diff --git a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
index 91400d1e0..46b77b5a4 100644
--- a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
+++ b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
@@ -32,6 +32,14 @@
     "bmm",
 }
 
+# Maps each FP32-sensitive op to the (positional arg indices, kwarg names)
+# whose tensor inputs must stay in float32 when the graph is downcast.
+FP32_SENSITIVE_OPS = {
+    torch.nn.functional.layer_norm: ({2, 3, 4}, {"weight", "bias", "eps"}),
+    torch.nn.functional.group_norm: ({2, 3, 4}, {"weight", "bias", "eps"}),
+    torch.nn.functional.batch_norm: ({1, 2, 3, 4}, {"running_mean", "running_var", "weight", "bias"}),
+    torch.nn.functional.embedding: ({1}, {"weight"}),
+}
 
 class ConcretePass(DtypeGeneralizationPass):
     """
@@ -78,6 +86,27 @@ def _node_need_rewrite(self, node: fx.Node) -> bool:
 
         return False
 
+    def _analyze_preserved_nodes(self, graph: fx.Graph) -> set[fx.Node]:
+        """Pre-scan the graph: find every argument node consumed by an FP32-sensitive op."""
+        preserved_nodes = set()
+
+        for node in graph.nodes:
+            if node.op != "call_function":
+                continue
+
+            if node.target in FP32_SENSITIVE_OPS:
+                target_indices, target_kwargs = FP32_SENSITIVE_OPS[node.target]
+
+                for i, arg in enumerate(node.args):
+                    if i in target_indices and isinstance(arg, fx.Node):
+                        preserved_nodes.add(arg)
+
+                for k, v in node.kwargs.items():
+                    if k in target_kwargs and isinstance(v, fx.Node):
+                        preserved_nodes.add(v)
+
+        return preserved_nodes
+
     def rewrite(self, gm: fx.GraphModule) -> fx.GraphModule:
         """
         Rewrite the graph to convert dtypes.
@@ -89,27 +118,31 @@ def rewrite(self, gm: fx.GraphModule) -> fx.GraphModule:
         """
         new_graph = fx.Graph()
         val_map = {}
+        preserved_nodes = self._analyze_preserved_nodes(gm.graph)
 
         def create_placeholder(node: fx.Node) -> fx.Node:
            """Create a placeholder node with dtype conversion if needed."""
            new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x))
-            if self._is_float32_tensor(node):
-                attr_name = str(node.target)
-                if self.should_preserve_weight(attr_name):
-                    return new_node
-
-                return new_graph.call_method("to", args=(new_node, self.torch_dtype))
-            return new_node
+            if not self._is_float32_tensor(node):
+                return new_node
+
+            # Inputs consumed by FP32-sensitive ops keep their original dtype.
+            if node in preserved_nodes:
+                return new_node
+
+            return new_graph.call_method("to", args=(new_node, self.torch_dtype))
 
         def create_get_attr(node: fx.Node) -> fx.Node:
             """Create a get_attr node with dtype conversion if needed."""
             new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x))
-            attr_name = str(node.target)
-            if self._is_float32_tensor(node) and not self.should_preserve_weight(
-                attr_name
-            ):
-                return new_graph.call_method("to", args=(new_node, self.torch_dtype))
-            return new_node
+            if not self._is_float32_tensor(node):
+                return new_node
+
+            # Attributes consumed by FP32-sensitive ops keep their original dtype.
+            if node in preserved_nodes:
+                return new_node
+
+            return new_graph.call_method("to", args=(new_node, self.torch_dtype))
 
         def create_new_args(node: fx.Node) -> list:
             """new_args of node with dtype conversion if needed."""
@@ -128,16 +161,15 @@ def create_new_args(node: fx.Node) -> list:
         def create_new_kwargs(node: fx.Node) -> dict:
             """new_kwargs of node with dtype conversion if needed."""
             new_kwargs = {}
 
             for k, v in node.kwargs.items():
                 if isinstance(v, fx.Node):
                     mapped = val_map[v]
                     if self._is_float32_tensor(v):
                         mapped = new_graph.call_method("to", (mapped, self.torch_dtype))
-                    else:
-                        new_kwargs[k] = mapped
+                    new_kwargs[k] = mapped
                 else:
                     new_kwargs[k] = v
 
             return new_kwargs
 
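Note for reviewers: below is a minimal, self-contained sketch of what `_analyze_preserved_nodes` computes, using a standalone copy of its logic trimmed to `layer_norm`. The `Toy` module and the `analyze_preserved_nodes` helper are illustrative, not part of this PR.

```python
import torch
import torch.fx as fx
import torch.nn.functional as F

# Standalone copy of the pass's lookup table, trimmed to one op.
FP32_SENSITIVE_OPS = {
    F.layer_norm: ({2, 3, 4}, {"weight", "bias", "eps"}),
}

def analyze_preserved_nodes(graph: fx.Graph) -> set:
    """Mirror of ConcretePass._analyze_preserved_nodes, for demonstration."""
    preserved = set()
    for node in graph.nodes:
        if node.op != "call_function" or node.target not in FP32_SENSITIVE_OPS:
            continue
        arg_indices, kwarg_names = FP32_SENSITIVE_OPS[node.target]
        for i, arg in enumerate(node.args):
            if i in arg_indices and isinstance(arg, fx.Node):
                preserved.add(arg)
        for k, v in node.kwargs.items():
            if k in kwarg_names and isinstance(v, fx.Node):
                preserved.add(v)
    return preserved

class Toy(torch.nn.Module):
    def forward(self, x, w, b):
        return F.layer_norm(x, (8,), weight=w, bias=b)

gm = fx.symbolic_trace(Toy())
print(sorted(n.name for n in analyze_preserved_nodes(gm.graph)))
# -> ['b', 'w']: the affine params are preserved; `x` may still be downcast.
```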
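For the downcast half, the PR rebuilds the graph through `node_copy`; the sketch below gets the same observable effect in place via `replace_all_uses_with`, reusing `gm` and `analyze_preserved_nodes` from the snippet above. `cast_placeholders` is a hypothetical helper for illustration, not this PR's API.

```python
def cast_placeholders(gm: fx.GraphModule, preserved: set, dtype: torch.dtype) -> fx.GraphModule:
    """Append `.to(dtype)` to every float placeholder except the preserved ones."""
    for node in list(gm.graph.nodes):
        if node.op != "placeholder" or node in preserved:
            continue
        with gm.graph.inserting_after(node):
            cast = gm.graph.call_method("to", (node, dtype))
        node.replace_all_uses_with(cast)
        # replace_all_uses_with also rewrote the cast's own input; restore it.
        cast.args = (node, dtype)
    gm.graph.lint()
    gm.recompile()
    return gm

gm = cast_placeholders(gm, analyze_preserved_nodes(gm.graph), torch.float16)
print(gm.graph)
# Only `x` gains a to(torch.float16) call; `w` and `b` flow into layer_norm
# unchanged, matching what create_placeholder does in the diff above.
```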