From d7362a7f130eedaa5298308813a2bca202565654 Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Thu, 17 Apr 2025 12:43:05 -0700 Subject: [PATCH 1/5] [wip] context parallelism --- recipes/full_finetune_distributed.py | 23 ++++++++--- torchtune/training/__init__.py | 2 + torchtune/training/_distributed.py | 60 +++++++++++++++++++++++++--- 3 files changed, 74 insertions(+), 11 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 9dbed787db..54e929e96f 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -148,6 +148,7 @@ def __init__(self, cfg: DictConfig) -> None: raise ValueError( "Tensor Parallel plan needs to be provided when tensor parallel is enabled." ) + self.cp_degree = cfg.get("context_parallel_dim", 1) data_shard = cfg.get("data_parallel_shard_dim", -1) # -1 means to infer data_replicate = cfg.get("data_parallel_replicate_dim", 1) @@ -156,6 +157,7 @@ def __init__(self, cfg: DictConfig) -> None: dp_replicate=data_replicate, dp_shard=data_shard, tp=self.tp_degree, + cp=self.cp_degree, world_size=self.world_size, ) self.world_mesh = self.parallel_dims.build_mesh(device_type=device_type) @@ -874,7 +876,14 @@ def train(self) -> None: torch.cuda.memory._record_memory_history() utils.batch_to_device(batch, self._device) - + optional_context_parallel_context_manager = ( + training.get_context_parallel_context_manager( + cp_enabled=self.cp_degree > 1, + model=self._model, + cp_mesh=self.world_mesh["cp"], + model_inputs=list(batch.values()), + ) + ) # Calculate the number of unmasked tokens in the current batch # and increment the total number of tokens seen in the step current_num_tokens = ( @@ -884,7 +893,8 @@ def train(self) -> None: # Loss is normalized by default so we multiply by the number of tokens # This way we can normalize by the total number of tokens if we're accumulating gradients - current_loss = self._loss_step(batch) * current_num_tokens + with optional_context_parallel_context_manager: + current_loss = self._loss_step(batch) * current_num_tokens running_loss += current_loss # For optimizer in backward, we need to normalize before calling backward @@ -893,8 +903,8 @@ def train(self) -> None: torch.distributed.all_reduce(num_tokens) torch.distributed.all_reduce(running_loss) current_loss = current_loss * (self.dp_degree / num_tokens) - - current_loss.backward() + with optional_context_parallel_context_manager: + current_loss.backward() # Optimizer step (if not fused in backward call) if (idx + 1) % self._gradient_accumulation_steps == 0: if not self._optimizer_in_bwd: @@ -903,7 +913,10 @@ def train(self) -> None: # This will ensure that the logged loss matches what we're optimizing torch.distributed.all_reduce(running_loss) # Manually scale the gradients from unnormalized loss by total # of tokens - training.scale_grads(self._model, self.dp_degree / num_tokens) + # TODO: check this + training.scale_grads( + self._model, self.dp_degree * self.cp_degree / num_tokens + ) if self._clip_grad_norm is not None: grad_norm = torch.nn.utils.clip_grad_norm_( self._model.parameters(), diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py index b2c327c617..fd72757634 100644 --- a/torchtune/training/__init__.py +++ b/torchtune/training/__init__.py @@ -11,6 +11,7 @@ from torchtune.training._compile import compile_loss, compile_model from torchtune.training._distributed import ( gather_cpu_state_dict, + get_context_parallel_context_manager, get_distributed_backend, 
get_full_optimizer_state_dict, get_shard_conditions, @@ -142,4 +143,5 @@ "get_distributed_backend", "disable_dropout", "DATALOADER_KEY", + "get_context_parallel_context_manager", ] diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index 875f7c313d..8bb237947f 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. +import contextlib import logging import os from dataclasses import dataclass @@ -27,6 +28,8 @@ ) from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.distributed.fsdp import FSDPModule, ShardingStrategy +from torch.distributed.tensor.experimental import context_parallel +from torch.distributed.tensor.experimental._attention import set_rotate_method from torch.nn.modules.module import _IncompatibleKeys from torch.optim import Optimizer from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4 @@ -53,26 +56,28 @@ class ParallelDims: dp_replicate: int dp_shard: int tp: int + cp: int world_size: int def __post_init__(self): self._validate() def _validate(self): - dp_replicate, dp_shard, tp = ( + dp_replicate, dp_shard, tp, cp = ( self.dp_replicate, self.dp_shard, self.tp, + self.cp, ) - for d in (dp_replicate, tp): + for d in (dp_replicate, tp, cp): assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard" assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1." if dp_shard < 0: - self.dp_shard = dp_shard = self.world_size // (dp_replicate * tp) + self.dp_shard = dp_shard = self.world_size // (dp_replicate * tp * cp) assert dp_shard >= 1 - assert dp_replicate * dp_shard * tp == self.world_size, ( + assert dp_replicate * dp_shard * tp * cp == self.world_size, ( f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * " f"tp({tp}) != WORLD_SIZE({self.world_size})" ) @@ -81,8 +86,8 @@ def build_mesh(self, device_type): dims = [] names = [] for d, name in zip( - [self.dp_replicate, self.dp_shard, self.tp], - ["dp_replicate", "dp_shard", "tp"], + [self.dp_replicate, self.dp_shard, self.tp, self.cp], + ["dp_replicate", "dp_shard", "tp", "cp"], ): if d > 1: dims.append(d) @@ -95,15 +100,30 @@ def build_mesh(self, device_type): # initialized: # Mesh for data loading (no communication on this mesh) dp_mesh_dim_names = [] + dp_shard_cp_mesh_dim_names = [] + dp_cp_mesh_dim_names = [] if self.dp_replicate_enabled: dp_mesh_dim_names.append("dp_replicate") + dp_cp_mesh_dim_names.append("dp_replicate") if self.dp_shard_enabled: dp_mesh_dim_names.append("dp_shard") + dp_shard_cp_mesh_dim_names.append("dp_shard") + dp_cp_mesh_dim_names.append("dp_shard") + if self.cp_enabled: + dp_mesh_dim_names.append("cp") + dp_cp_mesh_dim_names.append("cp") if dp_mesh_dim_names != []: mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") + if dp_shard_cp_mesh_dim_names != []: + mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten( + mesh_dim_name="dp_shard_cp" + ) + if dp_cp_mesh_dim_names != []: + mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") + return mesh @property @@ -122,6 +142,10 @@ def dp_shard_enabled(self): def tp_enabled(self): return self.tp > 1 + @property + def cp_enabled(self): + return self.cp > 1 + def _get_sharding_strategy(strategy: str) -> ShardingStrategy: """Helper function to convert sharding strategy strings to ShardingStrategy enum.""" @@ -710,3 +734,27 @@ def prepare_mha_for_tp( if is_fusion_model: model.decoder = decoder return model + + +# TODO: move this 
elsewhere + + +def get_context_parallel_context_manager( + cp_enabled: bool, + model: TransformerDecoder, # TODO: generalize + cp_mesh: DeviceMesh, + model_inputs: List[torch.Tensor], +) -> contextlib.contextmanager: + if not cp_enabled: + return contextlib.nullcontext() + # if "cp" not in mesh: + # raise ValueError("CP mesh not found in device mesh") + # cp_mesh = mesh["cp"] + set_rotate_method("allgather") # TODO: hardcode for now + buffers = list(model.buffers()) + return context_parallel( + cp_mesh, + buffers=model_inputs + buffers, + buffer_seq_dims=[1] * len(model_inputs) + [0] * len(buffers), + no_restore_buffers=set(model_inputs), + ) From 421b1faaf85bf44e3b4d93b39e906bf310ae20b7 Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Thu, 17 Apr 2025 14:37:29 -0700 Subject: [PATCH 2/5] messing with context manager --- recipes/full_finetune_distributed.py | 4 +- torchtune/training/_distributed.py | 85 ++++++++++++++++++++++------ 2 files changed, 72 insertions(+), 17 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 54e929e96f..c928608ed0 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -761,7 +761,9 @@ def _setup_data( collate_fn, padding_idx=self._tokenizer.pad_id, ignore_idx=self._loss_fn.ignore_index, - pad_to_multiple_of=self.tp_degree, + pad_to_multiple_of=self.tp_degree + * self.cp_degree + * 2, # TODO: overkill? https://github.com/pytorch/pytorch/blob/4f62dccbdae90d266e3cce4a499b77008f8f840f/torch/distributed/tensor/experimental/_attention.py#L1246 ) if not packed else padded_collate_packed diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index 8bb237947f..cff6eccbc4 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -10,7 +10,7 @@ import os from dataclasses import dataclass from itertools import chain -from typing import Any, Callable, cast, Dict, List, Optional, Tuple +from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -30,6 +30,7 @@ from torch.distributed.fsdp import FSDPModule, ShardingStrategy from torch.distributed.tensor.experimental import context_parallel from torch.distributed.tensor.experimental._attention import set_rotate_method +from torch.nn.attention import sdpa_kernel, SDPBackend from torch.nn.modules.module import _IncompatibleKeys from torch.optim import Optimizer from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4 @@ -736,25 +737,77 @@ def prepare_mha_for_tp( return model -# TODO: move this elsewhere - - +@contextlib.contextmanager def get_context_parallel_context_manager( cp_enabled: bool, model: TransformerDecoder, # TODO: generalize cp_mesh: DeviceMesh, model_inputs: List[torch.Tensor], -) -> contextlib.contextmanager: - if not cp_enabled: - return contextlib.nullcontext() - # if "cp" not in mesh: - # raise ValueError("CP mesh not found in device mesh") - # cp_mesh = mesh["cp"] +): set_rotate_method("allgather") # TODO: hardcode for now buffers = list(model.buffers()) - return context_parallel( - cp_mesh, - buffers=model_inputs + buffers, - buffer_seq_dims=[1] * len(model_inputs) + [0] * len(buffers), - no_restore_buffers=set(model_inputs), - ) + if not cp_enabled: + return contextlib.nullcontext() + else: + with contextlib.ExitStack() as stack: + stack.enter_context( + sdpa_kernel( + [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.CUDNN_ATTENTION, + ] + ) + 
).enter_context( + context_parallel( + cp_mesh, + buffers=model_inputs + buffers, + buffer_seq_dims=[1] * len(model_inputs) + [0] * len(buffers), + no_restore_buffers=set(model_inputs), + ) + ) + yield + + +# TODO: move this elsewhere + + +# @contextlib.contextma +# def get_context_parallel_context_manager( +# cp_enabled: bool, +# model: TransformerDecoder, # TODO: generalize +# cp_mesh: DeviceMesh, +# model_inputs: List[torch.Tensor], +# ) -> Union[Generator[None, None, None], contextlib.nullcontext]: +# if not cp_enabled: +# return contextlib.nullcontext() +# # if "cp" not in mesh: +# # raise ValueError("CP mesh not found in device mesh") +# # cp_mesh = mesh["cp"] +# set_rotate_method("allgather") # TODO: hardcode for now +# buffers = list(model.buffers()) + +# @contextlib.contextmanager +# def context(): +# with contextlib.ExitStack() as stack: +# stack.enter_context( +# sdpa_kernel( +# [ +# SDPBackend.FLASH_ATTENTION, +# SDPBackend.EFFICIENT_ATTENTION, +# SDPBackend.CUDNN_ATTENTION, +# ] +# ) +# ) + +# cp_context = context_parallel( +# cp_mesh, +# buffers=model_inputs + buffers, +# buffer_seq_dims=[1] * len(model_inputs) + [0] * len(buffers), +# no_restore_buffers=set(model_inputs), +# ) + +# stack.enter_context(cp_context) +# yield + +# return context From 898bfa90b9dad78c2f9d6664228f62a79c8ad7af Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Thu, 1 May 2025 14:09:28 -0700 Subject: [PATCH 3/5] wip changes --- recipes/full_finetune_distributed.py | 33 +++-- torchtune/training/__init__.py | 1 + torchtune/training/_distributed.py | 151 +++++++++++++++++++++- torchtune/training/context_managers.py | 169 +++++++++++++++++++++++++ 4 files changed, 341 insertions(+), 13 deletions(-) create mode 100644 torchtune/training/context_managers.py diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 3e10b4092a..8fb19af65c 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -892,14 +892,25 @@ def train(self) -> None: torch.cuda.memory._record_memory_history() utils.batch_to_device(batch, self._device) + model_inputs = list(batch.values()) + buffers = list(self._model.buffers()) optional_context_parallel_context_manager = ( - training.get_context_parallel_context_manager( + training.create_consolidated_train_context( cp_enabled=self.cp_degree > 1, - model=self._model, cp_mesh=self.world_mesh["cp"], - model_inputs=list(batch.values()), + cp_buffers=model_inputs + buffers, + cp_seq_dims=[1] * len(model_inputs) + [0] * len(buffers), + cp_no_restore_buffers=set(model_inputs), ) ) + # optional_context_parallel_context_manager = ( + # training.get_context_parallel_context_manager( + # cp_enabled=self.cp_degree > 1, + # model=self._model, + # cp_mesh=self.world_mesh["cp"], + # model_inputs=list(batch.values()), + # ) + # ) # Calculate the number of unmasked tokens in the current batch # and increment the total number of tokens seen in the step current_num_tokens = ( @@ -911,16 +922,14 @@ def train(self) -> None: # This way we can normalize by the total number of tokens if we're accumulating gradients with optional_context_parallel_context_manager: current_loss = self._loss_step(batch) * current_num_tokens - running_loss += current_loss - - # For optimizer in backward, we need to normalize before calling backward - # This case and gradient accumulation are mutually exclusive - if self._optimizer_in_bwd: - torch.distributed.all_reduce(num_tokens) - torch.distributed.all_reduce(running_loss) - current_loss = current_loss * 
(self.dp_degree / num_tokens) - with optional_context_parallel_context_manager: + # For optimizer in backward, we need to normalize before calling backward + # This case and gradient accumulation are mutually exclusive + if self._optimizer_in_bwd: + torch.distributed.all_reduce(num_tokens) + torch.distributed.all_reduce(running_loss) + current_loss = current_loss * (self.dp_degree / num_tokens) current_loss.backward() + # Optimizer step (if not fused in backward call) if (idx + 1) % self._gradient_accumulation_steps == 0: if not self._optimizer_in_bwd: diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py index fd72757634..5ea8d70a1f 100644 --- a/torchtune/training/__init__.py +++ b/torchtune/training/__init__.py @@ -10,6 +10,7 @@ ) from torchtune.training._compile import compile_loss, compile_model from torchtune.training._distributed import ( + create_consolidated_train_context, gather_cpu_state_dict, get_context_parallel_context_manager, get_distributed_backend, diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index c2f0bdad53..efc7c0833d 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -128,6 +128,10 @@ def build_mesh(self, device_type): return mesh + @property + def cp_enabled(self): + return self.cp > 1 + @property def dp_enabled(self): return self.dp_replicate > 1 or self.dp_shard > 1 @@ -762,7 +766,8 @@ def get_context_parallel_context_manager( SDPBackend.CUDNN_ATTENTION, ] ) - ).enter_context( + ) + stack.enter_context( context_parallel( cp_mesh, buffers=model_inputs + buffers, @@ -815,3 +820,147 @@ def get_context_parallel_context_manager( # yield # return context + +import contextlib +from typing import Callable, Generator, List, Optional, Set, TypeVar + +import torch + + +T = TypeVar("T") + + +def get_train_context() -> ( + Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]] +): + """ + Creates a training context that enables loss parallel and compiled autograd as specified. + + Args: + enable_loss_parallel: Whether to enable loss parallel. + enable_compiled_autograd: Whether to enable compiled autograd. + + Returns: + A context manager function that takes an optional context parallel context. + """ + + @contextlib.contextmanager + def context(cp_context: Generator[None, None, None] | None = None): + with contextlib.ExitStack() as stack: + if cp_context is not None: + from torch.nn.attention import sdpa_kernel, SDPBackend + + stack.enter_context( + sdpa_kernel( + [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.CUDNN_ATTENTION, + ] + ) + ) + stack.enter_context(cp_context) + + yield + + return context + + +def create_context_parallel_ctx( + cp_mesh: torch.distributed.DeviceMesh, + cp_buffers: List[torch.Tensor], + cp_seq_dims: List[int], + cp_no_restore_buffers: Set[torch.Tensor], +) -> Generator[None, None, None]: + """ + Creates a context parallel context. + + Args: + cp_mesh: Device mesh for context parallel. + cp_buffers: List of tensors for context parallel buffers. + cp_seq_dims: List of sequence dimensions for context parallel buffers. + cp_no_restore_buffers: Set of tensors that should not be restored. + + Returns: + A context manager for context parallel. 
+ """ + try: + from torch.distributed.tensor.experimental import context_parallel + from torch.distributed.tensor.experimental._attention import set_rotate_method + except ImportError: + print( + f"PyTorch version {torch.__version__} does not include the experimental " + "Context Parallel API. Please update to a newer version." + ) + return contextlib.nullcontext() + + set_rotate_method("allgather") + return context_parallel( + cp_mesh, + buffers=cp_buffers, + buffer_seq_dims=cp_seq_dims, + no_restore_buffers=cp_no_restore_buffers, + ) + + +def create_consolidated_train_context( + cp_enabled: bool = False, + cp_mesh: Optional[torch.distributed.DeviceMesh] = None, + cp_buffers: Optional[List[torch.Tensor]] = None, + cp_seq_dims: Optional[List[int]] = None, + cp_no_restore_buffers: Optional[Set[torch.Tensor]] = None, +) -> Generator[None, None, None]: + """ + Creates a consolidated training context that combines loss parallel, compiled autograd, + and context parallel settings in a single function. + + This function combines the functionality of get_train_context and create_context_parallel_ctx + into a single context manager. + + Args: + cp_enabled: Whether context parallel is enabled. + cp_mesh: Device mesh for context parallel. + cp_buffers: List of tensors for context parallel buffers. + cp_seq_dims: List of sequence dimensions for context parallel buffers. + cp_no_restore_buffers: Set of tensors that should not be restored. + + Returns: + A context manager that applies all the specified contexts. + + Example: + ```python + with create_consolidated_train_context( + cp_enabled=parallel_dims.cp_enabled, + cp_mesh=world_mesh["cp"] if parallel_dims.cp_enabled else None, + cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts] if parallel_dims.cp_enabled else None, + cp_seq_dims=[1, 1] + [0 for _ in model_parts] if parallel_dims.cp_enabled else None, + cp_no_restore_buffers={inputs, labels} if parallel_dims.cp_enabled else None, + ): + # Training code here + ``` + """ + + @contextlib.contextmanager + def context(): + # Create context parallel context if enabled + cp_context = None + if ( + cp_enabled + and cp_mesh is not None + and cp_buffers is not None + and cp_seq_dims is not None + ): + cp_context = create_context_parallel_ctx( + cp_mesh=cp_mesh, + cp_buffers=cp_buffers, + cp_seq_dims=cp_seq_dims, + cp_no_restore_buffers=cp_no_restore_buffers or set(), + ) + + # Create and enter the train context with the optional cp_context + train_context = get_train_context() + + with train_context(cp_context): + yield + + return context() diff --git a/torchtune/training/context_managers.py b/torchtune/training/context_managers.py new file mode 100644 index 0000000000..0f6a9be77a --- /dev/null +++ b/torchtune/training/context_managers.py @@ -0,0 +1,169 @@ +import contextlib +from typing import Callable, Generator, List, Optional, Set, TypeVar + +import torch + + +T = TypeVar("T") + + +def get_train_context( + enable_loss_parallel: bool, enable_compiled_autograd: bool +) -> Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]: + """ + Creates a training context that enables loss parallel and compiled autograd as specified. + + Args: + enable_loss_parallel: Whether to enable loss parallel. + enable_compiled_autograd: Whether to enable compiled autograd. + + Returns: + A context manager function that takes an optional context parallel context. 
+ """ + + @contextlib.contextmanager + def context(cp_context: Generator[None, None, None] | None = None): + with contextlib.ExitStack() as stack: + if enable_loss_parallel: + stack.enter_context(torch.distributed.tensor.parallel.loss_parallel()) + + if enable_compiled_autograd: + stack.enter_context( + torch._dynamo.utils.maybe_enable_compiled_autograd(True) + ) + + if cp_context is not None: + from torch.nn.attention import sdpa_kernel, SDPBackend + + stack.enter_context( + sdpa_kernel( + [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.CUDNN_ATTENTION, + ] + ) + ) + stack.enter_context(cp_context) + + yield + + return context + + +def create_context_parallel_ctx( + cp_mesh: torch.distributed.DeviceMesh, + cp_buffers: List[torch.Tensor], + cp_seq_dims: List[int], + cp_no_restore_buffers: Set[torch.Tensor], + cp_rotate_method: str, +) -> Generator[None, None, None]: + """ + Creates a context parallel context. + + Args: + cp_mesh: Device mesh for context parallel. + cp_buffers: List of tensors for context parallel buffers. + cp_seq_dims: List of sequence dimensions for context parallel buffers. + cp_no_restore_buffers: Set of tensors that should not be restored. + cp_rotate_method: Method for rotating in context parallel. + + Returns: + A context manager for context parallel. + """ + try: + from torch.distributed.tensor.experimental import context_parallel + from torch.distributed.tensor.experimental._attention import set_rotate_method + except ImportError: + print( + f"PyTorch version {torch.__version__} does not include the experimental " + "Context Parallel API. Please update to a newer version." + ) + return contextlib.nullcontext() + + set_rotate_method(cp_rotate_method) + return context_parallel( + cp_mesh, + buffers=cp_buffers, + buffer_seq_dims=cp_seq_dims, + no_restore_buffers=cp_no_restore_buffers, + ) + + +def create_consolidated_train_context( + # Parameters for train context + enable_loss_parallel: bool, + enable_compiled_autograd: bool, + # Parameters for context parallel + cp_enabled: bool = False, + cp_mesh: Optional[torch.distributed.DeviceMesh] = None, + cp_buffers: Optional[List[torch.Tensor]] = None, + cp_seq_dims: Optional[List[int]] = None, + cp_no_restore_buffers: Optional[Set[torch.Tensor]] = None, + cp_rotate_method: Optional[str] = None, +) -> Generator[None, None, None]: + """ + Creates a consolidated training context that combines loss parallel, compiled autograd, + and context parallel settings in a single function. + + This function combines the functionality of get_train_context and create_context_parallel_ctx + into a single context manager. + + Args: + enable_loss_parallel: Whether to enable loss parallel. + enable_compiled_autograd: Whether to enable compiled autograd. + cp_enabled: Whether context parallel is enabled. + cp_mesh: Device mesh for context parallel. + cp_buffers: List of tensors for context parallel buffers. + cp_seq_dims: List of sequence dimensions for context parallel buffers. + cp_no_restore_buffers: Set of tensors that should not be restored. + cp_rotate_method: Method for rotating in context parallel. + + Returns: + A context manager that applies all the specified contexts. 
+ + Example: + ```python + with create_consolidated_train_context( + enable_loss_parallel=parallel_dims.loss_parallel_enabled, + enable_compiled_autograd=parallelism_config.enable_compiled_autograd, + cp_enabled=parallel_dims.cp_enabled, + cp_mesh=world_mesh["cp"] if parallel_dims.cp_enabled else None, + cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts] if parallel_dims.cp_enabled else None, + cp_seq_dims=[1, 1] + [0 for _ in model_parts] if parallel_dims.cp_enabled else None, + cp_no_restore_buffers={inputs, labels} if parallel_dims.cp_enabled else None, + cp_rotate_method=job_config.parallelism.context_parallel_rotate_method if parallel_dims.cp_enabled else None, + ): + # Training code here + ``` + """ + + @contextlib.contextmanager + def context(): + # Create context parallel context if enabled + cp_context = None + if ( + cp_enabled + and cp_mesh is not None + and cp_buffers is not None + and cp_seq_dims is not None + and cp_rotate_method is not None + ): + cp_context = create_context_parallel_ctx( + cp_mesh=cp_mesh, + cp_buffers=cp_buffers, + cp_seq_dims=cp_seq_dims, + cp_no_restore_buffers=cp_no_restore_buffers or set(), + cp_rotate_method=cp_rotate_method, + ) + + # Create and enter the train context with the optional cp_context + train_context = get_train_context( + enable_loss_parallel=enable_loss_parallel, + enable_compiled_autograd=enable_compiled_autograd, + ) + + with train_context(cp_context): + yield + + return context() From 91d874d7477eb6fa4f5bf2bee749c12a047cff5c Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Fri, 2 May 2025 09:06:29 -0700 Subject: [PATCH 4/5] some changes --- recipes/full_finetune_distributed.py | 13 +++++++++++-- torchtune/models/llama3_2/_model_builders.py | 18 ++++++++++++++---- torchtune/modules/attention.py | 2 +- torchtune/training/_distributed.py | 13 +++++++------ 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 8fb19af65c..1323886e57 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -797,8 +797,10 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: labels = batch.pop("labels") with self.activations_handling_ctx: + print("gonna forward model") outputs = self._model(**batch) - + print("done forward model") + print("gonna forward loss") if self.linear_loss: weight = self._model.linear_projection_weight loss = self._loss_fn(weight, outputs, labels) @@ -806,6 +808,7 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: labels = labels.reshape(-1) outputs = outputs.reshape(-1, outputs.size(-1)) loss = self._loss_fn(outputs, labels) + print("done forward loss") # free logits otherwise it peaks backward memory del outputs @@ -881,6 +884,7 @@ def train(self) -> None: pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero) self._dataloader.sampler.set_epoch(curr_epoch) for idx, batch in enumerate(self._dataloader): + print(batch["tokens"].shape) # Start tracking CUDA memory for active steps for just the first epoch if ( self._is_rank_zero @@ -897,7 +901,7 @@ def train(self) -> None: optional_context_parallel_context_manager = ( training.create_consolidated_train_context( cp_enabled=self.cp_degree > 1, - cp_mesh=self.world_mesh["cp"], + world_mesh=self.world_mesh, cp_buffers=model_inputs + buffers, cp_seq_dims=[1] * len(model_inputs) + [0] * len(buffers), cp_no_restore_buffers=set(model_inputs), @@ -921,14 +925,19 @@ def train(self) -> None: 
# Loss is normalized by default so we multiply by the number of tokens # This way we can normalize by the total number of tokens if we're accumulating gradients with optional_context_parallel_context_manager: + print("gonna forward") current_loss = self._loss_step(batch) * current_num_tokens + print("done forward") + running_loss += current_loss # For optimizer in backward, we need to normalize before calling backward # This case and gradient accumulation are mutually exclusive if self._optimizer_in_bwd: torch.distributed.all_reduce(num_tokens) torch.distributed.all_reduce(running_loss) current_loss = current_loss * (self.dp_degree / num_tokens) + print("gonna backward") current_loss.backward() + print("done backward") # Optimizer step (if not fused in backward call) if (idx + 1) % self._gradient_accumulation_steps == 0: diff --git a/torchtune/models/llama3_2/_model_builders.py b/torchtune/models/llama3_2/_model_builders.py index 506de5000d..f2860a3027 100644 --- a/torchtune/models/llama3_2/_model_builders.py +++ b/torchtune/models/llama3_2/_model_builders.py @@ -3,8 +3,8 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List from functools import partial +from typing import List from torchtune.models.llama3_2._component_builders import llama3_2, lora_llama3_2 @@ -16,12 +16,14 @@ the llama3_2_1b model builder uses the llama3_2 component builder to create the Llama3.2 1B model. """ + + def llama3_2_1b( tie_word_embeddings: bool = True, ) -> TransformerDecoder: """ Builder for creating a Llama3.2 model initialized w/ the default 1b parameter values. - + Args: tie_word_embeddings (bool): whether the model's input and output word embeddings should be tied. @@ -42,6 +44,8 @@ def llama3_2_1b( scale_factor=32, tie_word_embeddings=tie_word_embeddings, ) + + def llama3_2_3b( tie_word_embeddings: bool = True, ) -> TransformerDecoder: @@ -68,6 +72,8 @@ def llama3_2_3b( scale_factor=32, tie_word_embeddings=tie_word_embeddings, ) + + def lora_llama3_2_1b( lora_attn_modules: List[LORA_ATTN_MODULES], apply_lora_to_mlp: bool = False, @@ -84,7 +90,7 @@ def lora_llama3_2_1b( The Llama3.2 defaults are the same as in :func:`~torchtune.models.llama3_2.llama3_2_1b`, while LoRA default params are based on https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. - + Args: lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers LoRA should be applied to in each self-attention block. Options are @@ -126,6 +132,8 @@ def lora_llama3_2_1b( quantize_base=quantize_base, tie_word_embeddings=tie_word_embeddings, ) + + def lora_llama3_2_3b( lora_attn_modules: List[LORA_ATTN_MODULES], apply_lora_to_mlp: bool = False, @@ -162,7 +170,7 @@ def lora_llama3_2_3b( Returns: TransformerDecoder: Instantiation of Llama3.2 3B model with LoRA applied """ - + return lora_llama3_2( lora_attn_modules=lora_attn_modules, apply_lora_to_mlp=apply_lora_to_mlp, @@ -185,6 +193,8 @@ def lora_llama3_2_3b( quantize_base=quantize_base, tie_word_embeddings=tie_word_embeddings, ) + + qlora_llama3_2_1b = partial(lora_llama3_2_1b, quantize_base=True) qlora_llama3_2_1b.__doc__ = """ Builder for creating a Llama3.2 1B model with QLoRA enabled. 
Base model weights in linear layers diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index ff6faccb5d..169f23e2bf 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -289,7 +289,7 @@ def forward( expand_shape = (b, self.num_kv_heads, q_per_kv, -1, self.head_dim) k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2) v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2) - + # torch.distributed.breakpoint() output = self._attention_call( q, k, diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index efc7c0833d..8d047cc347 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -88,8 +88,8 @@ def build_mesh(self, device_type): dims = [] names = [] for d, name in zip( - [self.dp_replicate, self.dp_shard, self.tp, self.cp], - ["dp_replicate", "dp_shard", "tp", "cp"], + [self.dp_replicate, self.dp_shard, self.cp, self.tp], + ["dp_replicate", "dp_shard", "cp", "tp"], ): if d > 1: dims.append(d) @@ -113,7 +113,7 @@ def build_mesh(self, device_type): dp_shard_cp_mesh_dim_names.append("dp_shard") dp_cp_mesh_dim_names.append("dp_shard") if self.cp_enabled: - dp_mesh_dim_names.append("cp") + dp_shard_cp_mesh_dim_names.append("cp") dp_cp_mesh_dim_names.append("cp") if dp_mesh_dim_names != []: @@ -903,9 +903,10 @@ def create_context_parallel_ctx( ) +# TODO: refactor these args def create_consolidated_train_context( cp_enabled: bool = False, - cp_mesh: Optional[torch.distributed.DeviceMesh] = None, + world_mesh: Optional[torch.distributed.DeviceMesh] = None, cp_buffers: Optional[List[torch.Tensor]] = None, cp_seq_dims: Optional[List[int]] = None, cp_no_restore_buffers: Optional[Set[torch.Tensor]] = None, @@ -946,12 +947,12 @@ def context(): cp_context = None if ( cp_enabled - and cp_mesh is not None + and world_mesh is not None and cp_buffers is not None and cp_seq_dims is not None ): cp_context = create_context_parallel_ctx( - cp_mesh=cp_mesh, + cp_mesh=world_mesh["cp"], cp_buffers=cp_buffers, cp_seq_dims=cp_seq_dims, cp_no_restore_buffers=cp_no_restore_buffers or set(), From da41f809ddbc3412bc7076fcb8143250087628f3 Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Fri, 2 May 2025 16:01:41 -0700 Subject: [PATCH 5/5] cleanup --- recipes/full_finetune_distributed.py | 28 +--- torchtune/modules/attention.py | 2 +- torchtune/training/__init__.py | 5 +- torchtune/training/_distributed.py | 207 ++++++------------------- torchtune/training/context_managers.py | 169 -------------------- 5 files changed, 55 insertions(+), 356 deletions(-) delete mode 100644 torchtune/training/context_managers.py diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 5ce72005cd..5f4a0da71c 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -796,10 +796,8 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: labels = batch.pop("labels") with self.activations_handling_ctx: - print("gonna forward model") outputs = self._model(**batch) - print("done forward model") - print("gonna forward loss") + if self.linear_loss: weight = self._model.linear_projection_weight loss = self._loss_fn(weight, outputs, labels) @@ -807,7 +805,6 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: labels = labels.reshape(-1) outputs = outputs.reshape(-1, outputs.size(-1)) loss = self._loss_fn(outputs, labels) - print("done forward loss") # free logits otherwise it peaks backward memory del outputs @@ 
-883,7 +880,6 @@ def train(self) -> None: pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero) self._dataloader.sampler.set_epoch(curr_epoch) for idx, batch in enumerate(self._dataloader): - print(batch["tokens"].shape) # Start tracking CUDA memory for active steps for just the first epoch if ( self._is_rank_zero @@ -895,25 +891,19 @@ def train(self) -> None: torch.cuda.memory._record_memory_history() utils.batch_to_device(batch, self._device) + + # Define optional context manager for context parallelism model_inputs = list(batch.values()) buffers = list(self._model.buffers()) optional_context_parallel_context_manager = ( - training.create_consolidated_train_context( + training.get_context_parallel_context( cp_enabled=self.cp_degree > 1, world_mesh=self.world_mesh, - cp_buffers=model_inputs + buffers, - cp_seq_dims=[1] * len(model_inputs) + [0] * len(buffers), - cp_no_restore_buffers=set(model_inputs), + model_inputs=list(batch.values()), + model_buffers=list(self._model.buffers()), ) ) - # optional_context_parallel_context_manager = ( - # training.get_context_parallel_context_manager( - # cp_enabled=self.cp_degree > 1, - # model=self._model, - # cp_mesh=self.world_mesh["cp"], - # model_inputs=list(batch.values()), - # ) - # ) + # Calculate the number of unmasked tokens in the current batch # and increment the total number of tokens seen in the step current_num_tokens = ( @@ -924,9 +914,7 @@ def train(self) -> None: # Loss is normalized by default so we multiply by the number of tokens # This way we can normalize by the total number of tokens if we're accumulating gradients with optional_context_parallel_context_manager: - print("gonna forward") current_loss = self._loss_step(batch) * current_num_tokens - print("done forward") running_loss += current_loss # For optimizer in backward, we need to normalize before calling backward # This case and gradient accumulation are mutually exclusive @@ -934,9 +922,7 @@ def train(self) -> None: torch.distributed.all_reduce(num_tokens) torch.distributed.all_reduce(running_loss) current_loss = current_loss * (self.dp_degree / num_tokens) - print("gonna backward") current_loss.backward() - print("done backward") # Optimizer step (if not fused in backward call) if (idx + 1) % self._gradient_accumulation_steps == 0: diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index 169f23e2bf..ff6faccb5d 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -289,7 +289,7 @@ def forward( expand_shape = (b, self.num_kv_heads, q_per_kv, -1, self.head_dim) k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2) v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2) - # torch.distributed.breakpoint() + output = self._attention_call( q, k, diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py index 5ea8d70a1f..824b7c26a7 100644 --- a/torchtune/training/__init__.py +++ b/torchtune/training/__init__.py @@ -10,9 +10,8 @@ ) from torchtune.training._compile import compile_loss, compile_model from torchtune.training._distributed import ( - create_consolidated_train_context, gather_cpu_state_dict, - get_context_parallel_context_manager, + get_context_parallel_context, get_distributed_backend, get_full_optimizer_state_dict, get_shard_conditions, @@ -144,5 +143,5 @@ "get_distributed_backend", "disable_dropout", "DATALOADER_KEY", - "get_context_parallel_context_manager", + "get_context_parallel_context", ] diff --git a/torchtune/training/_distributed.py 
b/torchtune/training/_distributed.py index 8d047cc347..13be3c9192 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -11,7 +11,7 @@ from dataclasses import dataclass from functools import cached_property from itertools import chain -from typing import Any, Callable, cast, Dict, List, Optional, Tuple +from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple import torch import torch.distributed as dist @@ -745,100 +745,11 @@ def prepare_mha_for_tp( return model -@contextlib.contextmanager -def get_context_parallel_context_manager( - cp_enabled: bool, - model: TransformerDecoder, # TODO: generalize - cp_mesh: DeviceMesh, - model_inputs: List[torch.Tensor], -): - set_rotate_method("allgather") # TODO: hardcode for now - buffers = list(model.buffers()) - if not cp_enabled: - return contextlib.nullcontext() # noqa - else: - with contextlib.ExitStack() as stack: - stack.enter_context( - sdpa_kernel( - [ - SDPBackend.FLASH_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, - SDPBackend.CUDNN_ATTENTION, - ] - ) - ) - stack.enter_context( - context_parallel( - cp_mesh, - buffers=model_inputs + buffers, - buffer_seq_dims=[1] * len(model_inputs) + [0] * len(buffers), - no_restore_buffers=set(model_inputs), - ) - ) - yield # noqa - - -# TODO: move this elsewhere - - -# @contextlib.contextma -# def get_context_parallel_context_manager( -# cp_enabled: bool, -# model: TransformerDecoder, # TODO: generalize -# cp_mesh: DeviceMesh, -# model_inputs: List[torch.Tensor], -# ) -> Union[Generator[None, None, None], contextlib.nullcontext]: -# if not cp_enabled: -# return contextlib.nullcontext() -# # if "cp" not in mesh: -# # raise ValueError("CP mesh not found in device mesh") -# # cp_mesh = mesh["cp"] -# set_rotate_method("allgather") # TODO: hardcode for now -# buffers = list(model.buffers()) - -# @contextlib.contextmanager -# def context(): -# with contextlib.ExitStack() as stack: -# stack.enter_context( -# sdpa_kernel( -# [ -# SDPBackend.FLASH_ATTENTION, -# SDPBackend.EFFICIENT_ATTENTION, -# SDPBackend.CUDNN_ATTENTION, -# ] -# ) -# ) - -# cp_context = context_parallel( -# cp_mesh, -# buffers=model_inputs + buffers, -# buffer_seq_dims=[1] * len(model_inputs) + [0] * len(buffers), -# no_restore_buffers=set(model_inputs), -# ) - -# stack.enter_context(cp_context) -# yield - -# return context - -import contextlib -from typing import Callable, Generator, List, Optional, Set, TypeVar - -import torch - - -T = TypeVar("T") - - -def get_train_context() -> ( +def _get_sdpa_context() -> ( Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]] ): """ - Creates a training context that enables loss parallel and compiled autograd as specified. - - Args: - enable_loss_parallel: Whether to enable loss parallel. - enable_compiled_autograd: Whether to enable compiled autograd. + Creates a context manager to confine to flash/efficient/cuDNN attention backends. Returns: A context manager function that takes an optional context parallel context. 
@@ -848,8 +759,6 @@ def get_train_context() -> ( def context(cp_context: Generator[None, None, None] | None = None): with contextlib.ExitStack() as stack: if cp_context is not None: - from torch.nn.attention import sdpa_kernel, SDPBackend - stack.enter_context( sdpa_kernel( [ @@ -866,81 +775,54 @@ def context(cp_context: Generator[None, None, None] | None = None): return context -def create_context_parallel_ctx( - cp_mesh: torch.distributed.DeviceMesh, - cp_buffers: List[torch.Tensor], - cp_seq_dims: List[int], - cp_no_restore_buffers: Set[torch.Tensor], -) -> Generator[None, None, None]: - """ - Creates a context parallel context. - - Args: - cp_mesh: Device mesh for context parallel. - cp_buffers: List of tensors for context parallel buffers. - cp_seq_dims: List of sequence dimensions for context parallel buffers. - cp_no_restore_buffers: Set of tensors that should not be restored. - - Returns: - A context manager for context parallel. - """ - try: - from torch.distributed.tensor.experimental import context_parallel - from torch.distributed.tensor.experimental._attention import set_rotate_method - except ImportError: - print( - f"PyTorch version {torch.__version__} does not include the experimental " - "Context Parallel API. Please update to a newer version." - ) - return contextlib.nullcontext() - - set_rotate_method("allgather") - return context_parallel( - cp_mesh, - buffers=cp_buffers, - buffer_seq_dims=cp_seq_dims, - no_restore_buffers=cp_no_restore_buffers, - ) - - -# TODO: refactor these args -def create_consolidated_train_context( +def get_context_parallel_context( + *, cp_enabled: bool = False, - world_mesh: Optional[torch.distributed.DeviceMesh] = None, - cp_buffers: Optional[List[torch.Tensor]] = None, - cp_seq_dims: Optional[List[int]] = None, - cp_no_restore_buffers: Optional[Set[torch.Tensor]] = None, + world_mesh: torch.distributed.DeviceMesh, + model_inputs: List[torch.Tensor], + model_buffers: Optional[List[torch.Tensor]] = None, ) -> Generator[None, None, None]: """ - Creates a consolidated training context that combines loss parallel, compiled autograd, - and context parallel settings in a single function. - - This function combines the functionality of get_train_context and create_context_parallel_ctx - into a single context manager. + Context manager for applying context parallelism to a model. In addition to applying the + standard context manager to patch SDPA and shard model inputs and buffers along the sequence + dimension, this context manager also calls into _get_sdpa_context to filter to acceptable SDPA backends. Args: - cp_enabled: Whether context parallel is enabled. - cp_mesh: Device mesh for context parallel. - cp_buffers: List of tensors for context parallel buffers. - cp_seq_dims: List of sequence dimensions for context parallel buffers. - cp_no_restore_buffers: Set of tensors that should not be restored. + cp_enabled (bool): Whether context parallel is enabled. Default: False + world_mesh (torch.distributed.DeviceMesh): Global device mesh. + model_inputs (List[torch.Tensor]): List of any model inputs which should be + sharded along sequence dimension. + model_buffers (Optional[List[torch.Tensor]]): List of any model buffers. These should also be + sharded along sequence dimension. Default: None Returns: - A context manager that applies all the specified contexts. + A context manager applying context parallelism if cp_enabled is True. Otherwise a context manager + disabling the math SDPA backend. 
+ + Raises: + ValueError: if cp_enabled is True but world_mesh does not contain a "cp" dimension Example: ```python - with create_consolidated_train_context( + batch = {"inputs": inputs, "labels": labels} + with get_context_parallel_context( cp_enabled=parallel_dims.cp_enabled, cp_mesh=world_mesh["cp"] if parallel_dims.cp_enabled else None, - cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts] if parallel_dims.cp_enabled else None, - cp_seq_dims=[1, 1] + [0 for _ in model_parts] if parallel_dims.cp_enabled else None, - cp_no_restore_buffers={inputs, labels} if parallel_dims.cp_enabled else None, + model_inputs=list(batch.values()), + model_buffers=model.buffers(), ): - # Training code here + logits = model(inputs) + loss = loss(logits, labels) + loss.backward() ``` """ + if cp_enabled and "cp" not in world_mesh.mesh_dim_names: + raise ValueError( + "Context parallel is enabled but no context parallel device mesh is provided." + ) + model_buffers = model_buffers or [] + @contextlib.contextmanager def context(): # Create context parallel context if enabled @@ -948,20 +830,21 @@ def context(): if ( cp_enabled and world_mesh is not None - and cp_buffers is not None - and cp_seq_dims is not None + and model_inputs + model_buffers is not None ): - cp_context = create_context_parallel_ctx( - cp_mesh=world_mesh["cp"], - cp_buffers=cp_buffers, - cp_seq_dims=cp_seq_dims, - cp_no_restore_buffers=cp_no_restore_buffers or set(), + # TODO: do we need to parametrize this? + set_rotate_method("allgather") + cp_context = context_parallel( + world_mesh["cp"], + buffers=model_inputs + model_buffers, + buffer_seq_dims=[1] * len(model_inputs) + [0] * len(model_buffers), + no_restore_buffers=set(model_inputs), ) # Create and enter the train context with the optional cp_context - train_context = get_train_context() + sdpa_context = _get_sdpa_context() - with train_context(cp_context): + with sdpa_context(cp_context): yield return context() diff --git a/torchtune/training/context_managers.py b/torchtune/training/context_managers.py deleted file mode 100644 index 0f6a9be77a..0000000000 --- a/torchtune/training/context_managers.py +++ /dev/null @@ -1,169 +0,0 @@ -import contextlib -from typing import Callable, Generator, List, Optional, Set, TypeVar - -import torch - - -T = TypeVar("T") - - -def get_train_context( - enable_loss_parallel: bool, enable_compiled_autograd: bool -) -> Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]: - """ - Creates a training context that enables loss parallel and compiled autograd as specified. - - Args: - enable_loss_parallel: Whether to enable loss parallel. - enable_compiled_autograd: Whether to enable compiled autograd. - - Returns: - A context manager function that takes an optional context parallel context. 
- """ - - @contextlib.contextmanager - def context(cp_context: Generator[None, None, None] | None = None): - with contextlib.ExitStack() as stack: - if enable_loss_parallel: - stack.enter_context(torch.distributed.tensor.parallel.loss_parallel()) - - if enable_compiled_autograd: - stack.enter_context( - torch._dynamo.utils.maybe_enable_compiled_autograd(True) - ) - - if cp_context is not None: - from torch.nn.attention import sdpa_kernel, SDPBackend - - stack.enter_context( - sdpa_kernel( - [ - SDPBackend.FLASH_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, - SDPBackend.CUDNN_ATTENTION, - ] - ) - ) - stack.enter_context(cp_context) - - yield - - return context - - -def create_context_parallel_ctx( - cp_mesh: torch.distributed.DeviceMesh, - cp_buffers: List[torch.Tensor], - cp_seq_dims: List[int], - cp_no_restore_buffers: Set[torch.Tensor], - cp_rotate_method: str, -) -> Generator[None, None, None]: - """ - Creates a context parallel context. - - Args: - cp_mesh: Device mesh for context parallel. - cp_buffers: List of tensors for context parallel buffers. - cp_seq_dims: List of sequence dimensions for context parallel buffers. - cp_no_restore_buffers: Set of tensors that should not be restored. - cp_rotate_method: Method for rotating in context parallel. - - Returns: - A context manager for context parallel. - """ - try: - from torch.distributed.tensor.experimental import context_parallel - from torch.distributed.tensor.experimental._attention import set_rotate_method - except ImportError: - print( - f"PyTorch version {torch.__version__} does not include the experimental " - "Context Parallel API. Please update to a newer version." - ) - return contextlib.nullcontext() - - set_rotate_method(cp_rotate_method) - return context_parallel( - cp_mesh, - buffers=cp_buffers, - buffer_seq_dims=cp_seq_dims, - no_restore_buffers=cp_no_restore_buffers, - ) - - -def create_consolidated_train_context( - # Parameters for train context - enable_loss_parallel: bool, - enable_compiled_autograd: bool, - # Parameters for context parallel - cp_enabled: bool = False, - cp_mesh: Optional[torch.distributed.DeviceMesh] = None, - cp_buffers: Optional[List[torch.Tensor]] = None, - cp_seq_dims: Optional[List[int]] = None, - cp_no_restore_buffers: Optional[Set[torch.Tensor]] = None, - cp_rotate_method: Optional[str] = None, -) -> Generator[None, None, None]: - """ - Creates a consolidated training context that combines loss parallel, compiled autograd, - and context parallel settings in a single function. - - This function combines the functionality of get_train_context and create_context_parallel_ctx - into a single context manager. - - Args: - enable_loss_parallel: Whether to enable loss parallel. - enable_compiled_autograd: Whether to enable compiled autograd. - cp_enabled: Whether context parallel is enabled. - cp_mesh: Device mesh for context parallel. - cp_buffers: List of tensors for context parallel buffers. - cp_seq_dims: List of sequence dimensions for context parallel buffers. - cp_no_restore_buffers: Set of tensors that should not be restored. - cp_rotate_method: Method for rotating in context parallel. - - Returns: - A context manager that applies all the specified contexts. 
- - Example: - ```python - with create_consolidated_train_context( - enable_loss_parallel=parallel_dims.loss_parallel_enabled, - enable_compiled_autograd=parallelism_config.enable_compiled_autograd, - cp_enabled=parallel_dims.cp_enabled, - cp_mesh=world_mesh["cp"] if parallel_dims.cp_enabled else None, - cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts] if parallel_dims.cp_enabled else None, - cp_seq_dims=[1, 1] + [0 for _ in model_parts] if parallel_dims.cp_enabled else None, - cp_no_restore_buffers={inputs, labels} if parallel_dims.cp_enabled else None, - cp_rotate_method=job_config.parallelism.context_parallel_rotate_method if parallel_dims.cp_enabled else None, - ): - # Training code here - ``` - """ - - @contextlib.contextmanager - def context(): - # Create context parallel context if enabled - cp_context = None - if ( - cp_enabled - and cp_mesh is not None - and cp_buffers is not None - and cp_seq_dims is not None - and cp_rotate_method is not None - ): - cp_context = create_context_parallel_ctx( - cp_mesh=cp_mesh, - cp_buffers=cp_buffers, - cp_seq_dims=cp_seq_dims, - cp_no_restore_buffers=cp_no_restore_buffers or set(), - cp_rotate_method=cp_rotate_method, - ) - - # Create and enter the train context with the optional cp_context - train_context = get_train_context( - enable_loss_parallel=enable_loss_parallel, - enable_compiled_autograd=enable_compiled_autograd, - ) - - with train_context(cp_context): - yield - - return context()
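
---

For reference, the sketch below is a minimal, standalone reconstruction of the degree bookkeeping this series adds to `ParallelDims` in PATCH 1 (`torchtune/training/_distributed.py`): the world size must factor as `dp_replicate * dp_shard * cp * tp`, with `dp_shard = -1` meaning "infer from the remaining degrees". `ParallelDimsSketch` is a hypothetical stand-in used only to illustrate that arithmetic; it is not the torchtune class and omits the mesh construction.

```python
# Minimal sketch of the validation logic added in PATCH 1 (assumption: mirrors
# ParallelDims._validate; this is illustrative only, not the torchtune implementation).
from dataclasses import dataclass


@dataclass
class ParallelDimsSketch:
    dp_replicate: int
    dp_shard: int  # -1 means: infer as world_size // (dp_replicate * cp * tp)
    cp: int
    tp: int
    world_size: int

    def __post_init__(self) -> None:
        # All degrees except dp_shard must be explicit and >= 1.
        for d in (self.dp_replicate, self.cp, self.tp):
            assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
        if self.dp_shard == -1:
            self.dp_shard = self.world_size // (self.dp_replicate * self.cp * self.tp)
        assert self.dp_shard >= 1
        # The product of all degrees must cover the full world size.
        assert self.dp_replicate * self.dp_shard * self.cp * self.tp == self.world_size, (
            f"dp_replicate({self.dp_replicate}) * dp_shard({self.dp_shard}) * "
            f"cp({self.cp}) * tp({self.tp}) != world_size({self.world_size})"
        )


# Example: 8 ranks with context parallel degree 2 and dp_shard inferred -> dp_shard == 4.
dims = ParallelDimsSketch(dp_replicate=1, dp_shard=-1, cp=2, tp=1, world_size=8)
print(dims.dp_shard)  # 4
```

The same factorization is what the recipe relies on when scaling gradients by `dp_degree * cp_degree / num_tokens`: with CP enabled, each rank sees only a sequence shard, so the per-rank loss must be rescaled by the extra `cp` factor before the all-reduce.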