
[wip] context parallelism #2668


Draft · ebsmothers wants to merge 8 commits into main

Conversation

ebsmothers (Contributor)

Initial implementation of context parallelism in torchtune.
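
For anyone skimming the thread: this presumably builds on PyTorch's experimental `context_parallel` API, which shards the listed tensors along their sequence dimension across a CP group and runs SDPA as ring attention within that group. A minimal sketch under that assumption; the function name, signature, and seq-dim choices below are illustrative, not this PR's exact code:

```python
import contextlib
from typing import Dict

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.experimental import context_parallel


def get_context_parallel_ctx(
    model: torch.nn.Module,
    batch: Dict[str, torch.Tensor],
    cp_mesh: DeviceMesh,
    cp_degree: int,
):
    """Build a CP context manager, or a no-op one when CP is disabled (sketch)."""
    if cp_degree <= 1:
        return contextlib.nullcontext()
    # Shard the per-step inputs and the model's persistent buffers (e.g. RoPE
    # caches) along their sequence dimensions across the CP group.
    model_inputs = list(batch.values())
    buffers = list(model.buffers())
    return context_parallel(
        cp_mesh,
        buffers=model_inputs + buffers,
        # Assumption: inputs are (batch, seq) -> seq dim 1, model buffers are
        # (seq, ...) -> seq dim 0; the PR may handle this differently.
        buffer_seq_dims=[1] * len(model_inputs) + [0] * len(buffers),
        no_restore_buffers=set(model_inputs),
    )
```

Each training step's forward and backward would then run under `with get_context_parallel_ctx(...)`, so attention is computed ring-style over sequence shards while everything outside attention just sees a shorter local sequence.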

Initial test

tune run --nproc_per_node 8 full_finetune_distributed --config llama3/8B_full \
  context_parallel_dim=4 metric_logger=torchtune.training.metric_logging.WandBLogger \
  metric_logger.project=context-parallel metric_logger.name=llama3-8b-cp4-dp2
[Screenshot: results from the run above, 2025-05-02]
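
For reference, the layout that command implies (a sketch, not the recipe's actual mesh setup): 8 GPUs split into a context-parallel dimension of 4 and a data-parallel dimension of 2, matching the run name llama3-8b-cp4-dp2.

```python
# Sketch of the 2 (dp) x 4 (cp) device mesh implied by the command above;
# the mesh dim names are illustrative.
from torch.distributed.device_mesh import init_device_mesh

world_size = 8                        # --nproc_per_node 8
cp_degree = 4                         # context_parallel_dim=4
dp_degree = world_size // cp_degree   # remaining ranks form the DP group (2)

mesh = init_device_mesh("cuda", (dp_degree, cp_degree), mesh_dim_names=("dp", "cp"))
cp_mesh = mesh["cp"]                  # 1-D sub-mesh for this rank's CP group
```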

Also confirmed that we can run a 1M sequence length on a single node (will paste results here shortly)
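
Rough intuition for the memory savings: CP shards each sequence across the CP group, so every rank only materializes attention activations for about seq_len / cp_degree tokens; at a 1M sequence length with cp=8 on one node, that would be roughly 128K tokens per GPU (illustrative numbers; the exact savings are what the memory tests below should confirm).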

Still to test

Should test (a) equivalent loss curves and (b) requisite memory improvements on a long-context dataset for each of the below:


pytorch-bot bot commented May 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2668

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure, 2 Cancelled Jobs

As of commit da41f80 with merge base d39fd9b:

NEW FAILURE - The following job has failed:

  • GPU tests / gpu_test (3.9, stable) (gh)
    tests/recipes/test_full_finetune_distributed.py::TestFullFinetuneDistributedRecipe::test_training_state_on_resume_from_distributed_checkpoint_multi_rank[llama3/8B_full-llama3-tune-4-1-True]

CANCELLED JOBS - The following jobs were cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on May 2, 2025
@@ -718,3 +743,108 @@ def prepare_mha_for_tp(
    if is_fusion_model:
        model.decoder = decoder
    return model


def _get_sdpa_context() -> (
Contributor:

Does this mean CP doesn't work with FlexAttention?

ebsmothers (Contributor, Author):
Yes, at least until pytorch/pytorch#151497 lands

ebsmothers (Contributor, Author):
But I also think this is somewhat orthogonal: flex does not have its own backend (see here). My assumption is that it should be using the flash attention backend (but I need to confirm).
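
For context on the question above, one plausible shape for `_get_sdpa_context` (a sketch only; the helper in this PR may differ, and the `cp_enabled` argument is an assumption) is to pin SDPA to kernels the ring-attention CP path can dispatch, which excludes FlexAttention until the linked PyTorch PR lands:

```python
# Sketch only; not necessarily this PR's implementation.
import contextlib
from typing import ContextManager

from torch.nn.attention import SDPBackend, sdpa_kernel


def _get_sdpa_context(cp_enabled: bool) -> ContextManager:
    """Pin SDPA to CP-compatible kernels when context parallelism is enabled."""
    if not cp_enabled:
        return contextlib.nullcontext()
    return sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
```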

# Define optional context manager for context parallelism
model_inputs = list(batch.values())
buffers = list(self._model.buffers())
optional_context_parallel_context_manager = (
Contributor:
Is this the naming we're using for other optional ctx managers? We have "activations_handling_ctx", though I'd prefer to consolidate on something like "context_parallel" or "maybe_context_parallel"; I think the "with" statement already says it's a context manager.

ebsmothers (Contributor, Author):
Yeah, I'm good with taking out the "optional" here and matching what we do for activation offloading.
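
To make that concrete (naming illustrative, not a commitment): drop the `optional_` prefix so the snippet above becomes something like `context_parallel_ctx = <cp context manager> if cp_degree > 1 else contextlib.nullcontext()`, mirroring whatever no-op pattern `activations_handling_ctx` already uses.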
