Fix batch-and-skip benchmark exploit via per-call timing#104

Open
nataliakokoromyti wants to merge 3 commits into gpu-mode:main from nataliakokoromyti:fix/per-call-timing-anti-exploit

Conversation

@nataliakokoromyti

Summary

Fixes a benchmark exploit in eval_better_bench_grouped_gemm.py where a submission can batch all 15 custom_kernel() calls into a single GPU kernel launch and make 14/15 timed calls into no-ops (pure dict lookups returning cached results). This reports ~1/15th of the real per-call cost.

Why #102's fix is insufficient: The clone+shuffle approach in #102 breaks trivial id()-based caching, but a more sophisticated exploit uses a shape-matching fallback path that collects cloned data objects by problem shape and still batches them — the pointer-update path doesn't depend on stable id() values at all.

Changes

  1. Clone data each timing iteration — prevents object-identity caching
  2. Per-call CUDA events with GPU sync — each custom_kernel() call is individually timed with torch.cuda.synchronize() between calls, preventing work deferral across calls
  3. Per-call correctness check in recheck mode — if a submission skips the kernel and returns uncomputed tensors, the correctness check fails immediately (fixes the indentation bug where only the last call was checked)
  4. Local seed variable — avoids mutating test.args["seed"] across iterations

How the exploit works

The exploit:

  1. Learning phase (first 15 calls): Records each data object's id(), tensors, and results
  2. _build_superbatch(): Merges all 15 × 8 groups = 120 groups into a single kernel launch
  3. Fast path: On subsequent iterations, only the first id() triggers the batched kernel; the other 14 return pre-cached results (zero GPU work)
  4. Pointer-update fallback: When id() values change (e.g., after cloning), collects all 15 new objects by shape match, updates pointer tables, and still launches only once — defeating clone-based mitigations
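The batch-and-skip pattern above can be sketched in simplified, CPU-only form. This is an illustrative reconstruction, not the actual submission: the class and its bookkeeping are hypothetical, and the "superbatch" is modeled as a loop over cached data objects. The key property it demonstrates is that after one learning pass, only one of the N timed calls does real work; the rest are dict lookups.

```python
class BatchAndSkip:
    """Illustrative sketch of the batch-and-skip exploit pattern.

    During the learning phase it records each data object it sees.
    In the timed phase, the first call computes results for ALL
    recorded objects at once (the "superbatch"); the remaining
    calls return cached results with zero real work.
    """

    def __init__(self, real_kernel, n_calls=15):
        self.real_kernel = real_kernel
        self.n_calls = n_calls
        self.seen = {}    # id(data) -> data, in learning order
        self.cache = {}   # id(data) -> precomputed result

    def __call__(self, data):
        key = id(data)
        if len(self.seen) < self.n_calls:
            # Learning phase: compute honestly, remember the object.
            self.seen[key] = data
            return self.real_kernel(data)
        if not self.cache:
            # First timed call: do all the work in one "batch".
            for k, d in self.seen.items():
                self.cache[k] = self.real_kernel(d)
        # Other calls: pure dict lookup, no kernel work.
        return self.cache.pop(key)
```

Because the timed phase reuses the same data objects, the per-call average comes out to roughly 1/N of the honest cost. The pointer-update fallback described in item 4 works the same way, except it re-keys the tables by tensor shape instead of by `id()`.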

Why this fix works

  • GPU sync between calls forces each call to either launch a kernel (measurable cost) or not (returns uncomputed results)
  • Per-call correctness check catches deferred computation — if a call returns without launching a kernel, its output tensors contain garbage and fail verification
  • The only viable strategy for a submission is to actually compute the result for each call independently — which is exactly what a legitimate kernel does
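A minimal host-side sketch of the per-call timing loop, assuming stand-ins for the GPU machinery: `time.perf_counter` replaces CUDA events, and a plain `check()` callback replaces the recheck-mode verification (all names here are hypothetical, not the eval script's actual API). In the real harness, `torch.cuda.synchronize()` after each call is what prevents work from being deferred across call boundaries.

```python
import time

def time_calls(make_data, kernel, check, n_calls=15):
    """Time each call individually; verify its output before moving on.

    make_data(i) returns a fresh clone for iteration i (defeats
    object-identity caching); check(data, out) is the per-call
    correctness check from recheck mode.
    """
    durations = []
    for i in range(n_calls):
        data = make_data(i)           # fresh clone each iteration
        t0 = time.perf_counter()      # real harness: start_event.record()
        out = kernel(data)
        t1 = time.perf_counter()      # real harness: end_event.record();
                                      #               torch.cuda.synchronize()
        if not check(data, out):
            # A call that skipped its kernel returns uncomputed tensors
            # and fails here immediately.
            raise ValueError(f"call {i} returned an uncomputed result")
        durations.append(t1 - t0)
    return durations
```

For an honest kernel the per-call mean matches the old batch mean; for a deferring kernel the very first unverified output aborts the run.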

Test plan

  • Verify legitimate submissions produce same scores (per-call mean = batch mean for honest kernels)
  • Verify the known exploit kernel fails correctness in leaderboard mode
  • Check benchmark runtime overhead is acceptable (extra sync per call adds ~5μs × 15 = ~75μs per repeat)

…ess checks

The current eval times all 15 custom_kernel() calls as a single batch and
divides by 15. A malicious submission can exploit this by deferring all work
to one call (batching 15 problems into a single kernel launch) and making the
other 14 calls no-ops, reporting ~1/15th of the real per-call cost.

Cloning data alone (as proposed in gpu-mode#102) does not fully prevent this -- a
shape-matching fallback path can still collect new data objects and batch them.

This fix:
- Clones data each timing iteration (prevents object-identity caching)
- Times each call individually with its own CUDA events and GPU sync
  (prevents amortization across calls)
- Checks correctness after each individual call in recheck/leaderboard mode
  (catches deferred-computation exploits that return uncomputed tensors)
- Uses a local seed variable instead of mutating test.args
- Fixes the recheck indentation bug where only the last call was checked
@nataliakokoromyti nataliakokoromyti marked this pull request as draft February 22, 2026 10:20
@G-structure

Hey @nataliakokoromyti — this is awesome, thanks for writing it up so clearly. The explanation of why #102’s clone+shuffle isn’t enough (shape-match + pointer-update path) is exactly right.

One thing I noticed when I ran 208fd03 against the same known-good kernel we’ve been comparing with: the per-call sync approach ends up being really expensive in leaderboard mode. I’m seeing a geomean around 266,298 ns (~266 µs) — about +1040% vs the lighter-hardened baseline (~23 µs). At that point we’re mostly benchmarking torch.cuda.synchronize() + event overhead rather than the kernel itself, and it risks turning the leaderboard into “who is least harmed by eval overhead” instead of “who’s closest to speed-of-light math throughput.”

One nuance on semantics: the per-call correctness checks guarantee “correct when checked,” but they don’t fully enforce call independence as a contract. A clever submission can still coordinate across calls (batching/deferral) as long as it lands the writes before the check. So it’s a strong hammer, but it’s costly and still not quite the clean independence guarantee we want.

A direction that seems both cheaper and more targeted is output fingerprint auditing (we’ve been experimenting with this and it’s been working well):

  • After a probed call returns, compute a lightweight fingerprint of the output buffers (sampled indices + random weights; computed on GPU so it stays in stream order).
  • After the full call window (after the usual sync), recompute the fingerprint on the same buffers.
  • If it changed → the output mutated after return → direct signal of deferred / cross-call writes (“return handle now, fill later”).

This hits the exploit mechanism directly (temporal integrity) instead of inferring cheating from timing skew. In our tests a fingerprint-based audit (f52ff4b style) caught the deferred-mutation exploit while keeping overhead much closer to baseline (~26.6 µs, ~+13.9%).
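The fingerprint audit can be sketched host-side as follows. This is a simplified analogue, assuming hypothetical names: the real version computes the weighted sample on the GPU so it stays in stream order, whereas this sketch reads a plain Python buffer. The idea is to snapshot a weighted sample of the output right after the probed call returns, then re-check it after the window's final sync; any mismatch means the buffer mutated after return.

```python
import random

def fingerprint(buf, idxs, weights):
    # Weighted sample of output entries. On GPU this would be a small
    # reduction kernel enqueued on the same stream as the submission.
    return sum(w * buf[i] for i, w in zip(idxs, weights))

def make_probe(out, n_samples=256, seed=0):
    """Snapshot a fingerprint of `out` now; the returned closure
    re-checks it later (after the call window's final sync)."""
    rng = random.Random(seed)
    idxs = [rng.randrange(len(out)) for _ in range(n_samples)]
    weights = [rng.random() for _ in range(n_samples)]
    before = fingerprint(out, idxs, weights)

    def unchanged():
        # True  -> buffer is temporally stable (honest)
        # False -> post-return mutation: deferred / cross-call writes
        return fingerprint(out, idxs, weights) == before

    return unchanged
```

A "return handle now, fill later" submission rewrites the output buffer after the probe was taken, so the re-check fails even though the final contents are numerically correct.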

A couple notes / caveats on fingerprinting (worth us digging into together):

  • It assumes the call’s work lands on the current stream; if custom streams are allowed, we probably need to disallow them (common in these comps) or treat them as undefined for timing/integrity.
  • It’s probabilistic because we sample, but using 2 hashes + enough samples (e.g. 256–1024) makes collisions super unlikely.
  • We don’t need to fingerprint every call — probing 1–2 random positions per integrity repeat tends to be enough to catch “flush at end” patterns.

I also tried a couple follow-ups building on your approach:

  • 20fb8c3: randomized window length/order + sparse probes (more heuristic)
  • f52ff4b: fingerprint audit that directly catches post-return output mutation

Diff is here (easy to cherry-pick pieces): https://gist.github.com/G-structure/f9de3df9b051f43c06422ffd7a21a8dd

Down to pair on integrating this in a way that keeps leaderboard scoring “real,” with the stronger checks happening only on integrity repeats.

@ngc92
Collaborator

ngc92 commented Feb 22, 2026

Hi,
I've been working on a new benchmark+test implementation, unfortunately not in time for the competition:
https://github.com/ngc92/pygpubench

I think it avoids most of the problems mentioned above, and it tries to minimize the overhead of the benchmarking framework by implementing the main loop in C++ and calling the user function through nanobind.
This also prevents malicious users from tampering with benchmarking data via Python's inspect module.

Note that the checking kernel is started using PDL (programmatic dependent launch) to minimize the attack window, and it checks the entries in the result in randomized order.
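A host-side analogue of the randomized-order check might look like the sketch below (this is not pygpubench's actual implementation, which runs as a GPU kernel; function and parameter names are made up). Visiting entries in a random order means a deferred write cannot reliably land before its entry is inspected.

```python
import random

def check_randomized(result, reference, atol=1e-3, seed=None):
    """Compare result to reference, visiting entries in random order.

    A fixed left-to-right sweep gives an attacker a predictable window
    to fill trailing entries late; shuffling removes that predictability.
    """
    order = list(range(len(result)))
    random.Random(seed).shuffle(order)
    for i in order:
        if abs(result[i] - reference[i]) > atol:
            return False
    return True
```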

@G-structure G-structure force-pushed the fix/per-call-timing-anti-exploit branch from 208fd03 to 11fe446 Compare February 23, 2026 03:09
@G-structure G-structure force-pushed the fix/per-call-timing-anti-exploit branch from 11fe446 to 340e48e Compare February 23, 2026 03:11
@nataliakokoromyti nataliakokoromyti marked this pull request as ready for review February 23, 2026 03:21
@nataliakokoromyti
Author

Thanks @G-structure and @ngc92 for your help and thoughtful responses. I don't know what the timeline for migrating to C++ is (great idea), but until then something like this PR could be beneficial.
