Chtruong/r2.3.0 cherry picks pt9 #13424


Draft: wants to merge 5 commits into base r2.3.0
24 changes: 24 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -1488,6 +1488,27 @@ jobs:
SCRIPT: L2_NeMo_2_NEVA_LOAD_GENERATE
IS_OPTIONAL: true

L2_NeMo_2_QWEN2VL_MOCK_FINETUNE_TP2:
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_QWEN2VL_MOCK_FINETUNE_TP2')
with:
RUNNER: self-hosted-azure
SCRIPT: L2_NeMo_2_QWEN2VL_MOCK_FINETUNE_TP2

L2_NeMo_2_QWEN2VL_PRELOADED_FINETUNE_TP2:
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_QWEN2VL_PRELOADED_FINETUNE_TP2')
with:
RUNNER: self-hosted-azure
SCRIPT: L2_NeMo_2_QWEN2VL_PRELOADED_FINETUNE_TP2

L2_NeMo_2_QWEN2VL_ENERGON_FINETUNE_TP2:
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_QWEN2VL_ENERGON_FINETUNE_TP2')
with:
RUNNER: self-hosted-azure
SCRIPT: L2_NeMo_2_QWEN2VL_ENERGON_FINETUNE_TP2

L2_NeMo_2_LLAVA_IMPORT:
needs: [pre-flight, cicd-test-container-build]
uses: ./.github/workflows/_test_template.yml
@@ -2186,6 +2207,9 @@ jobs:
- L2_NeMo_2_NEVA_MOCK_FINETUNE_PP2
- L2_NeMo_2_NEVA_MOCK_FINETUNE_CP2
- OPTIONAL_L2_NeMo_2_NEVA_LOAD_GENERATE
- L2_NeMo_2_QWEN2VL_MOCK_FINETUNE_TP2
- L2_NeMo_2_QWEN2VL_PRELOADED_FINETUNE_TP2
- L2_NeMo_2_QWEN2VL_ENERGON_FINETUNE_TP2
- Optional_L2_NeMo_2_MLLAMA_MOCK_FINETUNE_TP2
- L2_NeMo_2_MLLAMA_PRELOADED_FINETUNE_TP2
- Optional_L2_NEMO_2_MLLAMA_Inference
26 changes: 10 additions & 16 deletions nemo/collections/common/data/lhotse/cutset.py
@@ -30,7 +30,6 @@
LazyNeMoTarredIterator,
expand_sharded_filepaths,
)
from nemo.collections.common.data.lhotse.sampling import PlaceholderFilter
from nemo.collections.common.data.lhotse.text_adapters import (
LhotseTextAdapter,
LhotseTextPairAdapter,
@@ -50,19 +49,17 @@ def read_cutset_from_config(config: Union[DictConfig, dict]) -> Tuple[CutSet, bool]:
if not isinstance(config, DictConfig):
config = DictConfig(config)
if config.get("input_cfg") is not None:
-        cuts, is_tarred = read_dataset_config(config)
-    else:
-        # Now, we'll figure out if we should read Lhotse manifest or NeMo manifest.
-        use_nemo_manifest = all(config.get(opt) is None for opt in ("cuts_path", "shar_path"))
-        if use_nemo_manifest:
-            if config.get("manifest_filepath") is None:
-                raise IncompleteConfigError("You must specify either: manifest_filepath, cuts_path, or shar_path")
-            cuts, is_tarred = read_nemo_manifest(config)
-        else:
-            cuts, is_tarred = read_lhotse_manifest(config)
-
-    # After reading cuts we filter cutsets to exclude cuts with valid "_skipme" values.
-    # This filtration is done before mixing cutsets as well. Here it is being done for non-mixed cutsets.
-    cuts = cuts.filter(PlaceholderFilter())
+        return read_dataset_config(config)
+    # Now, we'll figure out if we should read Lhotse manifest or NeMo manifest.
+    use_nemo_manifest = all(config.get(opt) is None for opt in ("cuts_path", "shar_path"))
+    if use_nemo_manifest:
+        if config.get("manifest_filepath") is None:
+            raise IncompleteConfigError("You must specify either: manifest_filepath, cuts_path, or shar_path")
+        cuts, is_tarred = read_nemo_manifest(config)
+    else:
+        cuts, is_tarred = read_lhotse_manifest(config)
return cuts, is_tarred


@@ -351,6 +348,7 @@ def parse_and_combine_datasets(
assert len(weights) == 0 or len(cuts) == len(
weights
), "Missing dataset weight. When weighting datasets, every dataset must have a specified weight."

if len(cuts) > 1:
cuts = mux(
*cuts,
@@ -426,7 +424,6 @@ def read_lhotse_manifest(config) -> tuple[CutSet, bool]:
cutsets.append(cs)
weights.append(weight)

cutsets = [cutset.filter(PlaceholderFilter()) for cutset in cutsets]
cuts = mux(
*cutsets,
weights=weights,
@@ -620,9 +617,6 @@ def read_nemo_manifest(config) -> tuple[CutSet, bool]:
else:
cutsets.append(CutSet(nemo_iter))
weights.append(weight)
# Finally, we multiplex the dataset streams to mix the data.
# Before that we filter cutsets to exclude cuts with valid "_skipme" values to mix the data correctly.
cutsets = [cutset.filter(PlaceholderFilter()) for cutset in cutsets]
cuts = mux(
*cutsets,
weights=weights,
6 changes: 6 additions & 0 deletions nemo/collections/common/data/lhotse/nemo_adapters.py
@@ -110,6 +110,9 @@ def __iter__(self) -> Generator[Cut, None, None]:
# Propagate the random seed
extra_fields = [ExtraField.from_dict({"seed": seed, **field_cfg}) for field_cfg in self.extra_fields or ()]
for data in self.source:
# Skip entries that carry a truthy "_skipme" value.
if data.get("_skipme", False):
continue
audio_path = get_full_path(str(data.pop("audio_filepath")), str(self.path), force_cache=False)
duration = data.pop("duration")
offset = data.pop("offset", None)
@@ -409,6 +412,9 @@ def basename(d: dict) -> str:
)
cuts_for_recording = []
for data in sorted(shard_manifest[tar_info.name], key=lambda d: d["audio_filepath"]):
# Skip entries that carry a truthy "_skipme" value.
if data.get("_skipme", False):
continue
# Cut the recording into corresponding segment and discard audio data outside the segment.
cut = make_cut_with_subset_inmemory_recording(
recording, offset=data.get("offset", 0.0), duration=data.get("duration")
14 changes: 0 additions & 14 deletions nemo/collections/common/data/lhotse/sampling.py
@@ -381,17 +381,3 @@ def _measure_tokens(cut: Cut) -> int:
def _measure_tps(cut: Cut) -> float:
num_tokens = _measure_tokens(cut)
return num_tokens / cut.duration


class PlaceholderFilter:
"""
Callable, returns ``True`` if a cut's "_skipme" is set and ``False`` otherwise.
Acts as a pass-through for objects of other type than Cut.
"""

def __call__(self, example) -> bool:
if not isinstance(example, Cut):
return True

custom = getattr(example, "custom", None)
return custom is None or not custom.pop("_skipme", False)
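Taken together, the Lhotse changes above move `_skipme` handling out of the sampler (the deleted `PlaceholderFilter`) and into the manifest iterators themselves. Below is a minimal, standalone sketch of that behaviour; it is not NeMo code, and the manifest entries and `iter_manifest` helper are made up for illustration:

```python
# Hypothetical manifest entries; in NeMo these would come from a JSON-lines manifest.
manifest = [
    {"audio_filepath": "a.wav", "duration": 1.0, "text": "keep me"},
    {"audio_filepath": "b.wav", "duration": 2.0, "text": "drop me", "_skipme": True},
]


def iter_manifest(entries):
    """Yield manifest entries, dropping any that carry a truthy "_skipme" flag."""
    for data in entries:
        if data.get("_skipme", False):
            continue  # dropped while reading, so no later sampler-side filter is needed
        yield data


print([d["audio_filepath"] for d in iter_manifest(manifest)])  # -> ['a.wav']
```

Because placeholder entries are dropped at read time, mixed cutsets never see them, which is why the `cuts.filter(PlaceholderFilter())` calls and the `PlaceholderFilter` class itself are removed.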
229 changes: 229 additions & 0 deletions nemo/collections/vlm/recipes/qwen2vl_2b.py
@@ -0,0 +1,229 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional

import lightning.pytorch as pl
import nemo_run as run
import torch
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig
from transformers import Qwen2VLImageProcessor

from nemo import lightning as nl
from nemo.collections import llm, vlm
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.llm.recipes.log.default import tensorboard_logger
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.collections.vlm.qwen2vl.data.mock import Qwen2VLMockDataModule
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
from nemo.utils.exp_manager import TimingCallback

NAME = "qwen2vl_2b"

HF_MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"


@run.cli.factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
"""
Factory function to create a Qwen2VL 2B model configuration.

Returns:
run.Config[pl.LightningModule]: Configuration for the Qwen2VL 2B model.

Examples:
CLI usage:
$ nemo llm pretrain model=qwen2vl_2b ...

Python API usage:
>>> model_config = model()
>>> print(model_config)
"""
return run.Config(vlm.Qwen2VLModel, config=run.Config(vlm.Qwen2VLConfig2B))


@run.cli.factory(target=llm.finetune, name=NAME)
def finetune_recipe(
dir: Optional[str] = None,
name: str = "default",
num_nodes: int = 1,
num_gpus_per_node: int = 8,
peft_scheme: Optional[str] = 'none',
) -> run.Partial:
"""
Create a fine-tuning recipe for Qwen2VL 2B model.

This function sets up a complete configuration for fine-tuning, including
model, trainer, data, logging, optimization, and resumption settings.
The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning when peft_scheme is set to 'lora'; with the default 'none', all model parameters are fine-tuned.

Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the fine-tuning run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
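peft_scheme (Optional[str]): Parameter-efficient fine-tuning scheme. 'lora' enables LoRA; 'none' or None (default) fine-tunes all parameters.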

Returns:
run.Partial: Partial configuration for fine-tuning.

Examples:
CLI usage:
$ nemo llm finetune --factory qwen2vl_2b

Python API usage:
>>> recipe = finetune_recipe(name="qwen2vl_2b_finetune", num_nodes=1)
>>> print(recipe)

Note:
This recipe uses the Mock dataset for fine-tuning.
"""

strategy = run.Config(
nl.MegatronStrategy,
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
encoder_pipeline_model_parallel_size=0,
sequence_parallel=True,
pipeline_dtype=torch.bfloat16,
ddp=run.Config(
DistributedDataParallelConfig,
check_for_nan_in_grad=True,
grad_reduce_in_fp32=True,
overlap_grad_reduce=True,
overlap_param_gather=True,
average_in_collective=True,
),
)

trainer = run.Config(
nl.Trainer,
accelerator="gpu",
accumulate_grad_batches=1,
devices=num_gpus_per_node,
limit_val_batches=10,
log_every_n_steps=1,
max_steps=10,
num_nodes=num_nodes,
plugins=bf16_mixed(),
strategy=strategy,
val_check_interval=1000,
callbacks=[
run.Config(TimingCallback),
run.Config(MegatronCommOverlapCallback, tp_comm_overlap=True),
],
)
tokenizer = run.Config(AutoTokenizer, HF_MODEL_NAME)
image_processor = run.Config(Qwen2VLImageProcessor)

max_sequence_length = 4096

language_transformer_config = run.Config(llm.Qwen2Config1P5B, seq_length=max_sequence_length)

vision_transformer_config = run.Config(vlm.Qwen2VLVisionConfig)

vision_projection_config = run.Config(
vlm.MultimodalProjectorConfig,
projector_type="mcore_mlp",
input_size=vision_transformer_config.ffn_hidden_size,
hidden_size=language_transformer_config.hidden_size,
ffn_hidden_size=vision_transformer_config.ffn_hidden_size,
)

# Qwen2VL model configuration
qwen2vl_config = run.Config(
vlm.Qwen2VLConfig,
language_transformer_config=language_transformer_config,
vision_transformer_config=vision_transformer_config,
vision_projection_config=vision_projection_config,
freeze_language_model=False,
freeze_vision_model=True,
)

model = run.Config(vlm.Qwen2VLModel, qwen2vl_config, tokenizer=tokenizer)
nemo_resume = run.Config(
nl.AutoResume,
resume_if_exists=True,
resume_ignore_no_checkpoint=True,
)

opt_config = run.Config(
OptimizerConfig,
optimizer='adam',
lr=2.0e-06,
adam_beta1=0.9,
adam_beta2=0.95,
use_distributed_optimizer=True,
bf16=True,
)
sched = run.Config(
CosineAnnealingScheduler, max_steps=trainer.max_steps, warmup_steps=0, constant_steps=1000, min_lr=1.0e-07
)
opt = run.Config(MegatronOptimizerModule, opt_config, sched)

recipe = run.Partial(
llm.finetune,
model=model,
trainer=trainer,
data=run.Config(
Qwen2VLMockDataModule,
seq_length=max_sequence_length,
global_batch_size=128,
micro_batch_size=2,
tokenizer=tokenizer,
image_processor=image_processor,
num_workers=4,
),
log=llm.default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)),
optim=opt,
resume=nemo_resume,
)

if peft_scheme is None or peft_scheme.lower() == 'none':
recipe.trainer.strategy.tensor_model_parallel_size = 1
recipe.optim.config.lr = 2e-05
elif peft_scheme.lower() == 'lora':
recipe.peft = run.Config(
vlm.LoRA,
target_modules=[
"linear_qkv",
"linear_proj",
"linear_fc1",
"linear_fc2",
],
)
recipe.optim.config.lr = 1e-4
else:
raise ValueError(f"Unrecognized peft scheme: {peft_scheme}")

return recipe


if __name__ == "__main__":
env_vars = {
"CUDA_VISIBLE_DEVICES": "0,1",
}
recipe = finetune_recipe(num_gpus_per_node=2)
recipe.trainer.max_steps = 10
recipe.trainer.val_check_interval = 10
recipe.trainer.limit_val_batches = 0.0
recipe.data.global_batch_size = 8
recipe.trainer.strategy.tensor_model_parallel_size = 1

executor = run.LocalExecutor(ntasks_per_node=2, launcher="torchrun", env_vars=env_vars)

run.run(recipe, executor=executor, name="qwen2vl_2b_finetune")
8 changes: 6 additions & 2 deletions scripts/performance/README.md
@@ -42,7 +42,11 @@ The following line shows an example of how you can launch a pre-training experim
- -ng/--num_gpus: Number of gpus.
- -gn/--gpus_per_node: Number of gpus per node. Defaults to 8.
- -ms/--max_steps: Number of train steps. Defaults to 100.
- -cg/--cuda_graphs: Enable CUDA graphs. Disabled by default.
- -cg/--cuda_graphs: Enable CUDA graphs. Options: 'true', '1', 't', 'yes', 'y' set it to True; 'false', '0', 'f', 'n', 'no' set it to False. Defaults to None, in which case the program tries to load a default from the recommended model configs and, failing that, falls back to false (see the parsing sketch after the example below).
- -fsdp/--use_mcore_fsdp: Enable megatron-core FSDP. Options: 'true', '1', 't', 'yes', 'y' set it to True; 'false', '0', 'f', 'n', 'no' set it to False. Defaults to None, in which case the program tries to load a default from the recommended model configs and, failing that, falls back to false.
- -rl/--recompute_layers: Number of transformer layers whose activations are recomputed during training. Defaults to None, in which case the program tries to load a default from the recommended model configs and, failing that, falls back to 0.
- -ol/--activation_offload_layers: Number of transformer layers whose activations are offloaded to CPU during training. Defaults to None, in which case the program tries to load a default from the recommended model configs and, failing that, falls back to 0.
- -rm/--recompute_modules: Comma-separated list of modules in a transformer layer to recompute. If set, the program uses selective recompute for all layers. Zero or more values may be provided; options are "core_attn", "moe_act", "layernorm", "mla_up_proj", "mlp", "moe". Defaults to None, in which case the program tries to load a default from the recommended model configs and, failing that, falls back to None, meaning no selective recompute.
- -cm/--custom_mounts: Comma separated string of mounts.
- You don't need to set any value for `--enable_nsys`, `--tensorboard` and `--dryrun`. See the below example for reference-
- You don't need to set any value for `--enable_nsys`, `--tensorboard`, `--wandb`, and `--dryrun`. See the below example for reference-
`python -m scripts.performance.llm.llama3_8b --account <your_slurm_account> -p <your_slurm_partition> -en --tensorboard -d`
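For the boolean-style options documented above, the accepted spellings and the None fallback can be illustrated with a small, hypothetical argparse sketch; this is not the script's actual parser, just the parsing rule the bullet points describe:

```python
import argparse


def optional_bool(value: str) -> bool:
    """Map the documented spellings ('true', '1', 't', 'yes', 'y' / 'false', '0', 'f', 'n', 'no') to a bool."""
    lowered = value.lower()
    if lowered in ("true", "1", "t", "yes", "y"):
        return True
    if lowered in ("false", "0", "f", "n", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean-like value, got {value!r}")


parser = argparse.ArgumentParser()
# Omitting the flag keeps the default None, i.e. "use the recommended model config, else false".
parser.add_argument("-cg", "--cuda_graphs", type=optional_bool, default=None)

print(parser.parse_args(["-cg", "yes"]).cuda_graphs)  # True
print(parser.parse_args([]).cuda_graphs)              # None
```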