Add rewrite for argmax/argmin of monotonic functions #1869
+385
−0
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Add rewrite for argmax/argmin/max/min of monotonic functions
Closes #1851
Summary
This PR implements graph rewrites that optimize
argmax/argmin/max/minoperations applied to monotonic functions by eliminating unnecessary function evaluations.Motivation
Computing
argmax(exp(x))ormax(exp(x))is wasteful because the exponential computation doesn't affect which index has the maximum value or the relative ordering - we only care about the ordering relationship. Since monotonic functions preserve ordering, we can skip expensive function applications entirely for argmax/argmin, or move them outside the reduction for max/min.Implementation
New rewrites:
local_argmax_argmin_monotonic- for argmax/argmin operationslocal_max_min_monotonic- for max/min operationsArgmax/Argmin Rewrites
The rewrite handles four transformation paths based on function monotonicity:
Monotonically Increasing Functions
argmax(f(x)) → argmax(x)argmin(f(x)) → argmin(x)Supported increasing functions:
Exp,Exp2,Expm1,Log,Log2,Log10,Log1p,Sqrt,Deg2Rad,Rad2Deg,ArcSin,Tan,ArcTan,ArcCosh,Sinh,ArcSinh,Tanh,ArcTanhMonotonically Decreasing Functions
argmax(f(x)) → argmin(x)argmin(f(x)) → argmax(x)Supported decreasing functions:
Neg,Reciprocal,ArcCosMax/Min Rewrites
The rewrite handles four transformation paths by moving the monotonic function outside the reduction:
Monotonically Increasing Functions
max(f(x)) → f(max(x))min(f(x)) → f(min(x))Monotonically Decreasing Functions
max(f(x)) → f(min(x))min(f(x)) → f(max(x))Same supported increasing and decreasing functions as argmax/argmin.
Key Features
argminwhich is internally represented asArgmax(Neg(...))in PyTensorNone,0,-1, etc.)Elemwisewrapper detection to identify scalar operationscopy_stack_traceChanges
pytensor/tensor/rewriting/math.pyMONOTONIC_INCREASINGtuple containing 18 monotonically increasing scalar operationsMONOTONIC_DECREASINGtuple containing 3 monotonically decreasing scalar operations_is_argmin()helper function to detect argmin patterns (handlesArgmax(Neg(...))representation)local_argmax_argmin_monotonic()rewriter for argmax/argmin with@register_canonicalizedecoratorlocal_max_min_monotonic()rewriter for max/min with@register_canonicalizedecoratortests/tensor/rewriting/test_math.pyTestArgmaxArgminMonotonictest class with comprehensive coverage:test_argmax_increasing_functions- Tests argmax rewrite for increasing functionstest_argmin_increasing_functions- Tests argmin rewrite for increasing functionstest_argmax_decreasing_functions- Tests argmax rewrite for decreasing functions (flips to argmin)test_argmin_decreasing_functions- Tests argmin rewrite for decreasing functions (flips to argmax)test_max_increasing_functions- Tests max rewrite for increasing functionstest_min_increasing_functions- Tests min rewrite for increasing functionstest_max_decreasing_functions- Tests max rewrite for decreasing functionstest_min_decreasing_functions- Tests min rewrite for decreasing functionsNone,0,-1)arccosinstead ofnegfor decreasing function tests to avoid confusion with argmin's internal representationExample
Argmax/Argmin Example
Max/Min Example
Performance Impact
These rewrites provide significant speedups when:
Testing
All 24 tests pass with various configurations:
None,0,-1)The rewrites correctly handle edge cases including:
Argmax(Neg(...))representation forargmin