Skip to content

om class prototype #123

@rburghol

Description

@rburghol

Because numba classes currently do not support inheritance, we need to be a little creative in how we use classes. The benefit may still be worthwhile in terms of readability and eventual migration to future versions of numba classes, provided that performance is improved, or at least maintained, relative to the old methods that use integer tokens.

  • Classes can reference other classes in their spec by using the [class].class_type.instance_type reference
    • Ex: RQUAL_Class has properties that are references to other jitclass entities such as OXRX -> OXRX_Class
      • OXRX_Class_ = OXRX_Class.class_type.instance_type
      • spec = [
        ("OXRX", OXRX_Class_),
  • Initially, we can have classes support:
    • both object.value and state.state_ix[object_id] for grabbing values
    • An op_tokens local array to compare execution speeds
  • Performance:
    • numba 0.60 (python3.9):
      • iterate_op_arr_test(op1,100000) takes 0.00033164024353027344 seconds
      • iterate_op_obj_test(op1,100000) takes 0.0016977787017822266 seconds
      • So Array takes about 80% less time (roughly 5× faster)
    • numba 0.64 (python 3.12):
      • iterate_op_arr_test(op1,100000) takes 0.017830848693847656 seconds
      • iterate_op_obj_test(op1,100000) takes 0.007010936737060547 seconds
      • So object takes about 60% less time (roughly 2.5× faster) BECAUSE ARRAY GOT SLOWER
from hsp2.state.state import state_class_numba, state_class, make_state_numba, state_copy
from hsp2.state.state_definitions import state_empty
from hsp2.hsp2.om import (
    om_state_hsp2_run_setup,
    state_om_model_run_finish,
)
from hsp2.state.state_timer import timer_class
from numpy import int64, zeros
from numba import deferred_type, optional, types, njit, typeof
from numba.experimental import jitclass
import numpy as np
import numba as nb

# Field layout for the minimal jitclass below: a single 1-D float64 array.
arr_class_spec = [("state_ix", nb.float64[:])]


@jitclass(arr_class_spec)
class arr_class_numba:
    """Bare-bones jitclass that holds nothing but a ``state_ix`` array."""

    def __init__(self, num_ops):
        # One float64 slot per operation, all starting at zero.
        self.state_ix = zeros(num_ops, dtype=np.float64)

# Uncomment to swap the stripped-down arr_class_numba in for the imported state class.
# state_class_numba = arr_class_numba

# Placeholder type so that om_op_class fields can refer to om_op_class itself;
# it is bound to the real instance type by om_node_type.define() after the class is built.
om_node_type = deferred_type()

# jitclass field declarations for om_op_class.
spec = [
    # state object (class imported from hsp2.state.state)
    ("state", state_class_numba.class_type.instance_type), 
    # array typed to match state_empty["state_ix"]
    ("state_ix", typeof(state_empty["state_ix"])), 
    # instance of the minimal arr_class_numba wrapper
    ("arr_ix", arr_class_numba.class_type.instance_type), 
    # current value of this node
    ("value", nb.float64),
    # index used by the evaluate* methods when writing into a state_ix array
    ("ix", nb.int64),
    # operator code; the evaluate* methods only act when it equals 1
    ("operator", nb.int64),
    # optional references to other om_op_class nodes (None until assigned)
    ("input1", optional(om_node_type)),
    ("input2", optional(om_node_type)),
    # typed from an int64 array built from 64 zeros
    ("op_tokens", typeof(int64(zeros( 64))) ),
    #("matrix", types.float64[:, :]),
]

@jitclass(spec)
class om_op_class:
    """Prototype operator node used to compare value-access strategies.

    Each node holds a ``value`` and optional references to two input nodes.
    The four ``evaluate*`` methods all perform the same subtraction when
    ``operator == 1`` and do nothing otherwise; they differ only in where the
    operands are read from and where the result is written.
    """

    def __init__(self, state, init_value=0.0, operator=1):
        """Set up the node.

        Args:
            state: state object stored as ``self.state``.
            init_value: starting ``value`` of the node.
            operator: operator code; the evaluate* methods only act on 1.

        ``self.ix`` is not assigned here; the calling script sets it afterwards.
        """
        self.state = state
        # Local 64-element float array, used by evaluate_state_ix().
        state_ix = zeros(64)
        self.state_ix = state_ix.astype(np.float64)
        # Separate 64-element array wrapped in a jitclass, used by evaluate_arr_ix().
        self.arr_ix = arr_class_numba(64)
        self.value = init_value
        # Inputs start unset; evaluate() reads their .value, so they
        # presumably must be assigned before evaluate() is called.
        self.input1 = None
        self.input2 = None
        self.operator = operator
        # Local 64-element integer array; slots 1 and 2 are read as operand
        # indexes by the array-based evaluate variants.
        op_tokens = zeros(64)
        self.op_tokens = op_tokens.astype(int64)
        self.op_tokens[0] = 0
        #self.matrix = None
    
    def evaluate(self):
        """Object path: subtract ``input2.value`` from ``input1.value``.

        Stores the result in ``self.value`` and in ``self.state.state_ix[self.ix]``.
        """
        if self.operator == 1:
            self.value = self.input1.value - self.input2.value
            self.state.state_ix[self.ix] = self.value
            return
    
    def evaluate_state(self):
        """State path: operands are read from ``self.state.state_ix``.

        The operand indexes are ``op_tokens[1]`` and ``op_tokens[2]``; the
        result is stored in ``self.value`` and ``self.state.state_ix[self.ix]``.
        """
        if self.operator == 1:
            self.value = self.state.state_ix[self.op_tokens[1]] - self.state.state_ix[self.op_tokens[2]]
            self.state.state_ix[self.ix] = self.value
            return
        
    def evaluate_arr_ix(self):
        """Same as evaluate_state() but reads and writes ``self.arr_ix.state_ix``."""
        # a stripped down variant of state_numba to see if that matters?
        if self.operator == 1:
            self.value = self.arr_ix.state_ix[self.op_tokens[1]] - self.arr_ix.state_ix[self.op_tokens[2]]
            self.arr_ix.state_ix[self.ix] = self.value
            return
        
    def evaluate_state_ix(self):
        """Same as evaluate_state() but reads and writes the node's own ``self.state_ix``."""
        # uses the array held directly on this object, with no wrapper object in between
        if self.operator == 1:
            self.value = self.state_ix[self.op_tokens[1]] - self.state_ix[self.op_tokens[2]]
            self.state_ix[self.ix] = self.value
            return

# 2. Bind the deferred placeholder now that om_op_class exists, so the
#    input1/input2 fields declared as optional(om_node_type) refer to om_op_class instances.
om_node_type.define(om_op_class.class_type.instance_type)

# set up functions to test iterations
@njit
def iterate_op_obj_test(
    op, steps
):
    """Call ``op.evaluate()`` ``steps`` times inside a compiled loop.

    Args:
        op: object exposing an ``evaluate()`` method.
        steps: number of evaluations to perform.
    """
    for _ in range(steps):
        op.evaluate()
    # Report the number of evaluations run. The loop index would be one short
    # of that count and is never assigned when steps == 0.
    print("Steps completed", steps)

@njit
def iterate_op_arr_test(
    op, steps
):
    """Call ``op.evaluate_state()`` ``steps`` times inside a compiled loop.

    Args:
        op: object exposing an ``evaluate_state()`` method.
        steps: number of evaluations to perform.
    """
    for _ in range(steps):
        op.evaluate_state()
    # Report the number of evaluations run. The loop index would be one short
    # of that count and is never assigned when steps == 0.
    print("Steps completed", steps)

@njit
def iterate_state_ix(
    op, steps
):
    """Call ``op.evaluate_state_ix()`` ``steps`` times inside a compiled loop.

    Args:
        op: object exposing an ``evaluate_state_ix()`` method.
        steps: number of evaluations to perform.
    """
    for _ in range(steps):
        op.evaluate_state_ix()
    # Report the number of evaluations run. The loop index would be one short
    # of that count and is never assigned when steps == 0.
    print("Steps completed", steps)

@njit
def iterate_arrix(
    op, steps
):
    """Call ``op.evaluate_arr_ix()`` ``steps`` times inside a compiled loop.

    Args:
        op: object exposing an ``evaluate_arr_ix()`` method.
        steps: number of evaluations to perform.
    """
    for _ in range(steps):
        op.evaluate_arr_ix()
    # Report the number of evaluations run. The loop index would be one short
    # of that count and is never assigned when steps == 0.
    print("Steps completed", steps)

# Usage
# Build a state_class instance from the entries of the state_empty template.
state = state_class(
    state_empty["state_ix"], state_empty["op_tokens"], state_empty["state_paths"], 
    state_empty["op_exec_lists"], state_empty["model_exec_list"], state_empty["dict_ix"], 
    state_empty["ts_ix"], state_empty["hsp_segments"]
)
# NOTE(review): presumably 64 is the number of state slots to allocate -- confirm against make_state_numba.
state_numba = make_state_numba(64)
# Three nodes sharing one state object: op1 = op2 - op3.
op1 = om_op_class(state_numba, 0.0, 1) # init value = 0.0, operator = 1 (-)
op2 = om_op_class(state_numba, 100.0)
op3 = om_op_class(state_numba, 15.0)
op1.input1 = op2 # Reference to another instance 
op1.input2 = op3 # Reference to another instance 
# The value returned by set_state() is stored as each node's ix, which the
# evaluate* methods use as an index into the state_ix arrays.
op1.ix = state.set_state('op1', op1.value)
op2.ix = state.set_state('op2', op2.value)
op3.ix = state.set_state('op3', op3.value)
# Tell the array-based evaluate variants where op1's two operands live.
op1.op_tokens[1] = op2.ix # shortcut to add input
op1.op_tokens[2] = op3.ix # shortcut to add input
# NOTE(review): presumably copies the contents of `state` into `state_numba` -- confirm direction.
state_copy(state, state_numba )
# Seed op1's two local arrays with copies of the state values ...
op1.arr_ix.state_ix = np.copy(state_numba.state_ix)
op1.state_ix = np.copy(state_numba.state_ix)
# ... and point op2 and op3 at those same objects.
op3.arr_ix = op2.arr_ix = op1.arr_ix
op3.state_ix = op2.state_ix = op1.state_ix 
# Call two of the evaluate variants once directly before the timed loops.
op1.evaluate()
op1.evaluate_state()

# compile: run every benchmark loop once up front so that JIT compilation
# is not included in any of the timings below
iterate_op_arr_test(op1, 1)
iterate_op_obj_test(op1, 1)
iterate_arrix(op1, 1)
iterate_state_ix(op1, 1)  # was missing, so its timed run also paid for compilation

timer = timer_class()
n = 100000

# Each measurement: split() once to restart the clock, run the loop, then
# split() again to read the elapsed time.
timer.split()
iterate_op_arr_test(op1, n)
t = timer.split()
print(n,"ARR iterations took", t,"seconds","with numba", nb.__version__)

timer.split()
iterate_op_obj_test(op1, n)
t = timer.split()
print(n,"OBJ iterations took", t,"seconds","with numba", nb.__version__)

timer.split()
iterate_arrix(op1, n)
t = timer.split()
print(n,"Local prop iterate_arrix iterations took", t,"seconds","with numba", nb.__version__)

timer.split()
iterate_state_ix(op1, n)
t = timer.split()
# Label fixed: a space was missing between the function name and "iterations".
print(n,"Local prop iterate_state_ix iterations took", t,"seconds","with numba", nb.__version__)

Test python3.12/numba 0.64 vs python3.9/numba 0.60

  • save as bug.py
  • 100000000 iterations took 0.3137941360473633 seconds With numba 0.60.0
import numpy as np
from numpy import zeros
from numba import njit
import time
import numba as nb
from numpy import int64
from numba.experimental import jitclass

class timer_class():
    """Minimal stopwatch for timing consecutive sections of a script.

    Attributes:
        tstart: clock reading at the start of the current interval.
        tend: clock reading taken by the most recent ``split()``.
        tsplit: length in seconds of the most recently closed interval.
    """

    def __init__(self):
        # perf_counter() is monotonic and high resolution, so measured
        # intervals are not distorted by wall-clock adjustments.
        self.tstart = time.perf_counter()

    def split(self):
        """Close the current interval and start the next one.

        Returns:
            Seconds elapsed since construction or since the previous
            ``split()``; ``0`` if the measured interval is not positive.
        """
        self.tend = time.perf_counter()
        self.tsplit = self.tend - self.tstart
        # Start the next interval from the same reading so that no time
        # falls between two consecutive intervals.
        self.tstart = self.tend
        return self.tsplit if self.tsplit > 0 else 0


# The only field is a 1-D float64 array.
arr_class_spec = [("state_ix", nb.float64[:])]


@jitclass(arr_class_spec)
class arr_class_numba:
    """Thin jitclass wrapper around a single ``state_ix`` array."""

    def __init__(self, num_ops):
        # num_ops zero-filled float64 slots.
        self.state_ix = zeros(num_ops, dtype=np.float64)

    def step(self):
        """Store the sum of two ``state_ix`` entries in slot 0.

        The two indexes come from ``add_keys``, which is neither a parameter
        nor an attribute of this class, so it is looked up as a module-level
        name rather than being supplied by the caller.
        """
        first_key = add_keys[0]
        second_key = add_keys[1]
        self.state_ix[0] = self.state_ix[first_key] + self.state_ix[second_key]

@njit
def iterate_method_test(arr_class, add_keys, steps):
    """Invoke ``arr_class.step()`` ``steps`` times in a compiled loop.

    ``add_keys`` is accepted, keeping the signature parallel to the other
    iterate_* functions, but it is not used in this body.
    """
    for _ in range(steps):
        arr_class.step()

@njit
def iterate_arr_test(arr, add_keys, steps):
    """Repeatedly write ``arr[add_keys[0]] + arr[add_keys[1]]`` into ``arr[0]``.

    The array is passed in directly rather than through a wrapper object.
    Slot 0 is reset to 0.0 before the loop, which then runs ``steps`` times.
    """
    arr[0] = 0.0
    for _ in range(steps):
        arr[0] = arr[add_keys[0]] + arr[add_keys[1]]

@njit
def iterate_arr_class_test(arr_class, add_keys, steps):
    """Same sum as iterate_arr_test, with the array reached via ``arr_class.state_ix``.

    Each of the ``steps`` iterations reads both operands from, and writes the
    result to, the ``state_ix`` attribute of ``arr_class``.
    """
    for _ in range(steps):
        arr_class.state_ix[0] = arr_class.state_ix[add_keys[0]] + arr_class.state_ix[add_keys[1]]

timer=timer_class()
arr_class=arr_class_numba(8) # create a wrapper holding an 8-slot array
# Operand values to be added, stored in slots 1 and 2.
arr_class.state_ix[1]=100.0
arr_class.state_ix[2]=15.0
# Indexes of the two operands, as an int64 array.
add_keys = np.asarray([1,2])
add_keys = add_keys.astype(int64)

# Now run the 3 variations:
# iterate_arr_test(): sends the state_ix array "naked"
# iterate_arr_class_test(): sends the arr_class wrapper and indexes arr_class.state_ix in the loop
# iterate_method_test(): sends the arr_class wrapper and calls its step() method in the loop
iterate_arr_test(arr_class.state_ix, add_keys, 1) # compile it
iterate_arr_class_test(arr_class, add_keys, 1) # compile it
iterate_method_test(arr_class, add_keys, 1) # compile it
n=100000000
# For each measurement the first split() restarts the clock and the second reads the elapsed time.
t=timer.split();iterate_arr_test(arr_class.state_ix, add_keys, n);t=timer.split()
print(n,"iterations took", t,"seconds","with array numba", nb.__version__)
t=timer.split();iterate_arr_class_test(arr_class, add_keys, n);t=timer.split()
print(n,"iterations took", t,"seconds","with class.array numba", nb.__version__)
t=timer.split();iterate_method_test(arr_class, add_keys, n);t=timer.split()
print(n,"iterations took", t,"seconds","with method math numba", nb.__version__)

  • Test with the state object as well, to check whether that is the source of the slowdown
@njit
def iterate_arr_state_test(state, add_keys, steps):
    """Same sum as the other array loops, with the array reached via ``state.state_ix``.

    Each of the ``steps`` iterations adds the two entries selected by
    ``add_keys`` and stores the result in ``state.state_ix[0]``.
    """
    for _ in range(steps):
        state.state_ix[0] = state.state_ix[add_keys[0]] + state.state_ix[add_keys[1]]

# Warm up both loops with the exact arguments that will be timed, so that no
# JIT compilation happens inside a timed call.
iterate_arr_state_test(state_numba , add_keys, 1) # compile it
# NOTE(review): iterate_arr_test was compiled earlier for arr_class.state_ix; if
# state_numba.state_ix has a different array type, this call compiles it again here.
iterate_arr_test(state_numba.state_ix, add_keys, 1) # compile it
n=100000000

# Variation A: the array is reached through the state object inside the loop.
timer.split()
iterate_arr_state_test(state_numba , add_keys, n)
t = timer.split()
# The labels now differ so the two result lines can be told apart.
print(n,"iterations took", t,"seconds","with state.state_ix numba", nb.__version__)

# Variation B: the same array is passed to the loop directly.
timer.split()
iterate_arr_test(state_numba.state_ix, add_keys, n)
t = timer.split()
print(n,"iterations took", t,"seconds","with naked state_ix array numba", nb.__version__)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions