Skip to content

Commit 3021d89

Browse files
committed
Call BranchMorpher after dw deduplication
1 parent 644eacb commit 3021d89

2 files changed

Lines changed: 35 additions & 2 deletions

File tree

pytato/transform/__init__.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,10 @@ class CopyMapper(CachedMapper[ArrayOrNames]):
210210
This does not copy the data of a :class:`pytato.array.DataWrapper`.
211211
"""
212212

213+
# type-ignore reason: incompatible type with Mapper.rec
214+
def __call__(self, expr: T) -> T: # type: ignore[override]
215+
return self.rec(expr) # type: ignore[no-any-return]
216+
213217
def rec_idx_or_size_tuple(self, situp: Tuple[IndexOrShapeExpr, ...]
214218
) -> Tuple[IndexOrShapeExpr, ...]:
215219
return tuple(self.rec(s) if isinstance(s, Array) else s for s in situp)
@@ -1706,6 +1710,33 @@ def map_distributed_recv(self, expr: DistributedRecv, *args: Any) \
17061710
# }}}
17071711

17081712

1713+
# {{{ BranchMorpher
1714+
1715+
class BranchMorpher(CopyMapper):
1716+
"""
1717+
A mapper that replaces equal segments of graphs with identical objects.
1718+
"""
1719+
def __init__(self) -> None:
1720+
super().__init__()
1721+
self.result_cache: Dict[ArrayOrNames, ArrayOrNames] = {}
1722+
1723+
def cache_key(self, expr: CachedMapperT) -> Any:
1724+
return (id(expr), expr)
1725+
1726+
# type-ignore reason: incompatible with Mapper.rec
1727+
def rec(self, expr: T) -> T: # type: ignore[override]
1728+
rec_expr = super().rec(expr)
1729+
try:
1730+
# type-ignored because 'result_cache' maps to ArrayOrNames
1731+
return self.result_cache[rec_expr] # type: ignore[return-value]
1732+
except KeyError:
1733+
self.result_cache[rec_expr] = rec_expr
1734+
# type-ignored because of super-class' relaxed types
1735+
return rec_expr # type: ignore[no-any-return]
1736+
1737+
# }}}
1738+
1739+
17091740
# {{{ deduplicate_data_wrappers
17101741

17111742
def _get_data_dedup_cache_key(ary: DataInterface) -> Hashable:
@@ -1782,8 +1813,10 @@ def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames:
17821813
len(data_wrapper_cache),
17831814
data_wrappers_encountered - len(data_wrapper_cache))
17841815

1785-
return array_or_names
1816+
# many paths in the DAG might be morphed after DWs are deduplicated => morph them
1817+
return BranchMorpher()(array_or_names)
17861818

17871819
# }}}
17881820

1821+
17891822
# vim: foldmethod=marker

test/test_codegen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1607,7 +1607,7 @@ def test_zero_size_cl_array_dedup(ctx_factory):
16071607
x4 = pt.make_data_wrapper(x_cl2)
16081608

16091609
out = pt.make_dict_of_named_arrays({"out1": 2*x1,
1610-
"out2": 2*x2,
1610+
"out2": 3*x2,
16111611
"out3": x3 + x4
16121612
})
16131613

0 commit comments

Comments
 (0)