Skip to content

Commit 9016175

Browse files
jaanerikricardoV94
andcommitted
Refactor AdvancedSubtensor
- newaxis is handled as explicit DimShuffel on the inputs - slices are encoded internally, so the Ops only take numerical inputs Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com>
1 parent 8747006 commit 9016175

26 files changed

Lines changed: 1358 additions & 1528 deletions

File tree

pytensor/graph/destroyhandler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -771,9 +771,9 @@ def orderings(self, fgraph, ordered=True):
771771
}
772772
tolerated.add(destroyed_idx)
773773
tolerate_aliased = getattr(
774-
app.op, "destroyhandler_tolerate_aliased", []
774+
app.op, "destroyhandler_tolerate_aliased", ()
775775
)
776-
assert isinstance(tolerate_aliased, list)
776+
assert isinstance(tolerate_aliased, tuple | list)
777777
ignored = {
778778
idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx
779779
}

pytensor/link/jax/dispatch/subtensor.py

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
Subtensor,
99
indices_from_subtensor,
1010
)
11-
from pytensor.tensor.type_other import MakeSlice
1211

1312

1413
BOOLEAN_MASK_ERROR = """JAX does not support resizing arrays with boolean
@@ -35,10 +34,8 @@
3534
@jax_funcify.register(AdvancedSubtensor)
3635
@jax_funcify.register(AdvancedSubtensor1)
3736
def jax_funcify_Subtensor(op, node, **kwargs):
38-
idx_list = getattr(op, "idx_list", None)
39-
4037
def subtensor(x, *ilists):
41-
indices = indices_from_subtensor(ilists, idx_list)
38+
indices = indices_from_subtensor(ilists, op.idx_list)
4239
if len(indices) == 1:
4340
indices = indices[0]
4441

@@ -48,10 +45,9 @@ def subtensor(x, *ilists):
4845

4946

5047
@jax_funcify.register(IncSubtensor)
48+
@jax_funcify.register(AdvancedIncSubtensor)
5149
@jax_funcify.register(AdvancedIncSubtensor1)
5250
def jax_funcify_IncSubtensor(op, node, **kwargs):
53-
idx_list = getattr(op, "idx_list", None)
54-
5551
if getattr(op, "set_instead_of_inc", False):
5652

5753
def jax_fn(x, indices, y):
@@ -62,7 +58,7 @@ def jax_fn(x, indices, y):
6258
def jax_fn(x, indices, y):
6359
return x.at[indices].add(y)
6460

65-
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
61+
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=op.idx_list):
6662
indices = indices_from_subtensor(ilist, idx_list)
6763
if len(indices) == 1:
6864
indices = indices[0]
@@ -73,29 +69,3 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
7369
return jax_fn(x, indices, y)
7470

7571
return incsubtensor
76-
77-
78-
@jax_funcify.register(AdvancedIncSubtensor)
79-
def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
80-
if getattr(op, "set_instead_of_inc", False):
81-
82-
def jax_fn(x, indices, y):
83-
return x.at[indices].set(y)
84-
85-
else:
86-
87-
def jax_fn(x, indices, y):
88-
return x.at[indices].add(y)
89-
90-
def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
91-
return jax_fn(x, ilist, y)
92-
93-
return advancedincsubtensor
94-
95-
96-
@jax_funcify.register(MakeSlice)
97-
def jax_funcify_MakeSlice(op, **kwargs):
98-
def makeslice(*x):
99-
return slice(*x)
100-
101-
return makeslice

pytensor/link/mlx/dispatch/subtensor.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,14 @@
1010
Subtensor,
1111
indices_from_subtensor,
1212
)
13-
from pytensor.tensor.type_other import MakeSlice
1413

1514

1615
@mlx_funcify.register(Subtensor)
1716
def mlx_funcify_Subtensor(op, node, **kwargs):
18-
idx_list = getattr(op, "idx_list", None)
19-
2017
def subtensor(x, *ilists):
21-
indices = indices_from_subtensor([int(element) for element in ilists], idx_list)
18+
indices = indices_from_subtensor(
19+
[int(element) for element in ilists], op.idx_list
20+
)
2221
if len(indices) == 1:
2322
indices = indices[0]
2423

@@ -30,10 +29,8 @@ def subtensor(x, *ilists):
3029
@mlx_funcify.register(AdvancedSubtensor)
3130
@mlx_funcify.register(AdvancedSubtensor1)
3231
def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
33-
idx_list = getattr(op, "idx_list", None)
34-
3532
def advanced_subtensor(x, *ilists):
36-
indices = indices_from_subtensor(ilists, idx_list)
33+
indices = indices_from_subtensor(ilists, op.idx_list)
3734
if len(indices) == 1:
3835
indices = indices[0]
3936

@@ -45,8 +42,6 @@ def advanced_subtensor(x, *ilists):
4542
@mlx_funcify.register(IncSubtensor)
4643
@mlx_funcify.register(AdvancedIncSubtensor1)
4744
def mlx_funcify_IncSubtensor(op, node, **kwargs):
48-
idx_list = getattr(op, "idx_list", None)
49-
5045
if getattr(op, "set_instead_of_inc", False):
5146

5247
def mlx_fn(x, indices, y):
@@ -63,7 +58,7 @@ def mlx_fn(x, indices, y):
6358
x[indices] += y
6459
return x
6560

66-
def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
61+
def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list):
6762
indices = indices_from_subtensor(ilist, idx_list)
6863
if len(indices) == 1:
6964
indices = indices[0]
@@ -95,11 +90,3 @@ def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn):
9590
return mlx_fn(x, ilist, y)
9691

9792
return advancedincsubtensor
98-
99-
100-
@mlx_funcify.register(MakeSlice)
101-
def mlx_funcify_MakeSlice(op, **kwargs):
102-
def makeslice(*x):
103-
return slice(*x)
104-
105-
return makeslice

0 commit comments

Comments
 (0)