From 7b43173241c249e22553d4a00c0ed35be7dac2cb Mon Sep 17 00:00:00 2001
From: WHoutstanding
Date: Tue, 27 Jan 2026 19:48:23 +0800
Subject: [PATCH 1/2] Add FP32_ONLY_FUNCS op to fix dtype generalization pass

---
 .../dtype_generalization_pass.py | 45 ++++++++++++++++---
 1 file changed, 38 insertions(+), 7 deletions(-)

diff --git a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
index 718b39197..fbf186557 100644
--- a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
+++ b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
@@ -32,6 +32,19 @@
     "bmm",
 }
 
+FP32_ONLY_FUNCS = {
+    torch.nn.functional.softmax,
+    torch.nn.functional.layer_norm,
+    torch.nn.functional.group_norm,
+    torch.nn.functional.batch_norm,
+    torch.nn.functional.embedding,
+    torch.exp,
+    torch.log,
+    torch.pow,
+    torch.sigmoid,
+    torch.tanh,
+    torch.conv_transpose2d,
+}
 
 class ConcretePass(DtypeGeneralizationPass):
     """
@@ -107,7 +120,7 @@ def create_get_attr(node: fx.Node) -> fx.Node:
                 return new_graph.call_method("to", args=(new_node, self.torch_dtype))
             return new_node
 
-        def create_new_args(node: fx.Node) -> list:
+        def create_new_args(node: fx.Node, target_dtype: torch.dtype) -> list:
             """new_args of node with dtype conversion if needed."""
 
             new_args = []
@@ -115,7 +128,10 @@ def create_new_args(node: fx.Node) -> list:
                 if isinstance(arg, fx.Node):
                     mapped = val_map[arg]
                     if self._is_float32_tensor(arg):
-                        mapped = new_graph.call_method("to", (mapped, self.torch_dtype))
+                        if target_dtype == torch.float32:
+                            mapped = new_graph.call_method("to", (mapped, torch.float32))
+                        elif target_dtype == self.torch_dtype:
+                            mapped = new_graph.call_method("to", (mapped, self.torch_dtype))
                     new_args.append(mapped)
                 else:
                     new_args.append(arg)
@@ -123,10 +139,13 @@ def create_call_function(node: fx.Node) -> fx.Node:
             """Create a call_function node with dtype conversion if needed."""
-            if node.target not in AMP_CALL_FUNCTION:
-                return new_graph.node_copy(node, lambda x: val_map[x])
+            require_fp32 = is_fp32_node(node)
+            target_dtype = torch.float32 if require_fp32 else self.torch_dtype
+
+            if node.target not in AMP_CALL_FUNCTION and not require_fp32:
+                return new_graph.node_copy(node, lambda x: val_map[x])
 
-            new_args = create_new_args(node)
+            new_args = create_new_args(node, target_dtype)
 
             new_kwargs = {
                 k: val_map[v] if isinstance(v, fx.Node) else v
@@ -140,10 +159,14 @@ def create_call_function(node: fx.Node) -> fx.Node:
             )
 
         def create_call_method(node: fx.Node) -> fx.Node:
-            if node.target not in AMP_CALL_METHOD:
+            """Create a call_method node with dtype conversion if needed."""
+            require_fp32 = is_fp32_node(node)
+            target_dtype = torch.float32 if require_fp32 else self.torch_dtype
+
+            if node.target not in AMP_CALL_METHOD and not require_fp32:
                 return new_graph.node_copy(node, lambda x: val_map[x])
 
-            new_args = create_new_args(node)
+            new_args = create_new_args(node, target_dtype)
 
             new_kwargs = {
                 k: (val_map[v] if isinstance(v, fx.Node) else v)
@@ -156,6 +179,14 @@ def create_call_method(node: fx.Node) -> fx.Node:
                 new_kwargs,
             )
 
+        def is_fp32_node(node: fx.Node) -> bool:
+            """Check if a node of float32 only op."""
+            if node.op == 'call_function':
+                return node.target in FP32_ONLY_FUNCS
+            elif node.op == 'call_method':
+                return node.target in AMP_CALL_METHOD
+            return False
+
         for node in gm.graph.nodes:
             if node.op == "placeholder":
                 val_map[node] = create_placeholder(node)

From e269e2e4eb6108807f31209c43aa9ad67da09f42 Mon Sep 17 00:00:00 2001
From: WHoutstanding
Date: Wed, 28 Jan 2026 15:25:39 +0800
Subject: [PATCH 2/2] Fix bug of invalid dtype for bias

---
 .../dtype_generalization_pass.py | 72 ++++++++-----------
 1 file changed, 27 insertions(+), 45 deletions(-)

diff --git a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
index fbf186557..91400d1e0 100644
--- a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
+++ b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
@@ -32,19 +32,6 @@
     "bmm",
 }
 
-FP32_ONLY_FUNCS = {
-    torch.nn.functional.softmax,
-    torch.nn.functional.layer_norm,
-    torch.nn.functional.group_norm,
-    torch.nn.functional.batch_norm,
-    torch.nn.functional.embedding,
-    torch.exp,
-    torch.log,
-    torch.pow,
-    torch.sigmoid,
-    torch.tanh,
-    torch.conv_transpose2d,
-}
 
 class ConcretePass(DtypeGeneralizationPass):
     """
@@ -107,6 +94,10 @@ def create_placeholder(node: fx.Node) -> fx.Node:
             """Create a placeholder node with dtype conversion if needed."""
             new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x))
             if self._is_float32_tensor(node):
+                attr_name = str(node.target)
+                if self.should_preserve_weight(attr_name):
+                    return new_node
+
                 return new_graph.call_method("to", args=(new_node, self.torch_dtype))
             return new_node
 
@@ -120,7 +111,7 @@ def create_get_attr(node: fx.Node) -> fx.Node:
                 return new_graph.call_method("to", args=(new_node, self.torch_dtype))
             return new_node
 
-        def create_new_args(node: fx.Node, target_dtype: torch.dtype) -> list:
+        def create_new_args(node: fx.Node) -> list:
            """new_args of node with dtype conversion if needed."""
 
             new_args = []
@@ -128,29 +119,34 @@ def create_new_args(node: fx.Node, target_dtype: torch.dtype) -> list:
                 if isinstance(arg, fx.Node):
                     mapped = val_map[arg]
                     if self._is_float32_tensor(arg):
-                        if target_dtype == torch.float32:
-                            mapped = new_graph.call_method("to", (mapped, torch.float32))
-                        elif target_dtype == self.torch_dtype:
-                            mapped = new_graph.call_method("to", (mapped, self.torch_dtype))
+                        mapped = new_graph.call_method("to", (mapped, self.torch_dtype))
                     new_args.append(mapped)
                 else:
                     new_args.append(arg)
             return new_args
 
+        def create_new_kwargs(node: fx.Node) -> dict:
+            """new_kwargs of node with dtype conversion if needed."""
+            new_kwargs = {}
+
+            for k, v in node.kwargs.items():
+                if isinstance(v, fx.Node):
+                    mapped = val_map[v]
+                    if self._is_float32_tensor(v):
+                        mapped = new_graph.call_method("to", (mapped, self.torch_dtype))
+                    new_kwargs[k] = mapped
+                else:
+                    new_kwargs[k] = v
+            return new_kwargs
+
         def create_call_function(node: fx.Node) -> fx.Node:
             """Create a call_function node with dtype conversion if needed."""
-            require_fp32 = is_fp32_node(node)
-            target_dtype = torch.float32 if require_fp32 else self.torch_dtype
-
-            if node.target not in AMP_CALL_FUNCTION and not require_fp32:
-                return new_graph.node_copy(node, lambda x: val_map[x])
+            if node.target not in AMP_CALL_FUNCTION:
+                return new_graph.node_copy(node, lambda x: val_map[x])
 
-            new_args = create_new_args(node, target_dtype)
+            new_args = create_new_args(node)
 
-            new_kwargs = {
-                k: val_map[v] if isinstance(v, fx.Node) else v
-                for k, v in node.kwargs.items()
-            }
+            new_kwargs = create_new_kwargs(node)
 
             return new_graph.call_function(
                 node.target,
@@ -160,18 +157,12 @@ def create_call_function(node: fx.Node) -> fx.Node:
             )
 
         def create_call_method(node: fx.Node) -> fx.Node:
             """Create a call_method node with dtype conversion if needed."""
-            require_fp32 = is_fp32_node(node)
-            target_dtype = torch.float32 if require_fp32 else self.torch_dtype
-
-            if node.target not in AMP_CALL_METHOD and not require_fp32:
+            if node.target not in AMP_CALL_METHOD:
                 return new_graph.node_copy(node, lambda x: val_map[x])
 
-            new_args = create_new_args(node, target_dtype)
+            new_args = create_new_args(node)
 
-            new_kwargs = {
-                k: (val_map[v] if isinstance(v, fx.Node) else v)
-                for k, v in node.kwargs.items()
-            }
+            new_kwargs = create_new_kwargs(node)
 
             return new_graph.call_method(
                 node.target,
@@ -179,14 +170,6 @@ def create_call_method(node: fx.Node) -> fx.Node:
                 new_kwargs,
             )
 
-        def is_fp32_node(node: fx.Node) -> bool:
-            """Check if a node of float32 only op."""
-            if node.op == 'call_function':
-                return node.target in FP32_ONLY_FUNCS
-            elif node.op == 'call_method':
-                return node.target in AMP_CALL_METHOD
-            return False
-
         for node in gm.graph.nodes:
             if node.op == "placeholder":
                 val_map[node] = create_placeholder(node)
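For reviewers who want to see the rewrite pattern these patches modify in isolation: the sketch below is hypothetical, self-contained example code (ToyModel and generalize_dtype are illustrative names, not part of graph_net), showing the same idea the pass uses -- copy each node into a fresh fx.Graph through val_map and insert a .to(dtype) call_method so downstream ops consume the cast values. It assumes a PyTorch build where bfloat16 matmul is supported on the device used.

import torch
import torch.fx as fx


class ToyModel(torch.nn.Module):
    def forward(self, x, y):
        return torch.matmul(x, y)


def generalize_dtype(gm: fx.GraphModule, dtype: torch.dtype) -> fx.GraphModule:
    # Build a fresh graph, mapping old nodes to their copies (same role as
    # val_map in the pass above).
    new_graph = fx.Graph()
    val_map = {}
    for node in gm.graph.nodes:
        # Copy the node, remapping its inputs to the already-copied nodes.
        new_node = new_graph.node_copy(node, lambda n: val_map[n])
        if node.op == "placeholder":
            # Cast every placeholder for simplicity; the real pass first checks
            # _is_float32_tensor before inserting the "to" call_method.
            new_node = new_graph.call_method("to", args=(new_node, dtype))
        val_map[node] = new_node
    return fx.GraphModule(gm, new_graph)


gm = fx.symbolic_trace(ToyModel())
gm_bf16 = generalize_dtype(gm, torch.bfloat16)
print(gm_bf16(torch.randn(4, 8), torch.randn(8, 2)).dtype)  # torch.bfloat16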