diff --git a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py index 718b39197..91400d1e0 100644 --- a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py +++ b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py @@ -94,6 +94,10 @@ def create_placeholder(node: fx.Node) -> fx.Node: """Create a placeholder node with dtype conversion if needed.""" new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x)) if self._is_float32_tensor(node): + attr_name = str(node.target) + if self.should_preserve_weight(attr_name): + return new_node + return new_graph.call_method("to", args=(new_node, self.torch_dtype)) return new_node @@ -121,6 +125,21 @@ def create_new_args(node: fx.Node) -> list: new_args.append(arg) return new_args + def create_new_kwargs(node: fx.Node) -> dict: + """new_kwargs of node with dtype conversion if needed.""" + new_kwargs = {} + + for k, v in node.kwargs.items(): + if isinstance(v, fx.Node): + mapped = val_map[v] + if self._is_float32_tensor(v): + mapped = new_graph.call_method("to", (mapped, self.torch_dtype)) + else: + new_kwargs[k] = mapped + else: + new_kwargs[k] = v + return new_kwargs + def create_call_function(node: fx.Node) -> fx.Node: """Create a call_function node with dtype conversion if needed.""" if node.target not in AMP_CALL_FUNCTION: @@ -128,10 +147,7 @@ def create_call_function(node: fx.Node) -> fx.Node: new_args = create_new_args(node) - new_kwargs = { - k: val_map[v] if isinstance(v, fx.Node) else v - for k, v in node.kwargs.items() - } + new_kwargs = create_new_kwargs(node) return new_graph.call_function( node.target, @@ -140,15 +156,13 @@ def create_call_function(node: fx.Node) -> fx.Node: ) def create_call_method(node: fx.Node) -> fx.Node: + """Create a call_method node with dtype conversion if needed.""" if node.target not in AMP_CALL_METHOD: return new_graph.node_copy(node, lambda x: val_map[x]) new_args = create_new_args(node) - new_kwargs = { - k: (val_map[v] if isinstance(v, fx.Node) else v) - for k, v in node.kwargs.items() - } + new_kwargs = create_new_kwargs(node) return new_graph.call_method( node.target,