
Add rewrite for argmax/argmin of monotonic functions #1869

Open

Jasjeet-Singh-S wants to merge 2 commits into pymc-devs:main from Jasjeet-Singh-S:argmax-argmin-monotonic-rewrite

Conversation


Jasjeet-Singh-S commented Feb 1, 2026

Add rewrite for argmax/argmin of monotonic functions

Summary

This PR implements a graph rewrite that optimizes argmax/argmin operations applied to monotonic functions by eliminating unnecessary function evaluations. Closes issue #1851.

Motivation

Computing argmax(exp(x)) is wasteful because the exponential computation doesn't affect which index holds the maximum value; only the relative ordering matters. Since monotonic functions preserve ordering, we can skip the expensive function application entirely.

Implementation

New rewrite: local_argmax_argmin_monotonic

The rewrite handles four transformation paths based on function monotonicity:

Monotonically Increasing Functions

  • argmax(f(x)) → argmax(x)
  • argmin(f(x)) → argmin(x)

Supported increasing functions: Exp, Exp2, Expm1, Log, Log2, Log10, Log1p, Sqrt, Deg2Rad, Rad2Deg, ArcSin, Tan, ArcTan, ArcCosh, Sinh, ArcSinh, Tanh, ArcTanh

Monotonically Decreasing Functions

  • argmax(f(x)) → argmin(x)
  • argmin(f(x)) → argmax(x)

Supported decreasing functions: Neg, Reciprocal, ArcCos
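The underlying identity is easy to check numerically with NumPy (an illustration, not part of the PR):

import numpy as np

x = np.random.default_rng(0).normal(size=100)
assert np.argmax(np.exp(x)) == np.argmax(x)  # increasing: argmax preserved
assert np.argmin(np.exp(x)) == np.argmin(x)  # increasing: argmin preserved
assert np.argmax(-x) == np.argmin(x)         # decreasing: argmax flips to argmin
assert np.argmin(-x) == np.argmax(x)         # decreasing: argmin flips to argmax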

Key Features

  • Handles PyTensor's internal representation: Correctly processes argmin, which PyTensor represents internally as Argmax(Neg(...)) (a helper sketch follows this list)
  • Preserves axis parameter: Works correctly with different axis specifications (None, 0, -1, etc.)
  • Robust pattern matching: Uses Elemwise wrapper detection to identify scalar operations
  • Stack trace preservation: Maintains debugging information via copy_stack_trace
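As a rough illustration, the detection helper could look like the sketch below; the import paths and exact structure are assumptions, not necessarily the PR's code:

from pytensor.scalar import Neg
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import Argmax

def _is_argmin(node):
    # PyTensor lowers argmin(x) to Argmax(Neg(x)), so an argmin node is
    # an Argmax whose input is produced by an Elemwise wrapping Neg.
    if not isinstance(node.op, Argmax):
        return False
    inner = node.inputs[0].owner
    return (
        inner is not None
        and isinstance(inner.op, Elemwise)
        and isinstance(inner.op.scalar_op, Neg)
    )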

Changes

pytensor/tensor/rewriting/math.py

  • Added MONOTONIC_INCREASING tuple containing 18 monotonically increasing scalar operations
  • Added MONOTONIC_DECREASING tuple containing 3 monotonically decreasing scalar operations
  • Implemented _is_argmin() helper function to detect argmin patterns (handles Argmax(Neg(...)) representation)
  • Implemented local_argmax_argmin_monotonic() rewriter with @register_canonicalize decorator (a rough sketch of its shape follows below)
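As a sketch of the overall shape such a rewriter can take (import paths, axis handling, and the MONOTONIC_* tuples from the diff below are assumptions, not the PR's exact code):

import pytensor.scalar as ps
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import Argmax, argmax, argmin
from pytensor.tensor.rewriting.basic import register_canonicalize

@register_canonicalize
@node_rewriter([Argmax])
def local_argmax_argmin_monotonic(fgraph, node):
    inner = node.inputs[0].owner
    if inner is None or not isinstance(inner.op, Elemwise):
        return None

    # argmin(y) is internally Argmax(Neg(y)); peel that wrapper off first.
    is_argmin = isinstance(inner.op.scalar_op, ps.Neg)
    if is_argmin:
        inner = inner.inputs[0].owner
        if inner is None or not isinstance(inner.op, Elemwise):
            return None

    scalar_op = inner.op.scalar_op
    if isinstance(scalar_op, ps.Neg):
        return None  # a bare argmin representation; nothing to simplify

    if isinstance(scalar_op, MONOTONIC_INCREASING):
        flip = False
    elif isinstance(scalar_op, MONOTONIC_DECREASING):
        flip = True
    else:
        return None

    (x,) = inner.inputs
    axis = node.op.axis
    # Decreasing functions flip argmax/argmin; the argmin case flips once more.
    new_out = argmin(x, axis=axis) if (is_argmin != flip) else argmax(x, axis=axis)
    copy_stack_trace(node.outputs[0], new_out)
    return [new_out]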

tests/tensor/rewriting/test_math.py

  • Added TestArgmaxArgminMonotonic test class with comprehensive coverage:
    • test_argmax_increasing_functions - Tests rewrite for increasing functions with argmax
    • test_argmin_increasing_functions - Tests rewrite for increasing functions with argmin
    • test_argmax_decreasing_functions - Tests rewrite for decreasing functions with argmax (flips to argmin)
    • test_argmin_decreasing_functions - Tests rewrite for decreasing functions with argmin (flips to argmax)
  • All tests parametrized over multiple axis values (None, 0, -1)
  • Tests verify both numerical correctness and graph structure optimization (an illustrative case is sketched after this list)
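In that spirit, a single parametrized case might look like this (a sketch; the suite's actual assertions may differ):

import numpy as np
import pytest

import pytensor
import pytensor.scalar as ps
import pytensor.tensor as pt

@pytest.mark.parametrize("axis", [None, 0, -1])
def test_argmax_increasing_functions(axis):
    x = pt.dmatrix("x")
    f = pytensor.function([x], pt.argmax(pt.exp(x), axis=axis))

    x_val = np.random.default_rng(0).normal(size=(3, 4))
    # Numerical correctness against NumPy.
    np.testing.assert_array_equal(f(x_val), np.argmax(x_val, axis=axis))
    # Graph structure: the Exp node should be gone from the compiled graph.
    assert not any(
        isinstance(getattr(node.op, "scalar_op", None), ps.Exp)
        for node in f.maker.fgraph.apply_nodes
    )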

Example

import pytensor
import pytensor.tensor as pt

x = pt.vector('x')
y = pt.argmax(pt.exp(x))  # Before: computes exp, then argmax
                          # After: computes argmax(x) directly

# The rewrite eliminates the expensive exp() computation,
# since argmax(exp(x)) = argmax(x) for monotonically increasing functions.
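A quick way to confirm the rewrite fired is to compile and print the graph (pytensor.dprint is PyTensor's graph printer; the exact printed form may vary):

f = pytensor.function([x], y)
pytensor.dprint(f)  # with the rewrite applied, the graph shows Argmax on x and no Exp node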

Performance Impact

This rewrite provides significant speedups when:

  • Computing argmax/argmin of exponentials, logarithms, or other monotonic transformations
  • Working with large arrays where the eliminated operations would be expensive
  • The monotonic function application is the dominant computational cost
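A rough way to measure the effect (a sketch, not a benchmark from the PR; it assumes the rewrite is registered in the default mode, while FAST_COMPILE applies fewer rewrites):

import time

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
graph = pt.argmax(pt.exp(x))
f_opt = pytensor.function([x], graph)                       # default mode: rewrite applies
f_raw = pytensor.function([x], graph, mode="FAST_COMPILE")  # fewer rewrites, exp likely kept

x_val = np.random.default_rng(0).normal(size=10_000_000)
for name, f in [("rewritten", f_opt), ("unrewritten", f_raw)]:
    start = time.perf_counter()
    f(x_val)
    print(name, time.perf_counter() - start)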

Testing

All tests pass with various configurations:

  • Multiple monotonic functions (18 increasing, 3 decreasing)
  • Different axis specifications (None, 0, -1)
  • Numerical correctness verification against expected results
  • Graph structure validation to ensure rewrites are applied correctly

The rewrite correctly handles edge cases including:

  • PyTensor's internal Argmax(Neg(...)) representation for argmin
  • Broadcasting and dimension handling
  • Proper flipping between argmax/argmin for decreasing functions

Copilot AI review requested due to automatic review settings February 1, 2026 17:22

Copilot AI left a comment


Pull request overview

This PR adds a graph rewrite optimization that eliminates unnecessary function evaluations when computing argmax or argmin of monotonic functions. The optimization leverages the property that monotonic functions preserve ordering, so argmax(exp(x)) can be simplified to argmax(x).

Changes:

  • Adds MONOTONIC_INCREASING and MONOTONIC_DECREASING tuples to classify scalar operations by monotonicity
  • Implements local_argmax_argmin_monotonic rewriter that optimizes argmax/argmin of monotonic functions
  • Adds comprehensive test suite with parametrized tests for different axis values

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.

File descriptions:

  • pytensor/tensor/rewriting/math.py: Adds monotonic function classifications and implements the core rewrite logic for argmax/argmin optimization
  • tests/tensor/rewriting/test_math.py: Adds test class with parametrized tests for increasing and decreasing monotonic functions


MONOTONIC_INCREASING = (
    ps.Exp, ps.Exp2, ps.Expm1, ps.Log, ps.Log2, ps.Log10, ps.Log1p,
    ps.Sqrt, ps.Deg2Rad, ps.Rad2Deg, ps.ArcSin, ps.Tan, ps.ArcTan,
Copilot AI commented Feb 1, 2026

ps.Tan should not be included in MONOTONIC_INCREASING. The tangent function is periodic and only monotonic within each period (e.g., on (-π/2, π/2)). For arrays spanning multiple periods, argmax(tan(x)) ≠ argmax(x). For example, with x = [0, π], tan(0) = 0 and tan(π) ≈ 0, so this optimization would be incorrect.

Suggested change
ps.Sqrt, ps.Deg2Rad, ps.Rad2Deg, ps.ArcSin, ps.Tan, ps.ArcTan,
ps.Sqrt, ps.Deg2Rad, ps.Rad2Deg, ps.ArcSin, ps.ArcTan,

    ps.ArcCosh, ps.Sinh, ps.ArcSinh, ps.Tanh, ps.ArcTanh
)
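The periodicity is easy to demonstrate numerically (illustration only):

import numpy as np

x = np.array([1.0, 2.0])  # straddles the pole of tan at pi/2
print(np.argmax(x), np.argmax(np.tan(x)))  # 1 vs. 0: tan breaks the ordering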

MONOTONIC_DECREASING = (ps.Neg, ps.Reciprocal, ps.ArcCos)
Copilot AI commented Feb 1, 2026

ps.Reciprocal should not be included in MONOTONIC_DECREASING. The reciprocal function (1/x) is not globally monotonic - it's discontinuous at 0 and changes sign. For example, if x = [-2, -1, 1, 2], then 1/x = [-0.5, -1, 1, 0.5], giving argmax(1/x) = 2 but argmin(x) = 0, which are different. The function is only monotonically decreasing on each of (0, ∞) and (-∞, 0) separately, not globally.

Suggested change
MONOTONIC_DECREASING = (ps.Neg, ps.Reciprocal, ps.ArcCos)
MONOTONIC_DECREASING = (ps.Neg, ps.ArcCos)
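This is likewise easy to check numerically (illustration only, using the values from the comment above):

import numpy as np

x = np.array([-2.0, -1.0, 1.0, 2.0])
print(np.argmin(x), np.argmax(1 / x))  # 0 vs. 2: 1/x does not flip the ordering globally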
