@@ -210,6 +210,10 @@ class CopyMapper(CachedMapper[ArrayOrNames]):
210210 This does not copy the data of a :class:`pytato.array.DataWrapper`.
211211 """
212212
213+ # type-ignore reason: incompatible type with Mapper.rec
214+ def __call__ (self , expr : T ) -> T : # type: ignore[override]
215+ return self .rec (expr ) # type: ignore[no-any-return]
216+
213217 def rec_idx_or_size_tuple (self , situp : Tuple [IndexOrShapeExpr , ...]
214218 ) -> Tuple [IndexOrShapeExpr , ...]:
215219 return tuple (self .rec (s ) if isinstance (s , Array ) else s for s in situp )
@@ -1706,6 +1710,33 @@ def map_distributed_recv(self, expr: DistributedRecv, *args: Any) \
17061710# }}}
17071711
17081712
1713+ # {{{ BranchMorpher
1714+
1715+ class BranchMorpher (CopyMapper ):
1716+ """
1717+ A mapper that replaces equal segments of graphs with identical objects.
1718+ """
1719+ def __init__ (self ) -> None :
1720+ super ().__init__ ()
1721+ self .result_cache : Dict [ArrayOrNames , ArrayOrNames ] = {}
1722+
1723+ def cache_key (self , expr : CachedMapperT ) -> Any :
1724+ return (id (expr ), expr )
1725+
1726+ # type-ignore reason: incompatible with Mapper.rec
1727+ def rec (self , expr : T ) -> T : # type: ignore[override]
1728+ rec_expr = super ().rec (expr )
1729+ try :
1730+ # type-ignored because 'result_cache' maps to ArrayOrNames
1731+ return self .result_cache [rec_expr ] # type: ignore[return-value]
1732+ except KeyError :
1733+ self .result_cache [rec_expr ] = rec_expr
1734+ # type-ignored because of super-class' relaxed types
1735+ return rec_expr # type: ignore[no-any-return]
1736+
1737+ # }}}
1738+
1739+
17091740# {{{ deduplicate_data_wrappers
17101741
17111742def _get_data_dedup_cache_key (ary : DataInterface ) -> Hashable :
@@ -1782,8 +1813,10 @@ def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames:
17821813 len (data_wrapper_cache ),
17831814 data_wrappers_encountered - len (data_wrapper_cache ))
17841815
1785- return array_or_names
1816+ # many paths in the DAG might be morphed after DWs are deduplicated => morph them
1817+ return BranchMorpher ()(array_or_names )
17861818
17871819# }}}
17881820
1821+
17891822# vim: foldmethod=marker
0 commit comments