40 changes: 40 additions & 0 deletions tests/unit/test_utils.py
@@ -1,3 +1,5 @@
import warnings

import numpy as np
import pytest
import torch
@@ -435,6 +437,44 @@ def test_init_xavier_normal(self, d_model, d_mlp):
assert torch.allclose(x_new, x, rtol=1e-2)


def test_tokenize_and_concatenate_no_spurious_sequence_length_warning():
"""Test that tokenize_and_concatenate does not emit the HF 'sequence length longer than maximum' warning."""
from datasets import Dataset
from transformers import AutoTokenizer

# Use a tokenizer with model_max_length and EOS
tokenizer = AutoTokenizer.from_pretrained("t5-small")
assert tokenizer.model_max_length == 512
assert tokenizer.eos_token is not None

# Long text so that when split into 20 chunks, at least one chunk tokenizes to > 512 tokens
long_text = "word " * 20000
dataset = Dataset.from_dict({"text": [long_text]})

with warnings.catch_warnings(record=True) as recorded:
warnings.simplefilter("always")
result = utils.tokenize_and_concatenate(
dataset,
tokenizer,
max_length=tokenizer.model_max_length,
add_bos_token=False,
streaming=True,
)

# No warning about sequence length exceeding model maximum
for w in recorded:
msg = str(w.message)
assert (
"longer than the specified maximum" not in msg
), f"tokenize_and_concatenate should not emit sequence-length warning; got: {msg}"

# Sanity check: the result is a Dataset and each row of "tokens" is a 1-D tensor of length max_length
assert len(result) >= 1
first_row = result[0]["tokens"]
assert first_row.shape[0] == tokenizer.model_max_length
assert first_row.dim() == 1


def test_tokenize_and_concatenate_short_sequence_no_invalid_tokens():
"""
When the tokenizer has no pad token, output should only contain token IDs in the model's vocab.
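A possible companion check to the new warning test above (a sketch, not part of this PR): in the transformers versions I'm aware of, the "Token indices sequence length is longer than the specified maximum..." message is emitted through transformers' logging rather than the `warnings` module, so a log-capture assertion can back up the `warnings.catch_warnings` check. Here `caplog` is pytest's built-in fixture, `enable_propagation()` is assumed to be needed because transformers disables propagation to the root logger by default, the test name is hypothetical, and the `utils` import is assumed to match the existing test module.

```python
import logging

from datasets import Dataset
from transformers import AutoTokenizer
from transformers.utils import logging as hf_logging

from transformer_lens import utils


def test_tokenize_and_concatenate_no_sequence_length_log_record(caplog):
    # Let transformers log records reach pytest's capture handler on the root logger.
    hf_logging.enable_propagation()
    try:
        tokenizer = AutoTokenizer.from_pretrained("t5-small")
        dataset = Dataset.from_dict({"text": ["word " * 20000]})

        with caplog.at_level(logging.WARNING, logger="transformers"):
            utils.tokenize_and_concatenate(
                dataset,
                tokenizer,
                max_length=tokenizer.model_max_length,
                add_bos_token=False,
                streaming=True,
            )

        # No captured log record should mention the sequence-length limit.
        assert not any(
            "longer than the specified maximum" in record.getMessage()
            for record in caplog.records
        )
    finally:
        hf_logging.disable_propagation()
```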
129 changes: 73 additions & 56 deletions transformer_lens/utils.py
@@ -352,64 +352,81 @@ def tokenize_and_concatenate(
if not has_pad_token:
# We add a padding token, purely to implement the tokenizer. This will be removed before inputting tokens to the model, so we do not need to increment d_vocab in the model.
tokenizer.add_special_tokens({"pad_token": "<PAD>"})
# Define the length to chop things up into - leaving space for a bos_token if required
if add_bos_token:
seq_len = max_length - 1
else:
seq_len = max_length

def tokenize_function(examples: Any) -> dict[str, np.ndarray]:
# datasets.map() may pass a LazyBatch, not a plain dict; accept dict-like batches
text = examples[column_name]
# Concatenate it all into an enormous string, separated by eos_tokens
assert tokenizer.eos_token is not None, "Tokenizer must have an EOS token."
full_text = tokenizer.eos_token.join(text)

# Handle the case when full_text is empty
if not full_text.strip():
return {"tokens": np.array([], dtype=np.int64)}

# Divide into 20 chunks of ~ equal length
num_chunks = 20
chunk_length = (len(full_text) - 1) // num_chunks + 1
chunks = [full_text[i * chunk_length : (i + 1) * chunk_length] for i in range(num_chunks)]
# Tokenize the chunks in parallel. Uses NumPy because HuggingFace map doesn't want tensors returned
tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten()
# Drop padding tokens
tokens = tokens[tokens != tokenizer.pad_token_id]
num_tokens = len(tokens)

# Handle cases where num_tokens is less than seq_len
if num_tokens < seq_len:
num_batches = 1
# Pad tokens if necessary. Use eos_token_id if the model has no pad token.
tokens = tokens[:seq_len]
if len(tokens) < seq_len:
padding_length = seq_len - len(tokens)
padding_id = tokenizer.eos_token_id if not has_pad_token else tokenizer.pad_token_id
padding = np.full(padding_length, padding_id)
tokens = np.concatenate([tokens, padding], axis=0)
else:
num_batches = num_tokens // seq_len
# Drop the final tokens if not enough to make a full sequence
tokens = tokens[: seq_len * num_batches]

tokens = einops.rearrange(
tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
)
# Suppress the "sequence length longer than maximum" warning during chunked tokenization.
_deprecation_warnings_saved = None
if hasattr(tokenizer, "deprecation_warnings"):
_deprecation_warnings_saved = tokenizer.deprecation_warnings.copy()
tokenizer.deprecation_warnings[
"sequence-length-is-longer-than-the-specified-maximum"
] = False
try:
# Define the length to chop things up into - leaving space for a bos_token if required
if add_bos_token:
prefix = np.full((num_batches, 1), tokenizer.bos_token_id)
tokens = np.concatenate([prefix, tokens], axis=1)
return {"tokens": tokens}

tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
num_proc=(num_proc if not streaming else None),
remove_columns=[column_name],
)
tokenized_dataset.set_format(type="torch", columns=["tokens"])
return tokenized_dataset
seq_len = max_length - 1
else:
seq_len = max_length

def tokenize_function(examples: Any) -> dict[str, np.ndarray]:
# datasets.map() may pass a LazyBatch, not a plain dict; accept dict-like batches
text = examples[column_name]
# Concatenate it all into an enormous string, separated by eos_tokens
assert tokenizer.eos_token is not None, "Tokenizer must have an EOS token."
full_text = tokenizer.eos_token.join(text)

# Handle the case when full_text is empty
if not full_text.strip():
return {"tokens": np.array([], dtype=np.int64)}

# Divide into 20 chunks of ~ equal length
num_chunks = 20
chunk_length = (len(full_text) - 1) // num_chunks + 1
chunks = [
full_text[i * chunk_length : (i + 1) * chunk_length] for i in range(num_chunks)
]
# Tokenize the chunks in parallel. Uses NumPy because HuggingFace map doesn't want tensors returned
tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten()
# Drop padding tokens
tokens = tokens[tokens != tokenizer.pad_token_id]
num_tokens = len(tokens)

# Handle cases where num_tokens is less than seq_len
if num_tokens < seq_len:
num_batches = 1
# Pad tokens if necessary
tokens = tokens[:seq_len]
if len(tokens) < seq_len:
padding_length = seq_len - len(tokens)
padding_id = (
tokenizer.eos_token_id if not has_pad_token else tokenizer.pad_token_id
)
padding = np.full(padding_length, padding_id)
tokens = np.concatenate([tokens, padding], axis=0)
else:
num_batches = num_tokens // seq_len
# Drop the final tokens if not enough to make a full sequence
tokens = tokens[: seq_len * num_batches]

tokens = einops.rearrange(
tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
)
if add_bos_token:
prefix = np.full((num_batches, 1), tokenizer.bos_token_id)
tokens = np.concatenate([prefix, tokens], axis=1)
return {"tokens": tokens}

tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
num_proc=(num_proc if not streaming else None),
remove_columns=[column_name],
)
tokenized_dataset.set_format(type="torch", columns=["tokens"])
return tokenized_dataset
finally:
if _deprecation_warnings_saved is not None:
tokenizer.deprecation_warnings.clear()
tokenizer.deprecation_warnings.update(_deprecation_warnings_saved)


def sample_logits(
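For background on the suppression approach in the utils.py change, here is a minimal standalone sketch (under assumptions, not the PR's code): HF tokenizers keep a `deprecation_warnings` dict that records which one-time messages have already been logged, and the length message is skipped once its key is marked as already shown. The key string is copied from the diff above; treating a truthy value as "already warned" is an assumption about transformers internals, and "t5-small" is simply the checkpoint the new test uses.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
long_text = "word " * 5000  # tokenizes to far more than model_max_length (512)

key = "sequence-length-is-longer-than-the-specified-maximum"
saved = dict(getattr(tokenizer, "deprecation_warnings", {}))
try:
    if hasattr(tokenizer, "deprecation_warnings"):
        # Mark the message as already emitted so the tokenizer skips logging it (assumed behavior).
        tokenizer.deprecation_warnings[key] = True
    ids = tokenizer(long_text)["input_ids"]  # would normally trigger the one-time length message
finally:
    if hasattr(tokenizer, "deprecation_warnings"):
        # Restore the tokenizer's original warning state.
        tokenizer.deprecation_warnings.clear()
        tokenizer.deprecation_warnings.update(saved)

print(len(ids), "tokens produced without the length log message")
```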