From 5a1e8cf600c5869e9eab9a2b331ce876a7e94eb0 Mon Sep 17 00:00:00 2001 From: Evelyn Yen Date: Tue, 17 Feb 2026 20:49:56 -0500 Subject: [PATCH] fix return output --- transformer_lens/hook_points.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/transformer_lens/hook_points.py b/transformer_lens/hook_points.py index 612b6f6d2..7585fa431 100644 --- a/transformer_lens/hook_points.py +++ b/transformer_lens/hook_points.py @@ -93,11 +93,12 @@ def full_hook( module_input: Any, module_output: Any, ): - if ( - dir == "bwd" - ): # For a backwards hook, module_output is a tuple of (grad,) - I don't know why. - module_output = module_output[0] - return hook(module_output, hook=self) + # For a backwards hook, module_output is a tuple of (grad,) + hook_arg = module_output[0] if dir == "bwd" else module_output + result = hook(hook_arg, hook=self) + if dir == "bwd" and result is not None: + return result if isinstance(result, tuple) and len(result) == 1 else (result,) + return result # annotate the `full_hook` with the string representation of the `hook` function if isinstance(hook, partial):