[wip] context parallelism #2668
base: main
@@ -5,12 +5,13 @@
 # LICENSE file in the root directory of this source tree.

 import contextlib
 import logging
 import os
 from dataclasses import dataclass
 from functools import cached_property
 from itertools import chain
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple
+from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple

 import torch
 import torch.distributed as dist
@@ -28,6 +29,9 @@
 )
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.fsdp import FSDPModule, ShardingStrategy
+from torch.distributed.tensor.experimental import context_parallel
+from torch.distributed.tensor.experimental._attention import set_rotate_method
+from torch.nn.attention import sdpa_kernel, SDPBackend
 from torch.nn.modules.module import _IncompatibleKeys
 from torch.optim import Optimizer
 from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4
@@ -54,26 +58,28 @@ class ParallelDims:
     dp_replicate: int
     dp_shard: int
     tp: int
+    cp: int
     world_size: int

     def __post_init__(self):
         self._validate()

     def _validate(self):
-        dp_replicate, dp_shard, tp = (
+        dp_replicate, dp_shard, tp, cp = (
             self.dp_replicate,
             self.dp_shard,
             self.tp,
+            self.cp,
         )
-        for d in (dp_replicate, tp):
+        for d in (dp_replicate, tp, cp):
             assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"

         assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."
         if dp_shard < 0:
-            self.dp_shard = dp_shard = self.world_size // (dp_replicate * tp)
+            self.dp_shard = dp_shard = self.world_size // (dp_replicate * tp * cp)
         assert dp_shard >= 1

-        assert dp_replicate * dp_shard * tp == self.world_size, (
+        assert dp_replicate * dp_shard * tp * cp == self.world_size, (
             f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
             f"tp({tp}) != WORLD_SIZE({self.world_size})"
         )
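To make the degree arithmetic concrete, here is a small sketch (not part of the diff, and assuming the fields shown above are the full constructor) of how `_validate` resolves `dp_shard=-1` from the remaining degrees:

```python
# Illustrative only: dp_shard=-1 is inferred from the other parallelism degrees.
# world_size=8, dp_replicate=1, tp=2, cp=2  ->  dp_shard = 8 // (1 * 2 * 2) = 2
dims = ParallelDims(dp_replicate=1, dp_shard=-1, tp=2, cp=2, world_size=8)
assert dims.dp_shard == 2
# and the final check holds: 1 * 2 * 2 * 2 == 8 == world_size
```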
@@ -82,8 +88,8 @@ def build_mesh(self, device_type):
         dims = []
         names = []
         for d, name in zip(
-            [self.dp_replicate, self.dp_shard, self.tp],
-            ["dp_replicate", "dp_shard", "tp"],
+            [self.dp_replicate, self.dp_shard, self.cp, self.tp],
+            ["dp_replicate", "dp_shard", "cp", "tp"],
         ):
             if d > 1:
                 dims.append(d)
@@ -96,17 +102,36 @@ def build_mesh(self, device_type):
         # initialized:
         # Mesh for data loading (no communication on this mesh)
         dp_mesh_dim_names = []
+        dp_shard_cp_mesh_dim_names = []
+        dp_cp_mesh_dim_names = []

         if self.dp_replicate_enabled:
             dp_mesh_dim_names.append("dp_replicate")
+            dp_cp_mesh_dim_names.append("dp_replicate")
         if self.dp_shard_enabled:
             dp_mesh_dim_names.append("dp_shard")
+            dp_shard_cp_mesh_dim_names.append("dp_shard")
+            dp_cp_mesh_dim_names.append("dp_shard")
+        if self.cp_enabled:
+            dp_shard_cp_mesh_dim_names.append("cp")
+            dp_cp_mesh_dim_names.append("cp")

         if dp_mesh_dim_names != []:
             mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")

+        if dp_shard_cp_mesh_dim_names != []:
+            mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(
+                mesh_dim_name="dp_shard_cp"
+            )
+        if dp_cp_mesh_dim_names != []:
+            mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")

         return mesh

+    @property
+    def cp_enabled(self):
+        return self.cp > 1
+
     @property
     def dp_enabled(self):
         return self.dp_replicate > 1 or self.dp_shard > 1
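For orientation, a sketch (not part of the diff) of what `build_mesh` produces under one assumed configuration of 4 ranks with `dp_shard=2` and `cp=2`:

```python
# Illustrative only; assumes torch.distributed is initialized across 4 ranks.
parallel_dims = ParallelDims(dp_replicate=1, dp_shard=2, tp=1, cp=2, world_size=4)
world_mesh = parallel_dims.build_mesh(device_type="cuda")

print(world_mesh.mesh_dim_names)  # ("dp_shard", "cp") -- only degrees > 1 become mesh dims
print(world_mesh["cp"].size())    # 2, the mesh get_context_parallel_context shards over
# The _flatten calls above also register "dp" (just dp_shard here), "dp_shard_cp"
# (dp_shard x cp = 4 ranks), and "dp_cp" (the same 4 ranks, since dp_replicate == 1).
```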
@@ -718,3 +743,108 @@ def prepare_mha_for_tp(
     if is_fusion_model:
         model.decoder = decoder
     return model
+
+
+def _get_sdpa_context() -> (
+    Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]
+):
+    """
+    Creates a context manager to confine to flash/efficient/cuDNN attention backends.
+
+    Returns:
+        A context manager function that takes an optional context parallel context.
+    """
+
+    @contextlib.contextmanager
+    def context(cp_context: Generator[None, None, None] | None = None):
+        with contextlib.ExitStack() as stack:
+            if cp_context is not None:
+                stack.enter_context(
+                    sdpa_kernel(
+                        [
+                            SDPBackend.FLASH_ATTENTION,
+                            SDPBackend.EFFICIENT_ATTENTION,
+                            SDPBackend.CUDNN_ATTENTION,
+                        ]
+                    )
+                )
+                stack.enter_context(cp_context)
+
+            yield
+
+    return context
+
+

[Review thread on _get_sdpa_context]
Comment: Does this mean CP doesn't work with FlexAttention?
Reply: Yes, at least until pytorch/pytorch#151497 lands.
Reply: But I also think this is somewhat orthogonal. Like flex does not have its own backend (see here). My assumption is that it should be using the flash attention backend (but need to confirm).
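A usage sketch of `_get_sdpa_context` as I read the code above (not part of the diff): called without a CP context it is effectively a pass-through; with one, it also pins SDPA to the flash/efficient/cuDNN backends:

```python
# Illustrative only.
import torch
import torch.nn.functional as F

sdpa_ctx = _get_sdpa_context()
q = k = v = torch.randn(2, 8, 128, 64)  # (batch, heads, seq, head_dim)

# Without a cp_context nothing extra is entered, so all SDPA backends stay available.
with sdpa_ctx():
    out = F.scaled_dot_product_attention(q, k, v)

# With a cp_context (as built in get_context_parallel_context below), the same call
# would also run under sdpa_kernel([FLASH, EFFICIENT, CUDNN]), i.e. the math backend
# is excluded while context parallelism is active.
```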
+def get_context_parallel_context(
+    *,
+    cp_enabled: bool = False,
+    world_mesh: torch.distributed.DeviceMesh,
+    model_inputs: List[torch.Tensor],
+    model_buffers: Optional[List[torch.Tensor]] = None,
+) -> Generator[None, None, None]:
+    """
+    Context manager for applying context parallelism to a model. In addition to applying the
+    standard context manager to patch SDPA and shard model inputs and buffers along the sequence
+    dimension, this context manager also calls into _get_sdpa_context to filter to acceptable SDPA backends.
+
+    Args:
+        cp_enabled (bool): Whether context parallel is enabled. Default: False
+        world_mesh (torch.distributed.DeviceMesh): Global device mesh.
+        model_inputs (List[torch.Tensor]): List of any model inputs which should be
+            sharded along sequence dimension.
+        model_buffers (Optional[List[torch.Tensor]]): List of any model buffers. These should also be
+            sharded along sequence dimension. Default: None
+
+    Returns:
+        A context manager applying context parallelism if cp_enabled is True. Otherwise a context manager
+        disabling the math SDPA backend.
+
+    Raises:
+        ValueError: if cp_enabled is True but world_mesh does not contain a "cp" dimension
+
+    Example:
+        ```python
+        batch = {"inputs": inputs, "labels": labels}
+        with get_context_parallel_context(
+            cp_enabled=parallel_dims.cp_enabled,
+            cp_mesh=world_mesh["cp"] if parallel_dims.cp_enabled else None,
+            model_inputs=list(batch.values()),
+            model_buffers=model.buffers(),
+        ):
+            logits = model(inputs)
+            loss = loss(logits, labels)
+            loss.backward()
+        ```
+    """
+
+    if cp_enabled and "cp" not in world_mesh.mesh_dim_names:
+        raise ValueError(
+            "Context parallel is enabled but no context parallel device mesh is provided."
+        )
+    model_buffers = model_buffers or []
+
+    @contextlib.contextmanager
+    def context():
+        # Create context parallel context if enabled
+        cp_context = None
+        if (
+            cp_enabled
+            and world_mesh is not None
+            and model_inputs + model_buffers is not None
+        ):
+            # TODO: do we need to parametrize this?
+            set_rotate_method("allgather")
+            cp_context = context_parallel(
+                world_mesh["cp"],
+                buffers=model_inputs + model_buffers,
+                buffer_seq_dims=[1] * len(model_inputs) + [0] * len(model_buffers),
+                no_restore_buffers=set(model_inputs),
+            )
+
+        # Create and enter the train context with the optional cp_context
+        sdpa_context = _get_sdpa_context()
+
+        with sdpa_context(cp_context):
+            yield
+
+    return context()
[Review thread]
Comment: Is this the naming we're using for other optional ctx managers? We have "activations_handling_ctx", though I'd prefer to consolidate on something like "context_parallel" or "maybe_context_parallel". I think the "with" statement says it's a context manager.
Reply: Yeah I'm good taking out the "optional" here and matching what we do for activation offloading.
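Finally, a sketch (mine, not from the PR) of how this might be wired into a training step; `model`, `loss_fn`, `optimizer`, `dataloader`, and the batch keys are assumed trainer objects, and the mesh comes from `ParallelDims.build_mesh` above:

```python
# Illustrative only: per-iteration usage of get_context_parallel_context.
world_mesh = parallel_dims.build_mesh(device_type="cuda")

for batch in dataloader:
    tokens, labels = batch["tokens"], batch["labels"]
    with get_context_parallel_context(
        cp_enabled=parallel_dims.cp_enabled,
        world_mesh=world_mesh,
        model_inputs=[tokens, labels],
        model_buffers=list(model.buffers()),
    ):
        # Inside the context (when cp is enabled), tokens/labels are sharded along
        # their sequence dim (dim 1) and buffers along dim 0 across the "cp" mesh,
        # and SDPA is restricted to the flash/efficient/cuDNN backends.
        logits = model(tokens)
        loss = loss_fn(logits, labels)
        loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```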