Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 199 additions & 25 deletions problems/nvidia/eval_better_bench_grouped_gemm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import base64
import dataclasses
import multiprocessing
import random
import re
import time
import os
import sys
import math
import random

# Disable CuTe DSL file caching for more stable benchmarking
os.environ["CUTE_DSL_DISABLE_FILE_CACHING"] = "1"
Expand Down Expand Up @@ -173,6 +175,88 @@ def _clone_data(data):
return data


def _collect_output_tensors(output):
"""Collect tensors from nested output structure in deterministic order."""
tensors = []

def _walk(x):
if isinstance(x, torch.Tensor):
tensors.append(x)
elif isinstance(x, (list, tuple)):
for y in x:
_walk(y)
elif isinstance(x, dict):
for k in sorted(x.keys()):
_walk(x[k])

_walk(output)
return tensors


def _make_fingerprint_plan(output, gen, samples_per_tensor: int = 256):
    """
    Build a secret sampled hash plan for this output structure.

    For each tensor found in *output*, draws up to *samples_per_tensor* random
    flat indices plus two signed integer weight vectors from generator *gen*.
    Returns a list of (numel, indices, weights1, weights2) tuples, one per
    tensor; empty tensors get a (0, None, None, None) placeholder.
    """
    plan = []
    for tensor in _collect_output_tensors(output):
        numel = int(tensor.numel())
        sample_count = min(samples_per_tensor, numel)
        if sample_count <= 0:
            # Nothing to sample from an empty tensor.
            plan.append((0, None, None, None))
            continue
        # NOTE(review): randint mixes *gen* with tensor.device; this assumes
        # the generator and the output tensors share a device — confirm for
        # CPU-resident outputs.
        indices = torch.randint(
            0, numel, (sample_count,), generator=gen, device=tensor.device, dtype=torch.int64
        )
        bound = 1 << 20
        weights_a = torch.randint(
            -bound, bound, (sample_count,), generator=gen, device=tensor.device, dtype=torch.int32
        ).to(torch.float64)
        weights_b = torch.randint(
            -bound, bound, (sample_count,), generator=gen, device=tensor.device, dtype=torch.int32
        ).to(torch.float64)
        plan.append((numel, indices, weights_a, weights_b))
    return plan


def _fingerprint_output(output, plan):
    """
    Compute a lightweight sampled fingerprint of output tensor contents.

    Returns two device scalars (h1, h2). If output changed post-return, the
    fingerprint almost certainly changes too.

    Raises ValueError when the output structure or a tensor size no longer
    matches *plan*.
    """
    tensors = _collect_output_tensors(output)
    if len(plan) != len(tensors):
        raise ValueError(
            f"output structure changed: expected {len(plan)} tensors, got {len(tensors)}"
        )

    if not tensors:
        # No tensors at all: a fixed zero fingerprint (same object twice).
        zero = torch.zeros((), dtype=torch.float64)
        return zero, zero

    device = tensors[0].device
    acc1 = torch.zeros((), device=device, dtype=torch.float64)
    acc2 = torch.zeros((), device=device, dtype=torch.float64)

    for tensor, (expected_n, idx, w1, w2) in zip(tensors, plan):
        n = int(tensor.numel())
        if n != expected_n:
            raise ValueError(f"output tensor size changed: expected {expected_n}, got {n}")
        if expected_n == 0:
            # Placeholder plan entry — nothing was sampled for this tensor.
            continue
        sampled = tensor.reshape(-1).index_select(0, idx).to(torch.float64)
        # Clamp non-finite values so NaN/inf cannot poison the weighted sums.
        sampled = torch.nan_to_num(sampled, nan=0.0, posinf=1e6, neginf=-1e6)
        acc1 = acc1 + (sampled * w1).sum(dtype=torch.float64)
        acc2 = acc2 + (sampled * w2).sum(dtype=torch.float64)
    return acc1, acc2


def _fingerprint_equal(a, b) -> bool:
return torch.equal(a[0], b[0]) and torch.equal(a[1], b[1])


def _run_single_test(test: TestCase):
"""
Runs a single test case. Do not call directly
Expand Down Expand Up @@ -242,15 +326,28 @@ def _run_single_benchmark(
data_list = []
# generate input data once

local_seed = test.args.get("seed", None)
for i in range(NUM_ITERATIONS_PER_BENCHMARK):
if "seed" in test.args:
test.args["seed"] += 42
data = generate_input(**test.args)
if local_seed is not None:
local_seed += 42
args = {**test.args, "seed": local_seed}
else:
args = test.args
data = generate_input(**args)
data_list.append(data)

check_copy = _clone_data(data_list)

# first, one obligatory correctness check
# Deterministic but hidden probe stream.
# In benchmark mode we use randomized call windows and sparse probes.
# In leaderboard mode we do one full sweep up front, then lightweight probes.
probe_seed = _combine(int(test.args.get("seed", 0) or 0), 0x4D455452)
probe_rng = random.Random(probe_seed)
full_calls = len(data_list)
fp_gen = torch.Generator(device="cuda")
fp_seed = _combine(probe_seed, 0xF1A9E5) & ((1 << 63) - 1)
fp_gen.manual_seed(fp_seed)

# First, one obligatory correctness check on fresh clones.
outputs = []
try:
for data in data_list:
Expand All @@ -262,41 +359,114 @@ def _run_single_benchmark(
good, message = check_implementation(reference_output, custom_output)
if not good:
return message
try:
fingerprint_plans = [_make_fingerprint_plan(out, fp_gen) for out in outputs]
except Exception as E:
return f"fingerprint plan build failed: {E}"

# now, do multiple timing runs without further correctness testing
# there is an upper bound of 200 runs, and a lower bound of 3 runs;
# otherwise, we repeat until we either measure at least 10 full seconds,
# or the relative error of the mean is below 1%.
# Timing: per-call intervals captured with CUDA events and one sync.
# We randomize window length/order in benchmark mode to break fixed-N exploits.
# Data is cloned each iteration to prevent object-identity caching.

bm_start_time = time.perf_counter_ns()
for i in range(max_repeats):
# Clone and shuffle data before timing to prevent both
# object-identity caching and call-order caching exploits
iteration_data = _clone_data(data_list)
shuffle_order = list(range(len(iteration_data)))
random.shuffle(shuffle_order)
iteration_data = [iteration_data[j] for j in shuffle_order]

torch.cuda.synchronize()

if recheck:
integrity_repeat = (i == 0) or (i % 20 == 0)
else:
integrity_repeat = (i < 3) or (i % 25 == 0)

if recheck:
call_indices = list(range(full_calls))
else:
call_indices = list(range(full_calls))
probe_rng.shuffle(call_indices)
if integrity_repeat:
# Integrity repeats must exercise the full call window so
# flush-at-N exploits cannot hide behind short random windows.
n_calls = full_calls
else:
min_calls = max(4, full_calls - 6)
n_calls = probe_rng.randint(min_calls, full_calls)
call_indices = call_indices[:n_calls]

outputs = []
clear_l2_cache()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for data in data_list:
output = custom_kernel(data)
outputs.append(output)
end_event.record()
events = [torch.cuda.Event(enable_timing=True) for _ in range(len(call_indices) + 1)]

if integrity_repeat and len(call_indices) <= 1:
in_loop_probe_pos = 0 if call_indices else None
elif integrity_repeat:
# Probe before last call to expose deferred-until-last behavior.
in_loop_probe_pos = probe_rng.randrange(0, len(call_indices) - 1)
else:
in_loop_probe_pos = None

probe_snapshot = None
events[0].record()
for k, idx in enumerate(call_indices):
output = custom_kernel(iteration_data[idx])
outputs.append((idx, output))
events[k + 1].record()

# Snapshot output state immediately after return; compare again after
# the full window to detect post-return deferred writes.
if in_loop_probe_pos is not None and k == in_loop_probe_pos:
try:
fp_before = _fingerprint_output(output, fingerprint_plans[idx])
except Exception as E:
return f"fingerprint snapshot failed: {E}"
probe_snapshot = (idx, output, fp_before)
torch.cuda.synchronize()
duration = (
start_event.elapsed_time(end_event) / NUM_ITERATIONS_PER_BENCHMARK
) * 1e6 # Convert ms to ns

if probe_snapshot is not None:
idx, probe_output, fp_before = probe_snapshot
try:
fp_after = _fingerprint_output(probe_output, fingerprint_plans[idx])
except Exception as E:
return f"fingerprint verify failed: {E}"
torch.cuda.synchronize()
if not _fingerprint_equal(fp_before, fp_after):
return (
"detected deferred/cross-call output mutation "
f"(call_index={idx}, window_calls={len(call_indices)})"
)

per_call_durations = [
events[k].elapsed_time(events[k + 1]) * 1e6 for k in range(len(call_indices))
]

# Correctness policy:
# - benchmark: sparse hidden integrity repeats + randomized windows/order.
# - leaderboard: sparse integrity repeats; first repeat gets full sweep.
if recheck:
for reference_output, custom_output in zip(check_copy, outputs):
good, message = check_implementation(reference_output, custom_output)
if i == 0:
check_positions = list(range(len(outputs)))
else:
check_positions = []
else:
check_positions = []

for pos in check_positions:
idx, output = outputs[pos]
good, message = check_implementation(check_copy[idx], output)
if not good:
return message

durations.append(duration)
duration = sum(per_call_durations) / len(call_indices)
if not integrity_repeat:
durations.append(duration)

total_bm_duration = time.perf_counter_ns() - bm_start_time
if (
i > 1 and total_bm_duration > 1e8
len(durations) > 1 and total_bm_duration > 1e8
): # at least 2 runs, and at least 100 ms total time
stats = calculate_stats(durations)
# stop if either
Expand All @@ -310,6 +480,9 @@ def _run_single_benchmark(
):
break

if not durations:
return "benchmark produced no timing samples"

return calculate_stats(durations)


Expand Down Expand Up @@ -518,8 +691,9 @@ def main():
break

logger.log("check", "pass" if passed else "fail")
return 0 if passed else 112
elif mode == "profile":
run_profiling(logger, pool, tests)
return run_profiling(logger, pool, tests)
else:
# TODO: Implement script mode
return 2
Expand Down