DDP model not synchronizing when static_graph=True #20704

Open

hyukkyukang opened this issue Apr 8, 2025 · 0 comments
Labels
bug (Something isn't working) · needs triage (Waiting to be triaged by maintainers) · ver: 2.5.x

Comments

hyukkyukang commented Apr 8, 2025

Bug description

When training with PyTorch Lightning's DDP strategy and static_graph=True, model parameters are not synchronized across processes: after identical training steps, each rank ends up with different weights. I tested the same setup against vanilla PyTorch DDP and confirmed that the issue appears only in Lightning.
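
For context, a quick way to verify whether DDP ranks hold identical parameters is to all-gather a flattened copy of the weights and compare against rank 0. This is a minimal sketch, not part of the repro below; the helper name assert_params_synced is mine, and it assumes torch.distributed is already initialized and the model lives on the rank's GPU:

import torch
import torch.distributed as dist

def assert_params_synced(model: torch.nn.Module) -> None:
    """Raise if any rank's flattened parameters differ from rank 0's."""
    with torch.no_grad():
        # One flat tensor per rank, all-gathered so every rank can compare.
        flat = torch.cat([p.detach().reshape(-1) for p in model.parameters()])
        gathered = [torch.empty_like(flat) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, flat)
        for rank, other in enumerate(gathered):
            if not torch.equal(other, gathered[0]):
                raise RuntimeError(f"rank {rank} parameters diverged from rank 0")

Under correct DDP semantics, calling this after every optimizer step should never raise, with or without static_graph.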

📄 Minimal Reproducible Example

I created a minimal script that compares model parameter changes across DDP processes after each optimizer step. It runs 2 training steps and, after each step, logs which indices of the first fully connected layer's weight changed and by how much.

This script runs 4 experiments:
• Lightning with static_graph=True
• Lightning with static_graph=False
• Vanilla PyTorch DDP with static_graph=True
• Vanilla PyTorch DDP with static_graph=False

Only the Lightning + static_graph=True case shows inconsistent or missing synchronization.

🔍 Observed Behavior
• When using Lightning + DDP + static_graph=True, each GPU ends up with a different copy of the model parameters after the training steps.
• When using Vanilla PyTorch DDP + static_graph=True, synchronization works as expected.

✅ Expected Behavior

Model parameters should remain synchronized across DDP processes, even when static_graph=True.
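
Because DDP averages gradients via all-reduce during backward, a complementary check is to confirm that gradients are already identical on every rank right after loss.backward(); if the gradients match but the parameters still drift, the divergence must happen at or after the optimizer step. Another hedged sketch (the helper name is mine; torch.distributed is assumed initialized):

import torch
import torch.distributed as dist

def assert_grads_synced(model: torch.nn.Module) -> None:
    """Raise if this rank's gradients differ from rank 0's after backward()."""
    with torch.no_grad():
        flat = torch.cat(
            [p.grad.reshape(-1) for p in model.parameters() if p.grad is not None]
        )
        reference = flat.clone()
        dist.broadcast(reference, src=0)  # every rank receives rank 0's gradients
        if not torch.allclose(flat, reference):
            raise RuntimeError("gradients diverged across ranks after backward()")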

What version are you seeing the problem on?

master

How to reproduce the bug

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Lightning imports
import lightning as L
from lightning import Trainer
from lightning.pytorch.strategies import DDPStrategy

# PyTorch DDP imports
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# ------------------------------------------------------------------------------
# Helper Function for Reporting Weight Changes
# ------------------------------------------------------------------------------
def report_weight_changes(rank, mode_name, step, prev_weights, prev_weights_sum, current_weights, suffix):
    """
    Computes and reports the change in parameters (for fc1 layer) between training steps.
    
    Args:
        rank (int): Process/GPU rank.
        mode_name (str): Mode name (e.g., "Lightning" or "PyTorch").
        step (int): The current training step (or batch index).
        prev_weights (Tensor or None): The weight vector from the previous step.
        prev_weights_sum (Tensor or None): Sum of the previous weights.
        current_weights (Tensor): The weight vector at the current step.
        suffix (str): Suffix to be appended to log file name (e.g., 'sgTrue' or 'sgFalse').
        
    Returns:
        current_weights_sum, current_weights: Updated weight sum and weight vector.
    """
    current_weights_sum = current_weights.sum()
    # Only report if we have previous weights to compare.
    if prev_weights_sum is not None:
        delta_weights = current_weights - prev_weights
        changed_indices = delta_weights.nonzero()[:10]
        file_name = f"{mode_name}_{suffix}_{rank}.txt"
        with open(file_name, "a") as f:
            f.write(f"[{mode_name} GPU {rank}] Step {step} Changed indices: {changed_indices.tolist()}\n")
            f.write(f"[{mode_name} GPU {rank}] Step {step} Weight delta: {delta_weights[changed_indices]}\n")
    return current_weights_sum, current_weights

# ------------------------------------------------------------------------------
# Shared Model Definition
# ------------------------------------------------------------------------------
class BaseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# ------------------------------------------------------------------------------
# PyTorch Lightning Module & Data Module
# ------------------------------------------------------------------------------
class LitClassifier(L.LightningModule):
    def __init__(self, graph_mode):
        """
        Args:
            graph_mode (bool): True if using static_graph=True, False otherwise.
        """
        super().__init__()
        self.model = BaseModel()
        # For tracking weight changes for fc1
        self.prev_weights_sum = None
        self.prev_weights = None
        self.automatic_optimization = False  # using manual optimization.
        # Save the graph mode suffix for logging purposes.
        self.graph_suffix = f"sg{'True' if graph_mode else 'False'}"

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()  # manual optimizer access.
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

        # Compute current weights for fc1 and compare with previous step.
        current_weights = self.model.fc1.weight.data.view(-1).clone().detach().cpu()
        self.prev_weights_sum, self.prev_weights = report_weight_changes(
            self.global_rank, "Lightning", batch_idx, self.prev_weights, self.prev_weights_sum, current_weights,
            self.graph_suffix
        )
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=1e-3)

class MNISTDataModule(L.LightningDataModule):
    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        datasets.MNIST("data", train=True, download=True)
        datasets.MNIST("data", train=False, download=True)

    def setup(self, stage=None):
        transform = transforms.ToTensor()
        full_dataset = datasets.MNIST("data", train=True, transform=transform)
        # Seed the split so every rank holds the identical train subset.
        self.train_set, _ = random_split(
            full_dataset, [55000, 5000], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, num_workers=4)

# ------------------------------------------------------------------------------
# Vanilla PyTorch DDP Implementation
# ------------------------------------------------------------------------------
def setup_ddp(rank, world_size, port):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup_ddp():
    dist.destroy_process_group()

def prepare_dataloader_ddp(rank, world_size, batch_size=64):
    transform = transforms.ToTensor()
    dataset = datasets.MNIST("data", train=True, download=True, transform=transform)
    # Seed the split so every rank holds the identical train subset.
    train_set, _ = random_split(dataset, [55000, 5000], generator=torch.Generator().manual_seed(42))
    sampler = torch.utils.data.distributed.DistributedSampler(train_set, num_replicas=world_size, rank=rank)
    return DataLoader(train_set, batch_size=batch_size, sampler=sampler, num_workers=4)

def ddp_train(rank, world_size, steps, static_graph, port):
    setup_ddp(rank, world_size, port)
    device = torch.device(f"cuda:{rank}")
    model = BaseModel().to(device)
    # Pass the static_graph flag from the argument.
    ddp_model = DDP(model, device_ids=[rank], static_graph=static_graph)
    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
    train_loader = prepare_dataloader_ddp(rank, world_size)

    # Determine the suffix for the log filename.
    graph_suffix = f"sg{'True' if static_graph else 'False'}"

    step = 0
    prev_weights_sum = None
    prev_weights = None
    ddp_model.train()

    for epoch in range(10):  # Loop over epochs if necessary.
        for batch_idx, (x, y) in enumerate(train_loader):
            if step >= steps:
                break
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output = ddp_model(x)
            loss = F.cross_entropy(output, y)
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                current_weights = ddp_model.module.fc1.weight.data.view(-1).clone().detach().cpu()
                prev_weights_sum, prev_weights = report_weight_changes(
                    rank, "PyTorch", step, prev_weights, prev_weights_sum, current_weights, graph_suffix
                )
            step += 1
        if step >= steps:
            break
    cleanup_ddp()

# ------------------------------------------------------------------------------
# Main: Running Both Versions for Lightning and PyTorch
# ------------------------------------------------------------------------------
def run_lightning(static_graph):
    print(f"Running Lightning mode with static_graph={static_graph} for 2 training steps")
    model = LitClassifier(graph_mode=static_graph)
    dm = MNISTDataModule(batch_size=64)
    trainer = Trainer(
        max_epochs=1,
        accelerator="gpu",
        devices=torch.cuda.device_count(),
        strategy=DDPStrategy(static_graph=static_graph),
        num_sanity_val_steps=0,
        deterministic=True,
        limit_train_batches=2,
    )
    trainer.fit(model, dm)

def run_pytorch(static_graph, port):
    print(f"Running vanilla PyTorch DDP mode with static_graph={static_graph} for 2 training steps (port={port})")
    world_size = torch.cuda.device_count()
    mp.spawn(ddp_train, args=(world_size, 2, static_graph, port), nprocs=world_size, join=True)

if __name__ == "__main__":
    # Run Lightning with static_graph True and False:
    run_lightning(static_graph=True)
    run_lightning(static_graph=False)
    # Run vanilla PyTorch DDP with static_graph True and False on different ports.
    run_pytorch(static_graph=True, port=12356)
    run_pytorch(static_graph=False, port=12357)
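
On a machine with at least two GPUs, running the script directly (e.g., python repro.py, assuming that filename) executes all four experiments in sequence; each experiment writes one log file per rank, named <Mode>_<sgTrue|sgFalse>_<rank>.txt, which is where the logs below come from.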

Error messages and logs

Running the reproduction script confirms that PyTorch Lightning DDP with static_graph=True holds different model parameters on different processes throughout training: the changed indices and weight deltas below differ between GPU 0 and GPU 1.

Lightning_sgTrue_0.txt:

[Lightning GPU 0] Step 1 Changed indices: [[67], [68], [69], [70], [71], [72], [73], [74], [95], [96]]
[Lightning GPU 0] Step 1 Weight delta: tensor([[-0.0007],
        [-0.0007],
        [-0.0007],
        [-0.0007],
        [-0.0007],
        [-0.0007],
        [-0.0007],
        [-0.0007],
        [-0.0007],
        [-0.0007]])

Lightning_sgTrue_1.txt:

[Lightning GPU 1] Step 1 Changed indices: [[39], [40], [66], [67], [68], [69], [70], [71], [72], [94]]
[Lightning GPU 1] Step 1 Weight delta: tensor([[-0.0007],
        [-0.0007],
        [-0.0007],
        [-0.0007],
        [-0.0008],
        [-0.0007],
        [-0.0010],
        [-0.0007],
        [-0.0007],
        [-0.0007]])
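
To spot the divergence without eyeballing the files, the per-rank logs can be compared line by line after stripping the rank-specific prefix. A small sketch (the function name and the two-GPU default are mine; filenames follow the naming scheme of the script above):

from pathlib import Path

def compare_rank_logs(mode: str = "Lightning", suffix: str = "sgTrue", ranks: int = 2) -> None:
    """Print every log line whose payload differs between ranks."""
    logs = [Path(f"{mode}_{suffix}_{r}.txt").read_text().splitlines() for r in range(ranks)]
    for lines in zip(*logs):
        # Strip the "[Mode GPU r] " prefix so only the payload is compared.
        payloads = [line.split("] ", 1)[-1] for line in lines]
        if len(set(payloads)) > 1:
            for rank, line in enumerate(lines):
                print(f"rank {rank}: {line}")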

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA RTX A6000
    - NVIDIA RTX A6000
    - NVIDIA RTX A6000
    - NVIDIA RTX A6000
    - available: True
    - version: 12.6
  • Lightning:
    - lightning: 2.5.0.post0
    - lightning-sdk: 0.2.5
    - lightning-utilities: 0.14.0
    - lion-pytorch: 0.2.3
    - pytorch-lightning: 2.5.0.post0
    - pytorch-triton: 3.3.0+git96316ce5
    - torch: 2.8.0.dev20250407+cu126
    - torch-tb-profiler: 0.4.3
    - torchmetrics: 1.6.2
    - torchvision: 0.22.0.dev20250407+cu126
  • Packages:
    - absl-py: 2.1.0
    - accelerate: 1.4.0
    - aiohappyeyeballs: 2.5.0
    - aiohttp: 3.11.13
    - aiosignal: 1.3.2
    - annotated-types: 0.7.0
    - antlr4-python3-runtime: 4.9.3
    - anyio: 4.9.0
    - attn-gym: 0.0.4.dev12+g41a96b6
    - attrs: 25.1.0
    - autocommand: 2.2.2
    - backoff: 2.2.1
    - backports.tarfile: 1.2.0
    - beautifulsoup4: 4.13.3
    - blinker: 1.4
    - boto3: 1.37.10
    - botocore: 1.37.10
    - bs4: 0.0.2
    - certifi: 2025.1.31
    - charset-normalizer: 3.4.1
    - click: 8.1.8
    - cloudpickle: 3.1.1
    - contourpy: 1.3.1
    - cryptography: 3.4.8
    - cssselect: 1.3.0
    - cycler: 0.12.1
    - dacite: 1.9.2
    - datasets: 3.3.2
    - dbus-python: 1.2.18
    - dill: 0.3.8
    - distro: 1.7.0
    - distro-info: 1.1+ubuntu0.2
    - docker: 7.1.0
    - docker-pycreds: 0.4.0
    - einops: 0.8.1
    - faiss: 1.10.0
    - fastapi: 0.115.11
    - feedfinder2: 0.0.4
    - feedparser: 6.0.11
    - filelock: 3.16.1
    - flash-attn: 2.7.4.post1
    - fonttools: 4.56.0
    - frozenlist: 1.5.0
    - fsspec: 2024.10.0
    - ftfy: 6.3.1
    - gitdb: 4.0.12
    - gitpython: 3.1.44
    - grpcio: 1.71.0
    - h11: 0.14.0
    - h5py: 3.13.0
    - hkkang-utils: 0.2.57
    - htmlmin: 0.1.12
    - httplib2: 0.20.2
    - huggingface-hub: 0.29.3
    - hydra-core: 1.3.2
    - idna: 3.10
    - importlib-metadata: 8.0.0
    - inflect: 7.3.1
    - jaraco.collections: 5.1.0
    - jaraco.context: 5.3.0
    - jaraco.functools: 4.0.1
    - jaraco.text: 3.12.1
    - jeepney: 0.7.1
    - jieba3k: 0.35.1
    - jinja2: 3.1.4
    - jmespath: 1.0.1
    - joblib: 1.4.2
    - jsonargparse: 4.37.0
    - keyring: 23.5.0
    - kiwisolver: 1.4.8
    - langdetect: 1.0.9
    - launchpadlib: 1.10.16
    - lazr.restfulclient: 0.14.4
    - lazr.uri: 1.0.6
    - legacy-cgi: 2.6.2
    - lightning: 2.5.0.post0
    - lightning-sdk: 0.2.5
    - lightning-utilities: 0.14.0
    - lion-pytorch: 0.2.3
    - lxml: 5.3.1
    - lxml-html-clean: 0.4.1
    - markdown: 3.7
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.5
    - matplotlib: 3.10.1
    - mdurl: 0.1.2
    - more-itertools: 10.3.0
    - mpmath: 1.3.0
    - multidict: 6.1.0
    - multiprocess: 0.70.16
    - networkx: 3.4.2
    - newspaper3k: 0.2.8
    - nltk: 3.9.1
    - numpy: 2.2.3
    - nvidia-cublas-cu12: 12.6.4.1
    - nvidia-cuda-cupti-cu12: 12.6.80
    - nvidia-cuda-nvrtc-cu12: 12.6.77
    - nvidia-cuda-runtime-cu12: 12.6.77
    - nvidia-cudnn-cu12: 9.5.1.17
    - nvidia-cufft-cu12: 11.3.0.4
    - nvidia-cufile-cu12: 1.11.1.6
    - nvidia-curand-cu12: 10.3.7.77
    - nvidia-cusolver-cu12: 11.7.1.2
    - nvidia-cusparse-cu12: 12.5.4.2
    - nvidia-cusparselt-cu12: 0.6.3
    - nvidia-nccl-cu12: 2.26.2
    - nvidia-nvjitlink-cu12: 12.6.85
    - nvidia-nvtx-cu12: 12.6.77
    - oauthlib: 3.2.0
    - omegaconf: 2.3.0
    - orjson: 3.10.15
    - packaging: 24.2
    - pandas: 2.2.3
    - pglast: 7.3
    - pillow: 11.1.0
    - pip: 25.0.1
    - platformdirs: 4.2.2
    - propcache: 0.3.0
    - protobuf: 5.29.3
    - psutil: 7.0.0
    - psycopg: 3.2.5
    - psycopg-binary: 3.2.5
    - psycopg-pool: 3.2.6
    - pyarrow: 19.0.1
    - pydantic: 2.10.6
    - pydantic-core: 2.27.2
    - pygments: 2.19.1
    - pygobject: 3.42.1
    - pyjwt: 2.3.0
    - pyparsing: 2.4.7
    - python-apt: 2.4.0+ubuntu4
    - python-dateutil: 2.9.0.post0
    - python-dotenv: 1.0.1
    - pytorch-lightning: 2.5.0.post0
    - pytorch-triton: 3.3.0+git96316ce5
    - pytz: 2025.1
    - pyyaml: 6.0.2
    - regex: 2024.11.6
    - requests: 2.32.3
    - requests-file: 2.1.0
    - rich: 13.9.4
    - s3transfer: 0.11.4
    - safetensors: 0.5.3
    - secretstorage: 3.3.1
    - sentencepiece: 0.2.0
    - sentry-sdk: 2.22.0
    - setproctitle: 1.3.5
    - setuptools: 75.8.0
    - sgmllib3k: 1.0.0
    - simple-term-menu: 1.6.6
    - six: 1.16.0
    - slack-sdk: 3.34.0
    - smmap: 5.0.2
    - sniffio: 1.3.1
    - soupsieve: 2.6
    - standard-imghdr: 3.13.0
    - starlette: 0.46.1
    - sympy: 1.13.3
    - tensorboard: 2.19.0
    - tensorboard-data-server: 0.7.2
    - tensordict: 0.7.2
    - tinysegmenter: 0.3
    - tldextract: 5.1.3
    - tokenizers: 0.21.0
    - tomli: 2.0.1
    - torch: 2.8.0.dev20250407+cu126
    - torch-tb-profiler: 0.4.3
    - torchmetrics: 1.6.2
    - torchvision: 0.22.0.dev20250407+cu126
    - tqdm: 4.67.1
    - transformers: 4.49.0
    - triton: 3.2.0
    - typeguard: 4.3.0
    - typing-extensions: 4.12.2
    - tzdata: 2025.1
    - ujson: 5.10.0
    - unattended-upgrades: 0.1
    - urllib3: 2.3.0
    - uvicorn: 0.34.0
    - wadllib: 1.3.6
    - wandb: 0.19.8
    - wcwidth: 0.2.13
    - websocket-client: 1.8.0
    - werkzeug: 3.1.3
    - wget: 3.2
    - wheel: 0.43.0
    - xxhash: 3.5.0
    - yarl: 1.18.3
    - zipp: 3.19.2
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.13.2
    - release: 5.15.0-107-generic
    - version: #117-Ubuntu SMP Fri Apr 26 12:26:49 UTC 2024

More info

No response
