diff --git a/tests/unit/test_weight_processing.py b/tests/unit/test_weight_processing.py index cd04c70ab..169be8ce0 100644 --- a/tests/unit/test_weight_processing.py +++ b/tests/unit/test_weight_processing.py @@ -746,6 +746,11 @@ def test_fold_layer_no_adapter_transformer_lens_format(self, basic_config): """ cfg = basic_config cfg.n_layers = 1 # Test with single layer for simplicity + # Match config to the tensor dimensions used below + cfg.d_model = 4 + cfg.n_heads = 2 + cfg.d_head = 2 + cfg.d_mlp = 8 # Create a state dict with known values for deterministic testing state_dict = {} diff --git a/transformer_lens/benchmarks/__init__.py b/transformer_lens/benchmarks/__init__.py index 6996211c0..391639195 100644 --- a/transformer_lens/benchmarks/__init__.py +++ b/transformer_lens/benchmarks/__init__.py @@ -36,7 +36,12 @@ validate_hook_shape_compatibility, ) from transformer_lens.benchmarks.main_benchmark import run_benchmark_suite -from transformer_lens.benchmarks.utils import BenchmarkResult, BenchmarkSeverity, PhaseReferenceData +from transformer_lens.benchmarks.text_quality import benchmark_text_quality +from transformer_lens.benchmarks.utils import ( + BenchmarkResult, + BenchmarkSeverity, + PhaseReferenceData, +) from transformer_lens.benchmarks.weight_processing import ( benchmark_weight_modification, benchmark_weight_processing, @@ -72,6 +77,8 @@ "benchmark_generation", "benchmark_generation_with_kv_cache", "benchmark_multiple_generation_calls", + # Text quality benchmarks + "benchmark_text_quality", # Weight processing benchmarks "benchmark_weight_processing", "benchmark_weight_sharing", diff --git a/transformer_lens/benchmarks/activation_cache.py b/transformer_lens/benchmarks/activation_cache.py index ebef781af..b37100785 100644 --- a/transformer_lens/benchmarks/activation_cache.py +++ b/transformer_lens/benchmarks/activation_cache.py @@ -6,7 +6,11 @@ from transformer_lens import HookedTransformer from transformer_lens.ActivationCache import ActivationCache -from transformer_lens.benchmarks.utils import BenchmarkResult, BenchmarkSeverity +from transformer_lens.benchmarks.utils import ( + BenchmarkResult, + BenchmarkSeverity, + safe_allclose, +) from transformer_lens.model_bridge import TransformerBridge @@ -65,9 +69,10 @@ def benchmark_run_with_cache( if missing_patterns: return BenchmarkResult( name="run_with_cache", - severity=BenchmarkSeverity.WARNING, + severity=BenchmarkSeverity.DANGER, message=f"Cache missing expected patterns: {missing_patterns}", details={"missing": missing_patterns, "cache_keys_count": len(cache_keys)}, + passed=False, ) # Verify cached tensors are actually tensors @@ -79,9 +84,10 @@ def benchmark_run_with_cache( if non_tensor_keys: return BenchmarkResult( name="run_with_cache", - severity=BenchmarkSeverity.WARNING, + severity=BenchmarkSeverity.DANGER, message=f"Cache contains {len(non_tensor_keys)} non-tensor values", details={"non_tensor_keys": non_tensor_keys[:5]}, + passed=False, ) if reference_model is not None: @@ -175,9 +181,11 @@ def benchmark_activation_cache( continue # Check values - if not torch.allclose(bridge_tensor, reference_tensor, atol=tolerance, rtol=0): - max_diff = torch.max(torch.abs(bridge_tensor - reference_tensor)).item() - mean_diff = torch.mean(torch.abs(bridge_tensor - reference_tensor)).item() + if not safe_allclose(bridge_tensor, reference_tensor, atol=tolerance, rtol=0.0): + b = bridge_tensor.float() + r = reference_tensor.float() + max_diff = torch.max(torch.abs(b - r)).item() + mean_diff = torch.mean(torch.abs(b - r)).item() mismatches.append( f"{key}: Value mismatch - max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}" ) diff --git a/transformer_lens/benchmarks/backward_gradients.py b/transformer_lens/benchmarks/backward_gradients.py index 60e9e21b8..f16a8d353 100644 --- a/transformer_lens/benchmarks/backward_gradients.py +++ b/transformer_lens/benchmarks/backward_gradients.py @@ -6,7 +6,11 @@ from transformer_lens import HookedTransformer from transformer_lens.benchmarks.hook_structure import validate_hook_shape_compatibility -from transformer_lens.benchmarks.utils import BenchmarkResult, BenchmarkSeverity +from transformer_lens.benchmarks.utils import ( + BenchmarkResult, + BenchmarkSeverity, + safe_allclose, +) from transformer_lens.model_bridge import TransformerBridge @@ -167,14 +171,14 @@ def hook_fn(tensor, hook): if bridge_finite.numel() > 0 and reference_finite.numel() > 0: # Compare finite values - if not torch.allclose( + if not safe_allclose( bridge_finite, reference_finite, atol=abs_tolerance, rtol=rel_tolerance ): - max_diff = torch.max(torch.abs(bridge_finite - reference_finite)).item() - mean_diff = torch.mean(torch.abs(bridge_finite - reference_finite)).item() - rel_diff = torch.abs(bridge_finite - reference_finite) / ( - torch.abs(bridge_finite) + 1e-8 - ) + bf = bridge_finite.float() + rf = reference_finite.float() + max_diff = torch.max(torch.abs(bf - rf)).item() + mean_diff = torch.mean(torch.abs(bf - rf)).item() + rel_diff = torch.abs(bf - rf) / (torch.abs(bf) + 1e-8) mean_rel = rel_diff.mean().item() mismatches.append( f"{hook_name}: Value mismatch - max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}, mean_rel={mean_rel:.6f}" @@ -195,11 +199,13 @@ def hook_fn(tensor, hook): "hook_k", "ln1.hook_", "ln2.hook_", + "ln_final.hook_", "hook_resid_mid", "hook_resid_pre", "hook_resid_post", "hook_embed", "hook_pos_embed", + "unembed.hook_", "mlp.hook_post", "mlp.hook_pre", "hook_mlp_out", @@ -431,10 +437,12 @@ def hook_fn(tensor, hook): reference_finite = reference_grad[torch.isfinite(reference_grad)] if bridge_finite.numel() > 0 and reference_finite.numel() > 0: - if not torch.allclose( + if not safe_allclose( bridge_finite, reference_finite, atol=abs_tolerance, rtol=rel_tolerance ): - max_diff = torch.max(torch.abs(bridge_finite - reference_finite)).item() + max_diff = torch.max( + torch.abs(bridge_finite.float() - reference_finite.float()) + ).item() mismatches.append(f"{hook_name}: max_diff={max_diff:.6f}") if mismatches: diff --git a/transformer_lens/benchmarks/component_benchmark.py b/transformer_lens/benchmarks/component_benchmark.py index 787fbb09c..0e653baf8 100644 --- a/transformer_lens/benchmarks/component_benchmark.py +++ b/transformer_lens/benchmarks/component_benchmark.py @@ -6,294 +6,10 @@ from typing import Any, Optional -import torch - from transformer_lens.benchmarks.component_outputs import ComponentBenchmarker from transformer_lens.benchmarks.utils import BenchmarkResult, BenchmarkSeverity -def benchmark_component_forward( - bridge_component: Any, - hf_component: Any, - test_input: torch.Tensor, - component_name: str, - atol: float = 1e-4, - rtol: float = 1e-4, -) -> BenchmarkResult: - """Benchmark forward pass equivalence for a single component. - - Args: - bridge_component: The bridge component to test - hf_component: The HuggingFace component to compare against - test_input: Input tensor for the component - component_name: Name of the component being tested - atol: Absolute tolerance for comparison - rtol: Relative tolerance for comparison - - Returns: - BenchmarkResult with the comparison results - """ - try: - # Run both components - with torch.no_grad(): - bridge_output = bridge_component(test_input) - hf_output = hf_component(test_input) - - # Extract tensors from outputs (handle both tensor and tuple outputs) - if isinstance(bridge_output, tuple): - bridge_tensor = bridge_output[0] - else: - bridge_tensor = bridge_output - - if isinstance(hf_output, tuple): - hf_tensor = hf_output[0] - else: - hf_tensor = hf_output - - # Compare outputs - if not torch.allclose(bridge_tensor, hf_tensor, atol=atol, rtol=rtol): - max_diff = (bridge_tensor - hf_tensor).abs().max().item() - mean_diff = (bridge_tensor - hf_tensor).abs().mean().item() - - return BenchmarkResult( - name=f"{component_name}_forward", - passed=False, - severity=BenchmarkSeverity.DANGER, - message=f"Component {component_name} outputs differ: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}", - details={ - "max_diff": max_diff, - "mean_diff": mean_diff, - "bridge_mean": bridge_tensor.mean().item(), - "hf_mean": hf_tensor.mean().item(), - }, - ) - - return BenchmarkResult( - name=f"{component_name}_forward", - passed=True, - severity=BenchmarkSeverity.INFO, - message=f"Component {component_name} produces equivalent outputs", - ) - - except Exception as e: - return BenchmarkResult( - name=f"{component_name}_forward", - passed=False, - severity=BenchmarkSeverity.ERROR, - message=f"Error testing component {component_name}: {str(e)}", - details={"exception": str(e)}, - ) - - -def benchmark_block_components( - bridge, - hf_model, - block_idx: int = 0, - atol: float = 1e-4, -) -> list[BenchmarkResult]: - """Benchmark all components within a transformer block. - - Args: - bridge: The TransformerBridge model - hf_model: The HuggingFace model - block_idx: Which block to test (default: 0) - atol: Absolute tolerance for comparison - - Returns: - List of BenchmarkResult for each component in the block - """ - results = [] - - try: - # Create test input - batch_size = 1 - seq_len = 4 - d_model = bridge.cfg.d_model - test_input = torch.randn(batch_size, seq_len, d_model, device=bridge.cfg.device) - - # Get the blocks - bridge_block = bridge.blocks[block_idx] - - # Get HF block using adapter - hf_blocks_path = bridge.adapter.component_mapping.blocks.name - hf_blocks = hf_model - for part in hf_blocks_path.split("."): - hf_blocks = getattr(hf_blocks, part) - hf_block = hf_blocks[block_idx] - - # Test attention component - if hasattr(bridge_block, "attn"): - bridge_attn = bridge_block.attn - - # Get HF attention - attn_path = bridge.adapter.component_mapping.blocks.submodules["attn"].name - if attn_path: - hf_attn = hf_block - for part in attn_path.split("."): - hf_attn = getattr(hf_attn, part) - - results.append( - benchmark_component_forward( - bridge_attn, - hf_attn, - test_input, - f"block_{block_idx}_attn", - atol=atol, - ) - ) - - # Test MLP component - if hasattr(bridge_block, "mlp"): - bridge_mlp = bridge_block.mlp - - # Get HF MLP - mlp_path = bridge.adapter.component_mapping.blocks.submodules["mlp"].name - if mlp_path: - hf_mlp = hf_block - for part in mlp_path.split("."): - hf_mlp = getattr(hf_mlp, part) - - results.append( - benchmark_component_forward( - bridge_mlp, - hf_mlp, - test_input, - f"block_{block_idx}_mlp", - atol=atol, - ) - ) - - # Test layer norms if present - if hasattr(bridge_block, "ln1"): - bridge_ln1 = bridge_block.ln1 - - # Get HF ln1 - ln1_path = bridge.adapter.component_mapping.blocks.submodules["ln1"].name - if ln1_path: - hf_ln1 = hf_block - for part in ln1_path.split("."): - hf_ln1 = getattr(hf_ln1, part) - - results.append( - benchmark_component_forward( - bridge_ln1, - hf_ln1, - test_input, - f"block_{block_idx}_ln1", - atol=atol, - ) - ) - - if hasattr(bridge_block, "ln2"): - bridge_ln2 = bridge_block.ln2 - - # Get HF ln2 - ln2_path = bridge.adapter.component_mapping.blocks.submodules["ln2"].name - if ln2_path: - hf_ln2 = hf_block - for part in ln2_path.split("."): - hf_ln2 = getattr(hf_ln2, part) - - results.append( - benchmark_component_forward( - bridge_ln2, - hf_ln2, - test_input, - f"block_{block_idx}_ln2", - atol=atol, - ) - ) - - except Exception as e: - results.append( - BenchmarkResult( - name=f"block_{block_idx}_components", - passed=False, - severity=BenchmarkSeverity.ERROR, - message=f"Error benchmarking block {block_idx} components: {str(e)}", - details={"exception": str(e)}, - ) - ) - - return results - - -def benchmark_attention_subcomponents( - bridge, - hf_model, - block_idx: int = 0, - atol: float = 1e-4, -) -> list[BenchmarkResult]: - """Benchmark attention subcomponents (Q, K, V, O projections). - - Args: - bridge: The TransformerBridge model - hf_model: The HuggingFace model - block_idx: Which block to test (default: 0) - atol: Absolute tolerance for comparison - - Returns: - List of BenchmarkResult for each attention subcomponent - """ - results = [] - - try: - # Create test input - batch_size = 1 - seq_len = 4 - d_model = bridge.cfg.d_model - test_input = torch.randn(batch_size, seq_len, d_model, device=bridge.cfg.device) - - # Get the attention components - bridge_attn = bridge.blocks[block_idx].attn - - # Get HF block - hf_blocks_path = bridge.adapter.component_mapping.blocks.name - hf_blocks = hf_model - for part in hf_blocks_path.split("."): - hf_blocks = getattr(hf_blocks, part) - hf_block = hf_blocks[block_idx] - - # Get HF attention - attn_path = bridge.adapter.component_mapping.blocks.submodules["attn"].name - hf_attn = hf_block - for part in attn_path.split("."): - hf_attn = getattr(hf_attn, part) - - # Test Q, K, V projections if they exist - for proj_name in ["q", "k", "v", "o"]: - if hasattr(bridge_attn, proj_name): - bridge_proj = getattr(bridge_attn, proj_name) - - # Try to get corresponding HF projection - hf_proj_name = f"{proj_name}_proj" - if hasattr(hf_attn, hf_proj_name): - hf_proj = getattr(hf_attn, hf_proj_name) - - results.append( - benchmark_component_forward( - bridge_proj, - hf_proj, - test_input, - f"block_{block_idx}_attn_{proj_name}", - atol=atol, - ) - ) - - except Exception as e: - results.append( - BenchmarkResult( - name=f"block_{block_idx}_attn_subcomponents", - passed=False, - severity=BenchmarkSeverity.ERROR, - message=f"Error benchmarking attention subcomponents: {str(e)}", - details={"exception": str(e)}, - ) - ) - - return results - - def benchmark_all_components( bridge, hf_model, diff --git a/transformer_lens/benchmarks/component_outputs.py b/transformer_lens/benchmarks/component_outputs.py index c631ea706..59f0442b8 100644 --- a/transformer_lens/benchmarks/component_outputs.py +++ b/transformer_lens/benchmarks/component_outputs.py @@ -751,16 +751,15 @@ def _compare_outputs( if bridge_output.shape != hf_output.shape: return False, float("inf"), float("inf"), {} - # Compute differences - diff = torch.abs(bridge_output - hf_output) + # Compute differences (upcast to float32 for safety) + bo = bridge_output.float() + ho = hf_output.float() + diff = torch.abs(bo - ho) max_diff = diff.max().item() mean_diff = diff.mean().item() # Compute percentile differences - # Convert to float32 for quantile computation (bfloat16 not supported) flat_diff = diff.flatten() - if flat_diff.dtype == torch.bfloat16 or flat_diff.dtype == torch.float16: - flat_diff = flat_diff.float() percentile_diffs = { "50th": torch.quantile(flat_diff, 0.5).item(), "90th": torch.quantile(flat_diff, 0.9).item(), @@ -768,7 +767,7 @@ def _compare_outputs( } # Check if within tolerance - passed = torch.allclose(bridge_output, hf_output, atol=self.atol, rtol=self.rtol) + passed = torch.allclose(bo, ho, atol=self.atol, rtol=self.rtol) return passed, max_diff, mean_diff, percentile_diffs diff --git a/transformer_lens/benchmarks/granular_weight_processing.py b/transformer_lens/benchmarks/granular_weight_processing.py index bee376d08..90379eed9 100644 --- a/transformer_lens/benchmarks/granular_weight_processing.py +++ b/transformer_lens/benchmarks/granular_weight_processing.py @@ -39,7 +39,7 @@ def __str__(self) -> str: return "+".join(flags) if flags else "none" -# Phase 4: Individual weight processing operations (test each flag in isolation) +# Phase 5: Individual weight processing operations (test each flag in isolation) # NOTE: Centering operations (center_writing_weights, center_unembed) require fold_ln=True # as they rely on LayerNorm ignoring the mean. Testing them without fold_ln produces # invalid/misleading results, so we test them with fold_ln enabled. @@ -82,7 +82,7 @@ def __str__(self) -> str: ), ] -# Phase 5: Combinations of weight processing operations +# Phase 6: Combinations of weight processing operations COMBINATION_CONFIGS = [ # Two-way combinations (fold_ln + one other) WeightProcessingConfig( @@ -185,8 +185,8 @@ def run_granular_weight_processing_benchmarks( ) -> Dict[str, List[BenchmarkResult]]: """Run benchmarks with each weight processing configuration. - This function tests each weight processing flag individually (Phase 4) and - in combination (Phase 5) to identify which specific processing steps cause issues. + This function tests each weight processing flag individually (Phase 5) and + in combination (Phase 6) to identify which specific processing steps cause issues. Args: model_name: Name of the model to benchmark @@ -194,7 +194,7 @@ def run_granular_weight_processing_benchmarks( test_text: Test text for generation/inference verbose: Whether to print detailed output include_refactor_tests: Whether to include experimental refactor_factored_attn_matrices tests - phase: Optional phase number (4 for individual, 5 for combinations). If None, runs both. + phase: Optional phase number (5 for individual, 6 for combinations). If None, runs both. Returns: Dictionary mapping config name to list of benchmark results @@ -247,18 +247,18 @@ def run_granular_weight_processing_benchmarks( configs_to_test = [] phase_name = "" - if phase is None or phase == 4: + if phase is None or phase == 5: configs_to_test.extend(INDIVIDUAL_CONFIGS) - if phase == 4: - phase_name = "PHASE 4: Individual Weight Processing Flags" + if phase == 5: + phase_name = "PHASE 5: Individual Weight Processing Flags" - if phase is None or phase == 5: + if phase is None or phase == 6: configs_to_test.extend(COMBINATION_CONFIGS) - if phase == 5: - phase_name = "PHASE 5: Combined Weight Processing Flags" + if phase == 6: + phase_name = "PHASE 6: Combined Weight Processing Flags" if phase is None: - phase_name = "PHASE 4 & 5: Granular Weight Processing" + phase_name = "PHASE 5 & 6: Granular Weight Processing" if include_refactor_tests: configs_to_test.extend(REFACTOR_ATTN_CONFIGS) @@ -268,9 +268,9 @@ def run_granular_weight_processing_benchmarks( print(phase_name) print(f"Model: {model_name}") print(f"Testing {len(configs_to_test)} configurations") - if phase is None or phase == 4: - print(f" Individual flags: {len(INDIVIDUAL_CONFIGS)}") if phase is None or phase == 5: + print(f" Individual flags: {len(INDIVIDUAL_CONFIGS)}") + if phase is None or phase == 6: print(f" Combinations: {len(COMBINATION_CONFIGS)}") if include_refactor_tests: print(f" Refactor tests: {len(REFACTOR_ATTN_CONFIGS)}") diff --git a/transformer_lens/benchmarks/hook_registration.py b/transformer_lens/benchmarks/hook_registration.py index f41ff60d2..dfa06772d 100644 --- a/transformer_lens/benchmarks/hook_registration.py +++ b/transformer_lens/benchmarks/hook_registration.py @@ -5,457 +5,45 @@ import torch from transformer_lens import HookedTransformer +from transformer_lens.benchmarks.hook_structure import validate_hook_shape_compatibility from transformer_lens.benchmarks.utils import ( BenchmarkResult, BenchmarkSeverity, compare_scalars, + safe_allclose, ) from transformer_lens.model_bridge import TransformerBridge - -def validate_hook_shape_compatibility( - target_shape: tuple, - reference_shape: tuple, - hook_name: str, - cross_model: bool = False, -) -> tuple[bool, Optional[str]]: - """Validate that hook shapes have compatible structure across different models. - - This allows comparing hooks from different models (e.g., Llama vs GPT-2) by checking - structural compatibility rather than exact shape matching. - - Args: - target_shape: Shape of the tensor from the target model - reference_shape: Shape of the tensor from the reference model - hook_name: Name of the hook (for error messages) - cross_model: If True, skip sequence dimension checks (different tokenizers - produce different token counts for the same text) - - Returns: - Tuple of (is_compatible, error_message) - - is_compatible: True if shapes are structurally compatible - - error_message: None if compatible, otherwise description of incompatibility - """ - # For GQA (Grouped Query Attention) models, k/v hooks may have different ranks - # GPT-2: (batch, seq, n_heads, d_head) = 4D - # Gemma/Llama with GQA: (batch, seq, d_head) = 3D (heads are already collapsed) - # This is expected and fine - both are valid attention representations - gqa_attention_hooks = ["hook_q", "hook_k", "hook_v", "hook_z"] - is_gqa_hook = any(pattern in hook_name for pattern in gqa_attention_hooks) - - # Attention pattern hooks have shape [batch, n_heads, seq_q, seq_k] - # Different models can have different numbers of heads - is_attention_pattern_hook = "hook_pattern" in hook_name or "hook_attn_scores" in hook_name - - # Same rank (number of dimensions) is required, except for GQA attention hooks - if len(target_shape) != len(reference_shape): - if is_gqa_hook: - # For GQA hooks, different ranks are okay - just verify batch and sequence dims match - if len(target_shape) >= 2 and len(reference_shape) >= 2: - if target_shape[0] != reference_shape[0]: - return ( - False, - f"Batch dimension mismatch: {target_shape[0]} vs {reference_shape[0]}", - ) - if not cross_model and target_shape[1] != reference_shape[1]: - return ( - False, - f"Sequence dimension mismatch: {target_shape[1]} vs {reference_shape[1]}", - ) - # Rank mismatch is fine for GQA - different attention implementations - return True, None - else: - return False, f"Invalid tensor rank: {len(target_shape)} or {len(reference_shape)}" - return False, f"Rank mismatch: {len(target_shape)} vs {len(reference_shape)}" - - # For each dimension, check compatibility - for i, (target_dim, ref_dim) in enumerate(zip(target_shape, reference_shape)): - if i == 0: # Batch dimension - # Should be same (both use same test input) - if target_dim != ref_dim: - return False, f"Batch dimension mismatch: {target_dim} vs {ref_dim}" - elif i == 1: # Usually sequence dimension, but n_heads for attention patterns - if is_attention_pattern_hook: - # For attention patterns: [batch, n_heads, seq_q, seq_k] - # Dimension 1 is n_heads, which can differ between models - # Just verify it's valid - if target_dim <= 0 or ref_dim <= 0: - return False, f"Invalid n_heads dimension: {target_dim} vs {ref_dim}" - else: - # For other hooks, dimension 1 is sequence - # Cross-model references may tokenize differently, so skip this check - if not cross_model and target_dim != ref_dim: - return False, f"Sequence dimension mismatch: {target_dim} vs {ref_dim}" - elif i >= 2 and is_attention_pattern_hook: - # For attention patterns, dimensions 2 and 3 are seq_q and seq_k - # Cross-model references may tokenize differently - if not cross_model and target_dim != ref_dim: - return False, f"Sequence dimension mismatch: {target_dim} vs {ref_dim}" - else: # Model-specific dimensions (d_model, n_heads, d_head, etc.) - # Can differ between models - just verify it's valid - if target_dim <= 0: - return False, f"Invalid dimension {i}: {target_dim} <= 0" - if ref_dim <= 0: - return False, f"Invalid reference dimension {i}: {ref_dim} <= 0" - - return True, None - - -def benchmark_forward_hooks_structure( - bridge: TransformerBridge, - test_text: str, - reference_model: Optional[HookedTransformer] = None, - prepend_bos: Optional[bool] = None, - cross_model: bool = False, -) -> BenchmarkResult: - """Benchmark forward hooks for structural correctness (existence, firing, shapes). - - This checks: - - All reference hooks exist in bridge - - Hooks can be registered - - Hooks fire during forward pass - - Hook tensor shapes are compatible (allows cross-model comparison) - - Args: - bridge: TransformerBridge model to test - test_text: Input text for testing - reference_model: Optional HookedTransformer for comparison - prepend_bos: Whether to prepend BOS token. If None, uses model default. - cross_model: If True, uses relaxed shape matching for cross-model comparison - - Returns: - BenchmarkResult with structural validation details - """ - try: - bridge_activations: Dict[str, torch.Tensor] = {} - reference_activations: Dict[str, torch.Tensor] = {} - - # Get all hook names - if reference_model is not None: - hook_names = list(reference_model.hook_dict.keys()) - else: - hook_names = list(bridge.hook_dict.keys()) - - # Register hooks on bridge and track missing hooks - def make_bridge_hook(name: str): - def hook_fn(tensor, hook): - if isinstance(tensor, torch.Tensor): - bridge_activations[name] = tensor.detach().clone() - elif isinstance(tensor, tuple) and len(tensor) > 0: - if isinstance(tensor[0], torch.Tensor): - bridge_activations[name] = tensor[0].detach().clone() - return tensor - - return hook_fn - - bridge_handles = [] - missing_from_bridge = [] - for hook_name in hook_names: - if hook_name in bridge.hook_dict: - hook_point = bridge.hook_dict[hook_name] - handle = hook_point.add_hook(make_bridge_hook(hook_name)) # type: ignore[func-returns-value] - bridge_handles.append((hook_name, handle)) - else: - missing_from_bridge.append(hook_name) - - # Run bridge forward pass - with torch.no_grad(): - if prepend_bos is not None: - _ = bridge(test_text, prepend_bos=prepend_bos) - else: - _ = bridge(test_text) - - # Clean up bridge hooks - for hook_name, handle in bridge_handles: - if handle is not None: - handle.remove() - - # Check for hooks that didn't fire - registered_hooks = {name for name, _ in bridge_handles} - hooks_that_didnt_fire = registered_hooks - set(bridge_activations.keys()) - - if reference_model is None: - # No reference - just verify hooks were captured - if hooks_that_didnt_fire: - return BenchmarkResult( - name="forward_hooks_structure", - severity=BenchmarkSeverity.WARNING, - message=f"{len(hooks_that_didnt_fire)}/{len(registered_hooks)} hooks didn't fire", - details={ - "captured": len(bridge_activations), - "registered": len(registered_hooks), - "didnt_fire": list(hooks_that_didnt_fire)[:10], - }, - ) - - return BenchmarkResult( - name="forward_hooks_structure", - severity=BenchmarkSeverity.INFO, - message=f"Bridge captured {len(bridge_activations)} forward hook activations", - details={"activation_count": len(bridge_activations)}, - ) - - # Register hooks on reference model - def make_reference_hook(name: str): - def hook_fn(tensor, hook): - if isinstance(tensor, torch.Tensor): - reference_activations[name] = tensor.detach().clone() - elif isinstance(tensor, tuple) and len(tensor) > 0: - if isinstance(tensor[0], torch.Tensor): - reference_activations[name] = tensor[0].detach().clone() - return tensor - - return hook_fn - - reference_handles = [] - for hook_name in hook_names: - if hook_name in reference_model.hook_dict: - hook_point = reference_model.hook_dict[hook_name] - handle = hook_point.add_hook(make_reference_hook(hook_name)) # type: ignore[func-returns-value] - reference_handles.append(handle) - - # Run reference forward pass - with torch.no_grad(): - if prepend_bos is not None: - _ = reference_model(test_text, prepend_bos=prepend_bos) - else: - _ = reference_model(test_text) - - # Clean up reference hooks - for handle in reference_handles: - if handle is not None: - handle.remove() - - # CRITICAL CHECK: Bridge must have all hooks that reference has - if missing_from_bridge: - return BenchmarkResult( - name="forward_hooks_structure", - severity=BenchmarkSeverity.DANGER, - message=f"Bridge MISSING {len(missing_from_bridge)} hooks from reference", - details={ - "missing_count": len(missing_from_bridge), - "missing_hooks": missing_from_bridge[:20], - "total_reference_hooks": len(hook_names), - }, - passed=False, - ) - - # CRITICAL CHECK: All registered hooks must fire - if hooks_that_didnt_fire: - return BenchmarkResult( - name="forward_hooks_structure", - severity=BenchmarkSeverity.DANGER, - message=f"{len(hooks_that_didnt_fire)} hooks DIDN'T FIRE during forward pass", - details={ - "didnt_fire_count": len(hooks_that_didnt_fire), - "didnt_fire_hooks": list(hooks_that_didnt_fire)[:20], - "total_registered": len(registered_hooks), - }, - passed=False, - ) - - # Check shapes - common_hooks = set(bridge_activations.keys()) & set(reference_activations.keys()) - shape_mismatches = [] - - for hook_name in sorted(common_hooks): - bridge_tensor = bridge_activations[hook_name] - reference_tensor = reference_activations[hook_name] - - if cross_model: - # Use relaxed shape matching for cross-model comparison - is_compatible, error_msg = validate_hook_shape_compatibility( - bridge_tensor.shape, reference_tensor.shape, hook_name, cross_model=True - ) - if not is_compatible: - shape_mismatches.append(f"{hook_name}: {error_msg}") - else: - # Exact shape matching for same-model comparison - if bridge_tensor.shape != reference_tensor.shape: - shape_mismatches.append( - f"{hook_name}: Shape {bridge_tensor.shape} vs {reference_tensor.shape}" - ) - - if shape_mismatches: - return BenchmarkResult( - name="forward_hooks_structure", - severity=BenchmarkSeverity.DANGER, - message=f"Found {len(shape_mismatches)}/{len(common_hooks)} hooks with shape incompatibilities", - details={ - "total_hooks": len(common_hooks), - "shape_mismatches": len(shape_mismatches), - "sample_mismatches": shape_mismatches[:5], - "cross_model": cross_model, - }, - passed=False, - ) - - ref_type = "cross-model reference" if cross_model else "same-model reference" - return BenchmarkResult( - name="forward_hooks_structure", - severity=BenchmarkSeverity.INFO, - message=f"All {len(common_hooks)} forward hooks structurally compatible ({ref_type})", - details={"hook_count": len(common_hooks), "cross_model": cross_model}, - ) - - except Exception as e: - return BenchmarkResult( - name="forward_hooks_structure", - severity=BenchmarkSeverity.ERROR, - message=f"Forward hooks structure check failed: {str(e)}", - passed=False, - ) - - -def benchmark_forward_hooks_values( - bridge: TransformerBridge, - test_text: str, - reference_model: HookedTransformer, - tolerance: float = 0.5, - prepend_bos: Optional[bool] = None, -) -> BenchmarkResult: - """Benchmark forward hooks for value equivalence (requires same-model reference). - - This checks that hook activation values match between bridge and reference. - Should only be called when reference_model is the same architecture as bridge. - - Args: - bridge: TransformerBridge model to test - test_text: Input text for testing - reference_model: HookedTransformer reference (must be same architecture) - tolerance: Tolerance for activation matching - prepend_bos: Whether to prepend BOS token. If None, uses model default. - - Returns: - BenchmarkResult with value comparison details - """ - try: - bridge_activations: Dict[str, torch.Tensor] = {} - reference_activations: Dict[str, torch.Tensor] = {} - - hook_names = list(reference_model.hook_dict.keys()) - - # Register hooks on bridge - def make_bridge_hook(name: str): - def hook_fn(tensor, hook): - if isinstance(tensor, torch.Tensor): - bridge_activations[name] = tensor.detach().clone() - elif isinstance(tensor, tuple) and len(tensor) > 0: - if isinstance(tensor[0], torch.Tensor): - bridge_activations[name] = tensor[0].detach().clone() - return tensor - - return hook_fn - - bridge_handles = [] - for hook_name in hook_names: - if hook_name in bridge.hook_dict: - hook_point = bridge.hook_dict[hook_name] - handle = hook_point.add_hook(make_bridge_hook(hook_name)) # type: ignore[func-returns-value] - bridge_handles.append((hook_name, handle)) - - # Run bridge forward pass - with torch.no_grad(): - if prepend_bos is not None: - _ = bridge(test_text, prepend_bos=prepend_bos) - else: - _ = bridge(test_text) - - # Clean up bridge hooks - for hook_name, handle in bridge_handles: - if handle is not None: - handle.remove() - - # Register hooks on reference - def make_reference_hook(name: str): - def hook_fn(tensor, hook): - if isinstance(tensor, torch.Tensor): - reference_activations[name] = tensor.detach().clone() - elif isinstance(tensor, tuple) and len(tensor) > 0: - if isinstance(tensor[0], torch.Tensor): - reference_activations[name] = tensor[0].detach().clone() - return tensor - - return hook_fn - - reference_handles = [] - for hook_name in hook_names: - if hook_name in reference_model.hook_dict: - hook_point = reference_model.hook_dict[hook_name] - handle = hook_point.add_hook(make_reference_hook(hook_name)) # type: ignore[func-returns-value] - reference_handles.append(handle) - - # Run reference forward pass - with torch.no_grad(): - if prepend_bos is not None: - _ = reference_model(test_text, prepend_bos=prepend_bos) - else: - _ = reference_model(test_text) - - # Clean up reference hooks - for handle in reference_handles: - if handle is not None: - handle.remove() - - # Compare activation values - common_hooks = set(bridge_activations.keys()) & set(reference_activations.keys()) - value_mismatches = [] - - for hook_name in sorted(common_hooks): - bridge_tensor = bridge_activations[hook_name] - reference_tensor = reference_activations[hook_name] - - # Check values - if not torch.allclose(bridge_tensor, reference_tensor, atol=tolerance, rtol=0): - max_diff = torch.max(torch.abs(bridge_tensor - reference_tensor)).item() - mean_diff = torch.mean(torch.abs(bridge_tensor - reference_tensor)).item() - value_mismatches.append( - f"{hook_name}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}" - ) - - if value_mismatches: - # Filter out known architectural differences - significant_mismatches = [ - m - for m in value_mismatches - if "hook_attn_scores" not in m # Exclude attn_scores (has inf from masking) - ] - - if significant_mismatches: - return BenchmarkResult( - name="forward_hooks_values", - severity=BenchmarkSeverity.DANGER, - message=f"Found {len(significant_mismatches)}/{len(common_hooks)} hooks with value mismatches", - details={ - "total_hooks": len(common_hooks), - "mismatches": len(significant_mismatches), - "sample_mismatches": significant_mismatches[:5], - "tolerance": tolerance, - }, - passed=False, - ) - else: - return BenchmarkResult( - name="forward_hooks_values", - severity=BenchmarkSeverity.WARNING, - message=f"All mismatches due to known differences ({len(value_mismatches)} hooks)", - details={"total_hooks": len(common_hooks), "tolerance": tolerance}, - ) - - return BenchmarkResult( - name="forward_hooks_values", - severity=BenchmarkSeverity.INFO, - message=f"All {len(common_hooks)} forward hooks match within tolerance", - details={"hook_count": len(common_hooks), "tolerance": tolerance}, - ) - - except Exception as e: - return BenchmarkResult( - name="forward_hooks_values", - severity=BenchmarkSeverity.ERROR, - message=f"Forward hooks values check failed: {str(e)}", - passed=False, - ) +# Hook patterns that bridge models inherently don't have because they wrap HF's +# native implementation. Used to filter expected missing/non-firing hooks. +_BRIDGE_EXPECTED_MISSING_PATTERNS = [ + "mlp.hook_pre", + "mlp.hook_post", + "hook_mlp_in", + "hook_mlp_out", + "attn.hook_rot_q", + "attn.hook_rot_k", + "hook_pos_embed", + "embed.ln.hook_scale", + "embed.ln.hook_normalized", + "attn.hook_q", + "attn.hook_k", + "attn.hook_v", + "hook_q_input", + "hook_k_input", + "hook_v_input", + "attn.hook_attn_scores", + "attn.hook_pattern", +] + + +def _filter_expected_missing(hook_names): + """Filter out hook names that bridge models are expected to be missing.""" + return [ + h + for h in hook_names + if not any(pattern in h for pattern in _BRIDGE_EXPECTED_MISSING_PATTERNS) + ] def benchmark_hook_registry( @@ -507,25 +95,9 @@ def benchmark_hook_registry( missing_hooks = reference_hooks - bridge_hooks extra_hooks = bridge_hooks - reference_hooks - # In cross-model mode, filter out hooks that are expected to differ - # due to architectural differences (e.g. fused QKV, rotary embeddings) - if cross_model and missing_hooks: - expected_missing_patterns = [ - "hook_pos_embed", - "attn.hook_q", - "attn.hook_k", - "attn.hook_v", - "hook_q_input", - "hook_k_input", - "hook_v_input", - "attn.hook_attn_scores", - "attn.hook_pattern", - ] - missing_hooks = { - h - for h in missing_hooks - if not any(pattern in h for pattern in expected_missing_patterns) - } + # Filter out hooks that are expected to differ due to architectural differences. + if missing_hooks: + missing_hooks = set(_filter_expected_missing(missing_hooks)) if missing_hooks: return BenchmarkResult( @@ -686,24 +258,9 @@ def hook_fn(tensor, hook): handle.remove() # CRITICAL CHECK: Bridge must have all hooks that reference has - # In cross-model mode, filter out expected architectural differences - if cross_model and missing_from_bridge: - expected_missing_patterns = [ - "hook_pos_embed", - "attn.hook_q", - "attn.hook_k", - "attn.hook_v", - "hook_q_input", - "hook_k_input", - "hook_v_input", - "attn.hook_attn_scores", - "attn.hook_pattern", - ] - missing_from_bridge = [ - h - for h in missing_from_bridge - if not any(pattern in h for pattern in expected_missing_patterns) - ] + # Filter out hooks that bridge models inherently don't have. + if missing_from_bridge: + missing_from_bridge = _filter_expected_missing(missing_from_bridge) if missing_from_bridge: return BenchmarkResult( @@ -719,26 +276,9 @@ def hook_fn(tensor, hook): ) # CRITICAL CHECK: All registered hooks must fire - # Filter out expected missing hooks in cross-model mode - if cross_model and hooks_that_didnt_fire: - # In cross-model mode, some hooks are expected to not fire due to architectural differences - expected_missing_patterns = [ - "hook_pos_embed", - "attn.hook_q", - "attn.hook_k", - "attn.hook_v", - "hook_q_input", - "hook_k_input", - "hook_v_input", - "attn.hook_attn_scores", - "attn.hook_pattern", - ] - actual_didnt_fire = [ - h - for h in hooks_that_didnt_fire - if not any(pattern in h for pattern in expected_missing_patterns) - ] - hooks_that_didnt_fire = set(actual_didnt_fire) + # Filter out hooks expected to not fire due to architectural differences. + if hooks_that_didnt_fire: + hooks_that_didnt_fire = set(_filter_expected_missing(hooks_that_didnt_fire)) if hooks_that_didnt_fire: return BenchmarkResult( @@ -774,17 +314,34 @@ def hook_fn(tensor, hook): # We only check that hooks exist, fire, and have compatible structure continue else: - # Use exact shape matching for same-model comparison + # Handle batch dimension differences: some HF models (e.g., OPT) + # internally reshape to 2D for MLP path, producing [seq, dim] hooks + # while HT always maintains [batch, seq, dim] if bridge_tensor.shape != reference_tensor.shape: - mismatches.append( - f"{hook_name}: Shape mismatch - Bridge{bridge_tensor.shape} vs Ref{reference_tensor.shape}" - ) - continue + if ( + bridge_tensor.ndim == reference_tensor.ndim - 1 + and reference_tensor.shape[0] == 1 + and bridge_tensor.shape == reference_tensor.shape[1:] + ): + bridge_tensor = bridge_tensor.unsqueeze(0) + elif ( + reference_tensor.ndim == bridge_tensor.ndim - 1 + and bridge_tensor.shape[0] == 1 + and reference_tensor.shape == bridge_tensor.shape[1:] + ): + reference_tensor = reference_tensor.unsqueeze(0) + else: + mismatches.append( + f"{hook_name}: Shape mismatch - Bridge{bridge_tensor.shape} vs Ref{reference_tensor.shape}" + ) + continue # Check values (only for same-model comparison) - if not torch.allclose(bridge_tensor, reference_tensor, atol=tolerance, rtol=0): - max_diff = torch.max(torch.abs(bridge_tensor - reference_tensor)).item() - mean_diff = torch.mean(torch.abs(bridge_tensor - reference_tensor)).item() + if not safe_allclose(bridge_tensor, reference_tensor, atol=tolerance, rtol=0.0): + b = bridge_tensor.float() + r = reference_tensor.float() + max_diff = torch.max(torch.abs(b - r)).item() + mean_diff = torch.mean(torch.abs(b - r)).item() mismatches.append( f"{hook_name}: Value mismatch - max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}" ) @@ -973,45 +530,36 @@ def hook_fn(tensor, hook): # Skip value comparison for cross-model (different architectures have different values) # We only check that hooks exist, fire, and have compatible structure else: - # Use exact shape matching for same-model comparison + # Handle batch dimension differences (see forward_hooks) if bridge_tensor.shape != reference_tensor.shape: - mismatches.append( - f"{hook_name}: Shape mismatch - Bridge{bridge_tensor.shape} vs Ref{reference_tensor.shape}" - ) - continue + if ( + bridge_tensor.ndim == reference_tensor.ndim - 1 + and reference_tensor.shape[0] == 1 + and bridge_tensor.shape == reference_tensor.shape[1:] + ): + bridge_tensor = bridge_tensor.unsqueeze(0) + elif ( + reference_tensor.ndim == bridge_tensor.ndim - 1 + and bridge_tensor.shape[0] == 1 + and reference_tensor.shape == bridge_tensor.shape[1:] + ): + reference_tensor = reference_tensor.unsqueeze(0) + else: + mismatches.append( + f"{hook_name}: Shape mismatch - Bridge{bridge_tensor.shape} vs Ref{reference_tensor.shape}" + ) + continue # Only compare values for same-model comparison - if not torch.allclose(bridge_tensor, reference_tensor, atol=tolerance, rtol=0): - max_diff = torch.max(torch.abs(bridge_tensor - reference_tensor)).item() + if not safe_allclose(bridge_tensor, reference_tensor, atol=tolerance, rtol=0.0): + max_diff = torch.max( + torch.abs(bridge_tensor.float() - reference_tensor.float()) + ).item() mismatches.append(f"{hook_name}: max_diff={max_diff:.6f}") - # Check if bridge is missing critical hooks (BAD) - # Filter out expected missing hooks in cross-model mode - if cross_model and bridge_missing: - # In cross-model mode, some hooks are expected to be missing due to architectural differences - # For example, rotary embedding models (Gemma, LLaMA) don't have hook_pos_embed - # Hooks that may be missing due to architectural differences: - # - hook_pos_embed: rotary models don't have positional embeddings - # - hook_q/k/v: fused QKV architectures (maintain_native_attention) - # - hook_q/k/v_input: same reason - # - hook_attn_scores/pattern: native attention doesn't expose these - expected_missing_patterns = [ - "hook_pos_embed", - "attn.hook_q", - "attn.hook_k", - "attn.hook_v", - "hook_q_input", - "hook_k_input", - "hook_v_input", - "attn.hook_attn_scores", - "attn.hook_pattern", - ] - actual_missing = [ - h - for h in bridge_missing - if not any(pattern in h for pattern in expected_missing_patterns) - ] - bridge_missing = actual_missing + # Filter out hooks expected to be missing in bridge models. + if bridge_missing: + bridge_missing = _filter_expected_missing(bridge_missing) if bridge_missing: return BenchmarkResult( @@ -1124,7 +672,7 @@ def benchmark_hook_functionality( def ablation_hook(activation, hook): # Zero out an attention head in layer 0 - # Clone to avoid in-place modification of a view from a custom Function + # Clone to avoid in-place modification of autograd views activation = activation.clone() # For GQA models, the head dimension may be smaller than n_heads n_heads = activation.shape[2] diff --git a/transformer_lens/benchmarks/main_benchmark.py b/transformer_lens/benchmarks/main_benchmark.py index 16071f759..9b06d15e6 100644 --- a/transformer_lens/benchmarks/main_benchmark.py +++ b/transformer_lens/benchmarks/main_benchmark.py @@ -1,11 +1,13 @@ """Main benchmark runner for TransformerBridge. This module provides the main benchmark suite that compares TransformerBridge -against reference implementations in an optimized 4-phase approach: +against reference implementations in an optimized multi-phase approach: Phase 1: HF + Bridge (unprocessed) - Compare against raw HuggingFace model Phase 2: Bridge (unprocessed) + HT (unprocessed) - Compare unprocessed models Phase 3: Bridge (processed) + HT (processed) - Full compatibility mode testing -Phase 4: Granular Weight Processing Tests (optional) +Phase 4: Text Quality - Perplexity-based legibility scoring via GPT-2 Medium +Phase 5: Granular Weight Processing Tests (optional, individual flags) +Phase 6: Granular Weight Processing Tests (optional, combined flags) """ import gc @@ -44,6 +46,7 @@ from transformer_lens.benchmarks.hook_structure import ( benchmark_activation_cache_structure, ) +from transformer_lens.benchmarks.text_quality import benchmark_text_quality from transformer_lens.benchmarks.utils import ( BenchmarkResult, BenchmarkSeverity, @@ -121,18 +124,10 @@ def get_auto_model_class(model_name: str, trust_remote_code: bool = False): def _fixup_custom_model(hf_model) -> None: - """Apply post-load fixups for models with custom code. - - Some custom models (e.g., OpenELM) have non-persistent buffers (inv_freq, - causal_mask) that may be zeroed during HuggingFace's meta-device loading. - This function recomputes broken buffers to minimize forward pass divergence - against the bridge model. - - Note: The bridge model goes through a more thorough initialization via the - adapter's prepare_loading() + prepare_model() lifecycle hooks. Any remaining - forward pass divergence is an inherent consequence of different loading paths - for custom-code models, not a bridge correctness issue (all individual - components produce identical output, and hooks have zero numerical impact). + """Apply post-load fixups for models with custom code (e.g., OpenELM). + + Recomputes non-persistent buffers (inv_freq, causal_mask) that may be + zeroed during HuggingFace's meta-device loading. """ # OpenELM fixups if hasattr(hf_model, "transformer") and hasattr(hf_model.transformer, "layers"): @@ -190,6 +185,7 @@ def run_comparison_benchmarks( verbose: bool = True, gpt2_reference: Optional[HookedTransformer] = None, phase1_reference: Optional[PhaseReferenceData] = None, + restore_dtype_after_equivalence: Optional[torch.dtype] = None, ) -> List[BenchmarkResult]: """Run standardized comparison benchmarks between Bridge and reference model. @@ -205,6 +201,9 @@ def run_comparison_benchmarks( verbose: Whether to print detailed results gpt2_reference: Optional GPT-2 reference for cross-model validation phase1_reference: Optional saved Phase 1 HF reference data for equivalence testing + restore_dtype_after_equivalence: If set, downcast bridge_model to this dtype after + the equivalence comparison but before hook/cache/gradient tests. Used when the + bridge was upcast to float32 for precise equivalence testing. Returns: List of BenchmarkResult objects @@ -279,6 +278,8 @@ def add_result(result: BenchmarkResult) -> None: if verbose: print("2. Model Equivalence Benchmarks (Forward Pass)") + has_phase1_ref = phase1_reference is not None and phase1_reference.hf_logits is not None + if ht_available: try: add_result( @@ -293,29 +294,35 @@ def add_result(result: BenchmarkResult) -> None: except Exception as e: if verbose: print(f"✗ Equivalence benchmark failed: {e}\n") - elif phase1_reference is not None and phase1_reference.hf_logits is not None: - # Use saved Phase 1 bridge logits/loss as ground truth. - # Weight processing should be mathematically equivalent, so the processed - # bridge should produce the same output as the unprocessed bridge. - # - # Important: center_unembed intentionally shifts raw logits by a per-position - # constant (softmax-invariant). We compare log_softmax to be invariant to this. + elif has_phase1_ref: + # Compare processed bridge against unprocessed Phase 1 reference. + # We use log_softmax because center_unembed shifts raw logits by a + # softmax-invariant constant. Both passes run in float32 (no bf16 round-trip). try: if verbose: print("Using saved Phase 1 bridge reference for equivalence comparison") - # Compare log_softmax instead of raw logits to be centering-invariant. - # center_unembed shifts all vocab logits at each position by a constant, - # which changes raw logits but preserves log-probabilities. + + assert phase1_reference is not None + assert phase1_reference.hf_logits is not None + + # Compare log_softmax (centering-invariant) instead of raw logits. bridge_logits = bridge_model(test_text, return_type="logits") ref_logits = phase1_reference.hf_logits.to(bridge_logits.device) bridge_log_probs = torch.nn.functional.log_softmax(bridge_logits, dim=-1) ref_log_probs = torch.nn.functional.log_softmax(ref_logits, dim=-1) + + # Both passes in float32 — remaining error is float32 non-associativity + # in weight processing (~0.006 max_diff on 24-layer Qwen2). + logits_atol = 0.01 + logits_rtol = 1e-4 + loss_atol = 1e-3 + add_result( compare_tensors( bridge_log_probs, ref_log_probs, - atol=1e-4, - rtol=1e-4, + atol=logits_atol, + rtol=logits_rtol, name="logits_equivalence", ) ) @@ -325,7 +332,7 @@ def add_result(result: BenchmarkResult) -> None: bridge_model, test_text, reference_loss=phase1_reference.hf_loss, - atol=1e-3, + atol=loss_atol, ) ) else: @@ -354,6 +361,16 @@ def add_result(result: BenchmarkResult) -> None: ) ) + # Restore native dtype so remaining tests run in the model's real dtype. + if restore_dtype_after_equivalence is not None: + try: + bridge_model.to(restore_dtype_after_equivalence) + if verbose: + print(f" (restored to {restore_dtype_after_equivalence} for remaining tests)\n") + except Exception as e: + if verbose: + print(f"⚠ Could not restore dtype: {e}\n") + # ======================================================================== # 3. Hook Registration Benchmarks # Tests hooks exist and are registered - depends on model structure @@ -653,14 +670,15 @@ def run_benchmark_suite( ) -> List[BenchmarkResult]: """Run comprehensive benchmark suite for TransformerBridge. - This function implements an optimized 5-phase approach to minimize model reloading: + This function implements an optimized multi-phase approach to minimize model reloading: Phase 1: HF + Bridge (unprocessed) - Compare against raw HuggingFace model Phase 2: Bridge (unprocessed) + HT (unprocessed) - Compare unprocessed models Phase 3: Bridge (processed) + HT (processed) - Full compatibility mode testing - Phase 4: Individual Weight Processing Flags (optional) - Phase 5: Combined Weight Processing Flags (optional) + Phase 4: Text Quality - Perplexity-based legibility scoring via GPT-2 Medium + Phase 5: Individual Weight Processing Flags (optional) + Phase 6: Combined Weight Processing Flags (optional) - When test_weight_processing_individually=True, Phases 4 & 5 run after + When test_weight_processing_individually=True, Phases 5 & 6 run after Phase 3, testing each weight processing flag individually and in combinations. Args: @@ -715,44 +733,42 @@ def get_memory_mb(): print(f"Device: {device}") print(f"{'='*80}\n") - # Early exit if only running Phase 4/5 (they load their own models independently) - if phases is not None and all(p in [4, 5] for p in phases): + # Early exit if only running Phase 5/6 (they load their own models independently) + if phases is not None and all(p in [5, 6] for p in phases): if verbose: - print(f"Skipping Phase 1-3 (only running Phase {', '.join(map(str, sorted(phases)))})") - print("Phase 4/5 load their own models independently\n") + print(f"Skipping Phase 1-4 (only running Phase {', '.join(map(str, sorted(phases)))})") + print("Phase 5/6 load their own models independently\n") - # Jump directly to Phase 4/5 - # Jump to granular testing from transformer_lens.benchmarks.granular_weight_processing import ( run_granular_weight_processing_benchmarks, ) - if 4 in phases and test_weight_processing_individually and enable_compatibility_mode: - phase4_results = run_granular_weight_processing_benchmarks( + if 5 in phases and test_weight_processing_individually and enable_compatibility_mode: + phase5_results = run_granular_weight_processing_benchmarks( model_name=model_name, device=device, test_text=test_text, verbose=verbose, - phase=4, + phase=5, ) - for config_name, config_results in phase4_results.items(): + for config_name, config_results in phase5_results.items(): for result in config_results: - result.phase = 4 + result.phase = 5 results.append(result) if verbose: result.print_immediate() - if 5 in phases and test_weight_processing_individually and enable_compatibility_mode: - phase5_results = run_granular_weight_processing_benchmarks( + if 6 in phases and test_weight_processing_individually and enable_compatibility_mode: + phase6_results = run_granular_weight_processing_benchmarks( model_name=model_name, device=device, test_text=test_text, verbose=verbose, - phase=5, + phase=6, ) - for config_name, config_results in phase5_results.items(): + for config_name, config_results in phase6_results.items(): for result in config_results: - result.phase = 5 + result.phase = 6 results.append(result) if verbose: result.print_immediate() @@ -806,7 +822,7 @@ def cleanup_model(model, model_name_str: str): if track_memory and memory_tracker is not None: memory_before = get_memory_mb() - # NEW: Move model to CPU first to free GPU memory immediately + # Move model to CPU first to free GPU memory immediately if device != "cpu" and hasattr(model, "cpu"): try: model.cpu() @@ -841,7 +857,7 @@ def cleanup_model(model, model_name_str: str): if hasattr(module, "remove_all_hooks"): module.remove_all_hooks() - # NEW: Clear gradients + # Clear gradients if hasattr(module, "zero_grad"): try: module.zero_grad(set_to_none=True) @@ -859,15 +875,14 @@ def cleanup_model(model, model_name_str: str): if hasattr(model, "_forward_pre_hooks"): model._forward_pre_hooks.clear() - # NEW: Clear top-level gradients + # Clear top-level gradients if hasattr(model, "zero_grad"): try: model.zero_grad(set_to_none=True) except Exception: pass - # OPTIMIZATION: Break circular references more aggressively - # Clear all submodule references to help GC + # Break circular references to help GC if hasattr(model, "_modules"): # Clear each submodule's __dict__ to break circular references for name, submodule in list(model._modules.items()): @@ -886,7 +901,6 @@ def cleanup_model(model, model_name_str: str): for param_name in list(model._parameters.keys()): param = model._parameters[param_name] if param is not None: - # NEW: Delete parameter tensor del param model._parameters[param_name] = None model._parameters.clear() @@ -896,7 +910,6 @@ def cleanup_model(model, model_name_str: str): for buffer_name in list(model._buffers.keys()): buffer = model._buffers[buffer_name] if buffer is not None: - # NEW: Delete buffer tensor del buffer model._buffers[buffer_name] = None model._buffers.clear() @@ -910,7 +923,6 @@ def cleanup_model(model, model_name_str: str): # Clear CUDA cache if using GPU if device != "cpu" and torch.cuda.is_available(): torch.cuda.empty_cache() - # NEW: Synchronize to ensure GPU operations complete torch.cuda.synchronize() # Track memory after cleanup @@ -948,13 +960,10 @@ def cleanup_model(model, model_name_str: str): try: # Load a lightweight version without weights to get config bridge_config_only = TransformerBridge.boot_transformers(model_name, device=device, dtype=bridge_dtype, load_weights=False, trust_remote_code=trust_remote_code) # type: ignore[attr-defined] - # Extract attn_implementation for HF model loading. - # First check if adapter explicitly sets it (e.g. qwen3, gemma3). + # Match bridge's attn_implementation: check adapter config first, then + # default to "eager" (bridge uses output_attentions=True which forces eager). if hasattr(bridge_config_only.adapter.cfg, "attn_implementation"): attn_implementation = bridge_config_only.adapter.cfg.attn_implementation - # TransformerBridge always loads HF models with output_attentions=True - # (see sources/transformers.py), which causes HF to fall back from SDPA - # to eager attention. We must match this in the reference model. if attn_implementation is None: attn_implementation = "eager" if verbose: @@ -965,9 +974,8 @@ def cleanup_model(model, model_name_str: str): except Exception as e: if verbose: print(f"⚠ Could not detect config (will use defaults): {str(e)}") - # For custom code models, the config-only bridge may fail. We still need to - # apply architecture-specific patches (e.g., OpenELM _init_weights fix) before - # loading any model, otherwise _init_weights may re-randomize loaded weights. + # Config-only bridge failed; apply architecture patches directly to prevent + # _init_weights from re-randomizing loaded weights. if trust_remote_code: try: from transformer_lens.model_bridge.sources.transformers import ( @@ -995,9 +1003,7 @@ def cleanup_model(model, model_name_str: str): try: if verbose: print("Loading HuggingFace reference model...") - # Match loading path to TransformerBridge: no device_map, explicit .to(device) - # Using device_map causes different weight materialization than .to(device), - # which produces numerical divergence for bfloat16 models. + # Match bridge loading path: no device_map, explicit .to(device). hf_kwargs = { "low_cpu_mem_usage": True, # Reduce memory spikes during loading } @@ -1009,9 +1015,7 @@ def cleanup_model(model, model_name_str: str): auto_model_class = get_auto_model_class(model_name, trust_remote_code=trust_remote_code) if verbose and auto_model_class != AutoModelForCausalLM: print(f"Using {auto_model_class.__name__} for encoder-decoder model") - # Ensure pad_token_id exists on HF config. Transformers v5 raises - # AttributeError for missing config attributes, which crashes models - # like StableLM that access config.pad_token_id during __init__. + # Ensure pad_token_id exists (some models crash without it during init). hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code) if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__: hf_config.pad_token_id = getattr(hf_config, "eos_token_id", None) @@ -1020,11 +1024,9 @@ def cleanup_model(model, model_name_str: str): hf_kwargs["trust_remote_code"] = True hf_model = auto_model_class.from_pretrained(model_name, **hf_kwargs) # type: ignore[arg-type] # Post-load fixup for custom code models (e.g., OpenELM). - # NOTE: We intentionally use _fixup_custom_model instead of the adapter's - # prepare_model here. The adapter's prepare_model unconditionally recomputes - # non-persistent buffers (causal_mask, inv_freq) which is needed for the - # bridge path (meta-device loading), but the reference model loads normally - # on CPU with correct buffers. Recomputing them can introduce numeric drift. + # Uses _fixup_custom_model (not adapter.prepare_model) because the reference + # model loads normally on CPU — unconditional buffer recomputation would + # introduce numeric drift. _fixup_custom_model(hf_model) hf_model = hf_model.to(device) hf_model.eval() @@ -1035,6 +1037,21 @@ def cleanup_model(model, model_name_str: str): print(f"Detected dtype={bridge_dtype}") except StopIteration: pass + # Float16 introduces too much rounding error for benchmarking; upcast. + if bridge_dtype == torch.float16: + if verbose: + print("⚠ Float16 detected, upcasting to float32 for benchmarking...") + del hf_model + gc.collect() + bridge_dtype = torch.float32 + hf_model = auto_model_class.from_pretrained( + model_name, torch_dtype=torch.float32, **hf_kwargs + ) + _fixup_custom_model(hf_model) + hf_model = hf_model.to(device) + hf_model.eval() + if verbose: + print("✓ Reloaded in float32") if verbose: print("✓ HuggingFace model loaded\n") except Exception as e: @@ -1098,23 +1115,27 @@ def cleanup_model(model, model_name_str: str): if verbose: print(f"✗ Forward pass benchmark failed: {e}\n") - # Capture unprocessed bridge reference data for Phase 3 reuse. - # We save the BRIDGE's logits/loss (not the HF model's), because the bridge - # forward path may differ slightly from HF. Phase 3 tests whether weight - # processing preserves the bridge's own output — comparing processed bridge - # vs unprocessed bridge. + # Capture Phase 1 reference in float32 for Phase 3 equivalence comparison. + # Running both passes in float32 avoids bf16 forward-pass amplification. if bridge_unprocessed is not None: try: + original_dtype = bridge_unprocessed.cfg.dtype + needs_upcast = original_dtype not in (torch.float32, torch.float64) + if needs_upcast: + bridge_unprocessed.to(torch.float32) with torch.no_grad(): bridge_logits = bridge_unprocessed(test_text, return_type="logits") phase1_reference.hf_logits = bridge_logits.detach().cpu().clone() bridge_loss = bridge_unprocessed(test_text, return_type="loss") phase1_reference.hf_loss = bridge_loss.item() phase1_reference.test_text = test_text + if needs_upcast: + bridge_unprocessed.to(original_dtype) if verbose: + dtype_note = " (upcast to float32)" if needs_upcast else "" print( f"✓ Saved Phase 1 reference data " - f"(logits: {phase1_reference.hf_logits.shape})" + f"(logits: {phase1_reference.hf_logits.shape}){dtype_note}" ) except Exception as e: if verbose: @@ -1175,6 +1196,14 @@ def cleanup_model(model, model_name_str: str): message="Skipped (encoder-decoder model)", ) ) + add_result( + BenchmarkResult( + name="text_quality", + severity=BenchmarkSeverity.INFO, + passed=True, + message="Skipped (encoder-decoder model)", + ) + ) else: try: add_result(benchmark_generation(bridge_unprocessed, test_text, max_new_tokens=10)) @@ -1199,6 +1228,33 @@ def cleanup_model(model, model_name_str: str): if verbose: print(f"✗ Generation benchmark failed: {e}\n") + # Phase 4: Text Quality (runs in Phase 2 memory window) + if should_run_phase(4): + try: + if verbose: + print("\n2. Text Quality Benchmark (Phase 4)") + text_quality_result = benchmark_text_quality( + bridge_unprocessed, + test_text, + max_new_tokens=50, + scoring_model_name="gpt2-medium", + pass_threshold=85.0, + device=device, + ) + text_quality_result.phase = 4 + add_result(text_quality_result) + gc.collect() + except Exception as e: + if verbose: + print(f"✗ Text quality benchmark failed: {e}\n") + + # Match bridge's default_prepend_bos setting in HookedTransformer. + ht_prepend_bos = None + if bridge_unprocessed is not None and hasattr(bridge_unprocessed, "cfg"): + bridge_bos = getattr(bridge_unprocessed.cfg, "default_prepend_bos", None) + if bridge_bos is not None: + ht_prepend_bos = bridge_bos + # Load HookedTransformer for comparison (after generation benchmarks) ht_model_unprocessed = None if should_run_phase(2) and use_ht_reference: @@ -1214,6 +1270,7 @@ def cleanup_model(model, model_name_str: str): center_unembed=False, fold_value_biases=False, refactor_factored_attn_matrices=False, + default_prepend_bos=ht_prepend_bos, ) if verbose: print("✓ HookedTransformer loaded (unprocessed)\n") @@ -1247,11 +1304,8 @@ def cleanup_model(model, model_name_str: str): if ht_model_unprocessed is not None: cleanup_model(ht_model_unprocessed, "HookedTransformer (unprocessed)") ht_model_unprocessed = None - # NOTE: bridge_unprocessed is intentionally kept alive for Phase 3. - # Instead of loading a fresh bridge (which can produce non-deterministic - # outputs for some architectures like OpenELM), we reuse the same instance - # and process its weights in-place. This ensures Phase 3 tests purely - # measure the effect of weight processing, not loading variability. + # bridge_unprocessed is kept alive for Phase 3 — reusing the same instance + # avoids non-deterministic loading in some architectures (e.g., OpenELM). # ======================================================================== # PHASE 3: Bridge (processed) + HookedTransformer (processed) @@ -1295,17 +1349,25 @@ def _cleanup_bridge_unprocessed(): bridge_processed = None ht_model_processed = None - # Reuse the Phase 1 bridge instance for Phase 3 instead of loading a fresh one. - # This avoids non-deterministic loading issues (some architectures like OpenELM - # produce different outputs across separate from_pretrained calls despite - # identical parameters and buffers). Processing weights in-place on the same - # instance ensures Phase 3 purely measures weight processing equivalence. + # Reuse the Phase 1 bridge instance and process weights in-place. + # For reduced-precision models, upcast to float32 BEFORE processing so + # weight operations avoid bf16 quantization round-trips. The bridge is + # downcast back to native dtype after the equivalence comparison. + phase3_native_dtype = None # Set if we upcast; used to restore later if bridge_unprocessed is not None: try: if verbose: print("Processing weights on existing bridge (reusing Phase 1 instance)...") bridge_processed = bridge_unprocessed bridge_unprocessed = None # Transfer ownership + # Upcast to float32 before processing to avoid bf16 quantization loss. + phase3_native_dtype = bridge_processed.cfg.dtype + if phase3_native_dtype not in (torch.float32, torch.float64): + bridge_processed.to(torch.float32) + if verbose: + print(f" (upcast from {phase3_native_dtype} to float32 before processing)") + else: + phase3_native_dtype = None # No restore needed bridge_processed.enable_compatibility_mode(disable_warnings=True) if verbose: print("✓ TransformerBridge compatibility mode enabled (processed)\n") @@ -1408,6 +1470,7 @@ def _cleanup_bridge_unprocessed(): center_unembed=True, fold_value_biases=True, refactor_factored_attn_matrices=False, + default_prepend_bos=ht_prepend_bos, ) if verbose: print("✓ HookedTransformer loaded (processed)\n") @@ -1451,6 +1514,7 @@ def _cleanup_bridge_unprocessed(): verbose=verbose, gpt2_reference=gpt2_reference, # Use GPT-2 cross-model ref if no same-arch HT phase1_reference=phase1_reference, # Saved HF logits/loss for equivalence testing + restore_dtype_after_equivalence=phase3_native_dtype, ) # Tag all phase 3 results with phase number for result in phase3_results: @@ -1467,12 +1531,12 @@ def _cleanup_bridge_unprocessed(): ht_model_processed = None # ======================================================================== - # Phase 4: Granular Weight Processing Tests (Optional) + # Phase 5/6: Granular Weight Processing Tests (Optional) # ======================================================================== if test_weight_processing_individually and enable_compatibility_mode: if verbose: print("\n" + "=" * 80) - print("PHASE 4: GRANULAR WEIGHT PROCESSING TESTS") + print("PHASE 5/6: GRANULAR WEIGHT PROCESSING TESTS") print("=" * 80) print("Testing each weight processing flag individually and in combinations") print("to isolate which specific processing steps cause issues.") @@ -1499,7 +1563,7 @@ def _cleanup_bridge_unprocessed(): if verbose: print("\n" + "=" * 80) - print("PHASE 4 COMPLETE") + print("PHASE 5/6 COMPLETE") print("=" * 80) except Exception as e: diff --git a/transformer_lens/benchmarks/text_quality.py b/transformer_lens/benchmarks/text_quality.py new file mode 100644 index 000000000..5508b2926 --- /dev/null +++ b/transformer_lens/benchmarks/text_quality.py @@ -0,0 +1,310 @@ +"""Text quality benchmark for TransformerBridge. + +Generates text with the bridge model from multiple diverse prompts and scores +each continuation's legibility using GPT-2 Medium as a perplexity-based judge. +Only the generated continuation tokens are scored (prompt tokens are masked), +and a repetition penalty is applied to catch degenerate looping output. + +Generation is seeded for reproducibility, and the scoring model is loaded once +and reused across all prompts. +""" + +import gc +import math +from typing import List, Optional, Tuple + +import torch +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizerBase, +) + +from transformer_lens.benchmarks.utils import BenchmarkResult, BenchmarkSeverity +from transformer_lens.model_bridge import TransformerBridge + +# Diverse prompts used alongside the caller-provided test_text to get a robust +# quality signal across different domains and styles. +_DEFAULT_PROMPTS = [ + "The theory of relativity explains that", + "In the dense forests of the Amazon,", + "Modern computing relies heavily on", +] + + +def _load_scoring_model( + scoring_model_name: str, + device: str, +) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]: + """Load the scoring model and tokenizer. + + Separated from perplexity computation so the caller can load once and + reuse across multiple prompts. + """ + tokenizer = AutoTokenizer.from_pretrained(scoring_model_name) + model = AutoModelForCausalLM.from_pretrained(scoring_model_name) + torch.nn.Module.to(model, device) + model.eval() + return model, tokenizer + + +def _compute_continuation_perplexity( + prompt: str, + full_text: str, + tokenizer: PreTrainedTokenizerBase, + scoring_model: PreTrainedModel, + device: str, +) -> Tuple[float, Optional[str]]: + """Compute perplexity of only the continuation tokens (excluding prompt). + + Prompt tokens are masked with -100 in labels so CrossEntropyLoss ignores + them. This prevents well-formed prompt text from artificially lowering + the perplexity of generated content. + + Args: + prompt: The original input prompt. + full_text: The complete text (prompt + generated continuation). + tokenizer: Pre-loaded tokenizer. + scoring_model: Pre-loaded scoring model. + device: Device string. + + Returns: + Tuple of (perplexity, error_message). error_message is None on success. + """ + try: + encodings = tokenizer(full_text, return_tensors="pt") + input_ids = encodings["input_ids"].to(device) + + # Tokenize just the prompt to find where continuation starts + prompt_encodings = tokenizer(prompt, return_tensors="pt") + prompt_len = prompt_encodings["input_ids"].shape[1] + + # Build labels: -100 for prompt positions, actual ids for continuation + labels = input_ids.clone() + labels[0, :prompt_len] = -100 + + continuation_len = input_ids.shape[1] - prompt_len + if continuation_len < 2: + return float("inf"), "Generated continuation too short (< 2 tokens)" + + with torch.no_grad(): + outputs = scoring_model(input_ids, labels=labels) + loss = outputs.loss.item() + + perplexity = math.exp(loss) + return perplexity, None + + except Exception as e: + return float("inf"), f"Perplexity computation failed: {str(e)}" + + +def _compute_repetition_penalty(text: str, ns: Tuple[int, ...] = (2, 3, 4)) -> float: + """Compute a repetition penalty based on n-gram uniqueness ratio. + + Returns a multiplier in [0.0, 1.0] where 1.0 means no repetition and + lower values penalize repetitive text. The penalty is the minimum + unique-n-gram ratio across all checked n-gram sizes. + + Args: + text: The generated continuation text (prompt excluded). + ns: Tuple of n-gram sizes to check. + + Returns: + Penalty multiplier in [0.0, 1.0]. + """ + words = text.lower().split() + if len(words) < 2: + return 1.0 + + min_ratio = 1.0 + for n in ns: + if len(words) < n: + continue + ngrams = [tuple(words[i : i + n]) for i in range(len(words) - n + 1)] + if len(ngrams) == 0: + continue + unique_ratio = len(set(ngrams)) / len(ngrams) + min_ratio = min(min_ratio, unique_ratio) + + return min_ratio + + +def _perplexity_to_score(perplexity: float) -> float: + """Map continuation perplexity to a 0-100 legibility score. + + Uses: score = 135 - 10 * ln(perplexity), capped to [0, 100]. + Calibrated for continuation-only perplexity (higher than full-text). + A well-functioning model typically gets ppl 40-60 -> score 94-98. + Default pass threshold of 85 corresponds to approximately ppl 150. + + Args: + perplexity: The perplexity value from the scoring model. + + Returns: + Score from 0.0 to 100.0. + """ + if perplexity <= 0 or math.isinf(perplexity): + return 0.0 + return max(0.0, min(100.0, 135.0 - 10.0 * math.log(perplexity))) + + +def benchmark_text_quality( + bridge: TransformerBridge, + test_text: str, + max_new_tokens: int = 50, + scoring_model_name: str = "gpt2-medium", + pass_threshold: float = 85.0, + device: str = "cpu", +) -> BenchmarkResult: + """Benchmark text generation quality using continuation-only perplexity scoring. + + Generates text from multiple diverse prompts, scores each continuation using + GPT-2 Medium perplexity (prompt tokens masked), applies a repetition penalty, + and returns the averaged score. + + Args: + bridge: TransformerBridge model to test. + test_text: Primary input prompt (additional diverse prompts are also used). + max_new_tokens: Number of tokens to generate per prompt. + scoring_model_name: HuggingFace model to use as scorer. + pass_threshold: Minimum average score to pass (default 95.0). + device: Device for the scoring model. + + Returns: + BenchmarkResult with quality score details. + """ + scoring_model = None + try: + prompts = [test_text] + _DEFAULT_PROMPTS + + # Seed for reproducibility + torch.manual_seed(42) + + # Generate text for each prompt + generations: List[Tuple[str, str]] = [] # (prompt, full_text) + primary_generated = "" + for i, prompt in enumerate(prompts): + generated = bridge.generate( + prompt, + max_new_tokens=max_new_tokens, + temperature=0.7, + do_sample=True, + ) + if not isinstance(generated, str) or len(generated.strip()) == 0: + continue + generations.append((prompt, generated)) + if i == 0: + primary_generated = generated + + if len(generations) == 0: + return BenchmarkResult( + name="text_quality", + severity=BenchmarkSeverity.DANGER, + message="Generation produced empty output for all prompts", + passed=False, + ) + + # Load scoring model once + scoring_model, tokenizer = _load_scoring_model(scoring_model_name, device) + + # Score each continuation + per_prompt_scores = [] + per_prompt_perplexities = [] + per_prompt_penalties = [] + prompt_details_parts = [] + + for prompt, full_text in generations: + perplexity, error = _compute_continuation_perplexity( + prompt, full_text, tokenizer, scoring_model, device + ) + if error is not None: + continue + + raw_score = _perplexity_to_score(perplexity) + + # Repetition penalty on continuation only + continuation = full_text[len(prompt) :] + rep_penalty = _compute_repetition_penalty(continuation) + adjusted_score = raw_score * rep_penalty + + per_prompt_scores.append(adjusted_score) + per_prompt_perplexities.append(perplexity) + per_prompt_penalties.append(rep_penalty) + prompt_details_parts.append( + f"ppl={perplexity:.1f} score={adjusted_score:.1f} rep={rep_penalty:.2f}" + ) + + if len(per_prompt_scores) == 0: + return BenchmarkResult( + name="text_quality", + severity=BenchmarkSeverity.ERROR, + message="Scoring failed for all prompts", + details={"generated_text": primary_generated}, + passed=False, + ) + + avg_score = sum(per_prompt_scores) / len(per_prompt_scores) + avg_perplexity = sum(per_prompt_perplexities) / len(per_prompt_perplexities) + avg_rep_penalty = sum(per_prompt_penalties) / len(per_prompt_penalties) + + details = { + "score": round(avg_score, 1), + "avg_perplexity": round(avg_perplexity, 2), + "avg_repetition_penalty": round(avg_rep_penalty, 2), + "num_prompts": len(per_prompt_scores), + "per_prompt": " | ".join(prompt_details_parts), + "scoring_model": scoring_model_name, + "max_new_tokens": max_new_tokens, + "generated_text": primary_generated, + } + + if avg_score >= pass_threshold: + return BenchmarkResult( + name="text_quality", + severity=BenchmarkSeverity.INFO, + message=( + f"Text quality score: {avg_score:.1f}/100 " + f"(avg perplexity: {avg_perplexity:.1f}, " + f"{len(per_prompt_scores)} prompts)" + ), + details=details, + ) + elif avg_score >= 80.0: + return BenchmarkResult( + name="text_quality", + severity=BenchmarkSeverity.WARNING, + message=( + f"Text quality score: {avg_score:.1f}/100 " + f"(below {pass_threshold}, avg perplexity: {avg_perplexity:.1f})" + ), + details=details, + passed=False, + ) + else: + return BenchmarkResult( + name="text_quality", + severity=BenchmarkSeverity.DANGER, + message=( + f"Text quality score: {avg_score:.1f}/100 " + f"(avg perplexity: {avg_perplexity:.1f}) " + f"— generated text may be incoherent" + ), + details=details, + passed=False, + ) + + except Exception as e: + return BenchmarkResult( + name="text_quality", + severity=BenchmarkSeverity.ERROR, + message=f"Text quality benchmark failed: {str(e)}", + passed=False, + ) + + finally: + if scoring_model is not None: + del scoring_model + gc.collect() + if device != "cpu" and torch.cuda.is_available(): + torch.cuda.empty_cache() diff --git a/transformer_lens/benchmarks/utils.py b/transformer_lens/benchmarks/utils.py index 50c1d0454..607ca7f52 100644 --- a/transformer_lens/benchmarks/utils.py +++ b/transformer_lens/benchmarks/utils.py @@ -7,6 +7,19 @@ import torch +def safe_allclose( + tensor1: torch.Tensor, + tensor2: torch.Tensor, + atol: float = 1e-5, + rtol: float = 1e-5, +) -> bool: + """torch.allclose that handles dtype mismatches by upcasting to float32.""" + if tensor1.dtype != tensor2.dtype: + tensor1 = tensor1.to(torch.float32) + tensor2 = tensor2.to(torch.float32) + return torch.allclose(tensor1, tensor2, atol=atol, rtol=rtol) + + class BenchmarkSeverity(Enum): """Severity levels for benchmark results.""" @@ -61,15 +74,11 @@ def print_immediate(self) -> None: @dataclass class PhaseReferenceData: - """Reference data saved from Phase 1 for reuse in Phase 3. - - When a model has no HookedTransformer support, Phase 1 HF logits serve as - ground truth for verifying that weight processing doesn't alter model output. - """ + """Float32 reference data from Phase 1 for Phase 3 equivalence comparison.""" - hf_logits: Optional[torch.Tensor] = None # [batch, seq, vocab] from HF model - hf_loss: Optional[float] = None # scalar loss from bridge (unprocessed) - test_text: Optional[str] = None # text used (for verification) + hf_logits: Optional[torch.Tensor] = None + hf_loss: Optional[float] = None + test_text: Optional[str] = None def compare_tensors( @@ -100,7 +109,10 @@ def compare_tensors( passed=False, ) - # Compare values + if tensor1.dtype != tensor2.dtype: + tensor1 = tensor1.to(torch.float32) + tensor2 = tensor2.to(torch.float32) + if torch.allclose(tensor1, tensor2, atol=atol, rtol=rtol): return BenchmarkResult( name=name, @@ -109,24 +121,15 @@ def compare_tensors( details={"atol": atol, "rtol": rtol}, ) - # Calculate differences diff = torch.abs(tensor1 - tensor2) max_diff = diff.max().item() mean_diff = diff.mean().item() rel_diff = diff / (torch.abs(tensor1) + 1e-10) mean_rel = rel_diff.mean().item() - # Determine severity based on differences - if max_diff < atol * 10 and mean_rel < rtol * 10: - severity = BenchmarkSeverity.WARNING - passed = True - else: - severity = BenchmarkSeverity.DANGER - passed = False - return BenchmarkResult( name=name, - severity=severity, + severity=BenchmarkSeverity.DANGER, message=f"Tensors differ: max_diff={max_diff:.6f}, mean_rel={mean_rel:.6f}", details={ "max_diff": max_diff, @@ -135,7 +138,7 @@ def compare_tensors( "atol": atol, "rtol": rtol, }, - passed=passed, + passed=False, ) @@ -165,18 +168,11 @@ def compare_scalars( message=f"Scalars match: {scalar1:.6f} ≈ {scalar2:.6f}", details={"diff": diff, "atol": atol}, ) - elif diff < atol * 10: - return BenchmarkResult( - name=name, - severity=BenchmarkSeverity.WARNING, - message=f"Scalars differ slightly: {scalar1:.6f} vs {scalar2:.6f}", - details={"diff": diff, "atol": atol}, - ) else: return BenchmarkResult( name=name, severity=BenchmarkSeverity.DANGER, - message=f"Scalars differ significantly: {scalar1:.6f} vs {scalar2:.6f}", + message=f"Scalars differ: {scalar1:.6f} vs {scalar2:.6f}", details={"diff": diff, "atol": atol}, passed=False, ) diff --git a/transformer_lens/benchmarks/weight_processing.py b/transformer_lens/benchmarks/weight_processing.py index 84b6875e6..6ed15db37 100644 --- a/transformer_lens/benchmarks/weight_processing.py +++ b/transformer_lens/benchmarks/weight_processing.py @@ -5,7 +5,11 @@ import torch from transformer_lens import HookedTransformer -from transformer_lens.benchmarks.utils import BenchmarkResult, BenchmarkSeverity +from transformer_lens.benchmarks.utils import ( + BenchmarkResult, + BenchmarkSeverity, + safe_allclose, +) from transformer_lens.model_bridge import TransformerBridge @@ -174,7 +178,7 @@ def benchmark_weight_sharing( }, ) - if not torch.allclose(bridge_W_V, reference_W_V): # type: ignore[arg-type] + if not safe_allclose(bridge_W_V, reference_W_V): # type: ignore[arg-type] return BenchmarkResult( name="weight_sharing", severity=BenchmarkSeverity.WARNING, @@ -311,11 +315,38 @@ def benchmark_weight_modification( # Loss should change change = abs(modified_loss - original_loss) if change < 1e-6: + # W_V modification didn't propagate. This can happen in models with + # combined QKV projections (e.g., Bloom) where the split V weight + # is separate from the combined QKV weight used in forward. + # Try MLP weight modification as fallback. + mlp_fallback_error = None + try: + with torch.no_grad(): + original_mlp_w = bridge.blocks[0].mlp.out.weight.clone() + bridge.blocks[0].mlp.out.weight[0, :] = 0 + mlp_modified_loss = bridge(test_text, return_type="loss") + with torch.no_grad(): + bridge.blocks[0].mlp.out.weight.copy_(original_mlp_w) + mlp_change = abs(mlp_modified_loss - original_loss) + if mlp_change > 1e-6: + return BenchmarkResult( + name="weight_modification", + severity=BenchmarkSeverity.INFO, + message=f"Weight modification propagates via MLP (change: {mlp_change:.6f}). " + f"W_V not propagated (combined QKV architecture).", + details={"change": mlp_change.item(), "fallback": "mlp"}, + ) + except Exception as mlp_err: + mlp_fallback_error = str(mlp_err) + + details = {"change": change.item()} + if mlp_fallback_error is not None: + details["mlp_fallback_error"] = mlp_fallback_error return BenchmarkResult( name="weight_modification", severity=BenchmarkSeverity.DANGER, message=f"Weight modification did not affect loss (change: {change:.6f})", - details={"change": change.item()}, + details=details, passed=False, ) @@ -364,6 +395,16 @@ def benchmark_layer_norm_folding( BenchmarkResult with layer norm folding verification details """ try: + # Skip for architectures that don't support fold_ln (e.g., post-LN like BERT) + adapter = getattr(bridge, "adapter", None) + if adapter and not getattr(adapter, "supports_fold_ln", True): + return BenchmarkResult( + name="layer_norm_folding", + severity=BenchmarkSeverity.SKIPPED, + message="Skipped (post-LN architecture does not support fold_ln)", + passed=True, + ) + # Get state dict from bridge (should return TransformerLens format keys) state_dict = bridge.state_dict() @@ -525,8 +566,20 @@ def benchmark_mlp_output_centering( details={"is_moe": True}, ) - # Check if W_out exists and is accessible - if not hasattr(bridge.blocks[0].mlp, "W_out"): + # Check if W_out exists and is accessible (HT format or bridge format) + w_out = None + if hasattr(bridge.blocks[0].mlp, "W_out"): + w_out = bridge.blocks[0].mlp.W_out + elif hasattr(bridge.blocks[0].mlp, "out"): + # Bridge format: mlp.out is a LinearBridge wrapping nn.Linear + out_module = bridge.blocks[0].mlp.out + if hasattr(out_module, "original_component") and hasattr( + out_module.original_component, "weight" + ): + w_out = out_module.original_component.weight + elif hasattr(out_module, "weight"): + w_out = out_module.weight + if w_out is None: return BenchmarkResult( name="mlp_output_centering", severity=BenchmarkSeverity.WARNING, @@ -534,8 +587,6 @@ def benchmark_mlp_output_centering( passed=False, ) - w_out = bridge.blocks[0].mlp.W_out - # Compute mean along output dimension mean_abs = torch.mean(torch.abs(torch.mean(w_out, dim=-1))).item() diff --git a/transformer_lens/benchmarks/weight_processing_benchmark.py b/transformer_lens/benchmarks/weight_processing_benchmark.py deleted file mode 100644 index 6fd0f94cd..000000000 --- a/transformer_lens/benchmarks/weight_processing_benchmark.py +++ /dev/null @@ -1,473 +0,0 @@ -#!/usr/bin/env python3 -"""Benchmark suite for validating weight processing in TransformerBridge models. - -This suite verifies that each weight processing step (layer norm folding, centering, -value bias folding, etc.) has been correctly applied to the model weights. -""" - -from dataclasses import dataclass -from typing import Any, Dict, List, Tuple - -import torch - - -@dataclass -class WeightProcessingCheck: - """Result of a single weight processing verification check.""" - - name: str - passed: bool - details: Dict[str, Any] - message: str - - -class WeightProcessingBenchmark: - """Benchmark suite for validating weight processing steps.""" - - def __init__(self, bridge_model: Any, verbose: bool = True): - """Initialize the benchmark suite. - - Args: - bridge_model: TransformerBridge model instance - verbose: Whether to print detailed output - """ - self.bridge = bridge_model - self.cfg = bridge_model.cfg - self.verbose = verbose - self.results: List[WeightProcessingCheck] = [] - - def run_all_checks(self) -> Tuple[int, int]: - """Run all weight processing validation checks. - - Returns: - Tuple of (passed_count, total_count) - """ - if self.verbose: - print("=" * 80) - print(f"Weight Processing Benchmark: {self.cfg.model_name}") - print("=" * 80) - - # Get state dict from the processed model - state_dict = self.bridge.original_model.state_dict() - - # Clean keys - cleaned_state_dict = {} - for key, value in state_dict.items(): - clean_key = key.replace("._original_component", "") - cleaned_state_dict[clean_key] = value - - # Run checks - self._check_layer_norm_folding(cleaned_state_dict) - self._check_weight_centering(cleaned_state_dict) - self._check_unembed_centering(cleaned_state_dict) - self._check_value_bias_folding(cleaned_state_dict) - self._check_no_nan_inf(cleaned_state_dict) - self._check_weight_magnitudes(cleaned_state_dict) - - # Print summary - passed = sum(1 for r in self.results if r.passed) - total = len(self.results) - - if self.verbose: - print("\n" + "=" * 80) - print("RESULTS") - print("=" * 80) - for result in self.results: - status = "✅" if result.passed else "❌" - print(f"\n{status} {result.name}") - print(f" {result.message}") - if not result.passed or self.verbose: - for key, value in result.details.items(): - print(f" {key}: {value}") - - print("\n" + "=" * 80) - print(f"SUMMARY: {passed}/{total} checks passed ({100*passed//total}%)") - print("=" * 80) - - return passed, total - - def _check_layer_norm_folding(self, state_dict: Dict[str, torch.Tensor]) -> None: - """Check that layer norm weights have been folded into subsequent layers.""" - # For models with LayerNorm/RMSNorm, check if normalization weights still exist - # After folding, ln weights should be removed or set to identity - uses_rms_norm = getattr(self.cfg, "uses_rms_norm", False) - - # Check if ln1 weights exist for first block - ln1_key_patterns = [ - f"blocks.0.ln1.weight", # GPT-2 (TransformerLens format) - f"model.layers.0.input_layernorm.weight", # Gemma - ] - - ln1_exists = False - ln1_key = None - for pattern in ln1_key_patterns: - if pattern in state_dict: - ln1_exists = True - ln1_key = pattern - break - - if ln1_exists and ln1_key: - ln1_weight = state_dict[ln1_key] - - if uses_rms_norm: - # RMS norm weights should be folded (multiplied into downstream weights) - # After folding, they might still exist but should be identity-like - # For RMS norm, "identity" means all ones - expected_val = 1.0 - is_identity = torch.allclose(ln1_weight, torch.ones_like(ln1_weight), atol=1e-4) - - self.results.append( - WeightProcessingCheck( - name="layer_norm_folding", - passed=is_identity, - details={ - "norm_type": "RMSNorm", - "ln1_mean": ln1_weight.mean().item(), - "ln1_std": ln1_weight.std().item(), - "expected_mean": expected_val, - "is_identity": is_identity, - }, - message=f"RMSNorm weights {'are' if is_identity else 'are NOT'} identity after folding", - ) - ) - else: - # LayerNorm weights should be folded - # After folding, they should be identity (all ones) - expected_val = 1.0 - is_identity = torch.allclose(ln1_weight, torch.ones_like(ln1_weight), atol=1e-4) - - self.results.append( - WeightProcessingCheck( - name="layer_norm_folding", - passed=is_identity, - details={ - "norm_type": "LayerNorm", - "ln1_mean": ln1_weight.mean().item(), - "ln1_std": ln1_weight.std().item(), - "expected_mean": expected_val, - "is_identity": is_identity, - }, - message=f"LayerNorm weights {'are' if is_identity else 'are NOT'} identity after folding", - ) - ) - else: - # No ln1 found - might be architecture without it - self.results.append( - WeightProcessingCheck( - name="layer_norm_folding", - passed=True, - details={"norm_type": "None", "reason": "No ln1 layer found"}, - message="No normalization layer found (expected for some architectures)", - ) - ) - - def _check_weight_centering(self, state_dict: Dict[str, torch.Tensor]) -> None: - """Check that writing weights have been centered.""" - # Writing weights are those that write to the residual stream: - # - Attention output projection (W_O) - # - MLP output projection - # These should have mean ≈ 0 along the output dimension after centering - - # Check attention output (W_O / o_proj) - wo_key_patterns = [ - "blocks.0.attn.o.weight", # GPT-2 (TransformerLens format) - "model.layers.0.self_attn.o_proj.weight", # Gemma - ] - - wo_key = None - for pattern in wo_key_patterns: - if pattern in state_dict: - wo_key = pattern - break - - if wo_key: - wo = state_dict[wo_key] - # W_O shape is typically (d_model, d_model) or similar - # After centering, mean along output dimension (dim=0) should be ~0 - output_means = wo.mean(dim=-1) # Mean along input dimension - mean_magnitude = output_means.abs().mean().item() - - # Threshold for "centered" - mean should be small - is_centered = mean_magnitude < 0.01 - - self.results.append( - WeightProcessingCheck( - name="attention_output_centering", - passed=is_centered, - details={ - "weight": "W_O (attention output)", - "mean_magnitude": mean_magnitude, - "threshold": 0.01, - "shape": list(wo.shape), - }, - message=f"Attention output weights {'are' if is_centered else 'are NOT'} centered (mean={mean_magnitude:.6f})", - ) - ) - else: - self.results.append( - WeightProcessingCheck( - name="attention_output_centering", - passed=True, - details={"reason": "No W_O found"}, - message="No attention output weight found", - ) - ) - - # Check MLP output - mlp_out_patterns = [ - "blocks.0.mlp.out", # GPT-2 (TransformerLens format) - "model.layers.0.mlp.down_proj.weight", # Gemma - ] - - mlp_out_key = None - for pattern in mlp_out_patterns: - if pattern in state_dict: - mlp_out_key = pattern - break - - if mlp_out_key: - mlp_out = state_dict[mlp_out_key] - output_means = mlp_out.mean(dim=-1) - mean_magnitude = output_means.abs().mean().item() - - is_centered = mean_magnitude < 0.01 - - self.results.append( - WeightProcessingCheck( - name="mlp_output_centering", - passed=is_centered, - details={ - "weight": "MLP output", - "mean_magnitude": mean_magnitude, - "threshold": 0.01, - "shape": list(mlp_out.shape), - }, - message=f"MLP output weights {'are' if is_centered else 'are NOT'} centered (mean={mean_magnitude:.6f})", - ) - ) - else: - self.results.append( - WeightProcessingCheck( - name="mlp_output_centering", - passed=True, - details={"reason": "No MLP output found"}, - message="No MLP output weight found", - ) - ) - - def _check_unembed_centering(self, state_dict: Dict[str, torch.Tensor]) -> None: - """Check that unembedding matrix has been centered.""" - unembed_patterns = [ - "lm_head.weight", # Most models - "output.weight", # Some models - ] - - unembed_key = None - for pattern in unembed_patterns: - if pattern in state_dict: - unembed_key = pattern - break - - if unembed_key: - unembed = state_dict[unembed_key] - # Unembed should have mean ≈ 0 along vocabulary dimension - vocab_means = unembed.mean(dim=0) # Mean across vocabulary - mean_magnitude = vocab_means.abs().mean().item() - - is_centered = mean_magnitude < 0.1 # Slightly higher tolerance - - self.results.append( - WeightProcessingCheck( - name="unembed_centering", - passed=is_centered, - details={ - "mean_magnitude": mean_magnitude, - "threshold": 0.1, - "shape": list(unembed.shape), - }, - message=f"Unembedding matrix {'is' if is_centered else 'is NOT'} centered (mean={mean_magnitude:.6f})", - ) - ) - else: - self.results.append( - WeightProcessingCheck( - name="unembed_centering", - passed=True, - details={"reason": "No unembed found"}, - message="No unembedding matrix found", - ) - ) - - def _check_value_bias_folding(self, state_dict: Dict[str, torch.Tensor]) -> None: - """Check that value biases have been folded into output bias.""" - # After value bias folding, b_V should be zero and b_O should be modified - - # Check if b_V exists and is zero - bv_patterns = [ - "blocks.0.attn.v.bias", # GPT-2 (TransformerLens format) - "model.layers.0.self_attn.v_proj.bias", # Gemma - ] - - bv_key = None - for pattern in bv_patterns: - if pattern in state_dict: - bv_key = pattern - break - - # Check value bias (already split in TransformerLens format) - if bv_key: - bv = state_dict[bv_key] - bv_is_zero = torch.allclose(bv, torch.zeros_like(bv), atol=1e-6) - bv_mean = bv.abs().mean().item() - - self.results.append( - WeightProcessingCheck( - name="value_bias_folding", - passed=bv_is_zero, - details={ - "bv_mean_abs": bv_mean, - "threshold": 1e-6, - }, - message=f"Value bias {'is' if bv_is_zero else 'is NOT'} zero after folding (mean={bv_mean:.8f})", - ) - ) - else: - # No value bias found (some models don't have biases) - self.results.append( - WeightProcessingCheck( - name="value_bias_folding", - passed=True, - details={"reason": "No value bias found"}, - message="No value bias found (expected for some architectures)", - ) - ) - - def _check_no_nan_inf(self, state_dict: Dict[str, torch.Tensor]) -> None: - """Check that no weights contain NaN or Inf values.""" - nan_keys = [] - inf_keys = [] - - for key, tensor in state_dict.items(): - if torch.isnan(tensor).any(): - nan_keys.append(key) - if torch.isinf(tensor).any(): - inf_keys.append(key) - - has_issues = len(nan_keys) > 0 or len(inf_keys) > 0 - - self.results.append( - WeightProcessingCheck( - name="no_nan_inf", - passed=not has_issues, - details={ - "nan_count": len(nan_keys), - "inf_count": len(inf_keys), - "nan_keys": nan_keys[:5] if nan_keys else [], - "inf_keys": inf_keys[:5] if inf_keys else [], - }, - message=f"Weights {'contain' if has_issues else 'do not contain'} NaN/Inf values", - ) - ) - - def _check_weight_magnitudes(self, state_dict: Dict[str, torch.Tensor]) -> None: - """Check that weight magnitudes are reasonable (no explosion/vanishing).""" - issues = [] - - for key, tensor in state_dict.items(): - if "weight" not in key.lower(): - continue - - mean_abs = tensor.abs().mean().item() - max_abs = tensor.abs().max().item() - - # Check for suspiciously large or small weights - if mean_abs > 100: - issues.append(f"{key}: mean_abs={mean_abs:.2f} (too large)") - elif mean_abs < 1e-6 and "norm" not in key.lower(): - issues.append(f"{key}: mean_abs={mean_abs:.2e} (too small)") - - if max_abs > 1000: - issues.append(f"{key}: max_abs={max_abs:.2f} (too large)") - - has_issues = len(issues) > 0 - - self.results.append( - WeightProcessingCheck( - name="weight_magnitudes", - passed=not has_issues, - details={ - "issue_count": len(issues), - "issues": issues[:10], # First 10 issues - }, - message=f"Weight magnitudes {'are suspicious' if has_issues else 'are reasonable'}", - ) - ) - - -def benchmark_weight_processing( - model_name: str, device: str = "cpu", verbose: bool = True -) -> Tuple[int, int]: - """Run weight processing benchmark on a model. - - Args: - model_name: HuggingFace model name - device: Device to load model on - verbose: Whether to print detailed output - - Returns: - Tuple of (passed_count, total_count) - """ - import torch - - from transformer_lens.model_bridge import TransformerBridge - - if verbose: - print(f"\nLoading {model_name}...") - - # Load model with weight processing - bridge = TransformerBridge.boot_transformers(model_name, device=device, dtype=torch.float32) # type: ignore[attr-defined] - - if verbose: - print(f"Processing weights...") - - bridge.process_compatibility_weights(verbose=False) - - # Run benchmark - benchmark = WeightProcessingBenchmark(bridge, verbose=verbose) - return benchmark.run_all_checks() - - -if __name__ == "__main__": - import sys - - # Test on multiple models - models = [ - "gpt2", - "google/gemma-2-2b-it", - ] - - if len(sys.argv) > 1: - models = sys.argv[1:] - - total_passed = 0 - total_checks = 0 - - for model_name in models: - try: - passed, total = benchmark_weight_processing(model_name, verbose=True) - total_passed += passed - total_checks += total - print() - except Exception as e: - print(f"\n❌ Error benchmarking {model_name}: {e}") - import traceback - - traceback.print_exc() - - print("\n" + "=" * 80) - print("OVERALL SUMMARY") - print("=" * 80) - print( - f"Total: {total_passed}/{total_checks} checks passed ({100*total_passed//total_checks if total_checks > 0 else 0}%)" - ) - print("=" * 80) diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index dd78fa53c..c5cd9ecda 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -718,7 +718,8 @@ def process_weights( if adapter and hasattr(adapter, "preprocess_weights"): state_dict = adapter.preprocess_weights(state_dict) - # Use unified ProcessWeights.process_weights() like HookedTransformer does + # Use unified ProcessWeights.process_weights() like HookedTransformer does. + # Float32 upcasting for precision is handled centrally in process_weights(). if verbose: print(" Processing weights (fold_ln, center_writing_weights, etc.)...") state_dict = ProcessWeights.process_weights( @@ -732,7 +733,29 @@ def process_weights( adapter=adapter, ) - # print("new", state_dict.keys()) + # Normalize any remaining HF-prefix keys to TL format. + # Some architectures (e.g., OPT with SymbolicBridge) produce state dict keys + # with HF prefixes (model.decoder.layers.0.mlp.in.weight) instead of TL prefixes + # (blocks.0.mlp.in.weight). distribute_weights_to_components uses TL prefixes + # for routing, so we normalize all keys here. + import re + + hf_to_tl_prefix = {} + for tl_name, (remote_path, _component) in self.real_components.items(): + if remote_path and remote_path != tl_name: + hf_to_tl_prefix[remote_path] = tl_name + + normalized_state_dict = {} + for key, value in state_dict.items(): + new_key = key + for hf_prefix, tl_prefix in hf_to_tl_prefix.items(): + if key.startswith(hf_prefix + "."): + suffix = key[len(hf_prefix) + 1 :] + new_key = f"{tl_prefix}.{suffix}" + break + normalized_state_dict[new_key] = value + state_dict = normalized_state_dict + if verbose: print(" Distributing weights to generalized components...") ProcessWeights.distribute_weights_to_components( diff --git a/transformer_lens/model_bridge/generalized_components/attention.py b/transformer_lens/model_bridge/generalized_components/attention.py index 07ce195bd..6c2fb9d8a 100644 --- a/transformer_lens/model_bridge/generalized_components/attention.py +++ b/transformer_lens/model_bridge/generalized_components/attention.py @@ -320,8 +320,6 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: ): hooked = hooked.to(dtype=target_dtype) args = (hooked,) + args[1:] - kwargs["hidden_states"] = args[0] - args = args[1:] output = self.original_component(*args, **kwargs) if isinstance(output, tuple) and len(output) >= 2: # output[0] is attention output diff --git a/transformer_lens/model_bridge/generalized_components/bloom_attention.py b/transformer_lens/model_bridge/generalized_components/bloom_attention.py index 96dfa583c..db16f6d60 100644 --- a/transformer_lens/model_bridge/generalized_components/bloom_attention.py +++ b/transformer_lens/model_bridge/generalized_components/bloom_attention.py @@ -3,7 +3,7 @@ BLOOM attention requires special arguments (residual, alibi, attention_mask) that standard JointQKVAttentionBridge doesn't handle. This custom component passes these arguments through. """ -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Mapping, Optional import torch @@ -109,3 +109,69 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: output = self.hook_out(output) return output + + def set_processed_weights( + self, weights: Mapping[str, torch.Tensor | None], verbose: bool = False + ) -> None: + """Set processed weights and recombine Q/K/V back into combined QKV. + + BloomAttentionBridge's forward() delegates to the original HF attention + component which uses the combined query_key_value weight. After weight + processing (fold_ln etc.) modifies the split Q/K/V weights, we must + recombine them back into the interleaved QKV format so the original + component uses the processed weights. + """ + # First, let the parent distribute weights to Q/K/V/O submodules + super().set_processed_weights(dict(weights), verbose=verbose) # type: ignore[arg-type] + + if self.original_component is None: + return + + # Get the processed Q/K/V weights from split components + assert self.q.original_component is not None + assert self.k.original_component is not None + assert self.v.original_component is not None + q_weight: torch.Tensor = self.q.original_component.weight.data # type: ignore[union-attr, assignment] + k_weight: torch.Tensor = self.k.original_component.weight.data # type: ignore[union-attr, assignment] + v_weight: torch.Tensor = self.v.original_component.weight.data # type: ignore[union-attr, assignment] + + assert self.config is not None + n_heads: int = self.config.n_heads + d_head: int = self.config.d_head + d_model = int(q_weight.shape[1]) + + # Reverse the split: recombine into interleaved QKV format + # [n_heads*d_head, d_model] -> [d_model, n_heads, d_head] + W_Q = q_weight.T.reshape(d_model, n_heads, d_head) + W_K = k_weight.T.reshape(d_model, n_heads, d_head) + W_V = v_weight.T.reshape(d_model, n_heads, d_head) + + # Stack into [d_model, n_heads, 3, d_head] (interleaved format) + W_combined = torch.stack([W_Q, W_K, W_V], dim=2) + + # Reshape to [d_model, 3*n_heads*d_head] and transpose to nn.Linear format + qkv_weight = W_combined.reshape(d_model, 3 * n_heads * d_head).T + + # Update the original component's combined QKV weight + self.original_component.query_key_value.weight = torch.nn.Parameter( # type: ignore[union-attr] + qkv_weight + ) + + # Also recombine biases + q_bias = self.q.original_component.bias # type: ignore[union-attr] + if q_bias is not None: + assert self.k.original_component is not None + assert self.v.original_component is not None + k_bias = self.k.original_component.bias.data # type: ignore[union-attr] + v_bias = self.v.original_component.bias.data # type: ignore[union-attr] + + # [n_heads*d_head] -> [n_heads, d_head] + b_Q = q_bias.data.reshape(n_heads, d_head) # type: ignore[union-attr, operator] + b_K = k_bias.reshape(n_heads, d_head) # type: ignore[operator] + b_V = v_bias.reshape(n_heads, d_head) # type: ignore[operator] + + # Stack into [n_heads, 3, d_head] and flatten + qkv_bias = torch.stack([b_Q, b_K, b_V], dim=1).reshape(3 * n_heads * d_head) + self.original_component.query_key_value.bias = torch.nn.Parameter( # type: ignore[union-attr] + qkv_bias + ) diff --git a/transformer_lens/model_bridge/generalized_components/t5_block.py b/transformer_lens/model_bridge/generalized_components/t5_block.py index fcacd1679..54a2be7fc 100644 --- a/transformer_lens/model_bridge/generalized_components/t5_block.py +++ b/transformer_lens/model_bridge/generalized_components/t5_block.py @@ -86,12 +86,19 @@ def patched_forward( **kwargs, ): """Patched T5 block forward with hooks.""" + import inspect + hidden_states = self.hook_in(hidden_states) if not hasattr(block_self, "layer"): raise RuntimeError(f"T5 block {block_self} does not have 'layer' attribute") layers = block_self.layer is_decoder_block = len(layers) == 3 - if past_key_value is not None: + + # Check which parameters are accepted by the layer forward methods + # (Transformers v5 removed past_key_value, use_cache, layer_head_mask) + self_attn_params = set(inspect.signature(layers[0].forward).parameters.keys()) + + if "past_key_value" in self_attn_params and past_key_value is not None: if not is_decoder_block: expected_num_past_key_values = 0 else: @@ -105,33 +112,43 @@ def patched_forward( else: self_attn_past_key_value = None cross_attn_past_key_value = None - self_attention_outputs = layers[0]( - hidden_states, + self_attn_kwargs = dict( + hidden_states=hidden_states, attention_mask=attention_mask, position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, - use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, ) + # Conditionally pass parameters removed in Transformers v5 + if "past_key_value" in self_attn_params: + self_attn_kwargs["past_key_value"] = self_attn_past_key_value + if "use_cache" in self_attn_params: + self_attn_kwargs["use_cache"] = use_cache + if "layer_head_mask" in self_attn_params: + self_attn_kwargs["layer_head_mask"] = layer_head_mask + self_attention_outputs = layers[0](**self_attn_kwargs) hidden_states = self_attention_outputs[0] # Keep self-attention outputs and relative position weights # attention_outputs contains: (position_bias,) or (position_bias, attn_weights) attention_outputs = self_attention_outputs[1:] hidden_states = self.hook_resid_mid(hidden_states) if is_decoder_block and encoder_hidden_states is not None: - cross_attention_outputs = layers[1]( - hidden_states, + cross_attn_params = set(inspect.signature(layers[1].forward).parameters.keys()) + cross_attn_kwargs = dict( + hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, ) + if "past_key_value" in cross_attn_params: + cross_attn_kwargs["past_key_value"] = cross_attn_past_key_value + if "use_cache" in cross_attn_params: + cross_attn_kwargs["use_cache"] = use_cache + if "layer_head_mask" in cross_attn_params: + cross_attn_kwargs["layer_head_mask"] = cross_attn_layer_head_mask + cross_attention_outputs = layers[1](**cross_attn_kwargs) hidden_states = cross_attention_outputs[0] if hasattr(self, "hook_resid_mid2"): hidden_states = self.hook_resid_mid2(hidden_states) diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index 41f2a6581..1bc117df1 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -11,8 +11,8 @@ import torch from transformers import ( AutoConfig, - AutoModel, AutoModelForCausalLM, + AutoModelForMaskedLM, AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedTokenizerBase, @@ -227,7 +227,7 @@ def get_hf_model_class_for_architecture(architecture: str): if architecture in seq2seq_architectures: return AutoModelForSeq2SeqLM elif architecture in masked_lm_architectures: - return AutoModel + return AutoModelForMaskedLM else: return AutoModelForCausalLM diff --git a/transformer_lens/model_bridge/supported_architectures/bert.py b/transformer_lens/model_bridge/supported_architectures/bert.py index daf0b9bbf..634e2975b 100644 --- a/transformer_lens/model_bridge/supported_architectures/bert.py +++ b/transformer_lens/model_bridge/supported_architectures/bert.py @@ -40,44 +40,63 @@ def __init__(self, cfg: Any) -> None: self.cfg.gated_mlp = False self.cfg.attn_only = False + # BERT uses post-LN (LayerNorm after residual, not before sublayer). + # fold_ln assumes pre-LN (LN before sublayer) and folds ln1 into attention + # QKV and ln2 into MLP. For post-LN, ln1 output feeds MLP (not attention) + # and ln2 output feeds next block's attention (not MLP), so folding into + # the wrong sublayer produces incorrect results. + self.supports_fold_ln = False + + n_heads = self.cfg.n_heads + self.weight_processing_conversions = { "blocks.{i}.attn.q.weight": ParamProcessingConversion( tensor_conversion=RearrangeTensorConversion( - "(h d_head) d_model -> h d_head d_model" + "(h d_head) d_model -> h d_model d_head", h=n_heads ), ), "blocks.{i}.attn.k.weight": ParamProcessingConversion( tensor_conversion=RearrangeTensorConversion( - "(h d_head) d_model -> h d_head d_model" + "(h d_head) d_model -> h d_model d_head", h=n_heads ), ), "blocks.{i}.attn.v.weight": ParamProcessingConversion( tensor_conversion=RearrangeTensorConversion( - "(h d_head) d_model -> h d_head d_model" + "(h d_head) d_model -> h d_model d_head", h=n_heads ), ), "blocks.{i}.attn.q.bias": ParamProcessingConversion( - tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head"), + tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads), ), "blocks.{i}.attn.k.bias": ParamProcessingConversion( - tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head"), + tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads), ), "blocks.{i}.attn.v.bias": ParamProcessingConversion( - tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head"), + tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads), ), "blocks.{i}.attn.o.weight": ParamProcessingConversion( tensor_conversion=RearrangeTensorConversion( - "d_model (h d_head) -> h d_head d_model" + "d_model (h d_head) -> h d_head d_model", h=n_heads ), ), } # Set up component mapping + # The bridge loads BertForMaskedLM, so core model paths need the 'bert.' prefix. + # The MLM head (cls.predictions) is at the top level of BertForMaskedLM. self.component_mapping = { - "embed": EmbeddingBridge(name="bert.embeddings"), + "embed": EmbeddingBridge(name="bert.embeddings.word_embeddings"), "pos_embed": PosEmbedBridge(name="bert.embeddings.position_embeddings"), "blocks": BlockBridge( name="bert.encoder.layer", + # BERT has no single MLP module (intermediate.dense and output.dense + # are siblings in BertLayer), so the MLPBridge forward is never called + # and mlp.hook_out never fires. Redirect hook_mlp_out to the actual + # MLP output hook (output of the "out" linear layer). + hook_alias_overrides={ + "hook_mlp_out": "mlp.out.hook_out", + "hook_mlp_in": "mlp.in.hook_in", + }, submodules={ "ln1": NormalizationBridge(name="attention.output.LayerNorm", config=self.cfg), "ln2": NormalizationBridge(name="output.LayerNorm", config=self.cfg), @@ -92,15 +111,17 @@ def __init__(self, cfg: Any) -> None: }, ), "mlp": MLPBridge( - name="intermediate", + name=None, config=self.cfg, submodules={ - "in": LinearBridge(name="dense"), - "out": LinearBridge(name="../output.dense"), + "in": LinearBridge(name="intermediate.dense"), + "out": LinearBridge(name="output.dense"), }, ), }, ), - "unembed": UnembeddingBridge(name="cls.predictions"), - "ln_final": NormalizationBridge(name="bert.pooler.dense", config=self.cfg), + "unembed": UnembeddingBridge(name="cls.predictions.decoder"), + "ln_final": NormalizationBridge( + name="cls.predictions.transform.LayerNorm", config=self.cfg + ), } diff --git a/transformer_lens/model_bridge/supported_architectures/bloom.py b/transformer_lens/model_bridge/supported_architectures/bloom.py index 87984c130..517f05376 100644 --- a/transformer_lens/model_bridge/supported_architectures/bloom.py +++ b/transformer_lens/model_bridge/supported_architectures/bloom.py @@ -35,34 +35,29 @@ def __init__(self, cfg: Any) -> None: self.cfg.attn_only = False self.cfg.default_prepend_bos = False + # After split_qkv_matrix, Q/K/V are individual [n_heads*d_head, d_model] weights. + # Convert to TL format [n_heads, d_model, d_head]. self.weight_processing_conversions = { "blocks.{i}.attn.q": ParamProcessingConversion( tensor_conversion=RearrangeTensorConversion( - "(three n h) m -> three n m h", - three=3, + "(n h) m -> n m h", n=self.cfg.n_heads, ), - source_key="transformer.h.{i}.self_attention.query_key_value.weight", ), "blocks.{i}.attn.k": ParamProcessingConversion( tensor_conversion=RearrangeTensorConversion( - "(three n h) m -> three n m h", - three=3, + "(n h) m -> n m h", n=self.cfg.n_heads, ), - source_key="transformer.h.{i}.self_attention.query_key_value.weight", ), "blocks.{i}.attn.v": ParamProcessingConversion( tensor_conversion=RearrangeTensorConversion( - "(three n h) m -> three n m h", - three=3, + "(n h) m -> n m h", n=self.cfg.n_heads, ), - source_key="transformer.h.{i}.self_attention.query_key_value.weight", ), "blocks.{i}.attn.o": ParamProcessingConversion( tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads), - source_key="transformer.h.{i}.self_attention.dense.weight", ), } @@ -118,41 +113,40 @@ def split_qkv_matrix( # Keep mypy happy assert isinstance(qkv_weights, torch.Tensor) - # We want to split weights into [d_model, n_heads * d_head] for each of Q, K, V - W_split = qkv_weights.T.reshape(self.cfg.d_model, 3, self.cfg.n_heads * self.cfg.d_head) + # Bloom QKV weights are interleaved: [Q0,K0,V0, Q1,K1,V1, ...] + # i.e. layout is (n_heads, 3, d_head), not (3, n_heads*d_head). + # Reshape to [d_model, n_heads, 3, d_head] to correctly deinterleave. + W_split = qkv_weights.T.reshape(self.cfg.d_model, self.cfg.n_heads, 3, self.cfg.d_head) - W_Q, W_K, W_V = W_split[:, 0, :], W_split[:, 1, :], W_split[:, 2, :] + # W_Q/K/V shape: [d_model, n_heads, d_head] + W_Q, W_K, W_V = W_split[..., 0, :], W_split[..., 1, :], W_split[..., 2, :] qkv_bias = original_attention_component.query_key_value.bias # Keep mypy happy assert isinstance(qkv_bias, torch.Tensor) - # Reshape to [3, n_heads * d_head] to split by Q, K, V - qkv_bias = qkv_bias.reshape(3, self.cfg.n_heads * self.cfg.d_head) + # Same interleaved layout for bias: reshape to [n_heads, 3, d_head] + qkv_bias = qkv_bias.reshape(self.cfg.n_heads, 3, self.cfg.d_head) - b_Q, b_K, b_V = qkv_bias[0, :], qkv_bias[1, :], qkv_bias[2, :] + # b_Q/K/V shape: [n_heads, d_head] + b_Q, b_K, b_V = qkv_bias[:, 0, :], qkv_bias[:, 1, :], qkv_bias[:, 2, :] # Create nn.Linear modules - # W_Q, W_K, W_V shapes are [d_model, n_heads * d_head] - # nn.Linear expects weight shape [out_features, in_features] - # So for Linear(d_model, n_heads * d_head), weight should be [n_heads * d_head, d_model] - W_Q_transformation = torch.nn.Linear(W_Q.shape[0], W_Q.shape[1], bias=True) - W_Q_transformation.weight = torch.nn.Parameter( - W_Q.T - ) # Transpose to [n_heads * d_head, d_model] - W_Q_transformation.bias = torch.nn.Parameter(b_Q) - - W_K_transformation = torch.nn.Linear(W_K.shape[0], W_K.shape[1], bias=True) - W_K_transformation.weight = torch.nn.Parameter( - W_K.T - ) # Transpose to [n_heads * d_head, d_model] - W_K_transformation.bias = torch.nn.Parameter(b_K) - - W_V_transformation = torch.nn.Linear(W_V.shape[0], W_V.shape[1], bias=True) - W_V_transformation.weight = torch.nn.Parameter( - W_V.T - ) # Transpose to [n_heads * d_head, d_model] - W_V_transformation.bias = torch.nn.Parameter(b_V) + # W_Q shape is [d_model, n_heads, d_head] -> flatten to [d_model, n_heads*d_head] + # nn.Linear expects weight shape [out_features, in_features] = [n_heads*d_head, d_model] + d_out = self.cfg.n_heads * self.cfg.d_head + + W_Q_transformation = torch.nn.Linear(self.cfg.d_model, d_out, bias=True) + W_Q_transformation.weight = torch.nn.Parameter(W_Q.reshape(self.cfg.d_model, d_out).T) + W_Q_transformation.bias = torch.nn.Parameter(b_Q.reshape(d_out)) + + W_K_transformation = torch.nn.Linear(self.cfg.d_model, d_out, bias=True) + W_K_transformation.weight = torch.nn.Parameter(W_K.reshape(self.cfg.d_model, d_out).T) + W_K_transformation.bias = torch.nn.Parameter(b_K.reshape(d_out)) + + W_V_transformation = torch.nn.Linear(self.cfg.d_model, d_out, bias=True) + W_V_transformation.weight = torch.nn.Parameter(W_V.reshape(self.cfg.d_model, d_out).T) + W_V_transformation.bias = torch.nn.Parameter(b_V.reshape(d_out)) return W_Q_transformation, W_K_transformation, W_V_transformation diff --git a/transformer_lens/model_bridge/supported_architectures/neox.py b/transformer_lens/model_bridge/supported_architectures/neox.py index 634afbeae..932a742bb 100644 --- a/transformer_lens/model_bridge/supported_architectures/neox.py +++ b/transformer_lens/model_bridge/supported_architectures/neox.py @@ -126,7 +126,9 @@ def __init__(self, cfg: Any) -> None: ), "blocks.{i}.attn.o": ParamProcessingConversion( tensor_conversion=RearrangeTensorConversion( - "d_model (head d_head) -> head d_head d_model" + "d_model (head d_head) -> head d_head d_model", + head=self.cfg.n_heads, + d_head=self.cfg.d_model // self.cfg.n_heads, ), source_key="gpt_neox.layers.{i}.attention.dense.weight", ), diff --git a/transformer_lens/model_bridge/supported_architectures/pythia.py b/transformer_lens/model_bridge/supported_architectures/pythia.py index 3f9c2c89a..7d4d9c924 100644 --- a/transformer_lens/model_bridge/supported_architectures/pythia.py +++ b/transformer_lens/model_bridge/supported_architectures/pythia.py @@ -119,7 +119,9 @@ def __init__(self, cfg: Any) -> None: ), "blocks.{i}.attn.o": ParamProcessingConversion( tensor_conversion=RearrangeTensorConversion( - "d_model (head d_head) -> head d_head d_model" + "d_model (head d_head) -> head d_head d_model", + head=self.cfg.n_heads, + d_head=self.cfg.d_model // self.cfg.n_heads, ), source_key="gpt_neox.layers.{i}.attention.dense.weight", ), diff --git a/transformer_lens/weight_processing.py b/transformer_lens/weight_processing.py index 8abdd0a8a..09f28938a 100644 --- a/transformer_lens/weight_processing.py +++ b/transformer_lens/weight_processing.py @@ -104,6 +104,43 @@ def _prepare_component_path(tl_key: str) -> str: return f"{base_path}.{replacement}" return tl_key + @staticmethod + def _resolve_state_dict_key( + state_dict: Dict[str, torch.Tensor], + key: str, + layer: Optional[int] = None, + ) -> str: + """Resolve a bridge-style key to the actual key in the state_dict. + + Some architectures (e.g., OPT with SymbolicBridge) store parameters + with HF-style prefixes instead of bridge-style prefixes. This method + handles the key resolution by falling back to a suffix search. + + Args: + state_dict: Model state dictionary + key: The expected key (e.g., "blocks.0.mlp.in.weight") + layer: Optional layer index for layer-specific searches + + Returns: + The actual key found in state_dict, or the original key if no match + """ + if key in state_dict: + return key + + # Extract the component path after "blocks.{i}." + import re + + match = re.match(r"blocks\.(\d+)\.(.*)", key) + if match: + layer_idx = match.group(1) + component_suffix = match.group(2) + # Search for keys ending with the component suffix that include the layer index + for sd_key in state_dict: + if sd_key.endswith(f".{component_suffix}") and f".{layer_idx}." in sd_key: + return sd_key + + return key + @staticmethod def _safe_get_tensor( state_dict: Dict[str, torch.Tensor], @@ -327,6 +364,22 @@ def extract_attention_tensors_for_folding( b_V_key, state_dict, bv_tensor, cfg, adapter, layer ) + # Auto-reshape 1D biases to [n_heads, d_head] when weights are 3D + # [n_heads, d_model, d_head]. This handles adapters that define weight + # conversions but not bias conversions (e.g., OPT). + def _reshape_bias_if_needed(bias, weight): + if bias is not None and weight is not None: + if len(weight.shape) == 3 and len(bias.shape) == 1: + n_heads = weight.shape[0] + d_head = weight.shape[2] + if bias.shape[0] == n_heads * d_head: + return bias.reshape(n_heads, d_head) + return bias + + bq_tensor = _reshape_bias_if_needed(bq_tensor, wq_tensor) + bk_tensor = _reshape_bias_if_needed(bk_tensor, wk_tensor) + bv_tensor = _reshape_bias_if_needed(bv_tensor, wv_tensor) + return { "wq": wq_tensor, "wk": wk_tensor, @@ -493,15 +546,27 @@ def _fold_mlp_layer_norm( if getattr(cfg, "attn_only", False): return state_dict - mlp_b_in_key = ProcessWeights._get_param_key(f"blocks.{layer}.mlp.b_in", adapter) - mlp_W_in_key = ProcessWeights._get_param_key(f"blocks.{layer}.mlp.W_in", adapter) + mlp_b_in_key = ProcessWeights._resolve_state_dict_key( + state_dict, ProcessWeights._get_param_key(f"blocks.{layer}.mlp.b_in", adapter), layer + ) + mlp_W_in_key = ProcessWeights._resolve_state_dict_key( + state_dict, ProcessWeights._get_param_key(f"blocks.{layer}.mlp.W_in", adapter), layer + ) mlp_W_gate_key = ( - ProcessWeights._get_param_key(f"blocks.{layer}.mlp.W_gate", adapter) + ProcessWeights._resolve_state_dict_key( + state_dict, + ProcessWeights._get_param_key(f"blocks.{layer}.mlp.W_gate", adapter), + layer, + ) if getattr(cfg, "gated_mlp", False) else None ) - ln2_b_key = ProcessWeights._get_param_key(f"blocks.{layer}.ln2.b", adapter) - ln2_w_key = ProcessWeights._get_param_key(f"blocks.{layer}.ln2.w", adapter) + ln2_b_key = ProcessWeights._resolve_state_dict_key( + state_dict, ProcessWeights._get_param_key(f"blocks.{layer}.ln2.b", adapter), layer + ) + ln2_w_key = ProcessWeights._resolve_state_dict_key( + state_dict, ProcessWeights._get_param_key(f"blocks.{layer}.ln2.w", adapter), layer + ) # CRITICAL FIX: For RMS norm (Gemma), ln2_b doesn't exist. Only require ln2_w! if ln2_w_key in state_dict: # MoE layers: fold ln2 into router gate and each expert's W_in/W_gate @@ -602,9 +667,27 @@ def _fold_mlp_layer_norm( mlp_W_in_key, state_dict, state_dict.get(mlp_W_in_key), cfg, adapter, layer ) assert mlp_W_in_centered is not None, f"MLP W_in not found at key {mlp_W_in_key}" - mlp_W_in_centered = mlp_W_in_centered - einops.reduce( - mlp_W_in_centered, "d_model d_mlp -> 1 d_mlp", "mean" - ) + # Center along d_model dimension. Detect format: + # TL format [d_model, d_mlp] -> center along dim=0 + # HF format [d_mlp, d_model] -> center along dim=-1 + d_model = cfg.d_model if cfg is not None else None + if ( + d_model is not None + and mlp_W_in_centered.shape[0] == d_model + and mlp_W_in_centered.shape[-1] != d_model + ): + # TL format [d_model, d_mlp] + mlp_W_in_centered = mlp_W_in_centered - mlp_W_in_centered.mean(0, keepdim=True) + elif ( + d_model is not None + and mlp_W_in_centered.shape[-1] == d_model + and mlp_W_in_centered.shape[0] != d_model + ): + # HF format [d_mlp, d_model] + mlp_W_in_centered = mlp_W_in_centered - mlp_W_in_centered.mean(-1, keepdim=True) + else: + # Fallback: assume TL format + mlp_W_in_centered = mlp_W_in_centered - mlp_W_in_centered.mean(0, keepdim=True) state_dict[mlp_W_in_key] = ProcessWeights.convert_tensor_to_hf_format( mlp_W_in_key, mlp_W_in_centered, cfg, adapter, layer ) @@ -638,9 +721,24 @@ def _fold_mlp_layer_norm( new_mlp_W_out = mlp_W_out * mlp_ln_w[:, None] if center_weights: - new_mlp_W_out = new_mlp_W_out - einops.reduce( - new_mlp_W_out, "d_mlp d_model -> 1 d_model", "mean" - ) + # Center along d_mlp dimension. Detect format: + # TL format [d_mlp, d_model] -> center along dim=0 + # HF format [d_model, d_mlp] -> center along dim=-1 + d_model_val = cfg.d_model if cfg is not None else None + if ( + d_model_val is not None + and new_mlp_W_out.shape[-1] == d_model_val + and new_mlp_W_out.shape[0] != d_model_val + ): + new_mlp_W_out = new_mlp_W_out - new_mlp_W_out.mean(0, keepdim=True) + elif ( + d_model_val is not None + and new_mlp_W_out.shape[0] == d_model_val + and new_mlp_W_out.shape[-1] != d_model_val + ): + new_mlp_W_out = new_mlp_W_out - new_mlp_W_out.mean(-1, keepdim=True) + else: + new_mlp_W_out = new_mlp_W_out - new_mlp_W_out.mean(0, keepdim=True) state_dict[mlp_W_out_key] = ProcessWeights.convert_tensor_to_hf_format( mlp_W_out_key, new_mlp_W_out, cfg, adapter, layer @@ -781,9 +879,23 @@ def _fold_unembed_layer_norm( unembed_weight_centered is not None ), f"Unembed weight not found at key {unembed_W_U_key}" if len(unembed_weight_centered.shape) == 2: - unembed_weight_centered = unembed_weight_centered - einops.reduce( - unembed_weight_centered, "d_model d_vocab -> 1 d_vocab", "mean" - ) + # Detect format: TL [d_model, d_vocab] vs HF [d_vocab, d_model]. + # Center along the d_model dimension (mean over d_model). + d_vocab = getattr(cfg, "d_vocab", None) if cfg is not None else None + if ( + d_vocab is not None + and unembed_weight_centered.shape[0] == d_vocab + and unembed_weight_centered.shape[-1] != d_vocab + ): + # HF format [d_vocab, d_model] — center along dim=-1 + unembed_weight_centered = ( + unembed_weight_centered - unembed_weight_centered.mean(-1, keepdim=True) + ) + else: + # TL format [d_model, d_vocab] — center along dim=0 + unembed_weight_centered = ( + unembed_weight_centered - unembed_weight_centered.mean(0, keepdim=True) + ) state_dict[unembed_W_U_key] = ProcessWeights.convert_tensor_to_hf_format( unembed_W_U_key, unembed_weight_centered, cfg, adapter, None ) @@ -921,7 +1033,8 @@ def center_writing_weights( try: pos_embed_W_pos_key = ( ProcessWeights._get_param_key("pos_embed.W_pos", adapter) - if getattr(cfg, "positional_embedding_type", "standard") != "rotary" + if getattr(cfg, "positional_embedding_type", "standard") + not in ("rotary", "alibi") else None ) except ValueError: @@ -940,7 +1053,7 @@ def center_writing_weights( ) if ( - getattr(cfg, "positional_embedding_type", "standard") != "rotary" + getattr(cfg, "positional_embedding_type", "standard") not in ("rotary", "alibi") and pos_embed_W_pos_key is not None ): if pos_embed_W_pos_key not in state_dict: @@ -966,8 +1079,12 @@ def center_writing_weights( attn_W_O_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.W_O", adapter) attn_b_O_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.b_O", adapter) try: - mlp_W_out_key = ProcessWeights._get_param_key(f"blocks.{l}.mlp.W_out", adapter) - mlp_b_out_key = ProcessWeights._get_param_key(f"blocks.{l}.mlp.b_out", adapter) + mlp_W_out_key = ProcessWeights._resolve_state_dict_key( + state_dict, ProcessWeights._get_param_key(f"blocks.{l}.mlp.W_out", adapter), l + ) + mlp_b_out_key = ProcessWeights._resolve_state_dict_key( + state_dict, ProcessWeights._get_param_key(f"blocks.{l}.mlp.b_out", adapter), l + ) except ValueError: mlp_W_out_key = None mlp_b_out_key = None @@ -1335,6 +1452,20 @@ def process_weights( Returns: Dict[str, torch.Tensor]: Fully processed state dict. """ + # Upcast to float32 for weight processing to avoid precision loss in + # reduced-precision dtypes (bfloat16, float16). Operations like LayerNorm + # folding involve multiplications that accumulate rounding errors when + # performed in low precision. + original_dtypes: Dict[str, torch.dtype] = {} + for k, v in state_dict.items(): + if isinstance(v, torch.Tensor) and v.is_floating_point() and v.dtype != torch.float32: + original_dtypes[k] = v.dtype + state_dict[k] = v.float() + + # Skip fold_ln for adapters that don't support it (e.g., post-LN architectures + # like BERT where LN placement means folding goes into the wrong sublayer). + if fold_ln and adapter and not getattr(adapter, "supports_fold_ln", True): + fold_ln = False if fold_ln: if getattr(cfg, "normalization_type", "LN") in ["LN", "LNPre"]: state_dict = ProcessWeights.fold_layer_norm( @@ -1378,6 +1509,12 @@ def process_weights( state_dict = ProcessWeights.refactor_factored_attn_matrices( state_dict, cfg, adapter=adapter ) + + # Downcast back to original dtypes + for k, orig_dtype in original_dtypes.items(): + if k in state_dict and isinstance(state_dict[k], torch.Tensor): + state_dict[k] = state_dict[k].to(orig_dtype) + return state_dict @staticmethod @@ -1597,9 +1734,22 @@ def convert_tensor_to_tl_format( if "blocks." in param_name: placeholder_param_name = re.sub(r"blocks\.(\d+)\.", "blocks.{i}.", param_name) - # Check if we have a conversion for this parameter + # Check if we have a conversion for this parameter. + # Try exact match first, then strip .weight suffix for adapters + # that define conversions without the suffix (e.g. Pythia's "blocks.{i}.attn.q"). + # NOTE: Only strip .weight, NOT .bias — stripping .bias would incorrectly + # match bias keys against weight conversions (e.g. "blocks.{i}.attn.q.bias" + # would match the weight conversion for "blocks.{i}.attn.q"). + matched_key = None if placeholder_param_name in adapter.weight_processing_conversions: - param_conversion = adapter.weight_processing_conversions[placeholder_param_name] + matched_key = placeholder_param_name + elif placeholder_param_name.endswith(".weight"): + stripped = placeholder_param_name[: -len(".weight")] + if stripped in adapter.weight_processing_conversions: + matched_key = stripped + + if matched_key is not None: + param_conversion = adapter.weight_processing_conversions[matched_key] # Handle both ParamProcessingConversion objects and legacy string mappings if isinstance(param_conversion, str): @@ -1624,9 +1774,32 @@ def convert_tensor_to_tl_format( param_name, param_conversion.source_key ) if resolved_key not in model_state_dict and tensor is not None: - return param_conversion.tensor_conversion.convert( - tensor, model_state_dict + # Source key not in state dict — the tensor is already in + # bridge format (e.g. already split from combined QKV). + # If the conversion is a ChainTensorConversion that includes + # a SplitTensorConversion, skip the split step since + # it was already applied during bridge construction. + from transformer_lens.conversion_utils.conversion_steps.chain_tensor_conversion import ( + ChainTensorConversion, + ) + from transformer_lens.conversion_utils.conversion_steps.split_tensor_conversion import ( + SplitTensorConversion, ) + + tc = param_conversion.tensor_conversion + if isinstance(tc, ChainTensorConversion): + non_split = [ + c + for c in tc.conversions + if not isinstance(c, SplitTensorConversion) + ] + if len(non_split) < len(tc.conversions): + # Apply only the non-split conversions + result = tensor + for conv in non_split: + result = conv.handle_conversion(result, model_state_dict) + return result + return tc.convert(tensor, model_state_dict) return param_conversion.convert(model_state_dict, param_name) else: # No conversion defined, return tensor as-is (may be None for optional params) @@ -1696,16 +1869,45 @@ def convert_tensor_to_hf_format( if "blocks." in param_name: placeholder_param_name = re.sub(r"blocks\.(\d+)\.", "blocks.{i}.", param_name) - # Check if we have a conversion for this parameter + # Check if we have a conversion for this parameter. + # Try exact match first, then strip .weight suffix (not .bias — see convert_tensor_to_tl_format). + matched_key = None if placeholder_param_name in adapter.weight_processing_conversions: - param_conversion = adapter.weight_processing_conversions[placeholder_param_name] + matched_key = placeholder_param_name + elif placeholder_param_name.endswith(".weight"): + stripped = placeholder_param_name[: -len(".weight")] + if stripped in adapter.weight_processing_conversions: + matched_key = stripped + + if matched_key is not None: + param_conversion = adapter.weight_processing_conversions[matched_key] # Handle both ParamProcessingConversion objects and legacy string mappings if isinstance(param_conversion, str): # Legacy string mapping - just return the tensor as-is return tensor else: - # Use ParamProcessingConversion to handle reverting + # Revert the conversion. For ChainTensorConversions that include + # SplitTensorConversion, skip the split revert step (which is a + # no-op anyway) to match the forward conversion path. + from transformer_lens.conversion_utils.conversion_steps.chain_tensor_conversion import ( + ChainTensorConversion, + ) + from transformer_lens.conversion_utils.conversion_steps.split_tensor_conversion import ( + SplitTensorConversion, + ) + + tc = param_conversion.tensor_conversion + if isinstance(tc, ChainTensorConversion): + non_split = [ + c for c in tc.conversions if not isinstance(c, SplitTensorConversion) + ] + if len(non_split) < len(tc.conversions): + # Revert only the non-split conversions in reverse order + result = tensor + for conv in reversed(non_split): + result = conv.revert(result) + return result return param_conversion.revert(tensor) else: return tensor