forked from respec/HSPsquared
-
Notifications
You must be signed in to change notification settings - Fork 0
Open
Description
Because there is currently no inheritance in numba classes, we need to be a little creative with using classes, but the benefit may be worthwhile in terms of readability and eventual migration to future versions of numba classes, provided that performance is enhanced or at least maintained with respect to the old methods using integer tokens.
- Classes can reference other classes in their spec by using the `[class].class_type.instance_type` reference.
  - Ex: `RQUAL_Class` has properties that are references to other `jitclass` entities, such as `OXRX` -> `OXRX_Class`:
    OXRX_Class_ = OXRX_Class.class_type.instance_type
    spec = [
        ("OXRX", OXRX_Class_),
    ]
- Initially, we can have classes support:
  - both `object.value` and `state.state_ix[object_id]` for grabbing values
  - an `op_tokens` local array to compare execution speeds
- Performance:
  - numba 0.60 (python 3.9):
    - `iterate_op_arr_test(op1, 100000)` takes 0.00033164024353027344 seconds
    - `iterate_op_obj_test(op1, 100000)` takes 0.0016977787017822266 seconds
    - So Array is 80% faster
  - numba 0.64 (python 3.12):
    - `iterate_op_arr_test(op1, 100000)` takes 0.017830848693847656 seconds
    - `iterate_op_obj_test(op1, 100000)` takes 0.007010936737060547 seconds
    - So object is 50% faster BECAUSE ARRAY GOT SLOWER
from hsp2.state.state import state_class_numba, state_class, make_state_numba, state_copy
from hsp2.state.state_definitions import state_empty
from hsp2.hsp2.om import (
om_state_hsp2_run_setup,
state_om_model_run_finish,
)
from hsp2.state.state_timer import timer_class
from numpy import int64, zeros
from numba import deferred_type, optional, types, njit, typeof
from numba.experimental import jitclass
import numpy as np
import numba as nb
# Field layout for arr_class_numba: a single 1-D float64 array.
arr_class_spec = [("state_ix", nb.float64[:])]


@jitclass(arr_class_spec)
class arr_class_numba:
    """Minimal jitclass that wraps one float64 array named ``state_ix``.

    Parameters
    ----------
    num_ops
        Length of the zero-filled ``state_ix`` array created on construction.
    """

    def __init__(self, num_ops):
        # Allocate zeros and cast explicitly so the field matches the spec.
        self.state_ix = zeros(num_ops).astype(np.float64)
# state_class_numba = arr_class_numba
# Placeholder type so om_op_class fields can refer to om_op_class itself;
# it is resolved with om_node_type.define(...) after the class is created.
om_node_type = deferred_type()

# Instance types of the jitclasses that om_op_class holds as members.
_state_instance_type = state_class_numba.class_type.instance_type
_arr_instance_type = arr_class_numba.class_type.instance_type

# Numba type of a 64-element int64 array, used for the op_tokens field.
_op_tokens_type = typeof(zeros(64, dtype=int64))

# Field layout for om_op_class.
spec = [
    ("state", _state_instance_type),
    ("state_ix", typeof(state_empty["state_ix"])),
    ("arr_ix", _arr_instance_type),
    ("value", nb.float64),
    ("ix", nb.int64),
    ("operator", nb.int64),
    ("input1", optional(om_node_type)),
    ("input2", optional(om_node_type)),
    ("op_tokens", _op_tokens_type),
    # ("matrix", types.float64[:, :]),
]
@jitclass(spec)
class om_op_class:
    """Operation node used to compare several ways of reading operands.

    Each ``evaluate*`` method computes ``value`` when ``operator == 1``
    (subtraction) and then stores ``value`` at index ``ix`` of the array that
    the method works with. The variants differ only in where the operands
    come from:

    * ``evaluate``          -- ``input1.value`` and ``input2.value``
    * ``evaluate_state``    -- ``state.state_ix`` indexed by ``op_tokens[1:3]``
    * ``evaluate_arr_ix``   -- ``arr_ix.state_ix`` indexed by ``op_tokens[1:3]``
    * ``evaluate_state_ix`` -- the node's own ``state_ix`` array

    NOTE(review): the write-back is assumed to be unconditional (outside the
    ``if``) -- confirm against the original indentation.

    Parameters
    ----------
    state
        State object whose ``state_ix`` array is read and written.
    init_value
        Initial ``value`` of the node.
    operator
        Operation code; only ``1`` (subtract) is handled.
    """

    def __init__(self, state, init_value=0.0, operator=1):
        # ``ix`` is not assigned here; callers set it after construction.
        self.state = state
        self.value = init_value
        self.operator = operator
        self.input1 = None
        self.input2 = None
        # Local copies of the storage used by the alternative evaluators.
        self.state_ix = zeros(64).astype(np.float64)
        self.arr_ix = arr_class_numba(64)
        self.op_tokens = zeros(64).astype(int64)
        self.op_tokens[0] = 0
        # self.matrix = None

    def evaluate(self):
        """Subtract the linked nodes' values and store the result in state."""
        if self.operator == 1:
            self.value = self.input1.value - self.input2.value
        self.state.state_ix[self.ix] = self.value

    def evaluate_state(self):
        """Subtract two ``state.state_ix`` entries selected by ``op_tokens``."""
        if self.operator == 1:
            left = self.state.state_ix[self.op_tokens[1]]
            right = self.state.state_ix[self.op_tokens[2]]
            self.value = left - right
        self.state.state_ix[self.ix] = self.value

    def evaluate_arr_ix(self):
        """Same as ``evaluate_state`` but through the ``arr_ix`` wrapper."""
        # A stripped-down variant of the state object, to see if that matters.
        if self.operator == 1:
            left = self.arr_ix.state_ix[self.op_tokens[1]]
            right = self.arr_ix.state_ix[self.op_tokens[2]]
            self.value = left - right
        self.arr_ix.state_ix[self.ix] = self.value

    def evaluate_state_ix(self):
        """Same as ``evaluate_state`` but on the node's own ``state_ix``."""
        # No wrapper object at all: the array is a direct member of the node.
        if self.operator == 1:
            left = self.state_ix[self.op_tokens[1]]
            right = self.state_ix[self.op_tokens[2]]
            self.value = left - right
        self.state_ix[self.ix] = self.value
# 2. Resolve the deferred type now that om_op_class exists, so the
# optional(om_node_type) fields in spec refer to om_op_class instances.
om_node_type.define(om_op_class.class_type.instance_type)
# set up functions to test iterations
@njit
def iterate_op_obj_test(op, steps):
    """Call ``op.evaluate()`` ``steps`` times, then print the last loop index.

    NOTE(review): the print is assumed to follow the loop (one line per
    call) -- confirm against the original indentation.
    """
    for step in range(steps):
        op.evaluate()
    print("Steps completed", step)
@njit
def iterate_op_arr_test(op, steps):
    """Call ``op.evaluate_state()`` ``steps`` times, then print the last index.

    NOTE(review): the print is assumed to follow the loop (one line per
    call) -- confirm against the original indentation.
    """
    for step in range(steps):
        op.evaluate_state()
    print("Steps completed", step)
@njit
def iterate_state_ix(op, steps):
    """Call ``op.evaluate_state_ix()`` ``steps`` times, then print the last index.

    NOTE(review): the print is assumed to follow the loop (one line per
    call) -- confirm against the original indentation.
    """
    for step in range(steps):
        op.evaluate_state_ix()
    print("Steps completed", step)
@njit
def iterate_arrix(op, steps):
    """Call ``op.evaluate_arr_ix()`` ``steps`` times, then print the last index.

    NOTE(review): the print is assumed to follow the loop (one line per
    call) -- confirm against the original indentation.
    """
    for step in range(steps):
        op.evaluate_arr_ix()
    print("Steps completed", step)
# Usage
# Build a plain state object from the fields of state_empty.
state = state_class(
state_empty["state_ix"], state_empty["op_tokens"], state_empty["state_paths"],
state_empty["op_exec_lists"], state_empty["model_exec_list"], state_empty["dict_ix"],
state_empty["ts_ix"], state_empty["hsp_segments"]
)
# NOTE(review): 64 presumably sizes the numba state arrays -- confirm in make_state_numba.
state_numba = make_state_numba(64)
# Three nodes sharing the same numba state: op1 = op2 - op3.
op1 = om_op_class(state_numba, 0.0, 1) # init value = 0.0., operator = 1 (-)
op2 = om_op_class(state_numba, 100.0)
op3 = om_op_class(state_numba, 15.0)
op1.input1 = op2 # Reference to another instance
op1.input2 = op3 # Reference to another instance
# Register each node in the plain state and keep the value set_state returns as its index.
op1.ix = state.set_state('op1', op1.value)
op2.ix = state.set_state('op2', op2.value)
op3.ix = state.set_state('op3', op3.value)
# Token-based operands for the array-indexed evaluators.
op1.op_tokens[1] = op2.ix # shortcut to add input
op1.op_tokens[2] = op3.ix # shortcut to add input
# NOTE(review): state_copy presumably copies the plain state into state_numba -- confirm.
state_copy(state, state_numba )
# Give op1 private copies of the state array, then share them with op2 and op3.
op1.arr_ix.state_ix = np.copy(state_numba.state_ix)
op1.state_ix = np.copy(state_numba.state_ix)
op3.arr_ix = op2.arr_ix = op1.arr_ix
op3.state_ix = op2.state_ix = op1.state_ix
# Run two evaluators once directly on op1 before the timed loops.
op1.evaluate()
op1.evaluate_state()
# Warm-up: call every benchmarked function once so that JIT compilation
# happens here and is not included in the timed runs below. The original
# omitted iterate_state_ix, so its timing also measured compilation.
iterate_op_arr_test(op1, 1)
iterate_op_obj_test(op1, 1)
iterate_arrix(op1, 1)
iterate_state_ix(op1, 1)

timer = timer_class()
n = 100000

# Each measurement resets the timer, runs the loop, then reads the split.
# NOTE(review): timer.split() is assumed to return the seconds elapsed since
# the previous split -- confirm in hsp2.state.state_timer.
timer.split()
iterate_op_arr_test(op1, n)
t = timer.split()
print(n, "ARR iterations took", t, "seconds", "with numba", nb.__version__)

timer.split()
iterate_op_obj_test(op1, n)
t = timer.split()
print(n, "OBJ iterations took", t, "seconds", "with numba", nb.__version__)

timer.split()
iterate_arrix(op1, n)
t = timer.split()
print(n, "Local prop iterate_arrix iterations took", t, "seconds", "with numba", nb.__version__)

timer.split()
iterate_state_ix(op1, n)
t = timer.split()
print(n, "Local prop iterate_state_ixiterations took", t, "seconds", "with numba", nb.__version__)
Test python 3.12 / numba 0.64 vs python 3.9 / numba 0.60
- save as `bug.py`
- 100000000 iterations took 0.3137941360473633 seconds with numba 0.60.0
import numpy as np
from numpy import zeros
from numba import njit
import time
import numba as nb
from numpy import int64
from numba.experimental import jitclass
class timer_class():
    """Stopwatch that reports the time elapsed between successive splits.

    Timestamps come from ``time.perf_counter()``, a monotonic high-resolution
    clock intended for measuring short durations; wall-clock ``time.time()``
    can jump when the system clock is adjusted.

    Attributes
    ----------
    tstart
        Clock reading at construction or at the end of the last ``split()``.
    tend
        Clock reading taken at the start of the last ``split()``.
    tsplit
        Raw difference ``tend - tstart`` from the last ``split()``.
    """

    def __init__(self):
        self.tstart = time.perf_counter()

    def split(self):
        """Return seconds since the previous split and restart the timer.

        Returns
        -------
        float or int
            ``tsplit`` when it is positive, otherwise ``0``.
        """
        self.tend = time.perf_counter()
        self.tsplit = self.tend - self.tstart
        # Restart immediately so the next split measures from this point.
        self.tstart = time.perf_counter()
        # Never report a negative or zero-length interval as anything but 0.
        return self.tsplit if self.tsplit > 0 else 0
# Field layout for arr_class_numba: a single 1-D float64 array.
arr_class_spec = [("state_ix", nb.float64[:])]


@jitclass(arr_class_spec)
class arr_class_numba:
    """Jitclass wrapping one float64 array, with an in-class update step.

    Parameters
    ----------
    num_ops
        Length of the zero-filled ``state_ix`` array created on construction.
    """

    def __init__(self, num_ops):
        # Allocate zeros and cast explicitly so the field matches the spec.
        self.state_ix = zeros(num_ops).astype(np.float64)

    def step(self):
        """Store the sum of two ``state_ix`` entries in ``state_ix[0]``."""
        # add_keys is not a parameter; it is looked up as a module-level name.
        # NOTE(review): numba presumably freezes that global as a constant
        # when this method is compiled -- confirm.
        self.state_ix[0] = self.state_ix[add_keys[0]] + self.state_ix[add_keys[1]]
@njit
def iterate_method_test(arr_class, add_keys, steps):
    """Call ``arr_class.step()`` ``steps`` times.

    ``add_keys`` is accepted so the signature matches the other loop
    drivers, but it is not used in the body.
    """
    for _ in range(steps):
        arr_class.step()
@njit
def iterate_arr_test(arr, add_keys, steps):
    """Repeatedly write ``arr[add_keys[0]] + arr[add_keys[1]]`` into ``arr[0]``.

    ``arr[0]`` is reset to ``0.0`` before the loop, and the array is passed
    directly rather than through a wrapper object.
    """
    arr[0] = 0.0
    for _ in range(steps):
        arr[0] = arr[add_keys[0]] + arr[add_keys[1]]
@njit
def iterate_arr_class_test(arr_class, add_keys, steps):
    """Repeat the two-entry sum on ``arr_class.state_ix`` ``steps`` times.

    The array is reached through the ``arr_class`` attribute on every
    iteration, and the result is written to ``arr_class.state_ix[0]``.
    """
    for _ in range(steps):
        arr_class.state_ix[0] = arr_class.state_ix[add_keys[0]] + arr_class.state_ix[add_keys[1]]
timer=timer_class()
# Eight-element array wrapper; entries 1 and 2 are the operands to be summed.
arr_class=arr_class_numba(8)
arr_class.state_ix[1]=100.0
arr_class.state_ix[2]=15.0
# Indices of the two operands, as an int64 array.
add_keys = np.asarray([1,2])
add_keys = add_keys.astype(int64)
# Now run the 3 variations:
# iterate_arr_test(): sends the state_ix array "naked"
# iterate_arr_class_test(): sends the state_ix array inside the arr_class wrapper object
# iterate_method_test(): does the math inside the wrapper's own step() method
iterate_arr_test(arr_class.state_ix, add_keys, 1) # compile it
iterate_arr_class_test(arr_class, add_keys, 1) # compile it
iterate_method_test(arr_class, add_keys, 1) # compile it
n=100000000
# Each timing line resets the timer, runs the loop, then reads the elapsed split.
t=timer.split();iterate_arr_test(arr_class.state_ix, add_keys, n);t=timer.split()
print(n,"iterations took", t,"seconds","with array numba", nb.__version__)
t=timer.split();iterate_arr_class_test(arr_class, add_keys, n);t=timer.split()
print(n,"iterations took", t,"seconds","with class.array numba", nb.__version__)
t=timer.split();iterate_method_test(arr_class, add_keys, n);t=timer.split()
print(n,"iterations took", t,"seconds","with method math numba", nb.__version__)
- Test with the state object, to see if that's the source
@njit
def iterate_arr_state_test(state, add_keys, steps):
    """Repeat the two-entry sum on ``state.state_ix`` ``steps`` times.

    The array is reached through the ``state`` attribute on every iteration,
    and the result is written to ``state.state_ix[0]``.
    """
    for _ in range(steps):
        state.state_ix[0] = state.state_ix[add_keys[0]] + state.state_ix[add_keys[1]]
# Warm-up call so JIT compilation is not included in the timing below.
iterate_arr_state_test(state_numba , add_keys, 1) # compile it
n=100000000
# Timed run: state_ix is accessed through the state_numba object.
t=timer.split();iterate_arr_state_test(state_numba , add_keys, n);t=timer.split()
print(n,"iterations took", t,"seconds","with numba", nb.__version__)
# Timed run: the same array is passed directly to iterate_arr_test.
# NOTE(review): both prints use the same label, so the two results can only
# be told apart by their order in the output.
t=timer.split();iterate_arr_test(state_numba.state_ix, add_keys, n);t=timer.split()
print(n,"iterations took", t,"seconds","with numba", nb.__version__)
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels