Web gpu #37

Merged

36000 merged 12 commits into dipy:master from neurolabusc:WebGPU
Mar 11, 2026

Conversation

@neurolabusc
Contributor

Provides support for Metal shaders on Apple Silicon and WebGPU across architectures. See the WebGPU README for details and benchmarks on both Apple Silicon and NVIDIA GPUs. Highlights include:

  • Optional Soft Angular Weighting biases trajectory continuation toward directions close to the current heading, applied only at voxels with multi-directional ODFs. This reflects the physiological prior that white matter fibers follow smooth, gradually curving paths — axons do not make sharp turns. In ambiguous voxels where fiber bundles run in different directions (the "kissing vs. crossing" problem), this weighting favors the interpretation most consistent with the incoming trajectory, reducing spurious crossings into adjacent bundles.
  • An SH basis fix in cu_direction_getters.py and the shared boot_utils.py module also benefits the CUDA path. This should make GPU and CPU results more similar.
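As a rough illustration of the soft angular weighting idea described above (this is not the actual shader code; the function name, the kappa parameter, and the exponential weighting form are assumptions for the sketch), candidate ODF peak directions can be weighted by their alignment with the incoming heading:

```python
import math

def soft_angular_weights(incoming, peaks, kappa=4.0):
    """Weight candidate peak directions by alignment with the incoming
    heading; kappa controls how sharply aligned peaks are favored."""
    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))
    # exp(kappa * |cos(theta)|) favors directions close to the heading;
    # abs() treats antipodal directions as equivalent, as ODF peaks are
    raw = [math.exp(kappa * abs(dot(incoming, p))) for p in peaks]
    total = sum(raw)
    return [w / total for w in raw]

# a voxel with two ODF peaks: one along the incoming heading, one orthogonal
heading = (1.0, 0.0, 0.0)
peaks = [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0)]
weights = soft_angular_weights(heading, peaks)
```

In an ambiguous crossing voxel this gives nearly all the probability mass to the peak consistent with the trajectory, which is the behavior the bullet above describes.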

Note: These backends were written by Claude Code (Anthropic's AI coding agent), with architectural direction, validation, and iterative review from me. The port is intentionally derivative: it mirrors the naming conventions, file structure, and two-pass kernel architecture of the existing CUDA backend to minimize cognitive overhead for contributors familiar with the codebase.

neurolabusc and others added 2 commits March 3, 2026 08:29
Adds two new GPU backends alongside the existing CUDA backend:

- **Metal** (Apple Silicon): Zero-copy unified memory, MSL shaders.
  ~95x faster than CPU at 100k seeds on M4 Pro.

- **WebGPU** (cross-platform via wgpu-py): WGSL shaders, runs on
  NVIDIA/AMD/Intel/Apple via Vulkan/D3D12/Metal. ~47x faster than
  CPU. Produces bit-identical results to Metal on shared hardware.

Both backends implement all three direction getters (Boot, Prob, PTT)
and share bootstrap matrix preparation via new boot_utils.py module.

Backend auto-detection priority: Metal → CUDA → WebGPU.
Install: pip install "cuslines[metal]" or "cuslines[webgpu]"

Includes cross-backend benchmark script:
  python -m cuslines.webgpu.benchmark --nseeds 10000

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add 100k seed benchmark results for NVIDIA RTX 4090 + Threadripper PRO
7995WX to WebGPU README: CUDA 2.9s (273x), WebGPU 19.3s (41x), CPU 783s.
Add /proc/cpuinfo parsing for Linux CPU name in benchmark script.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@neurolabusc mentioned this pull request Mar 3, 2026
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@skoudoro
Member

skoudoro commented Mar 3, 2026

really exciting work!

It makes sense that WebGPU is slower than native, but the multi-platform part is so important for DIPY.

We changed our tracking framework, and Cuslines does not currently follow it. We do not use LocalTracking anymore, and it will be deprecated in the next release.

This can be changed later on.

thanks @neurolabusc!

I will look into it more deeply later on

@neurolabusc
Contributor Author

@skoudoro I did use the CPU LocalTracking to validate against. Can you extend the benchmark I use to cover the modern tracking framework? It is nice to have a gold standard to compare against. I do think the quality of the streamlines can be improved with the soft angular weighting function that is included in my Metal solution but not enabled by default. It may also be nice to have validation datasets that exhibit edge cases we can test against. I did try to create quality ratings based on mean/median fiber length and the number of commissural fibers. However, I would be interested if you have more refined quality metrics and tests. If we can include these, I think we can explore optimizations that influence speed without sacrificing quality.

@36000
Collaborator

36000 commented Mar 4, 2026

Thus far GPUStreamlines has not attempted to mirror the DIPY API, just be compatible with it, but we can change that if people think that makes sense.

For validation, I will update pyAFQ to include this and see if it can track all the major bundles using our default settings.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@36000
Collaborator

36000 commented Mar 4, 2026

@neurolabusc

I tried running this on a subject from HBN and using asymmetric ODFs. I caused the WebGPU approach to crash at this line:

        bg0 = device.create_bind_group(
            layout=self.getnum_pipeline.get_bind_group_layout(0),
            entries=[
                {"binding": 0, "resource": {"buffer": params_buf}},
                {"binding": 1, "resource": {"buffer": sp.seeds_buf}},
                {"binding": 2, "resource": {"buffer": gt.dataf_buf}},
                {"binding": 4, "resource": {"buffer": gt.sphere_vertices_buf}},
                {"binding": 5, "resource": {"buffer": gt.sphere_edges_buf}},
            ],
        )

With this error:

  File "/home/john/pyAFQ/AFQ/tasks/tractography.py", line 293, in gpu_tractography
    sft = gpu_track(
          ^^^^^^^^^^
  File "/home/john/pyAFQ/AFQ/tractography/gputractography.py", line 221, in gpu_track
    return gpu_tracker.generate_trx(seeds, seed_img)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/GPUStreamlines/cuslines/webgpu/wg_tractography.py", line 226, in generate_trx
    self.seed_propagator.propagate(chunk)
  File "/home/john/GPUStreamlines/cuslines/webgpu/wg_propagate_seeds.py", line 152, in propagate
    self.gpu_tracker.dg.getNumStreamlines(nseeds, block, grid, self)
  File "/home/john/GPUStreamlines/cuslines/webgpu/wg_direction_getters.py", line 180, in getNumStreamlines
    bg0 = device.create_bind_group(
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/miniforge3/envs/afq/lib/python3.12/site-packages/wgpu/backends/wgpu_native/_api.py", line 1784, in create_bind_group
    id = libf.wgpuDeviceCreateBindGroup(self._internal, struct)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/miniforge3/envs/afq/lib/python3.12/site-packages/wgpu/backends/wgpu_native/_helpers.py", line 373, in proxy_func
    raise wgpu_error
  File "/home/john/miniforge3/envs/afq/lib/python3.12/site-packages/wgpu/backends/wgpu_native/_api.py", line 1784, in create_bind_group
    id = libf.wgpuDeviceCreateBindGroup(self._internal, struct)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
wgpu._classes.GPUValidationError: Validation Error

Caused by:
  In wgpuDeviceCreateBindGroup
    Buffer binding 2 range 2198915424 exceeds `max_*_buffer_binding_size` limit 2147483647

I looked into it, and that number is the size of gt.dataf_buf alone. So it looks like for large enough subjects with complicated enough ODFs, this data buffer is too big for binding? I am not sure how this WebGPU stuff works.
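The arithmetic behind this failure can be sketched in plain Python: the flattened ODF buffer scales with voxel count times the number of sphere directions times 4 bytes (float32), and a denser sphere can push it past the ~2 GiB binding limit. The volume dimensions below are illustrative, not the HBN subject's actual shape:

```python
MAX_BINDING = 2_147_483_647  # the max_*_buffer_binding_size reported in the error

def dataf_buffer_bytes(shape, n_directions, dtype_bytes=4):
    """Estimated size of a flattened float32 ODF data buffer."""
    x, y, z = shape
    return x * y * z * n_directions * dtype_bytes

# hypothetical volume; 724 vs 362 directions mimic a denser vs sparser sphere
full = dataf_buffer_bytes((106, 128, 100), 724)   # exceeds the binding limit
small = dataf_buffer_bytes((106, 128, 100), 362)  # fits under the limit
```

Halving the number of directions on the discretization sphere halves the buffer, which is why switching spheres (as in the next comment) can make the same subject fit.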

@36000
Collaborator

36000 commented Mar 4, 2026

I switched to using DIPY's small sphere instead of the default sphere to discretize the SH model. Now it runs on my GPU. It is roughly 5x slower than CUDA, but this is to be expected (roughly 5 million streamlines in 5 minutes on my GPU for this HBN subject, versus roughly 1 minute 15 seconds using CUDA). It is still quite fast, much faster than single-threaded CPU tracking. I haven't tested it (and it would depend on the hardware), but my guess is you would need to run DIPY's CPU code with 8-16 threads on a decent CPU to be competitive with the WebGPU implementation.

In terms of results, they are extremely similar to CUDA, so I would say we can merge this once we solve the above problem.
CUDA.csv
WebGPU.csv

@neurolabusc
Contributor Author

@36000 can you provide a way for me to reproduce your crash? For example, a sample dataset and a script that demonstrates the behavior. It would be ideal to have a simple example that works with the CUDA backend but fails with WebGPU. You can think of WebGPU as the lowest common denominator of the APIs it runs on, so beyond unified memory there are many other inherent limitations versus Metal or CUDA. Limits I have faced include Max Storage Buffers Per Shader Stage (8, vs 31 in Metal), Max Compute Workgroups Per Dimension (65535), no float atomics, and no storage pointers as function parameters. The good news is Claude can typically find a way to get things to work, albeit with some performance penalties.
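One consequence of the Max Compute Workgroups Per Dimension limit mentioned above is that a large 1-D dispatch must be split into pieces. A minimal stdlib-only sketch of that splitting (function and parameter names are hypothetical, not from this codebase; real dispatch calls are backend-specific):

```python
MAX_WG_PER_DIM = 65_535  # WebGPU maxComputeWorkgroupsPerDimension

def dispatch_chunks(n_items, workgroup_size=64):
    """Split a large 1-D dispatch into (offset, count) pieces that each
    respect the per-dimension workgroup limit."""
    total_wg = -(-n_items // workgroup_size)  # ceiling division
    chunks = []
    start = 0
    while start < total_wg:
        count = min(MAX_WG_PER_DIM, total_wg - start)
        chunks.append((start, count))
        start += count
    return chunks

# 65,536 workgroups' worth of items requires two dispatches
chunks = dispatch_chunks(65_536 * 64)
```

Each piece then needs its offset passed to the kernel (e.g. via a uniform), which is one of the small performance and complexity costs relative to CUDA's much larger grid limits.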

@36000
Collaborator

36000 commented Mar 9, 2026

@neurolabusc I believe I sent you an email regarding this, maybe it got sent to spam? Either way, I think it is better to show it here:

If you install pyAFQ it is relatively easy to download the HBN subject in BIDS format:

import AFQ.data.fetch as afd

bids_path = afd.fetch_hbn_preproc(
    ["NDARAA948VFH"])[1]

Then I ran pyAFQ to hit the error, but I was able to reproduce the same error with this minimal example using just GPUStreamlines (no pyAFQ required):

python run_gpu_streamlines.py ~/AFQ_data/HBN/derivatives/qsiprep/sub-NDARAA948VFH/ses-HBNsiteRU/dwi/sub-NDARAA948VFH_ses-HBNsiteRU_acq-64dir_space-T1w_desc-preproc_dwi.nii.gz ~/AFQ_data/HBN/derivatives/qsiprep/sub-NDARAA948VFH/ses-HBNsiteRU/dwi/sub-NDARAA948VFH_ses-HBNsiteRU_acq-64dir_space-T1w_desc-preproc_dwi.bval ~/AFQ_data/HBN/derivatives/qsiprep/sub-NDARAA948VFH/ses-HBNsiteRU/dwi/sub-NDARAA948VFH_ses-HBNsiteRU_acq-64dir_space-T1w_desc-preproc_dwi.bvec ~/AFQ_data/HBN/derivatives/qsiprep/sub-NDARAA948VFH/ses-HBNsiteRU/dwi/sub-NDARAA948VFH_ses-HBNsiteRU_acq-64dir_space-T1w_desc-brain_mask.nii.gz ~/AFQ_data/HBN/derivatives/qsiprep/sub-NDARAA948VFH/ses-HBNsiteRU/dwi/sub-NDARAA948VFH_ses-HBNsiteRU_acq-64dir_space-T1w_desc-brain_mask.nii.gz --device webgpu --dg prob

I believe this will be important to fix, as the Stanford HARDI test dataset we use is quite small, and HBN may be more representative of the size of typical datasets.

Request the adapter's maximum max-buffer-size and
max-storage-buffer-binding-size when creating the WebGPU device.
Without this, the device gets WebGPU spec defaults (256 MB buffer,
128 MB binding) which are too small for real-world diffusion MRI
datasets — e.g. HBN with CSD asymmetric ODFs produces a ~5 GB
dataf buffer that exceeds the default limit.

Also adds buffer size validation with clear error messages in both
WebGPU and Metal backends if the data exceeds the device limit.

Fixes dipy#37

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@neurolabusc
Contributor Author

@36000 Can you please test the latest commit? I believe this resolves your issue. However, there seems to be an issue with pyAFQ on my macOS setup:

pip install pyAFQ
Requirement already satisfied: pyAFQ in /Users/chris/py312/lib/python3.12/site-packages (2.1)

but I get an error:

% python
Python 3.12.12 (main, Oct  9 2025, 11:07:00) [Clang 17.0.0 (clang-1700.0.13.3)] on darwin
Type "help", "copyright", "credits" or "license" for more information.

>>> import AFQ.data.fetch as afd
>>> bids_path = afd.fetch_hbn_preproc(
...     ["NDARAA948VFH"])[1]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/chris/py312/lib/python3.12/site-packages/AFQ/data/fetch.py", line 1678, in fetch_hbn_preproc
    client = boto3.client('s3', config=Config(signature_version=UNSIGNED))
             ^^^^^
UnboundLocalError: cannot access local variable 'boto3' where it is not associated with a value

@neurolabusc
Contributor Author

Actually, the issue seems to be with pyAFQ's dependencies; I can get the file if I explicitly install boto3:

pip install boto3

neurolabusc and others added 3 commits March 9, 2026 16:48
- Add WebGPU installation and usage to root README.md
- Fix typos in README ("donaload" → "download", "both CUDA and Metal" → all 3)
- Clean up partial GPU buffers on allocation failure in WebGPU tracker
- Vectorize Metal get_buffer_size() to match WebGPU (numpy instead of loop)
- Complete CUDA file tree in CLAUDE.md

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@36000
Collaborator

36000 commented Mar 10, 2026

Hello @neurolabusc, I have fixed some unrelated bugs in the Python example script using this PR; I hope you do not mind. I have also added a small-sphere option and mention it in the informative error message. Now I run:

 python run_gpu_streamlines.py ~/AFQ_data/HBN/derivatives/qsiprep/sub-NDARAA948VFH/ses-HBNsiteRU/dwi/sub-NDARAA948VFH_ses-HBNsiteRU_acq-64dir_space-T1w_desc-preproc_dwi.nii.gz ~/AFQ_data/HBN/derivatives/qsiprep/sub-NDARAA948VFH/ses-HBNsiteRU/dwi/sub-NDARAA948VFH_ses-HBNsiteRU_acq-64dir_space-T1w_desc-preproc_dwi.bval ~/AFQ_data/HBN/derivatives/qsiprep/sub-NDARAA948VFH/ses-HBNsiteRU/dwi/sub-NDARAA948VFH_ses-HBNsiteRU_acq-64dir_space-T1w_desc-preproc_dwi.bvec ~/AFQ_data/HBN/derivatives/qsiprep/sub-NDARAA948VFH/ses-HBNsiteRU/dwi/sub-NDARAA948VFH_ses-HBNsiteRU_acq-64dir_space-T1w_desc-brain_mask.nii.gz ~/AFQ_data/HBN/derivatives/qsiprep/sub-NDARAA948VFH/ses-HBNsiteRU/dwi/sub-NDARAA948VFH_ses-HBNsiteRU_acq-64dir_space-T1w_desc-brain_mask.nii.gz --device webgpu --dg prob --sphere small

And get:

parsing arguments
Using webgpu backend
Fitting Tensor
Computing anisotropy measures (FA,MD,RGB)
Running CSD model...
Unable to find extension: VK_EXT_physical_device_drm
 0%|                                                                               | 0/100000 [00:00<?, ?it/s]
Traceback (most recent call last):
 File "/home/john/GPUStreamlines/run_gpu_streamlines.py", line 285, in <module>
   sft = gpu_tracker.generate_sft(seed_mask, img)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/john/GPUStreamlines/cuslines/webgpu/wg_tractography.py", line 222, in generate_sft
   self.seed_propagator.propagate(chunk)
 File "/home/john/GPUStreamlines/cuslines/webgpu/wg_propagate_seeds.py", line 165, in propagate
   self.gpu_tracker.dg.generateStreamlines(nseeds, block, grid, self)
 File "/home/john/GPUStreamlines/cuslines/webgpu/wg_direction_getters.py", line 224, in generateStreamlines
   bg1 = device.create_bind_group(
         ^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/john/miniforge3/envs/afq/lib/python3.12/site-packages/wgpu/backends/wgpu_native/_api.py", line 1784, in create_bind_group
   id = libf.wgpuDeviceCreateBindGroup(self._internal, struct)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/john/miniforge3/envs/afq/lib/python3.12/site-packages/wgpu/backends/wgpu_native/_helpers.py", line 373, in proxy_func
   raise wgpu_error
 File "/home/john/miniforge3/envs/afq/lib/python3.12/site-packages/wgpu/backends/wgpu_native/_api.py", line 1784, in create_bind_group
   id = libf.wgpuDeviceCreateBindGroup(self._internal, struct)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
wgpu._classes.GPUValidationError: Validation Error

Caused by:
 In wgpuDeviceCreateBindGroup
   Buffer binding 4 range 2909266920 exceeds `max_*_buffer_binding_size` limit 2147483647

So, I think we need informative errors for each device bind group. This is not a particularly large dataset, so I think this would be a common error for users.

@36000
Collaborator

36000 commented Mar 10, 2026

@neurolabusc Also, I do not believe it is being limited by the GPU's VRAM, because my GPU has 16 GB of VRAM but is still hitting this error. So maybe we should remove that from the error description?

Anyway, setting the chunk size to 25000 fixed this latest error and it ran. This is the default elsewhere, so I will just make it the new default in this example script.
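To see why a smaller chunk size helps: the per-chunk streamline output buffer scales linearly with the number of seeds in the chunk. A plain-Python sketch with hypothetical sizing constants (streamlines_per_seed and max_pts_per_streamline are illustrative assumptions, not the backend's actual values):

```python
MAX_BINDING = 2_147_483_647  # ~2 GiB Vulkan storage-buffer binding limit

def streamline_buffer_bytes(chunk_size, streamlines_per_seed=4,
                            max_pts_per_streamline=500):
    """Per-chunk output buffer: float32 xyz points reserved for every
    possible streamline in the chunk (sizing constants are illustrative)."""
    return chunk_size * streamlines_per_seed * max_pts_per_streamline * 3 * 4

big = streamline_buffer_bytes(100_000)   # exceeds the binding limit
small = streamline_buffer_bytes(25_000)  # fits comfortably
```

Since the binding limit is per-buffer rather than tied to total VRAM, reducing the chunk size shrinks every per-chunk allocation below the limit regardless of how much memory the GPU has.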

@36000
Collaborator

36000 commented Mar 10, 2026

I also invited you as a maintainer for this repo, so you can merge the code whenever you feel it is ready!

neurolabusc and others added 2 commits March 10, 2026 14:35
Check the streamline output buffer size against
max-storage-buffer-binding-size before allocation. On Vulkan
this limit is typically 2 GB, which CSD models can exceed at
large chunk sizes. The error message reports the buffer size,
streamline count, and suggests reducing --chunk-size.

Refs dipy#37

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Rename the misspelled constructor parameter in GPUTracker,
MetalGPUTracker, and WebGPUTracker. The parameter is always
passed positionally so no callers are affected.

Fixes dipy#38

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@neurolabusc
Contributor Author

@36000 why don't you check that you are happy with my inclusion of the informative error? If so, I think you should merge this.

@36000
Collaborator

36000 commented Mar 11, 2026

This all looks good, thank you! I will be busy from now until late April, but sometime in late April / early May I will write up documentation for this package. Merging

@36000 36000 merged commit c76ecd7 into dipy:master Mar 11, 2026