diff --git a/problems/nvidia/eval_better_bench_grouped_gemm.py b/problems/nvidia/eval_better_bench_grouped_gemm.py index 09b5279..92424a0 100644 --- a/problems/nvidia/eval_better_bench_grouped_gemm.py +++ b/problems/nvidia/eval_better_bench_grouped_gemm.py @@ -1,11 +1,12 @@ import base64 import dataclasses import multiprocessing +import random import re import time import os import sys import math # Disable CuTe DSL file caching for more stable benchmarking os.environ["CUTE_DSL_DISABLE_FILE_CACHING"] = "1" @@ -173,6 +174,88 @@ def _clone_data(data): return data + +def _collect_output_tensors(output): + """Collect tensors from nested output structure in deterministic order.""" + tensors = [] + + def _walk(x): + if isinstance(x, torch.Tensor): + tensors.append(x) + elif isinstance(x, (list, tuple)): + for y in x: + _walk(y) + elif isinstance(x, dict): + for k in sorted(x.keys()): + _walk(x[k]) + + _walk(output) + return tensors + + +def _make_fingerprint_plan(output, gen, samples_per_tensor: int = 256): + """ + Build a secret sampled hash plan for this output structure. + """ + tensors = _collect_output_tensors(output) + if not tensors: + return [] + + plan = [] + for t in tensors: + n = int(t.numel()) + s = min(samples_per_tensor, n) + if s <= 0: + plan.append((0, None, None, None)) + continue + idx = torch.randint(0, n, (s,), generator=gen, device=t.device, dtype=torch.int64) + w1 = torch.randint( + -(1 << 20), (1 << 20), (s,), generator=gen, device=t.device, dtype=torch.int32 + ).to(torch.float64) + w2 = torch.randint( + -(1 << 20), (1 << 20), (s,), generator=gen, device=t.device, dtype=torch.int32 + ).to(torch.float64) + plan.append((n, idx, w1, w2)) + return plan + + +def _fingerprint_output(output, plan): + """ + Compute a lightweight sampled fingerprint of output tensor contents. + + Returns two device scalars (h1, h2). If output changed post-return, the + fingerprint almost certainly changes too. 
+ """ + tensors = _collect_output_tensors(output) + if len(tensors) != len(plan): + raise ValueError( + f"output structure changed: expected {len(plan)} tensors, got {len(tensors)}" + ) + + if not tensors: + z = torch.zeros((), dtype=torch.float64) + return z, z + + device = tensors[0].device + h1 = torch.zeros((), device=device, dtype=torch.float64) + h2 = torch.zeros((), device=device, dtype=torch.float64) + + for t, (expected_n, idx, w1, w2) in zip(tensors, plan): + n = int(t.numel()) + if n != expected_n: + raise ValueError(f"output tensor size changed: expected {expected_n}, got {n}") + if expected_n == 0: + continue + vals = t.reshape(-1).index_select(0, idx).to(torch.float64) + vals = torch.nan_to_num(vals, nan=0.0, posinf=1e6, neginf=-1e6) + h1 = h1 + (vals * w1).sum(dtype=torch.float64) + h2 = h2 + (vals * w2).sum(dtype=torch.float64) + return h1, h2 + + +def _fingerprint_equal(a, b) -> bool: + return torch.equal(a[0], b[0]) and torch.equal(a[1], b[1]) + + def _run_single_test(test: TestCase): """ Runs a single test case. Do not call directly @@ -242,15 +326,28 @@ def _run_single_benchmark( data_list = [] # generate input data once + local_seed = test.args.get("seed", None) for i in range(NUM_ITERATIONS_PER_BENCHMARK): - if "seed" in test.args: - test.args["seed"] += 42 - data = generate_input(**test.args) + if local_seed is not None: + local_seed += 42 + args = {**test.args, "seed": local_seed} + else: + args = test.args + data = generate_input(**args) data_list.append(data) check_copy = _clone_data(data_list) - - # first, one obligatory correctness check + # Deterministic but hidden probe stream. + # In benchmark mode we use randomized call windows and sparse probes. + # In leaderboard mode we do one full sweep up front, then lightweight probes. 
+ probe_seed = _combine(int(test.args.get("seed", 0) or 0), 0x4D455452) + probe_rng = random.Random(probe_seed) + full_calls = len(data_list) + fp_gen = torch.Generator(device="cuda") + fp_seed = _combine(probe_seed, 0xF1A9E5) & ((1 << 63) - 1) + fp_gen.manual_seed(fp_seed) + + # First, one obligatory correctness check on fresh clones. outputs = [] try: for data in data_list: @@ -262,41 +359,114 @@ def _run_single_benchmark( good, message = check_implementation(reference_output, custom_output) if not good: return message + try: + fingerprint_plans = [_make_fingerprint_plan(out, fp_gen) for out in outputs] + except Exception as E: + return f"fingerprint plan build failed: {E}" - # now, do multiple timing runs without further correctness testing - # there is an upper bound of 200 runs, and a lower bound of 3 runs; - # otherwise, we repeat until we either measure at least 10 full seconds, - # or the relative error of the mean is below 1%. + # Timing: per-call intervals captured with CUDA events and one sync. + # We randomize window length/order in benchmark mode to break fixed-N exploits. + # Data is cloned each iteration to prevent object-identity caching. bm_start_time = time.perf_counter_ns() for i in range(max_repeats): + # Clone and shuffle data before timing to prevent both + # object-identity caching and call-order caching exploits + iteration_data = _clone_data(data_list) + shuffle_order = list(range(len(iteration_data))) + random.shuffle(shuffle_order) + iteration_data = [iteration_data[j] for j in shuffle_order] + torch.cuda.synchronize() + if recheck: + integrity_repeat = (i == 0) or (i % 20 == 0) + else: + integrity_repeat = (i < 3) or (i % 25 == 0) + + if recheck: + call_indices = list(range(full_calls)) + else: + call_indices = list(range(full_calls)) + probe_rng.shuffle(call_indices) + if integrity_repeat: + # Integrity repeats must exercise the full call window so + # flush-at-N exploits cannot hide behind short random windows. 
+ n_calls = full_calls + else: + min_calls = max(4, full_calls - 6) + n_calls = probe_rng.randint(min_calls, full_calls) + call_indices = call_indices[:n_calls] + outputs = [] - clear_l2_cache() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - for data in data_list: - output = custom_kernel(data) - outputs.append(output) - end_event.record() + events = [torch.cuda.Event(enable_timing=True) for _ in range(len(call_indices) + 1)] + + if integrity_repeat and len(call_indices) <= 1: + in_loop_probe_pos = 0 if call_indices else None + elif integrity_repeat: + # Probe before last call to expose deferred-until-last behavior. + in_loop_probe_pos = probe_rng.randrange(0, len(call_indices) - 1) + else: + in_loop_probe_pos = None + + probe_snapshot = None + events[0].record() + for k, idx in enumerate(call_indices): + output = custom_kernel(iteration_data[idx]) + outputs.append((idx, output)) + events[k + 1].record() + + # Snapshot output state immediately after return; compare again after + # the full window to detect post-return deferred writes. 
+ if in_loop_probe_pos is not None and k == in_loop_probe_pos: + try: + fp_before = _fingerprint_output(output, fingerprint_plans[idx]) + except Exception as E: + return f"fingerprint snapshot failed: {E}" + probe_snapshot = (idx, output, fp_before) torch.cuda.synchronize() - duration = ( - start_event.elapsed_time(end_event) / NUM_ITERATIONS_PER_BENCHMARK - ) * 1e6 # Convert ms to ns + if probe_snapshot is not None: + idx, probe_output, fp_before = probe_snapshot + try: + fp_after = _fingerprint_output(probe_output, fingerprint_plans[idx]) + except Exception as E: + return f"fingerprint verify failed: {E}" + torch.cuda.synchronize() + if not _fingerprint_equal(fp_before, fp_after): + return ( + "detected deferred/cross-call output mutation " + f"(call_index={idx}, window_calls={len(call_indices)})" + ) + + per_call_durations = [ + events[k].elapsed_time(events[k + 1]) * 1e6 for k in range(len(call_indices)) + ] + + # Correctness policy: + # - benchmark: sparse hidden integrity repeats + randomized windows/order. + # - leaderboard: sparse integrity repeats; first repeat gets full sweep. 
if recheck: - for reference_output, custom_output in zip(check_copy, outputs): - good, message = check_implementation(reference_output, custom_output) + if i == 0: + check_positions = list(range(len(outputs))) + else: + check_positions = [] + else: + check_positions = [] + + for pos in check_positions: + idx, output = outputs[pos] + good, message = check_implementation(check_copy[idx], output) if not good: return message - durations.append(duration) + duration = sum(per_call_durations) / len(call_indices) + if not integrity_repeat: + durations.append(duration) total_bm_duration = time.perf_counter_ns() - bm_start_time if ( - i > 1 and total_bm_duration > 1e8 + len(durations) > 1 and total_bm_duration > 1e8 ): # at least 2 runs, and at least 100 ms total time stats = calculate_stats(durations) # stop if either @@ -310,6 +480,9 @@ def _run_single_benchmark( ): break + if not durations: + return "benchmark produced no timing samples" + return calculate_stats(durations) @@ -518,8 +691,9 @@ def main(): break logger.log("check", "pass" if passed else "fail") + return 0 if passed else 112 elif mode == "profile": - run_profiling(logger, pool, tests) + return run_profiling(logger, pool, tests) else: # TODO: Implement script mode return 2