[wip] context parallelism #2668

Draft · wants to merge 8 commits into main

43 changes: 31 additions & 12 deletions recipes/full_finetune_distributed.py
@@ -151,6 +151,7 @@ def __init__(self, cfg: DictConfig) -> None:
raise ValueError(
"Tensor Parallel plan needs to be provided when tensor parallel is enabled."
)
self.cp_degree = cfg.get("context_parallel_dim", 1)
data_shard = cfg.get("data_parallel_shard_dim", -1) # -1 means to infer
data_replicate = cfg.get("data_parallel_replicate_dim", 1)

@@ -159,6 +160,7 @@ def __init__(self, cfg: DictConfig) -> None:
dp_replicate=data_replicate,
dp_shard=data_shard,
tp=self.tp_degree,
cp=self.cp_degree,
world_size=self.world_size,
)
self.world_mesh = self.parallel_dims.build_mesh(device_type=device_type)
@@ -771,12 +773,14 @@ def _setup_data(
dataset=ds,
batch_size=batch_size,
sampler=sampler,
# TODO: overkill?
# https://github.com/pytorch/pytorch/blob/4f62dcc/torch/distributed/tensor/experimental/_attention.py#L1246
collate_fn=(
partial(
collate_fn,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
pad_to_multiple_of=self.tp_degree,
pad_to_multiple_of=self.tp_degree * self.cp_degree * 2,
)
if not packed
else padded_collate_packed
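Note on the padding factor above: PyTorch's load-balanced context parallel implementation (linked in the TODO) splits each sequence into 2 * cp chunks before distributing them across CP ranks, and sequence-parallel TP shards activations tp ways, so the padded sequence length has to be divisible by both. Padding to tp_degree * cp_degree * 2 satisfies this, though as the TODO notes it can over-pad when the least common multiple is smaller. A minimal sketch of the divisibility requirement (not torchtune's collate code; the helper name is made up):

```python
import math


def pad_len_for_parallelism(seq_len: int, tp_degree: int, cp_degree: int) -> int:
    """Round seq_len up so it can be sharded for TP and load-balanced CP.

    Load-balanced context parallel splits a sequence into 2 * cp_degree chunks,
    so seq_len must be divisible by 2 * cp_degree; sequence-parallel TP also
    wants divisibility by tp_degree. tp * cp * 2 is a simple (possibly
    over-conservative) common multiple of both.
    """
    multiple = tp_degree * cp_degree * 2
    return math.ceil(seq_len / multiple) * multiple


assert pad_len_for_parallelism(1000, tp_degree=2, cp_degree=2) == 1000  # already divisible by 8
assert pad_len_for_parallelism(1001, tp_degree=2, cp_degree=2) == 1008  # padded up by 7 tokens
```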
@@ -888,6 +892,18 @@ def train(self) -> None:

utils.batch_to_device(batch, self._device)

# Define optional context manager for context parallelism
model_inputs = list(batch.values())
buffers = list(self._model.buffers())
optional_context_parallel_context_manager = (
Contributor:
Is this the naming we're using for other optional ctx managers? We have "activations_handling_ctx", though I'd prefer to consolidate on something like "context_parallel" or "maybe_context_parallel". I think the "with" statement already says it's a context manager.

Contributor Author:
Yeah, I'm good with taking out the "optional" here and matching what we do for activation offloading.

training.get_context_parallel_context(
cp_enabled=self.cp_degree > 1,
world_mesh=self.world_mesh,
model_inputs=model_inputs,
model_buffers=buffers,
)
)

# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
@@ -897,17 +913,17 @@

# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
current_loss = self._loss_step(batch) * current_num_tokens
running_loss += current_loss

# For optimizer in backward, we need to normalize before calling backward
# This case and gradient accumulation are mutually exclusive
if self._optimizer_in_bwd:
torch.distributed.all_reduce(num_tokens)
torch.distributed.all_reduce(running_loss)
current_loss = current_loss * (self.dp_degree / num_tokens)
with optional_context_parallel_context_manager:
current_loss = self._loss_step(batch) * current_num_tokens
running_loss += current_loss
# For optimizer in backward, we need to normalize before calling backward
# This case and gradient accumulation are mutually exclusive
if self._optimizer_in_bwd:
torch.distributed.all_reduce(num_tokens)
torch.distributed.all_reduce(running_loss)
current_loss = current_loss * (self.dp_degree / num_tokens)
current_loss.backward()

current_loss.backward()
# Optimizer step (if not fused in backward call)
if (idx + 1) % self._gradient_accumulation_steps == 0:
if not self._optimizer_in_bwd:
@@ -916,7 +932,10 @@
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, self.dp_degree / num_tokens)
# TODO: check this
training.scale_grads(
self._model, self.dp_degree * self.cp_degree / num_tokens
)
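On the "TODO: check this" above, a rough numerical sanity check of the dp_degree * cp_degree factor. It assumes that num_tokens is the global count of unmasked tokens after the all_reduce, and that gradients of the unnormalized token-sum loss are mean-reduced across all dp * cp ranks (each CP rank back-propagating only its sequence shard); it is not a substitute for verifying the actual reduction mesh:

```python
import torch

# Toy setup: 2-way data parallel x 2-way context parallel, 100 unmasked tokens
# in the global batch, so 25 tokens' worth of gradient on each of the 4 ranks.
dp_degree, cp_degree, num_tokens = 2, 2, 100
num_ranks = dp_degree * cp_degree
per_rank_grads = [torch.full((4,), 25.0) for _ in range(num_ranks)]

# Mean reduction across the dp * cp group (what the sharded backward does):
reduced = torch.stack(per_rank_grads).mean(dim=0)

# Rescaling by dp * cp / num_tokens recovers the per-token-normalized gradient,
# i.e. (sum of all ranks' gradients) / num_tokens.
rescaled = reduced * (dp_degree * cp_degree / num_tokens)
expected = torch.stack(per_rank_grads).sum(dim=0) / num_tokens
torch.testing.assert_close(rescaled, expected)
```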
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
18 changes: 14 additions & 4 deletions torchtune/models/llama3_2/_model_builders.py
@@ -3,8 +3,8 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
from functools import partial
from typing import List

from torchtune.models.llama3_2._component_builders import llama3_2, lora_llama3_2

@@ -16,12 +16,14 @@
the llama3_2_1b model builder uses the llama3_2 component builder to create the
Llama3.2 1B model.
"""


def llama3_2_1b(
tie_word_embeddings: bool = True,
) -> TransformerDecoder:
"""
Builder for creating a Llama3.2 model initialized w/ the default 1b parameter values.

Args:
tie_word_embeddings (bool): whether the model's input and output word embeddings should be tied.

@@ -42,6 +44,8 @@ def llama3_2_1b(
scale_factor=32,
tie_word_embeddings=tie_word_embeddings,
)


def llama3_2_3b(
tie_word_embeddings: bool = True,
) -> TransformerDecoder:
@@ -68,6 +72,8 @@ def llama3_2_3b(
scale_factor=32,
tie_word_embeddings=tie_word_embeddings,
)


def lora_llama3_2_1b(
lora_attn_modules: List[LORA_ATTN_MODULES],
apply_lora_to_mlp: bool = False,
@@ -84,7 +90,7 @@ def lora_llama3_2_1b(
The Llama3.2 defaults are the same as in :func:`~torchtune.models.llama3_2.llama3_2_1b`,
while LoRA default params are based on
https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.

Args:
lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
LoRA should be applied to in each self-attention block. Options are
@@ -126,6 +132,8 @@ def lora_llama3_2_1b(
quantize_base=quantize_base,
tie_word_embeddings=tie_word_embeddings,
)


def lora_llama3_2_3b(
lora_attn_modules: List[LORA_ATTN_MODULES],
apply_lora_to_mlp: bool = False,
@@ -162,7 +170,7 @@ def lora_llama3_2_3b(
Returns:
TransformerDecoder: Instantiation of Llama3.2 3B model with LoRA applied
"""

return lora_llama3_2(
lora_attn_modules=lora_attn_modules,
apply_lora_to_mlp=apply_lora_to_mlp,
@@ -185,6 +193,8 @@ def lora_llama3_2_3b(
quantize_base=quantize_base,
tie_word_embeddings=tie_word_embeddings,
)


qlora_llama3_2_1b = partial(lora_llama3_2_1b, quantize_base=True)
qlora_llama3_2_1b.__doc__ = """
Builder for creating a Llama3.2 1B model with QLoRA enabled. Base model weights in linear layers
2 changes: 2 additions & 0 deletions torchtune/training/__init__.py
@@ -11,6 +11,7 @@
from torchtune.training._compile import compile_loss, compile_model
from torchtune.training._distributed import (
gather_cpu_state_dict,
get_context_parallel_context,
get_distributed_backend,
get_full_optimizer_state_dict,
get_shard_conditions,
@@ -142,4 +143,5 @@
"get_distributed_backend",
"disable_dropout",
"DATALOADER_KEY",
"get_context_parallel_context",
]
144 changes: 137 additions & 7 deletions torchtune/training/_distributed.py
@@ -5,12 +5,13 @@
# LICENSE file in the root directory of this source tree.


import contextlib
import logging
import os
from dataclasses import dataclass
from functools import cached_property
from itertools import chain
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple

import torch
import torch.distributed as dist
@@ -28,6 +29,9 @@
)
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import FSDPModule, ShardingStrategy
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.modules.module import _IncompatibleKeys
from torch.optim import Optimizer
from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4
@@ -54,26 +58,28 @@ class ParallelDims:
dp_replicate: int
dp_shard: int
tp: int
cp: int
world_size: int

def __post_init__(self):
self._validate()

def _validate(self):
dp_replicate, dp_shard, tp = (
dp_replicate, dp_shard, tp, cp = (
self.dp_replicate,
self.dp_shard,
self.tp,
self.cp,
)
for d in (dp_replicate, tp):
for d in (dp_replicate, tp, cp):
assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"

assert dp_shard == -1 or dp_shard >= 1, "dp_shard must be -1 or >= 1."
if dp_shard < 0:
self.dp_shard = dp_shard = self.world_size // (dp_replicate * tp)
self.dp_shard = dp_shard = self.world_size // (dp_replicate * tp * cp)
assert dp_shard >= 1

assert dp_replicate * dp_shard * tp == self.world_size, (
assert dp_replicate * dp_shard * tp * cp == self.world_size, (
f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
f"tp({tp}) != WORLD_SIZE({self.world_size})"
)
@@ -82,8 +88,8 @@ def build_mesh(self, device_type):
dims = []
names = []
for d, name in zip(
[self.dp_replicate, self.dp_shard, self.tp],
["dp_replicate", "dp_shard", "tp"],
[self.dp_replicate, self.dp_shard, self.cp, self.tp],
["dp_replicate", "dp_shard", "cp", "tp"],
):
if d > 1:
dims.append(d)
@@ -96,17 +102,36 @@
# initialized:
# Mesh for data loading (no communication on this mesh)
dp_mesh_dim_names = []
dp_shard_cp_mesh_dim_names = []
dp_cp_mesh_dim_names = []

if self.dp_replicate_enabled:
dp_mesh_dim_names.append("dp_replicate")
dp_cp_mesh_dim_names.append("dp_replicate")
if self.dp_shard_enabled:
dp_mesh_dim_names.append("dp_shard")
dp_shard_cp_mesh_dim_names.append("dp_shard")
dp_cp_mesh_dim_names.append("dp_shard")
if self.cp_enabled:
dp_shard_cp_mesh_dim_names.append("cp")
dp_cp_mesh_dim_names.append("cp")

if dp_mesh_dim_names != []:
mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")

if dp_shard_cp_mesh_dim_names != []:
mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(
mesh_dim_name="dp_shard_cp"
)
if dp_cp_mesh_dim_names != []:
mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")

return mesh

@property
def cp_enabled(self):
return self.cp > 1

@property
def dp_enabled(self):
return self.dp_replicate > 1 or self.dp_shard > 1
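As a concrete illustration of the new cp dimension, here is how the parallel dims could be laid out on a single 8-GPU node (a sketch against the class above; build_mesh itself needs an initialized process group, so it is left commented out):

```python
from torchtune.training._distributed import ParallelDims

# 8 GPUs: FSDP sharding x context parallel x tensor parallel, no replication.
parallel_dims = ParallelDims(
    dp_replicate=1,
    dp_shard=-1,  # inferred as 8 // (1 * 2 * 2) = 2
    tp=2,
    cp=2,
    world_size=8,
)
assert parallel_dims.dp_shard == 2
assert parallel_dims.dp_enabled and parallel_dims.cp_enabled

# With a process group initialized, build_mesh would create a mesh with dims
# ("dp_shard", "cp", "tp") = (2, 2, 2) plus the flattened sub-meshes
# "dp" (data loading), "dp_shard_cp" ("dp_shard" x "cp"), and
# "dp_cp" (also "dp_shard" x "cp" here, since dp_replicate == 1).
# mesh = parallel_dims.build_mesh(device_type="cuda")
```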
@@ -718,3 +743,108 @@ def prepare_mha_for_tp(
if is_fusion_model:
model.decoder = decoder
return model


def _get_sdpa_context() -> (
Contributor:
Does this mean CP doesn't work with FlexAttention?

Contributor Author:
Yes, at least until pytorch/pytorch#151497 lands.

Contributor Author:
But I also think this is somewhat orthogonal. Flex does not have its own backend (see here); my assumption is that it should be using the flash attention backend (but I need to confirm).

Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]
):
"""
Creates a context manager that confines SDPA to the flash/efficient/cuDNN attention backends when a context parallel context is passed in.

Returns:
A context manager function that takes an optional context parallel context.
"""

@contextlib.contextmanager
def context(cp_context: Generator[None, None, None] | None = None):
with contextlib.ExitStack() as stack:
if cp_context is not None:
stack.enter_context(
sdpa_kernel(
[
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.CUDNN_ATTENTION,
]
)
)
stack.enter_context(cp_context)

yield

return context


def get_context_parallel_context(
*,
cp_enabled: bool = False,
world_mesh: torch.distributed.DeviceMesh,
model_inputs: List[torch.Tensor],
model_buffers: Optional[List[torch.Tensor]] = None,
) -> Generator[None, None, None]:
"""
Context manager for applying context parallelism to a model. In addition to entering PyTorch's context_parallel context manager, which patches SDPA and shards model inputs and buffers along the sequence dimension, this also calls into _get_sdpa_context to restrict SDPA to backends that support context parallelism.

Args:
cp_enabled (bool): Whether context parallel is enabled. Default: False
world_mesh (torch.distributed.DeviceMesh): Global device mesh.
model_inputs (List[torch.Tensor]): List of any model inputs which should be
sharded along sequence dimension.
model_buffers (Optional[List[torch.Tensor]]): List of any model buffers. These should also be
sharded along sequence dimension. Default: None

Returns:
A context manager that applies context parallelism (and restricts SDPA to the flash/efficient/cuDNN backends) if cp_enabled is True; otherwise an effectively no-op context manager.

Raises:
ValueError: if cp_enabled is True but world_mesh does not contain a "cp" dimension

Example:
```python
batch = {"inputs": inputs, "labels": labels}
with get_context_parallel_context(
cp_enabled=parallel_dims.cp_enabled,
world_mesh=world_mesh,
model_inputs=list(batch.values()),
model_buffers=list(model.buffers()),
):
logits = model(inputs)
loss = loss(logits, labels)
loss.backward()
```
"""

if cp_enabled and "cp" not in world_mesh.mesh_dim_names:
raise ValueError(
"Context parallel is enabled but no context parallel device mesh is provided."
)
model_buffers = model_buffers or []

@contextlib.contextmanager
def context():
# Create context parallel context if enabled
cp_context = None
if (
cp_enabled
and world_mesh is not None
and len(model_inputs) + len(model_buffers) > 0
):
# TODO: do we need to parametrize this?
set_rotate_method("allgather")
cp_context = context_parallel(
world_mesh["cp"],
buffers=model_inputs + model_buffers,
buffer_seq_dims=[1] * len(model_inputs) + [0] * len(model_buffers),
no_restore_buffers=set(model_inputs),
)

# Create and enter the train context with the optional cp_context
sdpa_context = _get_sdpa_context()

with sdpa_context(cp_context):
yield

return context()
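Finally, a self-contained sketch of the new helper's two modes, using only what this PR adds. With cp_enabled=False it is effectively a no-op and can run on a single device (passing world_mesh=None here relies on the disabled path never touching it); with cp_enabled=True and a world mesh containing a "cp" dim, it additionally shards the listed inputs/buffers along the sequence dimension and restricts SDPA backends, as in the docstring example above:

```python
import torch
from torchtune import training

model = torch.nn.Linear(16, 16)
inputs = torch.randn(2, 8, 16)  # (batch, seq, dim); inputs are sharded on dim 1 when CP is on

with training.get_context_parallel_context(
    cp_enabled=False,  # flip to True under torchrun with a "cp" mesh dim
    world_mesh=None,   # only consulted when cp_enabled is True
    model_inputs=[inputs],
    model_buffers=list(model.buffers()),
):
    out = model(inputs)

print(out.shape)  # torch.Size([2, 8, 16])
```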