Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion axengine/_axclrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ._base_session import Session, SessionOptions
from ._logging import get_logger
from ._node import NodeArg
from ._utils import _transform_dtype_axclrt as _transform_dtype
from ._utils_axclrt import _transform_dtype_axclrt as _transform_dtype

logger = get_logger(__name__)

Expand Down
22 changes: 0 additions & 22 deletions axengine/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import ml_dtypes as mldt
import numpy as np

from ._axclrt_capi import axclrt_cffi, axclrt_lib
from ._axe_capi import engine_cffi, engine_lib


Expand All @@ -24,24 +23,3 @@ def _transform_dtype(dtype):
return np.dtype(mldt.bfloat16)
else:
raise ValueError(f"Unsupported data type '{dtype}'.")


def _transform_dtype_axclrt(dtype):
if dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT8):
return np.dtype(np.uint8)
elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT8):
return np.dtype(np.int8)
elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT16):
return np.dtype(np.uint16)
elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT16):
return np.dtype(np.int16)
elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT32):
return np.dtype(np.uint32)
elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT32):
return np.dtype(np.int32)
elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_FP32):
return np.dtype(np.float32)
elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_BF16):
return np.dtype(mldt.bfloat16)
else:
raise ValueError(f"Unsupported data type '{dtype}'.")
30 changes: 30 additions & 0 deletions axengine/_utils_axclrt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
#
# AXCL dtype helpers — kept separate from _utils so AxEngine (board) path does not
# import axcl_rt at module load time.

import ml_dtypes as mldt
import numpy as np

from ._axclrt_capi import axclrt_cffi, axclrt_lib


def _transform_dtype_axclrt(dtype):
if dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT8):
return np.dtype(np.uint8)
elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT8):
return np.dtype(np.int8)
elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT16):
return np.dtype(np.uint16)
elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT16):
return np.dtype(np.int16)
elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT32):
return np.dtype(np.uint32)
elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT32):
return np.dtype(np.int32)
elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_FP32):
return np.dtype(np.float32)
elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_BF16):
return np.dtype(mldt.bfloat16)
else:
raise ValueError(f"Unsupported data type '{dtype}'.")