diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 2dc861a57..02ee73807 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -1,3 +1,5 @@
+import warnings
+
 import numpy as np
 import pytest
 import torch
@@ -435,6 +437,44 @@ def test_init_xavier_normal(self, d_model, d_mlp):
         assert torch.allclose(x_new, x, rtol=1e-2)
 
 
+def test_tokenize_and_concatenate_no_spurious_sequence_length_warning():
+    """Test that tokenize_and_concatenate does not emit the HF 'sequence length longer than maximum' warning."""
+    from datasets import Dataset
+    from transformers import AutoTokenizer
+
+    # Use a tokenizer with model_max_length and EOS
+    tokenizer = AutoTokenizer.from_pretrained("t5-small")
+    assert tokenizer.model_max_length == 512
+    assert tokenizer.eos_token is not None
+
+    # Long text so that when split into 20 chunks, at least one chunk tokenizes to > 512 tokens
+    long_text = "word " * 20000
+    dataset = Dataset.from_dict({"text": [long_text]})
+
+    with warnings.catch_warnings(record=True) as recorded:
+        warnings.simplefilter("always")
+        result = utils.tokenize_and_concatenate(
+            dataset,
+            tokenizer,
+            max_length=tokenizer.model_max_length,
+            add_bos_token=False,
+            streaming=True,
+        )
+
+    # No warning about sequence length exceeding model maximum
+    for w in recorded:
+        msg = str(w.message)
+        assert (
+            "longer than the specified maximum" not in msg
+        ), f"tokenize_and_concatenate should not emit sequence-length warning; got: {msg}"
+
+    # Sanity: output has expected shape (batch, max_length); result is a Dataset
+    assert len(result) >= 1
+    first_row = result[0]["tokens"]
+    assert first_row.shape[0] == tokenizer.model_max_length
+    assert first_row.dim() == 1
+
+
 def test_tokenize_and_concatenate_short_sequence_no_invalid_tokens():
     """
     When the tokenizer has no pad token, output should only contain token IDs in the model's vocab.
diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py
index 73693860f..414a6b99f 100644
--- a/transformer_lens/utils.py
+++ b/transformer_lens/utils.py
@@ -352,64 +352,81 @@ def tokenize_and_concatenate(
     if not has_pad_token:
         # We add a padding token, purely to implement the tokenizer. This will be removed before inputting tokens to the model, so we do not need to increment d_vocab in the model.
         tokenizer.add_special_tokens({"pad_token": ""})
 
-    # Define the length to chop things up into - leaving space for a bos_token if required
-    if add_bos_token:
-        seq_len = max_length - 1
-    else:
-        seq_len = max_length
-
-    def tokenize_function(examples: Any) -> dict[str, np.ndarray]:
-        # datasets.map() may pass a LazyBatch, not a plain dict; accept dict-like batches
-        text = examples[column_name]
-        # Concatenate it all into an enormous string, separated by eos_tokens
-        assert tokenizer.eos_token is not None, "Tokenizer must have an EOS token."
-        full_text = tokenizer.eos_token.join(text)
-
-        # Handle the case when full_text is empty
-        if not full_text.strip():
-            return {"tokens": np.array([], dtype=np.int64)}
-
-        # Divide into 20 chunks of ~ equal length
-        num_chunks = 20
-        chunk_length = (len(full_text) - 1) // num_chunks + 1
-        chunks = [full_text[i * chunk_length : (i + 1) * chunk_length] for i in range(num_chunks)]
-        # Tokenize the chunks in parallel. Uses NumPy because HuggingFace map doesn't want tensors returned
-        tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten()
-        # Drop padding tokens
-        tokens = tokens[tokens != tokenizer.pad_token_id]
-        num_tokens = len(tokens)
-
-        # Handle cases where num_tokens is less than seq_len
-        if num_tokens < seq_len:
-            num_batches = 1
-            # Pad tokens if necessary. Use eos_token_id if the model has no pad token.
-            tokens = tokens[:seq_len]
-            if len(tokens) < seq_len:
-                padding_length = seq_len - len(tokens)
-                padding_id = tokenizer.eos_token_id if not has_pad_token else tokenizer.pad_token_id
-                padding = np.full(padding_length, padding_id)
-                tokens = np.concatenate([tokens, padding], axis=0)
-        else:
-            num_batches = num_tokens // seq_len
-            # Drop the final tokens if not enough to make a full sequence
-            tokens = tokens[: seq_len * num_batches]
-        tokens = einops.rearrange(
-            tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
-        )
+    # Suppress the "sequence length longer than maximum" warning during chunked tokenization; the flag marks it as already emitted, so transformers skips it.
+    _deprecation_warnings_saved = None
+    if hasattr(tokenizer, "deprecation_warnings"):
+        _deprecation_warnings_saved = tokenizer.deprecation_warnings.copy()
+        tokenizer.deprecation_warnings[
+            "sequence-length-is-longer-than-the-specified-maximum"
+        ] = True
+    try:
+        # Define the length to chop things up into - leaving space for a bos_token if required
         if add_bos_token:
-            prefix = np.full((num_batches, 1), tokenizer.bos_token_id)
-            tokens = np.concatenate([prefix, tokens], axis=1)
-        return {"tokens": tokens}
-
-    tokenized_dataset = dataset.map(
-        tokenize_function,
-        batched=True,
-        num_proc=(num_proc if not streaming else None),
-        remove_columns=[column_name],
-    )
-    tokenized_dataset.set_format(type="torch", columns=["tokens"])
-    return tokenized_dataset
+            seq_len = max_length - 1
+        else:
+            seq_len = max_length
+
+        def tokenize_function(examples: Any) -> dict[str, np.ndarray]:
+            # datasets.map() may pass a LazyBatch, not a plain dict; accept dict-like batches
+            text = examples[column_name]
+            # Concatenate it all into an enormous string, separated by eos_tokens
+            assert tokenizer.eos_token is not None, "Tokenizer must have an EOS token."
+            full_text = tokenizer.eos_token.join(text)
+
+            # Handle the case when full_text is empty
+            if not full_text.strip():
+                return {"tokens": np.array([], dtype=np.int64)}
+
+            # Divide into 20 chunks of ~ equal length
+            num_chunks = 20
+            chunk_length = (len(full_text) - 1) // num_chunks + 1
+            chunks = [
+                full_text[i * chunk_length : (i + 1) * chunk_length] for i in range(num_chunks)
+            ]
+            # Tokenize the chunks in parallel. Uses NumPy because HuggingFace map doesn't want tensors returned
+            tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten()
+            # Drop padding tokens
+            tokens = tokens[tokens != tokenizer.pad_token_id]
+            num_tokens = len(tokens)
+
+            # Handle cases where num_tokens is less than seq_len
+            if num_tokens < seq_len:
+                num_batches = 1
+                # Pad tokens if necessary
+                tokens = tokens[:seq_len]
+                if len(tokens) < seq_len:
+                    padding_length = seq_len - len(tokens)
+                    padding_id = (
+                        tokenizer.eos_token_id if not has_pad_token else tokenizer.pad_token_id
+                    )
+                    padding = np.full(padding_length, padding_id)
+                    tokens = np.concatenate([tokens, padding], axis=0)
+            else:
+                num_batches = num_tokens // seq_len
+                # Drop the final tokens if not enough to make a full sequence
+                tokens = tokens[: seq_len * num_batches]
+
+            tokens = einops.rearrange(
+                tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
+            )
+            if add_bos_token:
+                prefix = np.full((num_batches, 1), tokenizer.bos_token_id)
+                tokens = np.concatenate([prefix, tokens], axis=1)
+            return {"tokens": tokens}
+
+        tokenized_dataset = dataset.map(
+            tokenize_function,
+            batched=True,
+            num_proc=(num_proc if not streaming else None),
+            remove_columns=[column_name],
+        )
+        tokenized_dataset.set_format(type="torch", columns=["tokens"])
+        return tokenized_dataset
+    finally:
+        if _deprecation_warnings_saved is not None:
+            tokenizer.deprecation_warnings.clear()
+            tokenizer.deprecation_warnings.update(_deprecation_warnings_saved)
 
 
 def sample_logits(