From 6274771613c44e0775f624ac51157272f80f6e62 Mon Sep 17 00:00:00 2001 From: Azily Date: Tue, 9 Jul 2024 12:51:26 +0800 Subject: [PATCH 1/5] [FIX] an error of script --- contrib/text2image/finetune_t2i.sh | 12 ++++++------ contrib/text2image/t2i_dataset.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/contrib/text2image/finetune_t2i.sh b/contrib/text2image/finetune_t2i.sh index 8e9bb324c..023bf11bd 100644 --- a/contrib/text2image/finetune_t2i.sh +++ b/contrib/text2image/finetune_t2i.sh @@ -2,8 +2,8 @@ model_name_or_path=stabilityai/stable-diffusion-2-1 model_type="unet" dataset_path=data/example -output_dir=output -main_port=29500 +output_dir=output_dir +main_process_port=29500 img_size=768 while [[ $# -ge 1 ]]; do @@ -25,8 +25,8 @@ while [[ $# -ge 1 ]]; do output_dir="$2" shift ;; - -p|--main_port) - main_port="$2" + -p|--main_process_port) + main_process_port="$2" shift ;; -i|--img_size) @@ -44,13 +44,13 @@ echo "model_name_or_path: ${model_name_or_path}" echo "model_type: ${model_type}" echo "dataset_path: ${dataset_path}" echo "output_dir: ${output_dir}" -echo "main_port: ${main_port}" +echo "main_process_port: ${main_process_port}" echo "img_size: ${img_size}" accelerate launch \ --config_file=./accelerate_t2i_config.yaml \ - --main_port=${main_port} \ + --main_process_port=${main_process_port} \ finetune_t2i.py \ --model_name_or_path=${model_name_or_path} \ --model_type=${model_type} \ diff --git a/contrib/text2image/t2i_dataset.py b/contrib/text2image/t2i_dataset.py index aafe2de6e..a7a9b03f9 100644 --- a/contrib/text2image/t2i_dataset.py +++ b/contrib/text2image/t2i_dataset.py @@ -26,7 +26,7 @@ def __init__(self, data_args: T2IDatasetArguments): self.data_file = osp.join(data_args.dataset_path, data_args.train_file) self.data_dict = json.load(open(self.data_file, "r")) - assert self.data_dict["type"] == "text-image", "The dataset type must be text-image." + assert self.data_dict["type"] == "image_text", "The dataset type must be image_text" self.data_instances = self.data_dict["instances"] From 91c82ab446f2592818634c16d97982e677131ddd Mon Sep 17 00:00:00 2001 From: Azily Date: Tue, 9 Jul 2024 13:03:30 +0800 Subject: [PATCH 2/5] [ADD] merge text2image to source code --- src/lmflow/args.py | 243 +++++++++++++++++- src/lmflow/datasets/dataset.py | 5 +- src/lmflow/datasets/text2image_dataset.py | 149 +++++++++++ src/lmflow/pipeline/auto_pipeline.py | 2 + src/lmflow/pipeline/diffuser_tuner.py | 289 ++++++++++++++++++++++ 5 files changed, 684 insertions(+), 4 deletions(-) create mode 100644 src/lmflow/datasets/text2image_dataset.py create mode 100644 src/lmflow/pipeline/diffuser_tuner.py diff --git a/src/lmflow/args.py b/src/lmflow/args.py index 7cb44f86d..ce1e010d2 100644 --- a/src/lmflow/args.py +++ b/src/lmflow/args.py @@ -60,7 +60,7 @@ class ModelArguments: a string representing the path or name of a pretrained model checkpoint for weights initialization. If None, a model will be trained from scratch. - model_type : str + arch_type : str a string representing the type of model to use if training from scratch. If not provided, a pretrained model will be used. @@ -140,7 +140,7 @@ class ModelArguments: ) } ) - model_type: Optional[str] = field( + arch_type: Optional[str] = field( default=None, metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, ) @@ -358,7 +358,63 @@ def __post_init__(self): logger.warning("use_qlora is set to True, but use_lora is not set to True. 
Setting use_lora to True.") self.use_lora = True - +@dataclass +class DiffuserModelArguments: + """ + Arguments for Diffuser model + + model_name_or_path : str + a string representing the path or name of a pretrained + model checkpoint for weights initialization. If None, a model will be trained from scratch. + + arch_type : str + a string representing the type of model to use + + use_lora: bool + a boolean indicating whether to use lora. + + lora_r: int + an integer indicating the rank of the lora parameters. + + lora_alpha: int + an integer indicating the merging ratio between the fine-tuned model and the original. + + lora_target_modules: List[str] + a list of strings representing the modules to apply lora. + + lora_dropout: float + a float indicating the dropout rate in lora.linear + """ + + model_name_or_path: Optional[str] = field( + default=None, metadata={"help": "The model name or path."} + ) + + arch_type: Optional[str] = field( + default=None, metadata={"help": "The model type."} + ) + + use_lora: bool = field( + default=False, + metadata={"help": "Whether to lora."}, + ) + + lora_r: int = field( + default=8, + metadata={"help": "the rank of the lora parameters. The smaller lora_r is , the fewer parameters lora has."}, + ) + lora_alpha: int = field( + default=8, + metadata={ + "help": "Merging ratio between the fine-tuned model and the original. This is controlled by a parameter called alpha in the paper."}, + ) + lora_target_modules: List[str] = field( + default=None, metadata={"help": "Modules to apply lora."} + ) + lora_dropout: float = field( + default=0.1, + metadata={"help": "The dropout rate in lora.linear."}, + ) @dataclass class VisModelArguments(ModelArguments): low_resource: Optional[bool] = field( @@ -613,6 +669,77 @@ def __post_init__(self): extension = self.validation_file.split(".")[-1] assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." +@dataclass +class T2IDatasetArguments(DatasetArguments): + """ + Additional Arguments for T2I dataset + + image_folder : str + a string representing the folder of the image file. + + image_size : int + an integer indicating the size of the image. + + image_crop_type : str + a string representing the type of image crop. + + preprocessor_kind : str + a string representing how to get text embedding. + + is_t2i : bool + a boolean indicating the flag for the modality type. + """ + + image_folder: Optional[str] = field( + default=None, metadata={"help": "The folder of the image file."} + ) + + image_size: Optional[int] = field( + default=512, metadata={"help": "The size of the image."} + ) + + image_crop_type: Optional[str] = field( + default="center", metadata={"help": "The type of image crop."} + ) + + preprocessor_kind: Optional[str] = field( + default="simple", metadata={"help": "How to get text embedding."} + ) + + is_t2i: Optional[bool] = field( + default=True, metadata={"help": "Flag for the modality type."} + ) + + def __post_init__(self): + def check_extension(file_path: str, extension: str): + assert file_path.split(".")[-1] == extension, f"The file must be a {extension} file." 
+ + + if self.dataset_path is None or self.image_folder is None: + raise ValueError("The dataset_path, image_folder must be provided.") + + else: + root_path = Path(self.dataset_path) + if self.train_file is None: + if root_path.joinpath("train.json").exists(): + self.train_file = "train.json" + else: + raise ValueError("The train_file must be provided.") + + check_extension(self.train_file, "json") + if (self.validation_file is not None and self.test_file is None)\ + or (self.validation_file is None and self.test_file is not None): + same_file = self.validation_file if self.validation_file is not None else self.test_file + self.validation_file = same_file + self.test_file = same_file + if self.validation_file is not None: + check_extension(self.validation_file, "json") + if not root_path.joinpath(self.validation_file).exists(): + self.validation_file = None + if self.test_file is not None: + check_extension(self.test_file, "json") + if not root_path.joinpath(self.test_file).exists(): + self.test_file = None @dataclass class MultiModalDatasetArguments(DatasetArguments): @@ -1372,6 +1499,115 @@ class IterativeAlignerArguments(InferencerArguments): pass +@dataclass +class DiffuserTunerArguments: + """ + Arguments for Diffuser Tuner + + output_dir : str + The output directory. + + logging_dir : str + The logging directory. {output_dir}/{logging_dir} + + overwrite_output_dir : bool + Overwrite the content of the output directory. + + mixed_precision : str + Whether to use mixed precision. ["no", "bf16", "fp16"] + + do_train : bool + Whether to run training. + num_train_epochs : int + The number of training epochs. + train_batch_size : int + The number of batch size in training. + learning_rate : float + The learning rate. + weight_decay : float + The weight decay. + save_steps : int + How often to save the model. + save_total_limit : int + The total number of checkpoints to save. + + do_valid : bool + Whether to run evaluation. + valid_steps : int + How often to run evaluation. + valid_seed : int + The seed for evaluation. + + do_valid : bool + Whether to run test. + valid_seed : int + The seed for test. 
+ """ + + output_dir: Optional[str] = field( + default="output", metadata={"help": "The output directory."} + ) + + logging_dir: Optional[str] = field( + default="logs", metadata={"help": "The logging directory."} + ) + + overwrite_output_dir: bool = field( + default=False, metadata={"help": "Overwrite the content of the output directory."} + ) + + mixed_precision: str = field( + default="no", metadata={"help": "Whether to use mixed precision."} + ) + + do_train: bool = field( + default=True, metadata={"help": "Whether to run training."} + ) + + num_train_epochs: Optional[int] = field( + default=50, metadata={"help": "The number of training epochs."} + ) + + train_batch_size: Optional[int] = field( + default=1, metadata={"help": "The number of batch size in training."} + ) + + learning_rate: Optional[float] = field( + default=1e-4, metadata={"help": "The learning rate."} + ) + + weight_decay: Optional[float] = field( + default=0.0, metadata={"help": "The weight decay."} + ) + + do_valid: bool = field( + default=True, metadata={"help": "Whether to run evaluation."} + ) + + do_test: bool = field( + default=True, metadata={"help": "Whether to run testing."} + ) + + valid_steps: Optional[int] = field( + default=50, metadata={"help": "The evaluation steps."} + ) + + valid_seed: Optional[int] = field( + default=42, metadata={"help": "The seed for validation."} + ) + + test_seed: Optional[int] = field( + default=42, metadata={"help": "The seed for testing."} + ) + + save_steps: Optional[int] = field( + default=500, metadata={"help": "The saving steps."} + ) + + save_total_limit: Optional[int] = field( + default=None, metadata={"help": "The total number of checkpoints to save."} + ) + PIPELINE_ARGUMENT_MAPPING = { "finetuner": FinetunerArguments, "evaluator": EvaluatorArguments, @@ -1382,6 +1618,7 @@ class IterativeAlignerArguments(InferencerArguments): "dpo_aligner": DPOAlignerArguments, "rm_tuner": RewardModelTunerArguments, "dpov2_aligner": DPOv2AlignerArguments, + "diffuser_tuner": DiffuserTunerArguments, } diff --git a/src/lmflow/datasets/dataset.py b/src/lmflow/datasets/dataset.py index 377be0465..d2dfe40a9 100644 --- a/src/lmflow/datasets/dataset.py +++ b/src/lmflow/datasets/dataset.py @@ -32,7 +32,7 @@ ) from .multi_modal_dataset import CustomMultiModalDataset - +from .text2image_dataset import build_t2i_dataset logger = logging.getLogger(__name__) @@ -129,6 +129,9 @@ def __init__(self, data_args: DatasetArguments=None, backend: str="huggingface", # FIXME refactor the backend name raw_dataset = CustomMultiModalDataset(self.dataset_path, data_args) self.backend_dataset = raw_dataset + elif backend == "t2i": + raw_dataset = build_t2i_dataset(data_args, *args, **kwargs) + self.backend_dataset = raw_dataset else: raise NotImplementedError(f'Unsupported dataset backend "{backend}"') diff --git a/src/lmflow/datasets/text2image_dataset.py b/src/lmflow/datasets/text2image_dataset.py new file mode 100644 index 000000000..b245bef24 --- /dev/null +++ b/src/lmflow/datasets/text2image_dataset.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python +# coding=utf-8 + +"""This Python code defines a class T2I Dataset. 
+""" +import json +from PIL import Image +import os.path as osp +from tqdm import tqdm +import logging + +from torch.utils.data import Dataset +from torchvision import transforms + +from lmflow.args import T2IDatasetArguments + +logger = logging.getLogger(__name__) + +class CustomT2IDataset(Dataset): + """ + Dataset for T2I data + + Parameters + ------------ + data_args: T2IDatasetArguments + The arguments for the dataset. + """ + + def __init__(self, data_args: T2IDatasetArguments): + self.data_args = data_args + self.image_folder = osp.join(data_args.dataset_path, data_args.image_folder) + self.data_file = osp.join(data_args.dataset_path, data_args.train_file) + + self.data_dict = json.load(open(self.data_file, "r")) + assert self.data_dict["type"] == "image_text", "The dataset type must be text-image." + + self.data_instances = self.data_dict["instances"] + + def __len__(self): + return len(self.data_instances) + + def __getitem__(self, idx): + instance = self.data_instances[idx] + image_path = osp.join(self.image_folder, instance["image"]) + image = Image.open(image_path) + image = image.convert("RGB") + + return { + "image": image, + "text": instance["text"], + } + +class EncodePreprocessor(object): + """ + This class implement the preparation of the data for the model. + For different Diffusion model, the preparation is different. + + Parameters + ------------ + data_args: T2IDatasetArguments + The arguments for the dataset. + + **kwargs + The arguments for the preprocessor. + + Example + ------------ + >>> data_args.preprocessor_kind + simple + >>> kwargs = {"tokenizer": tokenizer, "text_encoder": text_encoder, "vae": vae} + >>> raw_dataset = CustomT2IDataset(data_args) + >>> preprocessor = EncodePreprocessor(data_args=data_args, **kwargs) + >>> dataset = PreprocessedT2IDataset(raw_dataset, data_args, preprocessor) + """ + + def __init__(self, data_args: T2IDatasetArguments, + **kwargs): + self.transform = transforms.Compose( + [ + transforms.Resize(data_args.image_size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(data_args.image_size) if data_args.image_crop_type == "center" else transforms.RandomCrop(data_args.image_size), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5], std=[0.5]), + ] + ) + + self.pre_func = None + if data_args.preprocessor_kind == "simple": + self.register_simple_func(**kwargs) + else: + raise NotImplementedError(f"The preprocessor kind {data_args.preprocessor_kind} is not implemented.") + + def register_simple_func(self, + tokenizer, + text_encoder, + vae): + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.vae = vae + + def simple_func(data_item): + image = self.transform(data_item["image"]) + latents = self.vae.encode(image.to(self.vae.device, dtype=self.vae.dtype).unsqueeze(0)).latent_dist.sample() + encoded_image = latents * self.vae.config.scaling_factor + encoded_image = encoded_image.detach() + encoded_image=encoded_image.squeeze(0).cpu() + + max_length = self.tokenizer.model_max_length + tokens = self.tokenizer([data_item["text"]], max_length=max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids + encoded_text = self.text_encoder(tokens.to(self.text_encoder.device))[0] + encoded_text = encoded_text.detach() + encoded_text =encoded_text.squeeze(0).cpu() + + return { + "image": encoded_image, + "text": encoded_text, + } + + self.pre_func = simple_func + + def __call__(self, data_item): + return self.pre_func(data_item) + +class PreprocessedT2IDataset(Dataset): + 
"Preprocess dataset with prompt" + + def __init__(self, raw_dataset:Dataset, + data_args: T2IDatasetArguments, + preprocessor:EncodePreprocessor): + del data_args # Unused variable + self.data_dict = [] + + logger.info("Preprocessing data ...") + for data_item in tqdm(raw_dataset): + self.data_dict.append(preprocessor(data_item)) + + def __len__(self): + return len(self.data_dict) + + def __getitem__(self, idx): + return self.data_dict[idx] + +def build_t2i_dataset(data_args: T2IDatasetArguments, + **kwargs): + raw_dataset = CustomT2IDataset(data_args) + preprocessor = EncodePreprocessor(data_args=data_args, **kwargs) + dataset = PreprocessedT2IDataset(raw_dataset, data_args, preprocessor) + + return dataset \ No newline at end of file diff --git a/src/lmflow/pipeline/auto_pipeline.py b/src/lmflow/pipeline/auto_pipeline.py index 82212802f..d42db112e 100644 --- a/src/lmflow/pipeline/auto_pipeline.py +++ b/src/lmflow/pipeline/auto_pipeline.py @@ -22,6 +22,7 @@ def is_package_version_at_least(package_name, min_version): from lmflow.pipeline.dpov2_aligner import DPOv2Aligner from lmflow.pipeline.rm_tuner import RewardModelTuner from lmflow.pipeline.rm_inferencer import RewardModelInferencer +from lmflow.pipeline.diffuser_tuner import DiffuserModelTuner PIPELINE_MAPPING = { "evaluator": Evaluator, "finetuner": Finetuner, @@ -31,6 +32,7 @@ def is_package_version_at_least(package_name, min_version): "dpo_aligner": DPOAligner, "dpov2_aligner": DPOv2Aligner, "rm_tuner": RewardModelTuner, + "diffuser_tuner": DiffuserModelTuner, } if not is_package_version_at_least('transformers', '4.35.0'): diff --git a/src/lmflow/pipeline/diffuser_tuner.py b/src/lmflow/pipeline/diffuser_tuner.py new file mode 100644 index 000000000..2ac04b38c --- /dev/null +++ b/src/lmflow/pipeline/diffuser_tuner.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python +# coding=utf-8 +""" The Finetuner class is used to finetune a diffuser model. 
+""" + +import json +import os +import logging +import gc +import copy + +from tqdm import tqdm +import wandb +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +from diffusers import ( + DiffusionPipeline, + UNet2DConditionModel, + DDPMScheduler, +) +from diffusers.loaders import LoraLoaderMixin +from diffusers.utils import ( + convert_state_dict_to_diffusers, + convert_unet_state_dict_to_peft, +) +from diffusers.utils.torch_utils import is_compiled_module +from diffusers.optimization import get_scheduler +from accelerate import Accelerator +from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict + +from lmflow.pipeline.finetuner import BaseTuner +from lmflow.args import T2IDatasetArguments, DiffuserModelArguments, DiffuserTunerArguments +from lmflow.datasets import Dataset + +logger = logging.getLogger(__name__) + +def log_validation( + pipeline, + accelerator: Accelerator, + pipeline_args: dict, + save_dir, + global_step, +): + del global_step # unused + pipeline.to(accelerator.device) + pipeline.vae.to(torch.float32) + + with torch.no_grad(): + prompt_images = [ + (pipeline_arg["prompt"], pipeline(**pipeline_arg).images[0]) for pipeline_arg in pipeline_args + ] + + for tracker in accelerator.trackers: + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {prompt}") for i, (prompt, image) in enumerate(prompt_images) + ] + } + ) + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + for i, (prompt, image) in enumerate(prompt_images): + image.save(os.path.join(save_dir, f"{prompt.replace(' ', '_')}.png")) + + del pipeline + torch.cuda.empty_cache() + gc.collect() + + return + +class DiffuserModelTuner(BaseTuner): + """Initializes the `RewardModelTuner` class. + + Parameters + ---------- + model_args : ModelArguments object. + Contains the arguments required to load the model. + + data_args : DatasetArguments object. + Contains the arguments required to load the dataset. + + finetuner_args : RewardModelTunerArguments object. + Contains the arguments required to perform finetuning. + + args : Optional. + Positional arguments. + + kwargs : Optional. + Keyword arguments. 
+ """ + def __init__( + self, + model_args: DiffuserModelArguments, + data_args: T2IDatasetArguments, + finetuner_args: DiffuserTunerArguments, + *args, + **kwargs + ): + self.model_args = model_args + self.data_args = data_args + self.finetuner_args = finetuner_args + + def tune( + self, + accelerator: Accelerator, + model, + dataset: Dataset, + ): + dataloader = DataLoader(dataset=dataset.backend_dataset, batch_size=self.finetuner_args.train_batch_size, shuffle=True) + + noise_scheduler = DDPMScheduler.from_pretrained(self.model_args.model_name_or_path, subfolder="scheduler") + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # filter trainable parameters + params_to_optimize = list(filter(lambda p: p.requires_grad, model.parameters())) + accelerator.print(len(params_to_optimize)) + + optimizer = torch.optim.AdamW( + params_to_optimize, + lr=self.finetuner_args.learning_rate, + weight_decay=self.finetuner_args.weight_decay + ) + + lr_scheduler = get_scheduler( + "constant", + optimizer=optimizer, + ) + + model, dataloader, optimizer, lr_scheduler = accelerator.prepare( + model, dataloader, optimizer, lr_scheduler + ) + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + + progress_bar = tqdm( + range(self.finetuner_args.num_train_epochs * len(dataloader)), + desc="Training", + disable=not accelerator.is_main_process + ) + + global_step = 0 + for epoch in range(self.finetuner_args.num_train_epochs): + model.train() + for batch in dataloader: + clean_latents = batch["image"].to(dtype=weight_dtype) + text_embedding = batch["text"].to(dtype=weight_dtype) + + bsz, channel, height, width = clean_latents.shape + noise = torch.randn_like(clean_latents).to(dtype=weight_dtype) + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_latents.device + ) + timesteps = timesteps.long() + + noisy_latents = noise_scheduler.add_noise(clean_latents, noise, timesteps) + model_pred = model( + noisy_latents, timesteps, text_embedding, + )[0] + + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(clean_latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + progress_bar.update(1) + + if accelerator.is_main_process: + logs = {"loss": loss.item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + global_step += 1 + + # validation + if accelerator.is_main_process and self.finetuner_args.do_valid and self.data_args.validation_file is not None: + if global_step % self.finetuner_args.valid_steps == 0: + with torch.no_grad(): + pipeline = DiffusionPipeline.from_pretrained( + self.model_args.model_name_or_path, + torch_dtype=weight_dtype, + ) + if self.model_args.arch_type == "unet": + pipeline.unet = unwrap_model(model) + elif self.model_args.arch_type == "transformer": + pipeline.transformer = unwrap_model(model) + else: + raise ValueError(f"Unknown model type {self.model_args.arch_type}") + + with open(os.path.join(self.data_args.dataset_path, self.data_args.validation_file), "r") as f: + 
validation_data = json.load(f) + generator = torch.Generator(device=accelerator.device).manual_seed(self.finetuner_args.valid_seed) + pipeline_args = [ + {"prompt": item["text"], "generator": generator, "width": self.data_args.image_size, "height": self.data_args.image_size} + for item in validation_data["instances"] + ] + log_validation( + pipeline, + accelerator, + pipeline_args, + os.path.join(self.finetuner_args.output_dir, f"step_{global_step}_validation"), + global_step, + ) + + if accelerator.is_main_process and global_step % self.finetuner_args.save_steps == 0: + os.makedirs(os.path.join(self.finetuner_args.output_dir, f"checkpoints"), exist_ok=True) + if len(os.listdir(os.path.join(self.finetuner_args.output_dir, f"checkpoints"))) > self.finetuner_args.max_checkpoints: + os.remove(os.path.join(self.finetuner_args.output_dir, f"checkpoints", sorted(os.listdir(os.path.join(self.finetuner_args.output_dir, f"checkpoints")))[0])) + if self.model_args.use_lora: + temp_model = unwrap_model(copy.deepcopy(model)) + temp_model = temp_model.to(torch.float32) + model_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(temp_model)) + LoraLoaderMixin.save_lora_weights( + save_directory=os.path.join(self.finetuner_args.output_dir, f"checkpoints", f"final"), + unet_lora_layers=model_lora_state_dict if self.model_args.arch_type == "unet" else None, + transformer_lora_layers=model_lora_state_dict if self.model_args.arch_type == "transformer" else None, + ) + del temp_model + else: + accelerator.save( + accelerator.get_state_dict(model), + os.path.join(self.finetuner_args.output_dir, f"checkpoints", f"final.pt") + ) + + accelerator.wait_for_everyone() + progress_bar.close() + if accelerator.is_main_process: + if self.finetuner_args.do_test and self.data_args.test_file is not None: + pipeline = DiffusionPipeline.from_pretrained( + self.model_args.model_name_or_path, + torch_dtype=weight_dtype, + ) + if self.model_args.arch_type == "unet": + pipeline.unet = unwrap_model(model) + elif self.model_args.arch_type == "transformer": + pipeline.transformer = unwrap_model(model) + else: + raise ValueError(f"Unknown model type {self.model_args.arch_type}") + + with open(os.path.join(self.data_args.dataset_path, self.data_args.test_file), "r") as f: + test_data = json.load(f) + generator = torch.Generator(device=accelerator.device).manual_seed(self.finetuner_args.test_seed) + pipeline_args = [ + {"prompt": item["text"], "generator": generator, "width": self.data_args.image_size, "height": self.data_args.image_size} + for item in test_data["instances"] + ] + log_validation( + pipeline, + accelerator, + pipeline_args, + os.path.join(self.finetuner_args.output_dir, f"test_final"), + global_step, + ) + + os.makedirs(os.path.join(self.finetuner_args.output_dir, f"checkpoints"), exist_ok=True) + if self.model_args.use_lora: + model = unwrap_model(model) + model = model.to(torch.float32) + model_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) + LoraLoaderMixin.save_lora_weights( + save_directory=os.path.join(self.finetuner_args.output_dir, f"checkpoints", f"final"), + unet_lora_layers=model_lora_state_dict if self.model_args.arch_type == "unet" else None, + transformer_lora_layers=model_lora_state_dict if self.model_args.arch_type == "transformer" else None, + ) + # pipeline.load_lora_weights(output_dir, weight_name="pytorch_lora_weights.safetensors") + else: + accelerator.save( + accelerator.get_state_dict(model), + 
os.path.join(self.finetuner_args.output_dir, f"checkpoints", f"final.pt") + ) + + return + \ No newline at end of file From 591119fe9e36ccfe3d104a2d0a991ed51cab5d55 Mon Sep 17 00:00:00 2001 From: Azily Date: Tue, 9 Jul 2024 13:18:07 +0800 Subject: [PATCH 3/5] [ADD] fine-tune text2image script --- .gitignore | 2 + configs/accelerate_t2i_config.yaml | 17 +++++ examples/finetune_t2i.py | 105 +++++++++++++++++++++++++++ requirements.txt | 3 +- scripts/diffuser/README.md | 75 +++++++++++++++++++ scripts/diffuser/run_finetune_t2i.sh | 72 ++++++++++++++++++ 6 files changed, 273 insertions(+), 1 deletion(-) create mode 100644 configs/accelerate_t2i_config.yaml create mode 100644 examples/finetune_t2i.py create mode 100644 scripts/diffuser/README.md create mode 100644 scripts/diffuser/run_finetune_t2i.sh diff --git a/.gitignore b/.gitignore index ae14d4951..45555fafc 100644 --- a/.gitignore +++ b/.gitignore @@ -135,6 +135,8 @@ dmypy.json # vscode .vs .vscode +# vscode extensions +.history # Pycharm .idea diff --git a/configs/accelerate_t2i_config.yaml b/configs/accelerate_t2i_config.yaml new file mode 100644 index 000000000..1fb1d0f2c --- /dev/null +++ b/configs/accelerate_t2i_config.yaml @@ -0,0 +1,17 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +enable_cpu_affinity: false +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: fp16 +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/finetune_t2i.py b/examples/finetune_t2i.py new file mode 100644 index 000000000..a19ea0053 --- /dev/null +++ b/examples/finetune_t2i.py @@ -0,0 +1,105 @@ +import sys +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" +os.environ["WANDB_MODE"] = "offline" +import shutil +from pathlib import Path +import gc + +import torch +from diffusers import ( + AutoencoderKL, + UNet2DConditionModel +) +from transformers import ( + AutoTokenizer, + CLIPTextModel +) +from accelerate import Accelerator +from accelerate.utils import ProjectConfiguration +from transformers import HfArgumentParser +from peft import LoraConfig + +from lmflow.args import ( + DiffuserModelArguments, + T2IDatasetArguments, + AutoArguments, +) +from lmflow.datasets import Dataset +from lmflow.pipeline.auto_pipeline import AutoPipeline + +def main(): + pipeline_name = "diffuser_tuner" + PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name) + + parser = HfArgumentParser((DiffuserModelArguments, T2IDatasetArguments, PipelineArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + model_args, data_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses() + + + logging_dir = Path(pipeline_args.output_dir, pipeline_args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=pipeline_args.output_dir, logging_dir=logging_dir) + accelerator = Accelerator( + mixed_precision=pipeline_args.mixed_precision, + log_with="wandb", + project_config=accelerator_project_config, + ) + + if accelerator.is_main_process and pipeline_args.overwrite_output_dir and os.path.exists(pipeline_args.output_dir): + shutil.rmtree(pipeline_args.output_dir) + + tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, subfolder="tokenizer") + text_encoder = 
CLIPTextModel.from_pretrained(model_args.model_name_or_path, subfolder="text_encoder").to("cuda") + vae = AutoencoderKL.from_pretrained(model_args.model_name_or_path, subfolder="vae").to("cuda") + + # dataset = build_t2i_dataset(data_args, tokenizer, text_encoder, vae) + kwargs = {"tokenizer": tokenizer, "text_encoder": text_encoder, "vae": vae} + dataset = Dataset(data_args, backend="t2i", **kwargs) + + del tokenizer, text_encoder, vae + torch.cuda.empty_cache() + gc.collect() + + model = None + if model_args.arch_type == "unet": + model = UNet2DConditionModel.from_pretrained(model_args.model_name_or_path, subfolder=model_args.arch_type) + elif model_args.arch_type == "transformer": + raise NotImplementedError("Transformer model is not implemented.") + else: + raise ValueError("The model type is not supported.") + if model_args.use_lora: + accelerator.print(f"Using LoRA of {model_args.lora_target_modules} for training") + model.requires_grad_(False) + lora_config = LoraConfig( + r=model_args.lora_r, + lora_alpha=model_args.lora_alpha, + lora_dropout=model_args.lora_dropout, + init_lora_weights="gaussian", + target_modules=model_args.lora_target_modules, + ) + model.add_adapter(lora_config) + else: + model.requires_grad_(True) + + fintuner = AutoPipeline.get_pipeline( + pipeline_name=pipeline_name, + model_args=model_args, + data_args=data_args, + pipeline_args=pipeline_args, + ) + accelerator.init_trackers("text2image-finetune", config={ + "data_args": data_args, + "model_args": model_args, + "pipeline_args": pipeline_args, + }) + + accelerator.wait_for_everyone() + fintuner.tune( + accelerator=accelerator, + model=model, dataset=dataset + ) + +if __name__ == '__main__': + main() diff --git a/requirements.txt b/requirements.txt index f526b103d..976f0139f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,4 +22,5 @@ pydantic gradio accelerate>=0.27.2 einops>=0.6.1 -vllm>=0.4.1 \ No newline at end of file +vllm>=0.4.1 +diffusers>=0.29.2 \ No newline at end of file diff --git a/scripts/diffuser/README.md b/scripts/diffuser/README.md new file mode 100644 index 000000000..fd6cfaf57 --- /dev/null +++ b/scripts/diffuser/README.md @@ -0,0 +1,75 @@ +# Diffusers fine-tuning + +## Data Preparation + +Here is a tree struct of the required data organization. In detail, under a `dataset_path` *example*, by default, an `img` directory is used for image files, and `train.json`, `valid.json` and `test.json` are used for reference of training, validation and testinig data. The `valid.json` and `test.json` are optional. If one is provided and the other is not, the two files will be set as the same. + +```bash +data +└── example + ├── img + │   ├── 00.jpg + │   ├── 01.jpg + │   ├── 02.jpg + │   ├── 03.jpg + │   └── 04.jpg + ├── train.json + ├── [valid.json] + └── [test.json] +``` + +The `train.json` should be the format as follow: + +```json +{ + "type": "image_text", + "instances": [ + { + "images": "00.jpg", + "text": "A photo of a dog" + }, + ... + ] +} +``` + +And the `valid.json` and `test.json` should be the format as follow: + +```json +{ + "type": "text-only", + "instances": [ + { + "text": "A photo of a dog in front of Eiffel Tower." + }, + ... + ] +} +``` + +Here is a specific example of the data [dog_t2i_data_example](https://drive.google.com/drive/folders/106ahvIrXbiuZMBw0NuOTjY0vnM_xXARW?usp=sharing) + +## Pretrained Models + +The script will automatically download pretrained weights from huggingface. 
Just pass the correct path of pretrained weight like `stabilityai/stable-diffusion-2-1` to the script arg `model_name_or_path`. + +## Finetune + +For convenience, we provide a script `run_finetune_t2i.sh` for fine-tuning. It can be used as follow: + +```bash +bash run_finetune_t2i.sh \ + model_name_or_path=stabilityai/stable-diffusion-2-1 \ + dataset_path=data/example +``` + +The `model_name_or_path` is the model name in [huggingface](https://huggingface.co/) or path of the pre-trained model. The `dataset_path` is the path of the dataset, which should be organized as the above tree struct. + +There are also some optional arguments for the script: + +- `arch_type`: The type of the model, which can be `unet` or `transformer`. Default is `unet`. (The `transformer` is not supported yet.) +- `output_dir`: The output directory of the fine-tuned model. Default is `output_dir`. +- `main_process_port`: The main port of the server. Default is `29500`. +- `img_size`: The size of the image for fine-tuning, validation and testing. Default is `768`. + +For more customization, you can refer to the [run_finetune_t2i.sh](./run_finetune_t2i.sh) and example [finetune_t2i.py](../../examples/finetune_t2i.py). diff --git a/scripts/diffuser/run_finetune_t2i.sh b/scripts/diffuser/run_finetune_t2i.sh new file mode 100644 index 000000000..d7c9f30f6 --- /dev/null +++ b/scripts/diffuser/run_finetune_t2i.sh @@ -0,0 +1,72 @@ +# Parses arguments +model_name_or_path=stabilityai/stable-diffusion-2-1 +arch_type="unet" +dataset_path=data/example +output_dir=output +main_process_port=29500 +img_size=768 + +while [[ $# -ge 1 ]]; do + key="$1" + case ${key} in + -m|--model_name_or_path) + model_name_or_path="$2" + shift + ;; + -t|--arch_type) + arch_type="$2" + shift + ;; + -d|--dataset_path) + dataset_path="$2" + shift + ;; + -o|--output_dir) + output_dir="$2" + shift + ;; + -p|--main_process_port) + main_process_port="$2" + shift + ;; + -i|--img_size) + img_size="$2" + shift + ;; + *) + echo "error: unknown option \"${key}\"" 1>&2 + exit 1 + esac + shift +done + +echo "model_name_or_path: ${model_name_or_path}" +echo "arch_type: ${arch_type}" +echo "dataset_path: ${dataset_path}" +echo "output_dir: ${output_dir}" +echo "main_process_port: ${main_process_port}" +echo "img_size: ${img_size}" + + +accelerate launch \ + --config_file=configs/accelerate_t2i_config.yaml \ + --main_process_port=${main_process_port} \ + examples/finetune_t2i.py \ + --model_name_or_path=${model_name_or_path} \ + --arch_type=${arch_type} \ + --use_lora=True \ + --lora_target_module "to_k" "to_q" "to_v" "to_out.0" "add_k_proj" "add_v_proj" \ + --dataset_path=${dataset_path} \ + --image_folder="img" \ + --image_size=${img_size} \ + --train_file="train.json" \ + --validation_file="valid.json" \ + --test_file="test.json" \ + --output_dir=${output_dir} \ + --logging_dir="logs" \ + --overwrite_output_dir=True \ + --mixed_precision="fp16" \ + --num_train_epochs=100 \ + --train_batch_size=1 \ + --learning_rate=1e-4 \ + --valid_steps=50 From 8d89082a0c901fad0c36da19c0791eb6e182500b Mon Sep 17 00:00:00 2001 From: Azily Date: Tue, 9 Jul 2024 13:20:40 +0800 Subject: [PATCH 4/5] [MODI] align with standard --- contrib/text2image/README.md | 4 ++-- contrib/text2image/t2i_dataset.py | 2 +- src/lmflow/datasets/text2image_dataset.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/contrib/text2image/README.md b/contrib/text2image/README.md index 53e277409..4d58342e7 100644 --- a/contrib/text2image/README.md +++ b/contrib/text2image/README.md @@ 
-28,10 +28,10 @@ The `train.json` should be the format as follow: ```json { - "type": "text-image", + "type": "image_text", "instances": [ { - "image": "00.jpg", + "images": "00.jpg", "text": "A photo of a dog" }, ... diff --git a/contrib/text2image/t2i_dataset.py b/contrib/text2image/t2i_dataset.py index a7a9b03f9..d1484a62f 100644 --- a/contrib/text2image/t2i_dataset.py +++ b/contrib/text2image/t2i_dataset.py @@ -35,7 +35,7 @@ def __len__(self): def __getitem__(self, idx): instance = self.data_instances[idx] - image_path = osp.join(self.image_folder, instance["image"]) + image_path = osp.join(self.image_folder, instance["images"]) image = Image.open(image_path) image = image.convert("RGB") diff --git a/src/lmflow/datasets/text2image_dataset.py b/src/lmflow/datasets/text2image_dataset.py index b245bef24..5ac4d0b4a 100644 --- a/src/lmflow/datasets/text2image_dataset.py +++ b/src/lmflow/datasets/text2image_dataset.py @@ -41,7 +41,7 @@ def __len__(self): def __getitem__(self, idx): instance = self.data_instances[idx] - image_path = osp.join(self.image_folder, instance["image"]) + image_path = osp.join(self.image_folder, instance["images"]) image = Image.open(image_path) image = image.convert("RGB") From 5f882f413e75932a4bd37062e5b7e61e465bab2a Mon Sep 17 00:00:00 2001 From: Azily Date: Tue, 9 Jul 2024 13:37:20 +0800 Subject: [PATCH 5/5] [MODI] readme and default output_path --- contrib/text2image/README.md | 4 ++-- scripts/diffuser/README.md | 6 +++--- scripts/diffuser/run_finetune_t2i.sh | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/contrib/text2image/README.md b/contrib/text2image/README.md index 4d58342e7..1a6f1ee42 100644 --- a/contrib/text2image/README.md +++ b/contrib/text2image/README.md @@ -61,8 +61,8 @@ For convenience, we provide a script `finetune_t2i.sh` for fine-tuning. It can b ```bash bash finetune_t2i.sh \ - model_name_or_path=stabilityai/stable-diffusion-2-1 \ - dataset_path=data/example + --model_name_or_path "stabilityai/stable-diffusion-2-1" \ + --dataset_path "data/example" ``` The `model_name_or_path` is the model name in [huggingface](https://huggingface.co/) or path of the pre-trained model. The `dataset_path` is the path of the dataset, which should be organized as the above tree struct. diff --git a/scripts/diffuser/README.md b/scripts/diffuser/README.md index fd6cfaf57..3be24b172 100644 --- a/scripts/diffuser/README.md +++ b/scripts/diffuser/README.md @@ -58,9 +58,9 @@ The script will automatically download pretrained weights from huggingface. Just For convenience, we provide a script `run_finetune_t2i.sh` for fine-tuning. It can be used as follow: ```bash -bash run_finetune_t2i.sh \ - model_name_or_path=stabilityai/stable-diffusion-2-1 \ - dataset_path=data/example +bash scripts/diffuser/run_finetune_t2i.sh \ + --model_name_or_path "stabilityai/stable-diffusion-2-1" \ + --dataset_path "data/example" ``` The `model_name_or_path` is the model name in [huggingface](https://huggingface.co/) or path of the pre-trained model. The `dataset_path` is the path of the dataset, which should be organized as the above tree struct. diff --git a/scripts/diffuser/run_finetune_t2i.sh b/scripts/diffuser/run_finetune_t2i.sh index d7c9f30f6..3e3963ff8 100644 --- a/scripts/diffuser/run_finetune_t2i.sh +++ b/scripts/diffuser/run_finetune_t2i.sh @@ -2,7 +2,7 @@ model_name_or_path=stabilityai/stable-diffusion-2-1 arch_type="unet" dataset_path=data/example -output_dir=output +output_dir=output_dir main_process_port=29500 img_size=768
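
A minimal inference sketch (not part of the patch series above): when `run_finetune_t2i.sh` is run with `use_lora=True`, `DiffuserModelTuner` saves the LoRA weights to `{output_dir}/checkpoints/final` via `LoraLoaderMixin.save_lora_weights`, which by default writes `pytorch_lora_weights.safetensors` (the commented hint in `diffuser_tuner.py` points at the same file name). Assuming those defaults and the script's default arguments, the fine-tuned weights could be loaded back into a standard diffusers pipeline roughly like this:

```python
# Hedged sketch: assumes the default LoRA save location/filename produced by
# DiffuserModelTuner, i.e. {output_dir}/checkpoints/final/pytorch_lora_weights.safetensors.
import os

import torch
from diffusers import DiffusionPipeline

output_dir = "output_dir"  # same value passed to run_finetune_t2i.sh

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
).to("cuda")

# load_lora_weights reads pytorch_lora_weights.safetensors from the directory.
pipeline.load_lora_weights(
    os.path.join(output_dir, "checkpoints", "final"),
    weight_name="pytorch_lora_weights.safetensors",
)

image = pipeline(
    "A photo of a dog in front of Eiffel Tower.",  # same prompt format as test.json above
    width=768,
    height=768,
).images[0]
image.save("sample.png")
```

For full fine-tuning (`use_lora=False`), the tuner instead saves a plain state dict at `{output_dir}/checkpoints/final.pt`; that checkpoint would be restored with `pipeline.unet.load_state_dict(torch.load(...))` before running the same generation call.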