Add padding in dynamic sharding for tensors before all2all #2944

Open · wants to merge 1 commit into main

Conversation

aporialiao (Member) commented:

Summary:
Given that we can't expect shards in an embedding module to have the same dimensions along both dim 0 and dim 1, we have to pad the tensors passed into the `all_to_all_single` collective so that the expensive collective is called only once.

This diff:

  1. adds the logic for padding tensors in both dimensions
  2. adds logic to remove the padding when updating the state dict after resharding
  3. removes the original implementation of concatenating input tensors along dim 1 (which assumes dim 0 can be variable but dim 1 is consistent across all shards) and then transposing
    1. This ensures the existing CW unit test exercises the padding logic, since the CW unit test was the one that previously failed due to inconsistent dimensions.

Padding leverages `nn.functional.pad` and pads tensors with value 0 on the right and bottom, e.g.:

```
t = [[1, 2],
     [3, 4]]
max_dim_0 = 4
max_dim_1 = 3
t = pad_tensor_to_max_dims(t, max_dim_0, max_dim_1)
print(t)
>>> [[1, 2, 0],
     [3, 4, 0],
     [0, 0, 0],
     [0, 0, 0]]
```
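
For illustration, here is a minimal sketch of how such a padding helper (and the corresponding un-padding used when updating the state dict) could look with `torch.nn.functional.pad`. The name `pad_tensor_to_max_dims` mirrors the example above, but the bodies and the `remove_padding` helper are assumptions, not the PR's actual implementation:

```
import torch
import torch.nn.functional as F


def pad_tensor_to_max_dims(
    t: torch.Tensor, max_dim_0: int, max_dim_1: int
) -> torch.Tensor:
    # F.pad orders the `pad` argument from the last dimension backwards:
    # (left, right, top, bottom) for a 2D tensor. Padding only on the
    # right (dim 1) and bottom (dim 0) with value 0 matches the example above.
    return F.pad(
        t,
        pad=(0, max_dim_1 - t.size(1), 0, max_dim_0 - t.size(0)),
        mode="constant",
        value=0,
    )


def remove_padding(t: torch.Tensor, orig_dim_0: int, orig_dim_1: int) -> torch.Tensor:
    # Hypothetical helper: undoing the padding before writing a shard back
    # into the state dict is just a slice back to the original shard size.
    return t[:orig_dim_0, :orig_dim_1]
```

For example, `pad_tensor_to_max_dims(torch.tensor([[1, 2], [3, 4]]), 4, 3)` returns a 4×3 tensor, and `remove_padding(padded, 2, 2)` slices it back to the original shard.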

Max dimensions for dim 0 and dim 1 are determined by going through all shard sizes of the embedding module. This ensures the `output_tensor` passed into a2a is large enough.

NOTE: This will later be optimized to go through only the shard sizes that are being redistributed.
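
As a rough sketch of that scan over shard sizes, assuming they are available as `(dim_0, dim_1)` pairs from the module's sharding metadata (the helper name and input shape here are illustrative, not TorchRec's actual API):

```
from typing import List, Tuple


def get_max_dims(shard_sizes: List[Tuple[int, int]]) -> Tuple[int, int]:
    # Walk every shard size of the embedding module so that both the padded
    # input shards and the a2a output_tensor are guaranteed to be large enough.
    max_dim_0 = max(size[0] for size in shard_sizes)
    max_dim_1 = max(size[1] for size in shard_sizes)
    return max_dim_0, max_dim_1
```

Scanning every shard is deliberately conservative; restricting the scan to only the shards being redistributed is the optimization mentioned in the note above.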

Differential Revision: D74150894

@facebook-github-bot added the CLA Signed label on May 5, 2025
@facebook-github-bot (Contributor) commented:

This pull request was exported from Phabricator. Differential Revision: D74150894

aporialiao added a commit to aporialiao/torchrec that referenced this pull request on May 5, 2025
