diff --git a/graph_net/constraint_util.py b/graph_net/constraint_util.py index c1fa443eb..b59f21d3e 100644 --- a/graph_net/constraint_util.py +++ b/graph_net/constraint_util.py @@ -230,11 +230,15 @@ def _is_dyn_dim_cstr_feasible( def update_tensor_metas_by_dyn_dim_cstr( tensor_metas: list[TensorMeta], dyn_dim_cstr: DynamicDimConstraints ): - input_shapes = dyn_dim_cstr.get_reified_input_shapes() - # Only update input tensors (first len(input_shapes) tensors), skip weight tensors - for i in range(min(len(input_shapes), len(tensor_metas))): - tensor_meta = tensor_metas[i] - tensor_meta.shape = input_shapes[i] + input_shapes_with_names = dyn_dim_cstr.input_shapes + name2shape = { + name: [dyn_dim_cstr._try_reify(dim) for dim in shape] + for shape, name in input_shapes_with_names + } + for tensor_meta in tensor_metas: + if tensor_meta.name not in name2shape: + continue + tensor_meta.shape = name2shape[tensor_meta.name] if tensor_meta.data is not None: assert isinstance(tensor_meta.data, (list, tuple)) size = functools.reduce(lambda a, b: a * b, tensor_meta.shape, 1) diff --git a/graph_net/dimension_generalizer.py b/graph_net/dimension_generalizer.py index e8ac50731..be33d64ce 100644 --- a/graph_net/dimension_generalizer.py +++ b/graph_net/dimension_generalizer.py @@ -262,10 +262,15 @@ def _get_tensor_metas(self, model_path): def update_tensor_metas_by_dyn_dim_cstr( tensor_metas: list[TensorMeta], dyn_dim_cstr: DynamicDimConstraints ): - input_shapes = dyn_dim_cstr.get_reified_input_shapes() - assert len(tensor_metas) == len(input_shapes) - for i, tensor_meta in enumerate(tensor_metas): - tensor_meta.shape = input_shapes[i] + input_shapes_with_names = dyn_dim_cstr.input_shapes + name2shape = { + name: [dyn_dim_cstr._try_reify(dim) for dim in shape] + for shape, name in input_shapes_with_names + } + for tensor_meta in tensor_metas: + if tensor_meta.name not in name2shape: + continue + tensor_meta.shape = name2shape[tensor_meta.name] if tensor_meta.data is not None: assert isinstance(tensor_meta.data, (list, tuple)) size = functools.reduce(lambda a, b: a * b, tensor_meta.shape, 1)