Conversation
Support ingesting raw audio directories into WSDS shards via the new shard_from_audio_dir command. Update extract_index_for_shard to infer audio duration from metadata instead of requiring pre-computed fields, and broaden the torchcodec fallback to catch all exceptions (not just ImportError).
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pull request overview
Adds a new CLI command to generate WSDS audio shards from a directory of audio files, improves audio-duration inference for indexing, relaxes torchcodec import fallback behavior, and introduces a pytest-based test suite + packaging/test config updates.
Changes:
- Add `shard_from_audio_dir` command to write `.wsds` audio shards from a filesystem directory (with key customization and optional key mapping).
- Update indexing to optionally require audio duration and infer it more robustly from audio metadata.
- Add pytest suite for the new command; update wheel packaging and pytest configuration in `pyproject.toml`.
Reviewed changes
Copilot reviewed 4 out of 5 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| wsds/ws_tools.py | Adds shard_from_audio_dir, refactors shard indexing to infer audio duration from metadata, and threads require_audio_duration into index creation. |
| wsds/ws_audio.py | Broadens torchcodec import fallback to catch broader failures and fall back to compat decoder. |
| tests/test_shard_from_audio.py | New pytest suite validating shard writing behavior (keys, naming, skipping oversized, subdirs). |
| pyproject.toml | Fix wheel package inclusion (packages = ["wsds"]), add test extras and pytest discovery config. |
| print(f"[DONE] Wrote {shard_idx} WSDS shards -> {output_dir}") | ||
|
|
||
| dataset_root = output_dir.parent if output_dir.name == "audio" else output_dir | ||
|
|
||
| if key_mapping: | ||
| mapping_path = dataset_root / "key_mapping.json" | ||
| with open(mapping_path, "w") as f: | ||
| json.dump(key_mapping, f, indent=2) | ||
| print(f"[INFO] Wrote key mapping ({len(key_mapping):,} entries) -> {mapping_path}") | ||
|
|
||
| if init_index: | ||
| init(dataset_root, require_audio_duration=require_audio_duration) |
When init_index=True, init() is run on dataset_root, but if output_dir is not an audio/ subdirectory the shards are written into dataset_root itself (not into a column subdir). init()/list_all_shards() only discovers shards within subdirectories, so indexing will likely produce an empty index. Consider forcing output into dataset_root/audio/ (or erroring unless output_dir.name == "audio") when init_index is requested.
@copilot open a new pull request to apply changes based on this feedback
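A minimal sketch of the suggested guard, assuming a hypothetical helper name (`resolve_shard_dir` is not part of the PR), which routes shards into an `audio/` column subdirectory whenever an index will be built:

```python
from pathlib import Path


def resolve_shard_dir(output_dir: Path, init_index: bool) -> Path:
    """Hypothetical helper: ensure shards land in an 'audio' column
    subdirectory when init_index is requested, since init() only
    discovers shards inside subdirectories of the dataset root."""
    if output_dir.name == "audio" or not init_index:
        return output_dir
    # Redirect shard output into <output_dir>/audio so that
    # init(dataset_root) can find the shards during indexing.
    shard_dir = output_dir / "audio"
    shard_dir.mkdir(parents=True, exist_ok=True)
    return shard_dir
```

With this in place, `dataset_root` stays `output_dir` while the shards themselves always live one level down, matching what `init()` expects.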
| exts = (".wav", ".flac", ".mp3", ".m4a", ".ogg", ".opus") | ||
| all_files = sorted(p for p in input_dir.rglob("*") if p.suffix.lower() in exts) | ||
| print(f"[INFO] Found {len(all_files):,} audio files under {input_dir}") |
Building all_files = sorted(input_dir.rglob(...)) loads and sorts the entire file list in memory up-front. For large audio corpora this can be a significant memory/time hit. Consider iterating without sorting (or sorting per-directory), or making sorting optional behind a flag if deterministic ordering is needed.
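The per-directory sorting alternative can be sketched with stdlib `os.walk`; the function name and extension tuple here are illustrative, not the PR's actual code:

```python
import os
from pathlib import Path
from typing import Iterator

AUDIO_EXTS = (".wav", ".flac", ".mp3", ".m4a", ".ogg", ".opus")


def iter_audio_files(input_dir: Path) -> Iterator[Path]:
    """Yield audio files lazily, sorted within each directory, so the
    full corpus is never materialized and sorted in memory at once."""
    for dirpath, dirnames, filenames in os.walk(input_dir):
        dirnames.sort()  # sort in place: makes traversal order deterministic
        for name in sorted(filenames):
            p = Path(dirpath) / name
            if p.suffix.lower() in AUDIO_EXTS:
                yield p
```

Ordering is still deterministic (directories and files are each sorted locally), but memory use stays proportional to a single directory listing rather than the whole corpus.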
```diff
 if self.reader is None or sample_rate_switch:
     try:
         from torchcodec.decoders import AudioDecoder
-    except ImportError:
+    except Exception:
         AudioDecoder = CompatAudioDecoder
```
Catching Exception around the torchcodec import will also swallow unrelated issues (e.g., internal bugs in torchcodec or environment problems) and silently fall back to the torchaudio-based decoder, making failures harder to diagnose. Consider catching a narrower set of exceptions (e.g., ImportError, OSError, RuntimeError) and/or logging the exception when falling back.
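A hedged sketch of the narrower fallback the comment suggests, with logging; the helper name is hypothetical, and the compat decoder is passed in rather than imported:

```python
import logging

logger = logging.getLogger(__name__)


def load_audio_decoder(compat_decoder):
    """Hypothetical sketch: only treat import/environment failures as
    'torchcodec unavailable', and log whatever triggered the fallback
    instead of swallowing it silently."""
    try:
        from torchcodec.decoders import AudioDecoder
        return AudioDecoder
    except (ImportError, OSError, RuntimeError) as exc:
        logger.warning(
            "torchcodec unavailable (%s); falling back to compat decoder", exc
        )
        return compat_decoder
```

Keeping the exception tuple explicit means a genuine bug inside torchcodec still surfaces as a traceback, while missing packages or broken shared-library loads (`OSError`, `RuntimeError`) degrade gracefully with a diagnostic line in the log.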
```python
# Patch the read_bytes to attach a fake size, then patch len check via
# a wrapper around shard_from_audio_dir that lowers MAX_ARROW_BYTES.
# Since MAX_ARROW_BYTES is a local, we instead wrap the whole function
# by replacing it with one that sets a lower limit.
import wsds.ws_tools as mod

orig_code = mod.shard_from_audio_dir.__code__

# Replace the constant in the code object's co_consts
new_consts = tuple(
    fake_limit if c == 2_140_000_000 else c for c in orig_code.co_consts
)
new_code = orig_code.replace(co_consts=new_consts)
monkeypatch.setattr(mod.shard_from_audio_dir, "__code__", new_code)
```
This test mutates shard_from_audio_dir.__code__.co_consts to change a local constant. This is brittle (depends on CPython implementation details and the constant appearing exactly once) and can break with small refactors. Prefer making the max-arrow-bytes limit injectable (e.g., a parameter or a module-level constant) so the test can monkeypatch it safely.
Suggested change:

```diff
-# Patch the read_bytes to attach a fake size, then patch len check via
-# a wrapper around shard_from_audio_dir that lowers MAX_ARROW_BYTES.
-# Since MAX_ARROW_BYTES is a local, we instead wrap the whole function
-# by replacing it with one that sets a lower limit.
-import wsds.ws_tools as mod
-orig_code = mod.shard_from_audio_dir.__code__
-# Replace the constant in the code object's co_consts
-new_consts = tuple(
-    fake_limit if c == 2_140_000_000 else c for c in orig_code.co_consts
-)
-new_code = orig_code.replace(co_consts=new_consts)
-monkeypatch.setattr(mod.shard_from_audio_dir, "__code__", new_code)
+# Patch the Arrow byte-limit via a module-level constant so that
+# files larger than fake_limit are skipped.
+import wsds.ws_tools as mod
+# Override the max-bytes limit used by shard_from_audio_dir.
+monkeypatch.setattr(mod, "MAX_ARROW_BYTES", fake_limit)
```
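For that monkeypatch to take effect, `ws_tools` would need to read the limit from module scope at call time. A rough sketch (not the actual module code; `should_skip_file` is a hypothetical helper standing in for the skip check inside `shard_from_audio_dir`):

```python
# In wsds/ws_tools.py (sketch): hoist the limit to module scope so tests
# can monkeypatch it instead of rewriting the function's code object.
MAX_ARROW_BYTES = 2_140_000_000  # stay under Arrow's ~2 GiB binary limit


def should_skip_file(num_bytes: int) -> bool:
    """Reads the module-level constant at call time, which is what
    makes monkeypatch.setattr(mod, "MAX_ARROW_BYTES", ...) effective."""
    return num_bytes > MAX_ARROW_BYTES
```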
```python
    vad_column=vad_column,
    require_audio_duration=require_audio_duration,
)
all_shards = ds.get_shard_list()
```
init() now calls ds.get_shard_list() without ignore_index=True. If the source dataset already has an index (possibly stale/incomplete), this will rebuild the new index from the existing index’s shard list rather than from the shards on disk, which can silently miss shards. Consider using get_shard_list(ignore_index=True) here (as init_split already does), or add an explicit flag to control whether to trust the existing index.
Suggested change:

```diff
-all_shards = ds.get_shard_list()
+all_shards = ds.get_shard_list(ignore_index=True)
```
| key = s["__key__"] | ||
| sample = WSSample(ds, (dataset_path, shard_name), 0) | ||
|
|
||
| for s in ds.sequential_from(sample, 0): |
ds.sequential_from(sample, 0) relies on the falsy behavior of 0 to mean “no max_N limit” inside sequential_from (max_N or sys.maxsize). This is non-obvious and easy to break if sequential_from changes. Prefer passing max_N=None (or omitting the argument) for clarity.
Suggested change:

```diff
-for s in ds.sequential_from(sample, 0):
+for s in ds.sequential_from(sample):
```
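The pitfall can be seen by comparing the two idioms side by side; this is a sketch of the assumed limit handling, not `sequential_from` itself:

```python
import sys


def effective_limit(max_N):
    """Assumed current idiom: `max_N or sys.maxsize` treats 0 as falsy,
    so max_N=0 silently means 'unlimited'."""
    return max_N or sys.maxsize


def explicit_limit(max_N=None):
    """Explicit alternative: only None means 'unlimited'; 0 means
    'take nothing', as a reader would expect."""
    return sys.maxsize if max_N is None else max_N
```

Under the first idiom, `0` and `None` both collapse to `sys.maxsize`, which is why call sites that pass `0` to mean "no limit" are fragile: they break as soon as the implementation starts treating `0` literally.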
```python
# a wrapper around shard_from_audio_dir that lowers MAX_ARROW_BYTES.
# Since MAX_ARROW_BYTES is a local, we instead wrap the whole function
# by replacing it with one that sets a lower limit.
import wsds.ws_tools as mod
```
Module 'wsds.ws_tools' is imported with both 'import' and 'import from'.
```python
        if data_bytes > 0 and bytes_per_sample > 0:
            num_samples = data_bytes // bytes_per_sample
            return float(num_samples) / float(sample_rate)
    except Exception:
```
'except' clause does nothing but pass and there is no explanatory comment.
Suggested change:

```diff
 except Exception:
+    # If deriving duration from raw audio bytes fails for any reason,
+    # fall back to returning None so callers can handle missing duration.
```
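To illustrate the metadata-based inference, here is a self-contained sketch that derives duration from a canonical PCM WAV header using the same `data_bytes` / `bytes_per_sample` / `sample_rate` arithmetic; the helper name is hypothetical and this is not the PR's implementation:

```python
import struct


def wav_duration_seconds(path):
    """Derive duration from a canonical 44-byte PCM WAV header alone,
    without decoding any audio. Returns None if the header is not a
    plain RIFF/WAVE PCM layout."""
    with open(path, "rb") as f:
        header = f.read(44)
    if len(header) < 44 or header[:4] != b"RIFF" or header[8:12] != b"WAVE":
        return None
    # "<" disables struct padding so offsets match the on-disk layout:
    # channels at byte 22, sample rate at 24, bits/sample at 34, data size at 40.
    channels, sample_rate = struct.unpack_from("<HI", header, 22)
    bits_per_sample = struct.unpack_from("<H", header, 34)[0]
    data_bytes = struct.unpack_from("<I", header, 40)[0]
    bytes_per_sample = channels * bits_per_sample // 8
    if data_bytes > 0 and bytes_per_sample > 0:
        num_samples = data_bytes // bytes_per_sample
        return float(num_samples) / float(sample_rate)
    return None
```

Real containers (FLAC, MP3, Ogg) need format-specific parsing, which is why inferring from whatever metadata the decoder already exposes, as the PR does, is the more general route.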
Co-authored-by: tlebryk <43556997+tlebryk@users.noreply.github.com>
Fix init_index to create shards in audio/ subdirectory
Pull request overview
Copilot reviewed 4 out of 5 changed files in this pull request and generated 2 comments.
```python
        avoid collisions when processing multiple directories with
        files that share the same names (e.g., "egyptian", "saudi").
        """
        from tqdm import tqdm
```
shard_from_audio_dir imports tqdm unconditionally, but tqdm is not in the base dependencies (requirements.txt). This makes the CLI command fail for standard installs. Either add tqdm to the main dependencies or make the progress bar optional (fallback to plain iteration if tqdm isn’t available).
Suggested change:

```diff
-from tqdm import tqdm
+try:
+    from tqdm import tqdm  # optional progress bar dependency
+except ImportError:
+    def tqdm(iterable, *args, **kwargs):
+        return iterable
```
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Add shard_from_audio_dir command to create .wsds shards from directories of audio files.
Add robust audio duration inference that doesn't rely on specific metadata field names.
Broaden the torchcodec import fallback from ImportError to Exception to handle driver failures. Remove unused webdataset import.
Add pytest test suite for shard_from_audio_dir with inline WAV generation (no fixture files). Fix wheel config (packages = ["wsds"]) and add test optional-dependencies.
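Inline WAV generation of the kind the tests rely on can be sketched with the stdlib `wave` module; the helper name is illustrative, not the actual test code:

```python
import io
import wave


def make_wav_bytes(seconds=0.1, sample_rate=16000, channels=1):
    """Synthesize a small silent 16-bit PCM WAV entirely in memory,
    so tests can write shards without shipping fixture files."""
    n_frames = int(seconds * sample_rate)
    buf = io.BytesIO()
    with wave.open(buf, "wb") as w:
        w.setnchannels(channels)
        w.setsampwidth(2)  # 16-bit PCM
        w.setframerate(sample_rate)
        w.writeframes(b"\x00\x00" * n_frames * channels)
    return buf.getvalue()
```

A test can then drop these bytes into a temporary directory tree and point `shard_from_audio_dir` at it, keeping the suite hermetic.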