Conversation
!test

Review updated until commit 25b081c

Description
Relevant files (categories): Enhancement, Refactoring, Tests
PR Reviewer Guide
Here are some key observations to aid the review process:
🧪 PR contains tests
⚡ Recommended focus areas for review
Parameter Type Consistency
The parameter type changed from DeviceIdxType to int64_t for my_device_index and root_index in multiple post-communication functions. While this works for the relative indices being passed, we should verify that all call sites are consistently updated and that the type change doesn't introduce implicit-conversion issues, especially where absolute device IDs need to be converted to relative indices.
Greptile Summary
Refactored communication lowering to use relative device indices instead of absolute device IDs.
Key Changes:
Confidence Score: 5/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[lowerToScatter/Gather/Broadcast/SendRecv/Reduce] --> B[Create Team from DeviceMesh]
    B --> C[getRelativeIndex: Convert absolute DeviceIdxType to relative int64_t]
    C --> D[Create Communication with relative root_index]
    D --> E[postSingleCommunication called at runtime]
    E --> F[Convert my_device absolute to relative using getRelativeIndex]
    F --> G{Communication Type?}
    G -->|Broadcast| H[postBroadcast with relative indices]
    G -->|Gather| I[postGather with relative indices]
    G -->|SendRecv| J[postSendRecv with relative indices]
    G -->|Scatter| K[postScatter with relative indices]
    G -->|Reduce| L[postReduce with relative indices]
    J --> M[sender_index = root_index]
    M --> N[receiver_index = 1 - sender_index]
    N --> O[backend->send/recv with relative indices]
```
Last reviewed commit: 25b081c
Additional Comments (2)
After the changes in […], you'll need to add […].
Additional Comments (1)
After this PR, […]. This works only by coincidence, because […]. To be consistent with the new convention, this should pass the relative index:
Additional Comments (1)
After this PR, […]. To be consistent with the rest of the PR's intent, consider passing the relative index explicitly:
Additional Comments (1)
For correctness and to actually test the relative-index contract, pass the relative index:
```cpp
if (std::find(team.begin(), team.end(), my_device_index) == team.end()) {
  return nullptr;
}
my_device_index = getRelativeIndex(team, my_device_index);
```
I can also do this in caller methods such that postSingleCommunication receives all relative indices.
Using the same name for both absolute and relative index is a bit confusing. Maybe my_index_in_team of type int64_t will be clearer?
Also, you may want to consider using the _index suffix only for relative indices: e.g., my_device vs. my_device_index/my_index_in_team, and root vs. root_index. You already follow that convention in other methods, e.g., sender vs. sender_index.
!test
!test
A note for future work: communication libraries like NCCL take absolute ranks, so all relative indices here will have to be translated back to absolute by accessing the device mesh. This translation can be done in ProcessGroup (as is today), in HostIrEvaluator, or in the LLVM IR that HostIrJit generates. For SOL, we may have to avoid doing this translation for every communication. For that, we may have to unify identical meshes (recall that meshes are per TensorView) so LLVM's CSE can kick in. cc @Priya2698
NCCL accepts ranks (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reduce) instead of device IDs.
Similarly for NVSHMEM.
!test |
Renaming to distinguish device IDs from indices-in-team in other files will be done in a follow-up PR.
For cases such as broadcast-based allgather in a host loop, the root index is the for-loop index, which may not be the absolute device ID. I am changing all lowering methods to use the relative root index, which is what the backends use as well.