Add rewrite for argmax/argmin of monotonic functions #1869
Jasjeet-Singh-S wants to merge 2 commits into pymc-devs:main
Conversation
Implements graph rewrite that eliminates redundant monotonic function applications in argmax/argmin operations. For monotonically increasing functions, rewrites argmax(f(x)) → argmax(x) and argmin(f(x)) → argmin(x). For decreasing functions, flips operations: argmax(f(x)) → argmin(x) and argmin(f(x)) → argmax(x). Includes comprehensive tests.
Pull request overview
This PR adds a graph rewrite optimization that eliminates unnecessary function evaluations when computing argmax or argmin of monotonic functions. The optimization leverages the property that monotonic functions preserve ordering, so argmax(exp(x)) can be simplified to argmax(x).
Changes:
- Adds `MONOTONIC_INCREASING` and `MONOTONIC_DECREASING` tuples to classify scalar operations by monotonicity
- Implements `local_argmax_argmin_monotonic` rewriter that optimizes argmax/argmin of monotonic functions
- Adds a comprehensive test suite with parametrized tests for different axis values
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `pytensor/tensor/rewriting/math.py` | Adds monotonic function classifications and implements the core rewrite logic for argmax/argmin optimization |
| `tests/tensor/rewriting/test_math.py` | Adds test class with parametrized tests for increasing and decreasing monotonic functions |
`pytensor/tensor/rewriting/math.py` (Outdated)

```python
MONOTONIC_INCREASING = (
    ps.Exp, ps.Exp2, ps.Expm1, ps.Log, ps.Log2, ps.Log10, ps.Log1p,
    ps.Sqrt, ps.Deg2Rad, ps.Rad2Deg, ps.ArcSin, ps.Tan, ps.ArcTan,
```
ps.Tan should not be included in MONOTONIC_INCREASING. The tangent function is periodic and only monotonic within each period (e.g., on (-π/2, π/2)). For arrays spanning multiple periods, argmax(tan(x)) ≠ argmax(x). For example, with x = [0, π], tan(0) = 0 and tan(π) ≈ 0, so this optimization would be incorrect.
Suggested change:

```diff
-    ps.Sqrt, ps.Deg2Rad, ps.Rad2Deg, ps.ArcSin, ps.Tan, ps.ArcTan,
+    ps.Sqrt, ps.Deg2Rad, ps.Rad2Deg, ps.ArcSin, ps.ArcTan,
```
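The failure mode the review describes is easy to verify directly. Here is a quick NumPy check (an illustration, not part of the PR's code):

```python
import numpy as np

# tan is periodic, so it does not preserve ordering across periods:
# rewriting argmax(tan(x)) -> argmax(x) would change the result here.
x = np.array([1.0, 2.0])     # 2.0 lies past pi/2, where tan wraps negative
print(np.argmax(x))          # -> 1 (largest x)
print(np.argmax(np.tan(x)))  # -> 0 (tan(1.0) ~ 1.557, tan(2.0) ~ -2.185)
```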
`pytensor/tensor/rewriting/math.py` (Outdated)

```python
    ps.ArcCosh, ps.Sinh, ps.ArcSinh, ps.Tanh, ps.ArcTanh
)

MONOTONIC_DECREASING = (ps.Neg, ps.Reciprocal, ps.ArcCos)
```
ps.Reciprocal should not be included in MONOTONIC_DECREASING. The reciprocal function (1/x) is not globally monotonic - it's discontinuous at 0 and changes sign. For example, if x = [-2, -1, 1, 2], then 1/x = [-0.5, -1, 1, 0.5], giving argmax(1/x) = 2 but argmin(x) = 0, which are different. The function is only monotonically decreasing on each of (0, ∞) and (-∞, 0) separately, not globally.
Suggested change:

```diff
- MONOTONIC_DECREASING = (ps.Neg, ps.Reciprocal, ps.ArcCos)
+ MONOTONIC_DECREASING = (ps.Neg, ps.ArcCos)
```
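The review's counterexample for `Reciprocal` checks out numerically as well (a NumPy sketch, not part of the PR's code):

```python
import numpy as np

# reciprocal is discontinuous at 0 and changes sign, so it is not
# globally monotonic: argmax(1/x) need not equal argmin(x).
x = np.array([-2.0, -1.0, 1.0, 2.0])
print(np.argmax(1 / x))  # 1/x = [-0.5, -1, 1, 0.5] -> index 2
print(np.argmin(x))      # -> index 0, so the rewrite would be wrong
```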
Add rewrite for argmax/argmin of monotonic functions
Summary
This PR implements a graph rewrite that optimizes `argmax`/`argmin` operations applied to monotonic functions by eliminating unnecessary function evaluations. Addresses issue #1851.

Motivation
Computing `argmax(exp(x))` is wasteful because the exponential computation doesn't affect which index holds the maximum value; only the relative ordering matters. Since monotonic functions preserve ordering, we can skip the expensive function application entirely.

Implementation
New rewrite: `local_argmax_argmin_monotonic`

The rewrite handles four transformation paths based on function monotonicity:
Monotonically Increasing Functions

- `argmax(f(x)) → argmax(x)`
- `argmin(f(x)) → argmin(x)`

Supported increasing functions:
`Exp`, `Exp2`, `Expm1`, `Log`, `Log2`, `Log10`, `Log1p`, `Sqrt`, `Deg2Rad`, `Rad2Deg`, `ArcSin`, `Tan`, `ArcTan`, `ArcCosh`, `Sinh`, `ArcSinh`, `Tanh`, `ArcTanh`

Monotonically Decreasing Functions
- `argmax(f(x)) → argmin(x)`
- `argmin(f(x)) → argmax(x)`

Supported decreasing functions: `Neg`, `Reciprocal`, `ArcCos`

Key Features
- Handles `argmin`, which is internally represented as `Argmax(Neg(...))` in PyTensor
- Works with different axis values (`None`, `0`, `-1`, etc.)
- Uses `Elemwise` wrapper detection to identify scalar operations
- Preserves stack traces via `copy_stack_trace`

Changes
`pytensor/tensor/rewriting/math.py`
- Adds `MONOTONIC_INCREASING` tuple containing 18 monotonically increasing scalar operations
- Adds `MONOTONIC_DECREASING` tuple containing 3 monotonically decreasing scalar operations
- Adds `_is_argmin()` helper function to detect argmin patterns (handles the `Argmax(Neg(...))` representation)
- Adds `local_argmax_argmin_monotonic()` rewriter with the `@register_canonicalize` decorator

`tests/tensor/rewriting/test_math.py`
- Adds `TestArgmaxArgminMonotonic` test class with comprehensive coverage:
  - `test_argmax_increasing_functions` - Tests rewrite for increasing functions with argmax
  - `test_argmin_increasing_functions` - Tests rewrite for increasing functions with argmin
  - `test_argmax_decreasing_functions` - Tests rewrite for decreasing functions with argmax (flips to argmin)
  - `test_argmin_decreasing_functions` - Tests rewrite for decreasing functions with argmin (flips to argmax)
- Tests are parametrized over axis values (`None`, `0`, `-1`)

Example
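The four transformation paths can be illustrated numerically; this plain-NumPy sketch stands in for the rewritten PyTensor graphs and shows why the simplified expressions return the same indices:

```python
import numpy as np

x = np.array([0.3, -1.2, 2.5, 0.7])

# Increasing function (exp): argmax/argmin pass through unchanged.
assert np.argmax(np.exp(x)) == np.argmax(x)
assert np.argmin(np.exp(x)) == np.argmin(x)

# Decreasing function (neg): argmax and argmin are flipped.
assert np.argmax(-x) == np.argmin(x)
assert np.argmin(-x) == np.argmax(x)
```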
Performance Impact
This rewrite provides significant speedups when the wrapped function is expensive to evaluate, since the function application is eliminated from the graph entirely.
Testing
All tests pass with various configurations, including different axis values (`None`, `0`, `-1`).

The rewrite correctly handles edge cases including:
- The `Argmax(Neg(...))` representation for `argmin`
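The identity behind that representation, `argmin(x) == argmax(-x)`, can be confirmed with a NumPy check independent of PyTensor's internals:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=12)

# Negation reverses the ordering, so the index of the minimum of x
# is the index of the maximum of -x (ties resolve to the same first index).
assert np.argmin(x) == np.argmax(-x)

# The same holds along any axis of a multidimensional array.
m = rng.normal(size=(3, 4))
assert (np.argmin(m, axis=0) == np.argmax(-m, axis=0)).all()
```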