[wip] context parallelism #2668
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2668. Note: links to docs will display an error until the docs builds have completed.
❗ 1 active SEV. If your PR is affected, please view it below.
As of commit da41f80 with merge base d39fd9b: ❌ 1 new failure, 2 cancelled jobs (please retry the cancelled jobs).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -718,3 +743,108 @@ def prepare_mha_for_tp(
    if is_fusion_model:
        model.decoder = decoder
    return model


def _get_sdpa_context() -> (
Does this mean CP doesn't work with FlexAttention?
Yes, at least until pytorch/pytorch#151497 lands
But I also think this is somewhat orthogonal: flex does not have its own SDPA backend (see here). My assumption is that it should be using the flash attention backend (but I need to confirm).
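For context, a minimal sketch of what a helper like `_get_sdpa_context` could look like if it simply restricts SDPA to the backends that ring-attention-based context parallelism can dispatch to today. The body and the exact backend list are assumptions for illustration, not the PR's actual implementation:

```python
import contextlib

from torch.nn.attention import SDPBackend, sdpa_kernel


def _get_sdpa_context() -> contextlib.AbstractContextManager:
    # Restrict scaled_dot_product_attention to backends known to work with
    # context parallelism; FlexAttention is excluded until
    # pytorch/pytorch#151497 lands. (Sketch only; the backend list in the PR
    # may differ.)
    return sdpa_kernel(
        [
            SDPBackend.FLASH_ATTENTION,
            SDPBackend.EFFICIENT_ATTENTION,
            SDPBackend.CUDNN_ATTENTION,
        ]
    )
```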
# Define optional context manager for context parallelism
model_inputs = list(batch.values())
buffers = list(self._model.buffers())
optional_context_parallel_context_manager = (
Is this the naming we're using for other optional ctx managers? We have "activations_handling_ctx", though I'd prefer to consolidate on something like "context_parallel" or "maybe_context_parallel"; the "with" statement already says it's a context manager.
Yeah I'm good taking out the "optional" here and matching what we do for activation offloading
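For reference, a rough sketch of how this context manager could be assembled in the recipe using PyTorch's experimental `context_parallel` API. The helper name, the `cp_mesh` argument, and the sequence-dimension choices are assumptions for illustration, not the PR's exact code:

```python
import contextlib
from typing import Optional

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.experimental import context_parallel


def get_context_parallel_ctx(
    cp_mesh: Optional[DeviceMesh],
    model: torch.nn.Module,
    batch: dict[str, torch.Tensor],
) -> contextlib.AbstractContextManager:
    # Placeholder helper (not in the PR): shard the batch tensors and model
    # buffers along their sequence dimension for the duration of the
    # forward/backward pass, or no-op when context parallelism is disabled.
    if cp_mesh is None:
        return contextlib.nullcontext()
    model_inputs = list(batch.values())
    buffers = list(model.buffers())
    return context_parallel(
        cp_mesh,
        buffers=model_inputs + buffers,
        # Assumed layouts: inputs are [batch, seq] (seq dim 1); buffers such
        # as RoPE caches keep their sequence dim first (seq dim 0).
        buffer_seq_dims=[1] * len(model_inputs) + [0] * len(buffers),
    )
```

A recipe would then wrap the forward and backward pass in `with get_context_parallel_ctx(...):` so that inputs and buffers are sharded only for that step.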
Initial implementation of context parallelism in torchtune.
Initial test
Also confirmed that we can run 1M sequence length on a single node (will paste results here shortly).
Still to test
Should test (a) equivalent loss curves and (b) requisite memory improvements on a long-context dataset for each of the below: