Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions tests/cpu/st/testcase/trelu/gen_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def gen_golden_data_trelu(case_name, param):


class TReluParams:
def __init__(self, dtype, src_tile_row, src_tile_col, dst_tile_row, dst_tile_col, valid_row, valid_col):
def __init__(self, output_case_name, dtype, src_tile_row, src_tile_col, dst_tile_row, dst_tile_col, valid_row, valid_col):
self.output_case_name = output_case_name
self.dtype = dtype
self.src_tile_row = src_tile_row
self.src_tile_col = src_tile_col
Expand Down Expand Up @@ -64,24 +65,21 @@ def substring(a, b) -> str:
script_dir = os.path.dirname(os.path.abspath(__file__))

case_params_list = [
TReluParams(np.float32, 64, 64, 64, 64, 64, 64),
TReluParams(np.int32, 64, 64, 64, 64, 64, 64),
TReluParams(np.float16, 16, 256, 16, 256, 16, 256),
TReluParams(np.int16, 64, 64, 64, 64, 64, 64),
TReluParams(np.float32, 64, 64, 64, 64, 60, 55),
TReluParams(np.int32, 64, 64, 64, 64, 60, 55),
TReluParams(np.float16, 64, 64, 96, 96, 64, 60),
TReluParams(np.int16, 64, 64, 96, 96, 64, 60),
TReluParams("case_0", np.float32, 64, 64, 64, 64, 64, 64),
TReluParams("case_1", np.int32, 64, 64, 64, 64, 64, 64),
TReluParams("case_2", np.float16, 16, 256, 16, 256, 16, 256),
TReluParams("case_3", np.int16, 64, 64, 64, 64, 64, 64),
TReluParams("case_4", np.float32, 64, 64, 64, 64, 60, 55),
TReluParams("case_5", np.int32, 64, 64, 64, 64, 60, 55),
TReluParams("case_6", np.float16, 64, 64, 96, 96, 64, 60),
TReluParams("case_7", np.int16, 64, 64, 96, 96, 64, 60),
]
if os.getenv("PTO_CPU_SIM_ENABLE_BF16") == "1":
case_params_list.append(TReluParams(BF16_DTYPE, 16, 256, 16, 256, 16, 256))
case_params_list.append(TReluParams("case_bf16_16x256_16x256_16x256", BF16_DTYPE, 16, 256, 16, 256, 16, 256))

for i, param in enumerate(case_params_list):
for param in case_params_list:
case_name = generate_case_name(param)
if i < 8:
output_dir = os.path.join(script_dir, f"TRELUTest.case_{i}")
else:
output_dir = os.path.join(script_dir, case_name)
output_dir = os.path.join(script_dir, f"TRELUTest.{param.output_case_name}")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The case_name variable (defined on line 81) is no longer used for directory mapping, as it has been replaced by the explicit param.output_case_name. Additionally, case_name is passed to gen_golden_data_trelu (line 86), which does not use the parameter. This makes the generate_case_name function and the case_name variable redundant. Consider removing them to simplify the script and improve maintainability.

os.makedirs(output_dir, exist_ok=True)
original_dir = os.getcwd()
os.chdir(output_dir)
Expand Down
Loading