Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions graph_net/constraint_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,11 +230,15 @@ def _is_dyn_dim_cstr_feasible(
def update_tensor_metas_by_dyn_dim_cstr(
tensor_metas: list[TensorMeta], dyn_dim_cstr: DynamicDimConstraints
):
input_shapes = dyn_dim_cstr.get_reified_input_shapes()
# Only update input tensors (first len(input_shapes) tensors), skip weight tensors
for i in range(min(len(input_shapes), len(tensor_metas))):
tensor_meta = tensor_metas[i]
tensor_meta.shape = input_shapes[i]
input_shapes_with_names = dyn_dim_cstr.input_shapes
name2shape = {
name: [dyn_dim_cstr._try_reify(dim) for dim in shape]
for shape, name in input_shapes_with_names
}
for tensor_meta in tensor_metas:
if tensor_meta.name not in name2shape:
continue
tensor_meta.shape = name2shape[tensor_meta.name]
if tensor_meta.data is not None:
assert isinstance(tensor_meta.data, (list, tuple))
size = functools.reduce(lambda a, b: a * b, tensor_meta.shape, 1)
Expand Down
13 changes: 9 additions & 4 deletions graph_net/dimension_generalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,15 @@ def _get_tensor_metas(self, model_path):
def update_tensor_metas_by_dyn_dim_cstr(
tensor_metas: list[TensorMeta], dyn_dim_cstr: DynamicDimConstraints
):
input_shapes = dyn_dim_cstr.get_reified_input_shapes()
assert len(tensor_metas) == len(input_shapes)
for i, tensor_meta in enumerate(tensor_metas):
tensor_meta.shape = input_shapes[i]
input_shapes_with_names = dyn_dim_cstr.input_shapes
name2shape = {
name: [dyn_dim_cstr._try_reify(dim) for dim in shape]
for shape, name in input_shapes_with_names
}
for tensor_meta in tensor_metas:
if tensor_meta.name not in name2shape:
continue
tensor_meta.shape = name2shape[tensor_meta.name]
if tensor_meta.data is not None:
assert isinstance(tensor_meta.data, (list, tuple))
size = functools.reduce(lambda a, b: a * b, tensor_meta.shape, 1)
Expand Down