
Conversation


@Jasjeet-Singh-S commented on Feb 1, 2026

Add rewrite for argmax/argmin/max/min of monotonic functions

Closes #1851

Summary

This PR implements graph rewrites that optimize argmax/argmin/max/min operations applied to monotonic functions: the elementwise function is removed entirely for argmax/argmin, and hoisted outside the reduction for max/min.

Motivation

Computing argmax(exp(x)) or max(exp(x)) is wasteful: a monotonically increasing function such as exp preserves the ordering of its inputs, so it does not change which index holds the maximum. Since only the ordering matters, we can skip the function application entirely for argmax/argmin, or move it outside the reduction for max/min.
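
As a quick sanity check of the underlying property (illustrative only, not part of the PR), a small NumPy sketch:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=10)

# exp is monotonically increasing, so it preserves the ordering of the elements:
assert np.argmax(np.exp(x)) == np.argmax(x)
assert np.argmin(np.exp(x)) == np.argmin(x)

# and the maximum value itself commutes with the function:
assert np.isclose(np.max(np.exp(x)), np.exp(np.max(x)))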

Implementation

New rewrites:

  1. local_argmax_argmin_monotonic - for argmax/argmin operations
  2. local_max_min_monotonic - for max/min operations

Argmax/Argmin Rewrites

The rewrite handles four transformation paths, depending on whether the reduction is argmax or argmin and whether the function is increasing or decreasing:

Monotonically Increasing Functions

Supported increasing functions: Exp, Exp2, Expm1, Log, Log2, Log10, Log1p, Sqrt, Deg2Rad, Rad2Deg, ArcSin, Tan, ArcTan, ArcCosh, Sinh, ArcSinh, Tanh, ArcTanh

Monotonically Decreasing Functions

Supported decreasing functions: Neg, Reciprocal, ArcCos
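
For decreasing functions the rewrite flips the reduction (argmax becomes argmin and vice versa). A small NumPy sketch of the identity being exploited (illustrative only, not the PR's code):

import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=10)

# Neg is monotonically decreasing, so the extrema swap roles:
assert np.argmax(-x) == np.argmin(x)
assert np.argmin(-x) == np.argmax(x)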

Max/Min Rewrites

The rewrite handles four transformation paths by moving the monotonic function outside the reduction:

Monotonically Increasing Functions

For increasing f: max(f(x)) → f(max(x)) and min(f(x)) → f(min(x)).

Monotonically Decreasing Functions

For decreasing f: max(f(x)) → f(min(x)) and min(f(x)) → f(max(x)).

Both cases support the same increasing and decreasing function lists as the argmax/argmin rewrite.
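
A quick NumPy check of these identities (illustrative only, not part of the PR):

import numpy as np

rng = np.random.default_rng(2)
x = rng.normal(size=10)

# Increasing function: the reduction commutes with the function.
assert np.isclose(np.max(np.exp(x)), np.exp(np.max(x)))
assert np.isclose(np.min(np.exp(x)), np.exp(np.min(x)))

# Decreasing function: max and min swap roles before the function is applied.
assert np.isclose(np.max(-x), -np.min(x))
assert np.isclose(np.min(-x), -np.max(x))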

Key Features

  • Handles PyTensor's internal representation: Correctly processes argmin, which is internally represented as Argmax(Neg(...)) in PyTensor (see the sketch after this list)
  • Preserves axis parameter: Works correctly with different axis specifications (None, 0, -1, etc.)
  • Robust pattern matching: Uses Elemwise wrapper detection to identify scalar operations
  • Stack trace preservation: Maintains debugging information via copy_stack_trace
  • Comprehensive function coverage: Handles both value-returning (max/min) and index-returning (argmax/argmin) reductions
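
A minimal sketch of what the Argmax(Neg(...)) and Elemwise pattern detection can look like in PyTensor (the helper name is hypothetical and this is not the PR's actual _is_argmin() code):

from pytensor.tensor.elemwise import Elemwise
import pytensor.scalar as ps

def looks_like_argmin(node):
    """Sketch: given an Argmax Apply node, check whether its input is an
    elementwise Neg, i.e. the pattern PyTensor uses to express argmin(x)
    as Argmax(Neg(x))."""
    inner = node.inputs[0].owner
    return (
        inner is not None
        and isinstance(inner.op, Elemwise)
        and isinstance(inner.op.scalar_op, ps.Neg)
    )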

Changes

pytensor/tensor/rewriting/math.py

  • Added MONOTONIC_INCREASING tuple containing 18 monotonically increasing scalar operations
  • Added MONOTONIC_DECREASING tuple containing 3 monotonically decreasing scalar operations
  • Implemented _is_argmin() helper function to detect argmin patterns (handles Argmax(Neg(...)) representation)
  • Implemented local_argmax_argmin_monotonic() rewriter for argmax/argmin with @register_canonicalize decorator
  • Implemented local_max_min_monotonic() rewriter for max/min with @register_canonicalize decorator
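
For orientation, a rough and heavily simplified sketch of the general shape such a node rewriter takes when registered with the canonicalizer. The function name, the abbreviated op list, and the axis handling here are illustrative assumptions, not the PR's actual implementation:

from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import Argmax, argmax
from pytensor.tensor.rewriting.basic import register_canonicalize
import pytensor.scalar as ps

@register_canonicalize
@node_rewriter([Argmax])
def local_argmax_monotonic_sketch(fgraph, node):
    """Sketch: rewrite argmax(f(x)) -> argmax(x) for increasing elementwise f."""
    inner = node.inputs[0].owner
    if inner is None or not isinstance(inner.op, Elemwise):
        return None
    # Abbreviated list of increasing scalar ops, for illustration only.
    if not isinstance(inner.op.scalar_op, (ps.Exp, ps.Log, ps.Sqrt)):
        return None
    # Assumes node.op.axis holds the reduction axes stored by the Argmax op.
    new_out = argmax(inner.inputs[0], axis=node.op.axis)
    copy_stack_trace(node.outputs[0], new_out)
    return [new_out]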

tests/tensor/rewriting/test_math.py

  • Added parametrized tests for the new rewrites, covering increasing and decreasing monotonic functions and multiple axis specifications

Example

Argmax/Argmin Example

import pytensor.tensor as pt

x = pt.vector('x')
y_argmax = pt.argmax(pt.exp(x))  # Before: computes exp, then argmax
                                 # After: computes argmax directly
y_argmin = pt.argmin(pt.exp(x))  # Before: computes exp, then argmin
                                 # After: computes argmin directly

# The rewrite eliminates the expensive exp() computation,
# since argmax(exp(x)) == argmax(x) for monotonically increasing functions.

Max/Min Example

import pytensor.tensor as pt

x = pt.vector('x')
y_max = pt.max(pt.exp(x))  # Before: computes exp, then max
                           # After: computes max, then exp (exp(max(x)))
y_min = pt.min(pt.exp(x))  # Before: computes exp, then min
                           # After: computes min, then exp (exp(min(x)))

# The rewrite moves the expensive exp() outside the reduction,
# reducing the number of exponential evaluations from len(x) to 1.
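
If the rewrite is registered with the canonicalizer, one way to confirm it fired is to compile the expression and inspect the optimized graph (a generic debugging sketch, not from the PR):

import pytensor
import pytensor.tensor as pt

x = pt.vector('x')
f = pytensor.function([x], pt.max(pt.exp(x)))

# Print the optimized graph; with the rewrite applied, Exp should appear
# after the Max reduction rather than inside it.
pytensor.dprint(f)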

Performance Impact

These rewrites provide significant speedups when:

  • Computing argmax/argmin/max/min of exponentials, logarithms, or other monotonic transformations
  • Working with large arrays:
    • For argmax/argmin, the eliminated elementwise applications would otherwise be expensive
    • For max/min, N expensive function evaluations are reduced to 1 (where N is the size of the reduction dimension)
  • The monotonic function application is the dominant computational cost
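
The N-to-1 claim for max/min can be illustrated outside PyTensor with a rough NumPy timing sketch (array size and timings are machine-dependent):

import timeit
import numpy as np

x = np.random.default_rng(3).normal(size=1_000_000)

exp_then_max = timeit.timeit(lambda: np.max(np.exp(x)), number=20)
max_then_exp = timeit.timeit(lambda: np.exp(np.max(x)), number=20)

# exp(max(x)) applies exp once instead of a million times,
# so it is typically much faster while producing the same value.
print(exp_then_max, max_then_exp)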

Testing

All 24 tests pass with various configurations:

  • 8 argmax/argmin tests (parametrized over axis configurations and multiple monotonic functions)
  • 12 max/min tests (parametrized over axis configurations and multiple monotonic functions)
  • Multiple monotonic functions tested (18 increasing, 3 decreasing)
  • Different axis specifications (None, 0, -1)
  • Numerical correctness verification against expected results
  • Graph structure validation to ensure rewrites are applied correctly

The rewrites correctly handle edge cases including:

  • PyTensor's internal Argmax(Neg(...)) representation for argmin
  • Broadcasting and dimension handling
  • Proper flipping between argmax/argmin for decreasing functions
  • Moving functions outside reductions for max/min operations
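
A test along these lines (names and structure are hypothetical; the PR's actual tests may differ) checks both numerical correctness and that the elementwise op was removed from the graph:

import numpy as np
import pytest
import pytensor
import pytensor.tensor as pt
import pytensor.scalar as ps

@pytest.mark.parametrize("axis", [None, 0, -1])
def test_argmax_exp_rewrite(axis):
    x = pt.matrix("x")
    f = pytensor.function([x], pt.argmax(pt.exp(x), axis=axis))

    # Numerical correctness against NumPy on the unrewritten expression.
    x_val = np.random.default_rng(4).normal(size=(3, 4))
    np.testing.assert_array_equal(f(x_val), np.argmax(np.exp(x_val), axis=axis))

    # Graph structure: no Exp node should survive the rewrite.
    ops = [node.op for node in f.maker.fgraph.toposort()]
    assert not any(isinstance(getattr(op, "scalar_op", None), ps.Exp) for op in ops)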

Implements a graph rewrite that eliminates redundant monotonic function applications in argmax/argmin operations. For monotonically increasing functions, it rewrites argmax(f(x)) → argmax(x) and argmin(f(x)) → argmin(x). For decreasing functions, it flips the operation: argmax(f(x)) → argmin(x) and argmin(f(x)) → argmax(x). Includes comprehensive tests.
Copilot AI review requested due to automatic review settings February 1, 2026 17:22

Copilot AI left a comment


Pull request overview

This PR adds a graph rewrite optimization that eliminates unnecessary function evaluations when computing argmax or argmin of monotonic functions. The optimization leverages the property that monotonic functions preserve ordering, so argmax(exp(x)) can be simplified to argmax(x).

Changes:

  • Adds MONOTONIC_INCREASING and MONOTONIC_DECREASING tuples to classify scalar operations by monotonicity
  • Implements local_argmax_argmin_monotonic rewriter that optimizes argmax/argmin of monotonic functions
  • Adds comprehensive test suite with parametrized tests for different axis values

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.

  • pytensor/tensor/rewriting/math.py: Adds monotonic function classifications and implements the core rewrite logic for the argmax/argmin optimization
  • tests/tensor/rewriting/test_math.py: Adds a test class with parametrized tests for increasing and decreasing monotonic functions

