@@ -1317,8 +1317,6 @@ def perform(self, node, inputs, outputs):
13171317 z [0 ] = y
13181318
13191319 def grad (self , inputs , gout ):
1320- from pytensor .sparse .math import sp_sum
1321-
13221320 (x , s ) = inputs
13231321 (gz ,) = gout
13241322 return [col_scale (gz , s ), sp_sum (x * gz , axis = 0 )]
@@ -1368,8 +1366,6 @@ def perform(self, node, inputs, outputs):
13681366 z [0 ] = scipy .sparse .csc_matrix ((y_data , indices , indptr ), (M , N ))
13691367
13701368 def grad (self , inputs , gout ):
1371- from pytensor .sparse .math import sp_sum
1372-
13731369 (x , s ) = inputs
13741370 (gz ,) = gout
13751371 return [row_scale (gz , s ), sp_sum (x * gz , axis = 1 )]
@@ -1435,6 +1431,126 @@ def row_scale(x, s):
14351431 return col_scale (x .T , s ).T
14361432
14371433
class SpSum(Op):
    """Sum of a sparse matrix, over all elements or along one axis.

    WARNING: judgement call...
    We are not using the structured in the comparison or hashing
    because it doesn't change the perform method therefore, we
    *do* want Sums with different structured values to be merged
    by the merge optimization and this requires them to compare equal.
    """

    # ``structured`` is deliberately excluded from __props__ (see class docstring).
    __props__ = ("axis",)

    def __init__(self, axis=None, sparse_grad=True):
        super().__init__()
        self.axis = axis
        # ``structured`` only affects grad(), not perform() — see class docstring.
        self.structured = sparse_grad
        if self.axis not in (None, 0, 1):
            raise ValueError("Illegal value for self.axis.")

    def make_node(self, x):
        """Build the Apply node; output is a dense scalar (axis=None) or vector."""
        x = as_sparse_variable(x)
        assert x.format in ("csr", "csc")
        # Full reduction yields a 0-d tensor; axis reduction yields a 1-d tensor
        # of statically unknown length.
        out_shape = () if self.axis is None else (None,)
        out = TensorType(dtype=x.dtype, shape=out_shape)()
        return Apply(self, [x], [out])

    def perform(self, node, inputs, outputs):
        """Compute the sum with scipy; flatten the (1, n)/(n, 1) matrix result."""
        (sp_x,) = inputs
        (out,) = outputs
        if self.axis is None:
            out[0] = np.asarray(sp_x.sum())
        else:
            # scipy returns a dense matrix for an axis-wise sum; ravel to 1-d.
            out[0] = np.asarray(sp_x.sum(self.axis)).ravel()

    def grad(self, inputs, gout):
        """Gradient w.r.t. the sparse input.

        Structured mode broadcasts ``gz`` back over the nonzero pattern of
        ``x``; the dense mode densifies, broadcasts, then converts back to
        the input's sparse format.
        """
        (x,) = inputs
        (gz,) = gout
        # Integer/discrete inputs have a zero gradient.
        if x.dtype not in continuous_dtypes:
            return [x.zeros_like(dtype=config.floatX)]

        if self.structured:
            pattern = sp_ones_like(x)
            if self.axis is None:
                return [gz * pattern]
            if self.axis == 0:
                return [col_scale(pattern, gz)]
            if self.axis == 1:
                return [row_scale(pattern, gz)]
            raise ValueError("Illegal value for self.axis.")

        o_format = x.format
        dense_x = dense_from_sparse(x)
        grad_z = dense_from_sparse(gz) if _is_sparse_variable(gz) else gz
        if self.axis is None:
            result = ptb.second(dense_x, grad_z)
        else:
            ones = ptb.ones_like(dense_x)
            if self.axis == 0:
                result = specify_broadcastable(grad_z.dimshuffle("x", 0), 0) * ones
            elif self.axis == 1:
                result = specify_broadcastable(grad_z.dimshuffle(0, "x"), 1) * ones
            else:
                raise ValueError("Illegal value for self.axis.")
        return [SparseFromDense(o_format)(result)]

    def infer_shape(self, fgraph, node, shapes):
        """Scalar for a full sum; otherwise the length of the kept dimension."""
        (in_shape,) = shapes
        if self.axis is None:
            return [()]
        if self.axis == 0:
            return [(in_shape[1],)]
        return [(in_shape[0],)]

    def __str__(self):
        return f"{self.__class__.__name__}{{axis={self.axis}}}"
1517+
1518+
def sp_sum(x, axis=None, sparse_grad=False):
    """
    Calculate the sum of a sparse matrix along the specified axis.

    It operates a reduction along the specified axis. When `axis` is `None`,
    it is applied along all axes.

    Parameters
    ----------
    x
        Sparse matrix.
    axis
        Axis along which the sum is applied. Integer or `None`.
    sparse_grad : bool
        `True` to have a structured grad.

    Returns
    -------
    object
        The sum of `x` in a dense format.

    Notes
    -----
    The grad implementation is controlled with the `sparse_grad` parameter.
    `True` will provide a structured grad and `False` will provide a regular
    grad. For both choices, the grad returns a sparse matrix having the same
    format as `x`.

    This op does not return a sparse matrix, but a dense tensor matrix.

    """
    return SpSum(axis=axis, sparse_grad=sparse_grad)(x)
1552+
1553+
14381554class Diag (Op ):
14391555 """Extract the diagonal of a square sparse matrix as a dense vector.
14401556
@@ -1944,3 +2060,6 @@ def grad(self, inputs, grads):
19442060
19452061
19462062construct_sparse_from_list = ConstructSparseFromList ()
2063+
# NOTE(review): comment appears stale — `sp_sum` is defined in this module above,
# not imported from `pytensor.sparse.math`; confirm whether a backward-compatible
# re-export is still intended here.
0 commit comments