From 463337b79774e25f2802b4d199c8a1abde2231ca Mon Sep 17 00:00:00 2001 From: TelGome <2657726985@qq.com> Date: Mon, 26 Jan 2026 19:54:01 +0800 Subject: [PATCH 1/2] [Bug Fix]: Match tensor shapes by name instead of index in dim generalization. --- graph_net/dimension_generalizer.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/graph_net/dimension_generalizer.py b/graph_net/dimension_generalizer.py index e8ac50731..be33d64ce 100644 --- a/graph_net/dimension_generalizer.py +++ b/graph_net/dimension_generalizer.py @@ -262,10 +262,15 @@ def _get_tensor_metas(self, model_path): def update_tensor_metas_by_dyn_dim_cstr( tensor_metas: list[TensorMeta], dyn_dim_cstr: DynamicDimConstraints ): - input_shapes = dyn_dim_cstr.get_reified_input_shapes() - assert len(tensor_metas) == len(input_shapes) - for i, tensor_meta in enumerate(tensor_metas): - tensor_meta.shape = input_shapes[i] + input_shapes_with_names = dyn_dim_cstr.input_shapes + name2shape = { + name: [dyn_dim_cstr._try_reify(dim) for dim in shape] + for shape, name in input_shapes_with_names + } + for tensor_meta in tensor_metas: + if tensor_meta.name not in name2shape: + continue + tensor_meta.shape = name2shape[tensor_meta.name] if tensor_meta.data is not None: assert isinstance(tensor_meta.data, (list, tuple)) size = functools.reduce(lambda a, b: a * b, tensor_meta.shape, 1) From 0b36e1ba2060d2f020fa2fb18b8e1d9266875bb6 Mon Sep 17 00:00:00 2001 From: TelGome <2657726985@qq.com> Date: Tue, 27 Jan 2026 14:56:42 +0800 Subject: [PATCH 2/2] Fix. --- graph_net/constraint_util.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/graph_net/constraint_util.py b/graph_net/constraint_util.py index c1fa443eb..b59f21d3e 100644 --- a/graph_net/constraint_util.py +++ b/graph_net/constraint_util.py @@ -230,11 +230,15 @@ def _is_dyn_dim_cstr_feasible( def update_tensor_metas_by_dyn_dim_cstr( tensor_metas: list[TensorMeta], dyn_dim_cstr: DynamicDimConstraints ): - input_shapes = dyn_dim_cstr.get_reified_input_shapes() - # Only update input tensors (first len(input_shapes) tensors), skip weight tensors - for i in range(min(len(input_shapes), len(tensor_metas))): - tensor_meta = tensor_metas[i] - tensor_meta.shape = input_shapes[i] + input_shapes_with_names = dyn_dim_cstr.input_shapes + name2shape = { + name: [dyn_dim_cstr._try_reify(dim) for dim in shape] + for shape, name in input_shapes_with_names + } + for tensor_meta in tensor_metas: + if tensor_meta.name not in name2shape: + continue + tensor_meta.shape = name2shape[tensor_meta.name] if tensor_meta.data is not None: assert isinstance(tensor_meta.data, (list, tuple)) size = functools.reduce(lambda a, b: a * b, tensor_meta.shape, 1)