Bug description
When using PyTorch Lightning with DDP and static_graph=True, model parameters are not synchronized properly across processes. I tested this against vanilla PyTorch DDP and confirmed that the issue only appears in Lightning.
📄 Minimal Reproducible Example
I created a minimal script that compares model parameter changes across DDP processes after each optimizer step. It runs 2 training steps and logs the changed indices and delta of the weights from the first fully connected layer.
This script runs 4 experiments:
• Lightning with static_graph=True
• Lightning with static_graph=False
• Vanilla PyTorch DDP with static_graph=True
• Vanilla PyTorch DDP with static_graph=False
Only the Lightning + static_graph=True case shows inconsistent or missing synchronization.
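For quick reference, the only configuration difference between the failing and passing Lightning runs is the static_graph flag on DDPStrategy. A sketch of the relevant Trainer arguments, mirroring run_lightning() in the full script under "How to reproduce the bug":

import torch
from lightning import Trainer
from lightning.pytorch.strategies import DDPStrategy

# Same configuration as run_lightning() below; only static_graph is toggled
# between the failing (True) and passing (False) runs.
trainer = Trainer(
    accelerator="gpu",
    devices=torch.cuda.device_count(),
    strategy=DDPStrategy(static_graph=True),
    max_epochs=1,
    limit_train_batches=2,
    num_sanity_val_steps=0,
    deterministic=True,
)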
🔍 Observed Behavior
• When using Lightning + DDP + static_graph=True, each GPU maintains a different version of the model after training steps.
• When using Vanilla PyTorch DDP + static_graph=True, synchronization works as expected.
✅ Expected Behavior
Model parameters should remain synchronized across DDP processes, even when static_graph=True.
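The script below verifies this invariant indirectly by diffing per-rank logs. A more direct check is also possible; the following is a hedged sketch (not part of the original script) that gathers a parameter from every rank and asserts it matches rank 0, assuming torch.distributed is already initialized (e.g., inside training_step):

import torch
import torch.distributed as dist

def assert_params_in_sync(param: torch.Tensor) -> None:
    # Gather a flattened copy of `param` from every rank and compare against rank 0.
    world_size = dist.get_world_size()
    local = param.detach().flatten().clone()
    gathered = [torch.zeros_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)
    for rank, remote in enumerate(gathered):
        if not torch.equal(remote, gathered[0]):
            raise RuntimeError(f"Parameters on rank {rank} diverged from rank 0")

Calling assert_params_in_sync(self.model.fc1.weight) right after opt.step() in training_step should never raise if synchronization works; given the behavior reported here, it would be expected to raise only in the Lightning + static_graph=True runs.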
What version are you seeing the problem on?
master
How to reproduce the bug
import os
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Lightning imports
import lightning as L
from lightning import Trainer
from lightning.pytorch.strategies import DDPStrategy

# PyTorch DDP imports
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


# ------------------------------------------------------------------------------
# Helper Function for Reporting Weight Changes
# ------------------------------------------------------------------------------
def report_weight_changes(rank, mode_name, step, prev_weights, prev_weights_sum, current_weights, suffix):
    """
    Computes and reports the change in parameters (for fc1 layer) between training steps.

    Args:
        rank (int): Process/GPU rank.
        mode_name (str): Mode name (e.g., "Lightning" or "PyTorch").
        step (int): The current training step (or batch index).
        prev_weights (Tensor or None): The weight vector from the previous step.
        prev_weights_sum (Tensor or None): Sum of the previous weights.
        current_weights (Tensor): The weight vector at the current step.
        suffix (str): Suffix appended to the log file name (e.g., 'sgTrue' or 'sgFalse').

    Returns:
        current_weights_sum, current_weights: Updated weight sum and weight vector.
    """
    current_weights_sum = current_weights.sum()
    # Only report if we have previous weights to compare.
    if prev_weights_sum is not None:
        delta_weights = current_weights - prev_weights
        changed_indices = delta_weights.nonzero()[:10]
        file_name = f"{mode_name}_{suffix}_{rank}.txt"
        with open(file_name, "a") as f:
            f.write(f"[{mode_name} GPU {rank}] Step {step} Changed indices: {changed_indices.tolist()}\n")
            f.write(f"[{mode_name} GPU {rank}] Step {step} Weight delta: {delta_weights[changed_indices]}\n")
    return current_weights_sum, current_weights


# ------------------------------------------------------------------------------
# Shared Model Definition
# ------------------------------------------------------------------------------
class BaseModel(nn.Module):
    def __init__(self):
        super(BaseModel, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# ------------------------------------------------------------------------------
# PyTorch Lightning Module & Data Module
# ------------------------------------------------------------------------------
class LitClassifier(L.LightningModule):
    def __init__(self, graph_mode):
        """
        Args:
            graph_mode (bool): True if using static_graph=True, False otherwise.
        """
        super().__init__()
        self.model = BaseModel()
        # For tracking weight changes for fc1.
        self.prev_weights_sum = None
        self.prev_weights = None
        self.automatic_optimization = False  # using manual optimization.
        # Save the graph mode suffix for logging purposes.
        self.graph_suffix = f"sg{'True' if graph_mode else 'False'}"

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()  # manual optimizer access.
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        # Compute current weights for fc1 and compare with previous step.
        current_weights = self.model.fc1.weight.data.view(-1).clone().detach().cpu()
        self.prev_weights_sum, self.prev_weights = report_weight_changes(
            self.global_rank, "Lightning", batch_idx, self.prev_weights, self.prev_weights_sum, current_weights,
            self.graph_suffix
        )
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=1e-3)


class MNISTDataModule(L.LightningDataModule):
    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        datasets.MNIST("data", train=True, download=True)
        datasets.MNIST("data", train=False, download=True)

    def setup(self, stage=None):
        transform = transforms.ToTensor()
        full_dataset = datasets.MNIST("data", train=True, transform=transform)
        self.train_set, _ = random_split(full_dataset, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, num_workers=4)


# ------------------------------------------------------------------------------
# Vanilla PyTorch DDP Implementation
# ------------------------------------------------------------------------------
def setup_ddp(rank, world_size, port):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup_ddp():
    dist.destroy_process_group()


def prepare_dataloader_ddp(rank, world_size, batch_size=64):
    transform = transforms.ToTensor()
    dataset = datasets.MNIST("data", train=True, download=True, transform=transform)
    train_set, _ = random_split(dataset, [55000, 5000])
    sampler = torch.utils.data.distributed.DistributedSampler(train_set, num_replicas=world_size, rank=rank)
    return DataLoader(train_set, batch_size=batch_size, sampler=sampler, num_workers=4)


def ddp_train(rank, world_size, steps, static_graph, port):
    setup_ddp(rank, world_size, port)
    device = torch.device(f"cuda:{rank}")
    model = BaseModel().to(device)
    # Pass the static_graph flag from the argument.
    ddp_model = DDP(model, device_ids=[rank], static_graph=static_graph)
    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
    train_loader = prepare_dataloader_ddp(rank, world_size)
    # Determine the suffix for the log filename.
    graph_suffix = f"sg{'True' if static_graph else 'False'}"
    step = 0
    prev_weights_sum = None
    prev_weights = None
    ddp_model.train()
    for epoch in range(10):  # Loop over epochs if necessary.
        for batch_idx, (x, y) in enumerate(train_loader):
            if step >= steps:
                break
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output = ddp_model(x)
            loss = F.cross_entropy(output, y)
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                current_weights = ddp_model.module.fc1.weight.data.view(-1).clone().detach().cpu()
                prev_weights_sum, prev_weights = report_weight_changes(
                    rank, "PyTorch", step, prev_weights, prev_weights_sum, current_weights, graph_suffix
                )
            step += 1
        if step >= steps:
            break
    cleanup_ddp()


# ------------------------------------------------------------------------------
# Main: Running Both Versions for Lightning and PyTorch
# ------------------------------------------------------------------------------
def run_lightning(static_graph):
    print(f"Running Lightning mode with static_graph={static_graph} for 2 training steps")
    model = LitClassifier(graph_mode=static_graph)
    dm = MNISTDataModule(batch_size=64)
    trainer = Trainer(
        max_epochs=1,
        accelerator="gpu",
        devices=torch.cuda.device_count(),
        strategy=DDPStrategy(static_graph=static_graph),
        num_sanity_val_steps=0,
        deterministic=True,
        limit_train_batches=2,
    )
    trainer.fit(model, dm)


def run_pytorch(static_graph, port):
    print(f"Running vanilla PyTorch DDP mode with static_graph={static_graph} for 2 training steps (port={port})")
    world_size = torch.cuda.device_count()
    mp.spawn(ddp_train, args=(world_size, 2, static_graph, port), nprocs=world_size, join=True)


if __name__ == "__main__":
    # Run Lightning with static_graph True and False:
    run_lightning(static_graph=True)
    run_lightning(static_graph=False)
    # Run vanilla PyTorch DDP with static_graph True and False on different ports.
    run_pytorch(static_graph=True, port=12356)
    run_pytorch(static_graph=False, port=12357)
Error messages and logs
Running the reproduction script shows that PyTorch Lightning DDP with static_graph=True ends up with different model parameters on different processes after each training step.
Lightning_SGTrue_0.txt:
Lightning_SGTrue_1.txt:
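A minimal sketch for eyeballing the divergence between the two per-rank logs (assuming the filenames written by report_weight_changes above; adjust the names if the suffix casing differs):

from itertools import zip_longest

# Compare the rank-0 and rank-1 logs line by line; the script writes the
# lowercase "sgTrue" suffix, so adjust the filenames to match your run.
with open("Lightning_sgTrue_0.txt") as f0, open("Lightning_sgTrue_1.txt") as f1:
    for line0, line1 in zip_longest(f0, f1, fillvalue="<missing>\n"):
        marker = "==" if line0 == line1 else "!="
        print(f"{marker} rank0: {line0.rstrip()}")
        print(f"{marker} rank1: {line1.rstrip()}")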
Environment
Current environment
- GPU:
- NVIDIA RTX A6000
- NVIDIA RTX A6000
- NVIDIA RTX A6000
- NVIDIA RTX A6000
- available: True
- version: 12.6
- lightning: 2.5.0.post0
- lightning-sdk: 0.2.5
- lightning-utilities: 0.14.0
- lion-pytorch: 0.2.3
- pytorch-lightning: 2.5.0.post0
- pytorch-triton: 3.3.0+git96316ce5
- torch: 2.8.0.dev20250407+cu126
- torch-tb-profiler: 0.4.3
- torchmetrics: 1.6.2
- torchvision: 0.22.0.dev20250407+cu126
- absl-py: 2.1.0
- accelerate: 1.4.0
- aiohappyeyeballs: 2.5.0
- aiohttp: 3.11.13
- aiosignal: 1.3.2
- annotated-types: 0.7.0
- antlr4-python3-runtime: 4.9.3
- anyio: 4.9.0
- attn-gym: 0.0.4.dev12+g41a96b6
- attrs: 25.1.0
- autocommand: 2.2.2
- backoff: 2.2.1
- backports.tarfile: 1.2.0
- beautifulsoup4: 4.13.3
- blinker: 1.4
- boto3: 1.37.10
- botocore: 1.37.10
- bs4: 0.0.2
- certifi: 2025.1.31
- charset-normalizer: 3.4.1
- click: 8.1.8
- cloudpickle: 3.1.1
- contourpy: 1.3.1
- cryptography: 3.4.8
- cssselect: 1.3.0
- cycler: 0.12.1
- dacite: 1.9.2
- datasets: 3.3.2
- dbus-python: 1.2.18
- dill: 0.3.8
- distro: 1.7.0
- distro-info: 1.1+ubuntu0.2
- docker: 7.1.0
- docker-pycreds: 0.4.0
- einops: 0.8.1
- faiss: 1.10.0
- fastapi: 0.115.11
- feedfinder2: 0.0.4
- feedparser: 6.0.11
- filelock: 3.16.1
- flash-attn: 2.7.4.post1
- fonttools: 4.56.0
- frozenlist: 1.5.0
- fsspec: 2024.10.0
- ftfy: 6.3.1
- gitdb: 4.0.12
- gitpython: 3.1.44
- grpcio: 1.71.0
- h11: 0.14.0
- h5py: 3.13.0
- hkkang-utils: 0.2.57
- htmlmin: 0.1.12
- httplib2: 0.20.2
- huggingface-hub: 0.29.3
- hydra-core: 1.3.2
- idna: 3.10
- importlib-metadata: 8.0.0
- inflect: 7.3.1
- jaraco.collections: 5.1.0
- jaraco.context: 5.3.0
- jaraco.functools: 4.0.1
- jaraco.text: 3.12.1
- jeepney: 0.7.1
- jieba3k: 0.35.1
- jinja2: 3.1.4
- jmespath: 1.0.1
- joblib: 1.4.2
- jsonargparse: 4.37.0
- keyring: 23.5.0
- kiwisolver: 1.4.8
- langdetect: 1.0.9
- launchpadlib: 1.10.16
- lazr.restfulclient: 0.14.4
- lazr.uri: 1.0.6
- legacy-cgi: 2.6.2
- lightning: 2.5.0.post0
- lightning-sdk: 0.2.5
- lightning-utilities: 0.14.0
- lion-pytorch: 0.2.3
- lxml: 5.3.1
- lxml-html-clean: 0.4.1
- markdown: 3.7
- markdown-it-py: 3.0.0
- markupsafe: 2.1.5
- matplotlib: 3.10.1
- mdurl: 0.1.2
- more-itertools: 10.3.0
- mpmath: 1.3.0
- multidict: 6.1.0
- multiprocess: 0.70.16
- networkx: 3.4.2
- newspaper3k: 0.2.8
- nltk: 3.9.1
- numpy: 2.2.3
- nvidia-cublas-cu12: 12.6.4.1
- nvidia-cuda-cupti-cu12: 12.6.80
- nvidia-cuda-nvrtc-cu12: 12.6.77
- nvidia-cuda-runtime-cu12: 12.6.77
- nvidia-cudnn-cu12: 9.5.1.17
- nvidia-cufft-cu12: 11.3.0.4
- nvidia-cufile-cu12: 1.11.1.6
- nvidia-curand-cu12: 10.3.7.77
- nvidia-cusolver-cu12: 11.7.1.2
- nvidia-cusparse-cu12: 12.5.4.2
- nvidia-cusparselt-cu12: 0.6.3
- nvidia-nccl-cu12: 2.26.2
- nvidia-nvjitlink-cu12: 12.6.85
- nvidia-nvtx-cu12: 12.6.77
- oauthlib: 3.2.0
- omegaconf: 2.3.0
- orjson: 3.10.15
- packaging: 24.2
- pandas: 2.2.3
- pglast: 7.3
- pillow: 11.1.0
- pip: 25.0.1
- platformdirs: 4.2.2
- propcache: 0.3.0
- protobuf: 5.29.3
- psutil: 7.0.0
- psycopg: 3.2.5
- psycopg-binary: 3.2.5
- psycopg-pool: 3.2.6
- pyarrow: 19.0.1
- pydantic: 2.10.6
- pydantic-core: 2.27.2
- pygments: 2.19.1
- pygobject: 3.42.1
- pyjwt: 2.3.0
- pyparsing: 2.4.7
- python-apt: 2.4.0+ubuntu4
- python-dateutil: 2.9.0.post0
- python-dotenv: 1.0.1
- pytorch-lightning: 2.5.0.post0
- pytorch-triton: 3.3.0+git96316ce5
- pytz: 2025.1
- pyyaml: 6.0.2
- regex: 2024.11.6
- requests: 2.32.3
- requests-file: 2.1.0
- rich: 13.9.4
- s3transfer: 0.11.4
- safetensors: 0.5.3
- secretstorage: 3.3.1
- sentencepiece: 0.2.0
- sentry-sdk: 2.22.0
- setproctitle: 1.3.5
- setuptools: 75.8.0
- sgmllib3k: 1.0.0
- simple-term-menu: 1.6.6
- six: 1.16.0
- slack-sdk: 3.34.0
- smmap: 5.0.2
- sniffio: 1.3.1
- soupsieve: 2.6
- standard-imghdr: 3.13.0
- starlette: 0.46.1
- sympy: 1.13.3
- tensorboard: 2.19.0
- tensorboard-data-server: 0.7.2
- tensordict: 0.7.2
- tinysegmenter: 0.3
- tldextract: 5.1.3
- tokenizers: 0.21.0
- tomli: 2.0.1
- torch: 2.8.0.dev20250407+cu126
- torch-tb-profiler: 0.4.3
- torchmetrics: 1.6.2
- torchvision: 0.22.0.dev20250407+cu126
- tqdm: 4.67.1
- transformers: 4.49.0
- triton: 3.2.0
- typeguard: 4.3.0
- typing-extensions: 4.12.2
- tzdata: 2025.1
- ujson: 5.10.0
- unattended-upgrades: 0.1
- urllib3: 2.3.0
- uvicorn: 0.34.0
- wadllib: 1.3.6
- wandb: 0.19.8
- wcwidth: 0.2.13
- websocket-client: 1.8.0
- werkzeug: 3.1.3
- wget: 3.2
- wheel: 0.43.0
- xxhash: 3.5.0
- yarl: 1.18.3
- zipp: 3.19.2
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.13.2
- release: 5.15.0-107-generic
- version: #117-Ubuntu SMP Fri Apr 26 12:26:49 UTC 2024
More info
No response