78 changes: 61 additions & 17 deletions graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
@@ -32,6 +32,12 @@
"bmm",
}

# Ops whose listed positional args / kwargs must keep their float32 inputs.
# Format: op -> (positional argument indices to preserve, kwarg names to preserve).
FP32_SENSITIVE_OPS = {
    torch.nn.functional.layer_norm: ({2, 3, 4}, {"weight", "bias", "eps"}),
    torch.nn.functional.group_norm: ({2, 3, 4}, {"weight", "bias", "eps"}),
    torch.nn.functional.batch_norm: (
        {1, 2, 3, 4},
        {"running_mean", "running_var", "weight", "bias"},
    ),
    # F.embedding(input, weight, ...) takes weight at positional index 1, not 0.
    torch.nn.functional.embedding: ({1}, {"weight"}),
}
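
# A quick, hedged sanity check for the table above (illustrative only, not part of the
# pass). It assumes these functional ops remain plain Python functions whose signatures
# inspect.signature can read:
#
#     import inspect
#     for fn, (indices, kwarg_names) in FP32_SENSITIVE_OPS.items():
#         params = list(inspect.signature(fn).parameters)
#         assert all(0 <= i < len(params) for i in indices), fn.__name__
#         assert set(kwarg_names) <= set(params), fn.__name__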

class ConcretePass(DtypeGeneralizationPass):
"""
@@ -78,6 +84,27 @@ def _node_need_rewrite(self, node: fx.Node) -> bool:

return False

def _analyze_preserved_nodes(self, graph: fx.Graph) -> set[fx.Node]:
"""预扫描图:找到所有被 FP32 敏感算子使用的参数节点。"""
preserved_nodes = set()

for node in graph.nodes:
if node.op != "call_function":
continue

if node.target in FP32_SENSITIVE_OPS:
target_indices, target_kwargs = FP32_SENSITIVE_OPS[node.target]

for i, arg in enumerate(node.args):
if i in target_indices and isinstance(arg, fx.Node):
preserved_nodes.add(arg)

for k, v in node.kwargs.items():
if k in target_kwargs and isinstance(v, fx.Node):
preserved_nodes.add(v)

return preserved_nodes

def rewrite(self, gm: fx.GraphModule) -> fx.GraphModule:
"""
Rewrite the graph to convert dtypes.
@@ -89,27 +116,41 @@ def rewrite(self, gm: fx.GraphModule) -> fx.GraphModule:
"""
new_graph = fx.Graph()
val_map = {}

preserved_nodes = self._analyze_preserved_nodes(gm.graph)

def create_placeholder(node: fx.Node) -> fx.Node:
"""Create a placeholder node with dtype conversion if needed."""
new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x))
            if not self._is_float32_tensor(node):
                return new_node

            # Inputs feeding FP32-sensitive ops keep their original dtype.
            if node in preserved_nodes:
                return new_node

            return new_graph.call_method("to", args=(new_node, self.torch_dtype))

def create_get_attr(node: fx.Node) -> fx.Node:
"""Create a get_attr node with dtype conversion if needed."""
new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x))
            if not self._is_float32_tensor(node):
                return new_node

            # Parameters feeding FP32-sensitive ops keep their original dtype.
            if node in preserved_nodes:
                return new_node

            return new_graph.call_method("to", args=(new_node, self.torch_dtype))

def create_new_args(node: fx.Node) -> list:
"""new_args of node with dtype conversion if needed."""
@@ -128,16 +169,16 @@ def create_new_args(node: fx.Node) -> list:
def create_new_kwargs(node: fx.Node) -> dict:
"""new_kwargs of node with dtype conversion if needed."""
new_kwargs = {}

for k, v in node.kwargs.items():
if isinstance(v, fx.Node):
mapped = val_map[v]
if self._is_float32_tensor(v):
mapped = new_graph.call_method("to", (mapped, self.torch_dtype))
else:
else:
new_kwargs[k] = mapped
else:
new_kwargs[k] = v
new_kwargs[k] = v
return new_kwargs

def create_call_function(node: fx.Node) -> fx.Node:
@@ -186,6 +227,9 @@ def create_call_method(node: fx.Node) -> fx.Node:
gm.graph = new_graph
gm.recompile()

with open("output.txt", "w", encoding="utf-8") as f:
print(gm.graph, file=f)

return gm

def _is_float32_tensor(self, node: fx.Node) -> bool:
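For readers skimming the diff, here is a minimal, self-contained sketch of the preservation idea introduced above: trace a toy module that calls an FP32-sensitive op, then collect the argument nodes that feed it, mirroring what _analyze_preserved_nodes does. The names below (ToyLayerNorm, collect_preserved, SENSITIVE) are illustrative and not part of this PR.

import torch
import torch.fx as fx
import torch.nn.functional as F

# Same shape as the FP32_SENSITIVE_OPS table in the pass, trimmed to a single op.
SENSITIVE = {F.layer_norm: ({2, 3, 4}, {"weight", "bias", "eps"})}

class ToyLayerNorm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(8))
        self.bias = torch.nn.Parameter(torch.zeros(8))

    def forward(self, x):
        return F.layer_norm(x, (8,), weight=self.weight, bias=self.bias)

def collect_preserved(graph: fx.Graph) -> set:
    """Collect placeholder/get_attr nodes consumed by FP32-sensitive ops."""
    preserved = set()
    for node in graph.nodes:
        if node.op == "call_function" and node.target in SENSITIVE:
            indices, kwarg_names = SENSITIVE[node.target]
            for i, arg in enumerate(node.args):
                if i in indices and isinstance(arg, fx.Node):
                    preserved.add(arg)
            for k, v in node.kwargs.items():
                if k in kwarg_names and isinstance(v, fx.Node):
                    preserved.add(v)
    return preserved

gm = fx.symbolic_trace(ToyLayerNorm())
print(collect_preserved(gm.graph))  # expected: the get_attr nodes for weight and bias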