From 6c86a3dcde051f86559d046e0906baa829850d34 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 27 Apr 2025 22:11:32 +0200 Subject: [PATCH 01/19] add transformer --- src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformer_hunyuan_video_framepack.py | 381 ++++++++++++++++++ src/diffusers/utils/dummy_pt_objects.py | 15 + 5 files changed, 401 insertions(+) create mode 100644 src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f51a4ef2b3f6..c64a116b525a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -175,6 +175,7 @@ "HunyuanDiT2DControlNetModel", "HunyuanDiT2DModel", "HunyuanDiT2DMultiControlNetModel", + "HunyuanVideoFramepackTransformer3DModel", "HunyuanVideoTransformer3DModel", "I2VGenXLUNet", "Kandinsky3UNet", @@ -770,6 +771,7 @@ HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel, + HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, I2VGenXLUNet, Kandinsky3UNet, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 276b1836a797..1ee2e2a3d023 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -79,6 +79,7 @@ _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] + _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] @@ -156,6 +157,7 @@ FluxTransformer2DModel, HiDreamImageTransformer2DModel, HunyuanDiT2DModel, + HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, LatteTransformer3DModel, LTXVideoTransformer3DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 191484fd9692..b690ec6fc096 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -23,6 +23,7 @@ from .transformer_flux import FluxTransformer2DModel from .transformer_hidream_image import HiDreamImageTransformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel + from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_mochi import MochiTransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py new file mode 100644 index 000000000000..c97ccc88da6c --- /dev/null +++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py @@ -0,0 +1,381 @@ +# Copyright 2025 The Framepack Team, The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers +from ..cache_utils import CacheMixin +from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, get_1d_rotary_pos_embed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous +from .transformer_hunyuan_video import ( + HunyuanVideoPatchEmbed, + HunyuanVideoSingleTransformerBlock, + HunyuanVideoTokenRefiner, + HunyuanVideoTransformerBlock, +) + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class HunyuanVideoFramepackRotaryPosEmbed(nn.Module): + def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.rope_dim = rope_dim + self.theta = theta + + def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device): + frame_indices = frame_indices.unbind(0) + freqs = [self._forward(f, height, width, device) for f in frame_indices] + freqs_cos, freqs_sin = zip(*freqs) + freqs_cos = torch.stack(freqs_cos, dim=0) # [B, W * H * T, D / 2] + freqs_sin = torch.stack(freqs_sin, dim=0) # [B, W * H * T, D / 2] + return freqs_cos, freqs_sin + + def _forward(self, frame_indices, height, width, device): + height = height // self.patch_size + width = width // self.patch_size + grid = torch.meshgrid( + frame_indices.to(device=device, dtype=torch.float32), + torch.arange(0, height, device=device, dtype=torch.float32), + torch.arange(0, width, device=device, dtype=torch.float32), + indexing="ij", + ) # 3 * [W, H, T] + grid = torch.stack(grid, dim=0) # [3, W, H, T] + + freqs = [] + for i in range(3): + freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True) + freqs.append(freq) + + freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2) + freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2) + + return freqs_cos, freqs_sin + + +class FramepackClipVisionProjection(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.up = nn.Linear(in_channels, out_channels * 3) + self.down = nn.Linear(out_channels * 3, out_channels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.up(hidden_states) + hidden_states = F.silu(hidden_states) + hidden_states = self.down(hidden_states) + return hidden_states + + +class HunyuanVideoPatchEmbedForCleanLatents(nn.Module): + def __init__(self, inner_dim: int): + super().__init__() + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, 
inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + + def forward( + self, + clean_latents: Optional[torch.Tensor] = None, + clean_latents_2x: Optional[torch.Tensor] = None, + clean_latents_4x: Optional[torch.Tensor] = None, + ): + if clean_latents is not None: + clean_latents = self.proj(clean_latents) + clean_latents = clean_latents.flatten(2).transpose(1, 2) + if clean_latents_2x is not None: + clean_latents_2x = _pad_for_3d_conv(clean_latents_2x, (2, 4, 4)) + clean_latents_2x = self.proj_2x(clean_latents_2x) + clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2) + if clean_latents_4x is not None: + clean_latents_4x = _pad_for_3d_conv(clean_latents_4x, (4, 8, 8)) + clean_latents_4x = self.proj_4x(clean_latents_4x) + clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2) + return clean_latents, clean_latents_2x, clean_latents_4x + + +class HunyuanVideoFramepackTransformer3DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin +): + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"] + _no_split_modules = [ + "HunyuanVideoTransformerBlock", + "HunyuanVideoSingleTransformerBlock", + "HunyuanVideoPatchEmbedForCleanLatents", # TODO + "HunyuanVideoTokenRefiner", + ] + + @register_to_config + def __init__( + self, + in_channels: int = 16, + out_channels: int = 16, + num_attention_heads: int = 24, + attention_head_dim: int = 128, + num_layers: int = 20, + num_single_layers: int = 40, + num_refiner_layers: int = 2, + mlp_ratio: float = 4.0, + patch_size: int = 2, + patch_size_t: int = 1, + qk_norm: str = "rms_norm", + guidance_embeds: bool = True, + text_embed_dim: int = 4096, + pooled_projection_dim: int = 768, + rope_theta: float = 256.0, + rope_axes_dim: Tuple[int] = (16, 56, 56), + has_image_proj: int = False, + image_proj_dim: int = 1152, + has_clean_x_embedder: int = False, + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Latent and condition embedders + self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) + self.context_embedder = HunyuanVideoTokenRefiner( + text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers + ) + self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim) + + # 2. RoPE + self.rope = HunyuanVideoFramepackRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) + + # 3. Dual stream transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + HunyuanVideoTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_layers) + ] + ) + + # 4. Single stream transformer blocks + self.single_transformer_blocks = nn.ModuleList( + [ + HunyuanVideoSingleTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_single_layers) + ] + ) + + # 5. 
Output projection + self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) + + # Framepack specific modules + self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None + + self.clean_x_embedder = None + if has_clean_x_embedder: + self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(inner_dim) + + self.use_gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + pooled_projections: torch.Tensor, + image_embeddings: torch.Tensor, + latent_indices: torch.Tensor, + guidance: Optional[torch.Tensor] = None, + clean_latents: Optional[torch.Tensor] = None, + clean_latent_indices: Optional[torch.Tensor] = None, + clean_latents_2x: Optional[torch.Tensor] = None, + clean_latent_2x_indices: Optional[torch.Tensor] = None, + clean_latents_4x: Optional[torch.Tensor] = None, + clean_latent_4x_indices: Optional[torch.Tensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ): + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p, p_t = self.config.patch_size, self.config.patch_size_t + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + + hidden_states, image_rotary_emb = self._pack_history_states( + hidden_states, + latent_indices, + clean_latents, + clean_latent_indices, + clean_latents_2x, + clean_latent_2x_indices, + clean_latents_4x, + clean_latent_4x_indices, + ) + + temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections, guidance) + encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask) + + encoder_hidden_states_image = self.image_projection(image_embeddings) + attention_mask_image = encoder_attention_mask.new_ones((batch_size, encoder_hidden_states_image.shape[1])) + + # must cat before (not after) encoder_hidden_states, due to attn masking + encoder_hidden_states = torch.cat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + encoder_attention_mask = torch.cat([attention_mask_image, encoder_attention_mask], dim=1) + + latent_sequence_length = hidden_states.shape[1] + condition_sequence_length = encoder_hidden_states.shape[1] + sequence_length = latent_sequence_length + condition_sequence_length + attention_mask = torch.zeros( + batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool + ) # [B, N] + + effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,] + effective_sequence_length = latent_sequence_length + effective_condition_sequence_length + + for i in range(batch_size): + attention_mask[i, : effective_sequence_length[i]] = True + # [B, 1, 1, N], for broadcasting across attention heads + attention_mask = 
attention_mask.unsqueeze(1).unsqueeze(1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb + ) + + for block in self.single_transformer_blocks: + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb + ) + + else: + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb + ) + + for block in self.single_transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb + ) + + # 5. Output projection + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p + ) + hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) + hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (hidden_states,) + return Transformer2DModelOutput(sample=hidden_states) + + def _pack_history_states( + self, + hidden_states: torch.Tensor, + latent_indices: torch.Tensor, + clean_latents: Optional[torch.Tensor] = None, + clean_latents_2x: Optional[torch.Tensor] = None, + clean_latents_4x: Optional[torch.Tensor] = None, + clean_latent_indices: Optional[torch.Tensor] = None, + clean_latent_2x_indices: Optional[torch.Tensor] = None, + clean_latent_4x_indices: Optional[torch.Tensor] = None, + ): + batch_size, num_channels, num_frames, height, width = hidden_states.shape + if latent_indices is None: + latent_indices = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1) + + hidden_states = self.x_embedder(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + image_rotary_emb = self.rope( + frame_indices=latent_indices, height=height, width=width, device=hidden_states.device + ) + + clean_latents, clean_latents_2x, clean_latents_4x = self.clean_x_embedder( + clean_latents, clean_latents_2x, clean_latents_4x + ) + + if clean_latents is not None: + clean_rotary_emb_1x = self.rope( + frame_indices=clean_latent_indices, height=height, width=width, device=clean_latents.device + ) + hidden_states = torch.cat([clean_latents, hidden_states], dim=1) + image_rotary_emb = torch.cat([clean_rotary_emb_1x, image_rotary_emb], dim=1) + + if clean_latents_2x is not None and clean_latent_2x_indices is not None: + clean_rotary_emb_2x = self.rope( + frame_indices=clean_latent_2x_indices, height=height, width=width, device=clean_latents_2x.device + ) + clean_rotary_emb_2x = _pad_for_3d_conv(clean_rotary_emb_2x, (2, 2, 2)) + clean_rotary_emb_2x = _center_down_sample_3d(clean_rotary_emb_2x, (2, 2, 2)) + hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1) + image_rotary_emb = torch.cat([clean_rotary_emb_2x, image_rotary_emb], dim=1) + + if clean_latents_4x is not None and clean_latent_4x_indices is not None: + clean_rotary_emb_4x = self.rope( + frame_indices=clean_latent_4x_indices, height=height, width=width, device=clean_latents_4x.device + ) + clean_rotary_emb_4x 
= _pad_for_3d_conv(clean_rotary_emb_4x, (4, 4, 4)) + clean_rotary_emb_4x = _center_down_sample_3d(clean_rotary_emb_4x, (4, 4, 4)) + hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1) + image_rotary_emb = torch.cat([clean_rotary_emb_4x, image_rotary_emb], dim=1) + + return hidden_states, image_rotary_emb + + +def _pad_for_3d_conv(x, kernel_size): + b, c, t, h, w = x.shape + pt, ph, pw = kernel_size + pad_t = (pt - (t % pt)) % pt + pad_h = (ph - (h % ph)) % ph + pad_w = (pw - (w % pw)) % pw + return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate") + + +def _center_down_sample_3d(x, kernel_size): + return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index bf2f19ee2d26..e850c8f2c0b1 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -565,6 +565,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class HunyuanVideoFramepackTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class HunyuanVideoTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] From 5a523cd52734886c6736d31af37f35495a301db9 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 28 Apr 2025 23:19:36 +0200 Subject: [PATCH 02/19] add pipeline --- .../transformer_hunyuan_video_framepack.py | 114 +-- .../pipeline_hunyuan_video_framepack.py | 900 ++++++++++++++++++ .../hunyuan_video/pipeline_output.py | 19 + 3 files changed, 976 insertions(+), 57 deletions(-) create mode 100644 src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py index c97ccc88da6c..11b03ff7b858 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py @@ -98,22 +98,22 @@ def __init__(self, inner_dim: int): def forward( self, - clean_latents: Optional[torch.Tensor] = None, - clean_latents_2x: Optional[torch.Tensor] = None, - clean_latents_4x: Optional[torch.Tensor] = None, + latents_clean: Optional[torch.Tensor] = None, + latents_clean_2x: Optional[torch.Tensor] = None, + latents_clean_4x: Optional[torch.Tensor] = None, ): - if clean_latents is not None: - clean_latents = self.proj(clean_latents) - clean_latents = clean_latents.flatten(2).transpose(1, 2) - if clean_latents_2x is not None: - clean_latents_2x = _pad_for_3d_conv(clean_latents_2x, (2, 4, 4)) - clean_latents_2x = self.proj_2x(clean_latents_2x) - clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2) - if clean_latents_4x is not None: - clean_latents_4x = _pad_for_3d_conv(clean_latents_4x, (4, 8, 8)) - clean_latents_4x = self.proj_4x(clean_latents_4x) - clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2) - return clean_latents, clean_latents_2x, clean_latents_4x + if latents_clean is not None: + latents_clean = self.proj(latents_clean) + latents_clean = latents_clean.flatten(2).transpose(1, 2) + if latents_clean_2x is not None: + latents_clean_2x = 
_pad_for_3d_conv(latents_clean_2x, (2, 4, 4)) + latents_clean_2x = self.proj_2x(latents_clean_2x) + latents_clean_2x = latents_clean_2x.flatten(2).transpose(1, 2) + if latents_clean_4x is not None: + latents_clean_4x = _pad_for_3d_conv(latents_clean_4x, (4, 8, 8)) + latents_clean_4x = self.proj_4x(latents_clean_4x) + latents_clean_4x = latents_clean_4x.flatten(2).transpose(1, 2) + return latents_clean, latents_clean_2x, latents_clean_4x class HunyuanVideoFramepackTransformer3DModel( @@ -206,15 +206,15 @@ def forward( encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor, pooled_projections: torch.Tensor, - image_embeddings: torch.Tensor, + image_embeds: torch.Tensor, latent_indices: torch.Tensor, guidance: Optional[torch.Tensor] = None, - clean_latents: Optional[torch.Tensor] = None, - clean_latent_indices: Optional[torch.Tensor] = None, - clean_latents_2x: Optional[torch.Tensor] = None, - clean_latent_2x_indices: Optional[torch.Tensor] = None, - clean_latents_4x: Optional[torch.Tensor] = None, - clean_latent_4x_indices: Optional[torch.Tensor] = None, + latents_clean: Optional[torch.Tensor] = None, + indices_latents_clean: Optional[torch.Tensor] = None, + latents_history_2x: Optional[torch.Tensor] = None, + indices_latents_history_2x: Optional[torch.Tensor] = None, + latents_history_4x: Optional[torch.Tensor] = None, + indices_latents_history_4x: Optional[torch.Tensor] = None, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ): @@ -242,18 +242,18 @@ def forward( hidden_states, image_rotary_emb = self._pack_history_states( hidden_states, latent_indices, - clean_latents, - clean_latent_indices, - clean_latents_2x, - clean_latent_2x_indices, - clean_latents_4x, - clean_latent_4x_indices, + latents_clean, + indices_latents_clean, + latents_history_2x, + indices_latents_history_2x, + latents_history_4x, + indices_latents_history_4x, ) temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections, guidance) encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask) - encoder_hidden_states_image = self.image_projection(image_embeddings) + encoder_hidden_states_image = self.image_projection(image_embeds) attention_mask_image = encoder_attention_mask.new_ones((batch_size, encoder_hidden_states_image.shape[1])) # must cat before (not after) encoder_hidden_states, due to attn masking @@ -319,12 +319,12 @@ def _pack_history_states( self, hidden_states: torch.Tensor, latent_indices: torch.Tensor, - clean_latents: Optional[torch.Tensor] = None, - clean_latents_2x: Optional[torch.Tensor] = None, - clean_latents_4x: Optional[torch.Tensor] = None, - clean_latent_indices: Optional[torch.Tensor] = None, - clean_latent_2x_indices: Optional[torch.Tensor] = None, - clean_latent_4x_indices: Optional[torch.Tensor] = None, + latents_clean: Optional[torch.Tensor] = None, + latents_history_2x: Optional[torch.Tensor] = None, + latents_history_4x: Optional[torch.Tensor] = None, + indices_latents_clean: Optional[torch.Tensor] = None, + indices_latents_history_2x: Optional[torch.Tensor] = None, + indices_latents_history_4x: Optional[torch.Tensor] = None, ): batch_size, num_channels, num_frames, height, width = hidden_states.shape if latent_indices is None: @@ -336,34 +336,34 @@ def _pack_history_states( frame_indices=latent_indices, height=height, width=width, device=hidden_states.device ) - clean_latents, clean_latents_2x, clean_latents_4x = self.clean_x_embedder( - clean_latents, clean_latents_2x, clean_latents_4x 
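+        # Patchify the clean history latents (at 1x, 2x and 4x compression) into token sequences before they are
+        # concatenated with the noisy latents below.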
+ latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder( + latents_clean, latents_history_2x, latents_history_4x ) - if clean_latents is not None: - clean_rotary_emb_1x = self.rope( - frame_indices=clean_latent_indices, height=height, width=width, device=clean_latents.device + if latents_clean is not None: + image_rotary_emb_clean = self.rope( + frame_indices=indices_latents_clean, height=height, width=width, device=latents_clean.device ) - hidden_states = torch.cat([clean_latents, hidden_states], dim=1) - image_rotary_emb = torch.cat([clean_rotary_emb_1x, image_rotary_emb], dim=1) + hidden_states = torch.cat([latents_clean, hidden_states], dim=1) + image_rotary_emb = torch.cat([image_rotary_emb_clean, image_rotary_emb], dim=1) - if clean_latents_2x is not None and clean_latent_2x_indices is not None: - clean_rotary_emb_2x = self.rope( - frame_indices=clean_latent_2x_indices, height=height, width=width, device=clean_latents_2x.device + if latents_history_2x is not None and indices_latents_history_2x is not None: + image_rotary_emb_history_2x = self.rope( + frame_indices=indices_latents_history_2x, height=height, width=width, device=latents_history_2x.device ) - clean_rotary_emb_2x = _pad_for_3d_conv(clean_rotary_emb_2x, (2, 2, 2)) - clean_rotary_emb_2x = _center_down_sample_3d(clean_rotary_emb_2x, (2, 2, 2)) - hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1) - image_rotary_emb = torch.cat([clean_rotary_emb_2x, image_rotary_emb], dim=1) - - if clean_latents_4x is not None and clean_latent_4x_indices is not None: - clean_rotary_emb_4x = self.rope( - frame_indices=clean_latent_4x_indices, height=height, width=width, device=clean_latents_4x.device + image_rotary_emb_history_2x = _pad_for_3d_conv(image_rotary_emb_history_2x, (2, 2, 2)) + image_rotary_emb_history_2x = _center_down_sample_3d(image_rotary_emb_history_2x, (2, 2, 2)) + hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1) + image_rotary_emb = torch.cat([image_rotary_emb_history_2x, image_rotary_emb], dim=1) + + if latents_history_4x is not None and indices_latents_history_4x is not None: + image_rotary_emb_history_4x = self.rope( + frame_indices=indices_latents_history_4x, height=height, width=width, device=latents_history_4x.device ) - clean_rotary_emb_4x = _pad_for_3d_conv(clean_rotary_emb_4x, (4, 4, 4)) - clean_rotary_emb_4x = _center_down_sample_3d(clean_rotary_emb_4x, (4, 4, 4)) - hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1) - image_rotary_emb = torch.cat([clean_rotary_emb_4x, image_rotary_emb], dim=1) + image_rotary_emb_history_4x = _pad_for_3d_conv(image_rotary_emb_history_4x, (4, 4, 4)) + image_rotary_emb_history_4x = _center_down_sample_3d(image_rotary_emb_history_4x, (4, 4, 4)) + hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1) + image_rotary_emb = torch.cat([image_rotary_emb_history_4x, image_rotary_emb], dim=1) return hidden_states, image_rotary_emb diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py new file mode 100644 index 000000000000..4a9a7850719a --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py @@ -0,0 +1,900 @@ +# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + LlamaModel, + LlamaTokenizerFast, + SiglipImageProcessor, + SiglipVisionModel, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import HunyuanVideoLoraLoaderMixin +from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanVideoFramepackPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel + >>> from diffusers.utils import export_to_video + + >>> model_id = "hunyuanvideo-community/HunyuanVideo" + >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( + ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> output = pipe( + ... prompt="A cat walks on the grass, realistic", + ... height=320, + ... width=512, + ... num_frames=61, + ... num_inference_steps=30, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, +} + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. 
+ num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`LlamaModel`]): + [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + tokenizer (`LlamaTokenizer`): + Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + transformer ([`HunyuanVideoTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder_2 ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 
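+        image_encoder ([`SiglipVisionModel`]):
+            SigLIP vision encoder used to compute conditioning embeddings from the input image.
+        feature_extractor ([`SiglipImageProcessor`]):
+            Image processor that prepares the input image for `image_encoder`.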
+ tokenizer_2 (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: LlamaModel, + tokenizer: LlamaTokenizerFast, + transformer: HunyuanVideoTransformer3DModel, + vae: AutoencoderKLHunyuanVideo, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + image_encoder: SiglipVisionModel, + feature_extractor: SiglipImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline._get_llama_prompt_embeds + def _get_llama_prompt_embeds( + self, + prompt: Union[str, List[str]], + prompt_template: Dict[str, Any], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, + ) -> Tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + if crop_start is None: + prompt_template_input = self.tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|eot_id|> token and placeholder {} + crop_start -= 2 + + max_sequence_length += crop_start + text_inputs = self.tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) + 
prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 77, + ) -> torch.Tensor: + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]] = None, + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + ): + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + ) + + if pooled_prompt_embeds is None: + if prompt_2 is None: + prompt_2 = prompt + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=77, + ) + + return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask + + def encode_image( + self, image: PipelineImageInput, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + ): + device = device or self.image_encoder.device + dtype = dtype or self.image_encoder.dtype + image = self.image_processor(images=image, return_tensors="pt").to(device=device, dtype=dtype) + image_embeds = self.image_encoder(**image).last_hidden_state + return image_embeds + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + prompt_template=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") 
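+        # 16 = VAE spatial compression (8) * transformer patch_size (2); the latent grid must patchify without remainder.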
+ + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_template is not None: + if not isinstance(prompt_template, dict): + raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + if "template" not in prompt_template: + raise ValueError( + f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" + ) + + def prepare_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 16, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def prepare_image_latents( + self, + image: torch.Tensor, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + image = image.to(device=self.vae.device, dtype=self.vae.dtype) + latents = self.vae.encode(image).latent_dist.sample(generator=generator) + latents = latents * self.vae.config.scaling_factor + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 
+ """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: Union[str, List[str]] = None, + prompt_2: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Union[str, List[str]] = None, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + latent_window_size: int = 9, + num_inference_steps: int = 50, + sigmas: List[float] = None, + true_cfg_scale: float = 1.0, + guidance_scale: float = 6.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + image_latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + max_sequence_length: int = 256, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. 
+ height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `129`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + guidance_scale (`float`, defaults to `6.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. Note that the only available HunyuanVideo model is + CFG-distilled, which means that traditional guidance between unconditional and conditional latent is + not applied. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoFramepackPipelineOutput`] instead of a plain tuple. 
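+            latent_window_size (`int`, defaults to `9`):
+                The number of latent frames in each generation section. `num_frames` is split into sections covering
+                `(latent_window_size - 1) * vae_scale_factor_temporal + 1` video frames each, generated one after
+                another.
+            image_latents (`torch.Tensor`, *optional*):
+                Pre-encoded VAE latents of the conditioning `image`. If not provided, they are computed by encoding
+                `image` with the VAE.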
+ attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoFramepackPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoFramepackPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images and the second element is a list + of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) + content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + prompt_template, + ) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + transformer_dtype = self.transformer.dtype + vae_dtype = self.vae.dtype + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. 
Encode input prompt + transformer_dtype = self.transformer.dtype + prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) + pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) + + if do_true_cfg: + negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + prompt_attention_mask=negative_prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype) + + # 4. Prepare image + image = self.video_processor.preprocess(image, height, width) + image_embeds = self.encode_image(image, device=device) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + window_size = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1 + num_latent_sections = max(1, (num_frames + window_size - 1) // window_size) + # Specific to the released checkpoint: https://huggingface.co/lllyasviel/FramePackI2V_HY + # TODO: find a more generic way in future if there are more checkpoints + history_sizes = [1, 9, 16] + history_latents = torch.zeros( + batch_size, + num_channels_latents, + sum(history_sizes), + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + dtype=torch.float32, + ) + history_video = None + + image_latents = self.prepare_image_latentss(image, generator=generator, latents=image_latents) + + latent_paddings = list(reversed(range(num_latent_sections))) + if num_latent_sections > 4: + latent_paddings = [3] + [2] * (num_latent_sections - 3) + [1, 0] + + # 6. Prepare guidance condition + guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 + + # 7. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + for i in range(num_latent_sections): + current_latent_padding = latent_paddings[i] + is_last_section = current_latent_padding == 0 + latent_padding_size = current_latent_padding * latent_window_size + total_generated_latent_frames = 0 + + indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, *history_sizes])).unsqueeze(0) + ( + indices_prefix, + indices_padding, + indices_latents_clean, + indices_postfix, + indices_latents_history_2x, + indices_latents_history_4x, + ) = indices.split([1, latent_padding_size, latent_window_size, *history_sizes], dim=1) + # Inverted anti-drifting sampling: Figure 2(c) in the paper + indices_clean_latents = torch.cat([indices_prefix, indices_postfix], dim=1) + + latents_prefix = image_latents.to(device=device, dtype=transformer_dtype) + latents_postfix, latents_history_2x, latents_history_4x = history_latents[ + :, :, : sum(history_sizes) + ].split(history_sizes, dim=2) + latents_clean = torch.cat([latents_prefix, latents_postfix], dim=2) + + latents = self.prepare_latents( + batch_size, + num_channels_latents, + height, + width, + num_frames, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents, + ) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, + guidance=guidance, + image_embeds=image_embeds, + latents_clean=latents_clean, + indices_latents_clean=indices_clean_latents, + latents_history_2x=latents_history_2x, + indices_latents_history_2x=indices_latents_history_2x, + latents_history_4x=latents_history_4x, + indices_latents_history_4x=indices_latents_history_4x, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + pooled_projections=negative_pooled_prompt_embeds, + image_embeds=image_embeds, + latents_clean=latents_clean, + indices_latents_clean=indices_clean_latents, + latents_history_2x=latents_history_2x, + indices_latents_history_2x=indices_latents_history_2x, + latents_history_4x=latents_history_4x, + indices_latents_history_4x=indices_latents_history_4x, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if 
provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if is_last_section: + latents = torch.cat([latents_prefix, latents]) + + total_generated_latent_frames += latents.shape[2] + history_latents = torch.cat([latents.to(history_latents), history_latents], dim=2) + + real_history_latents = history_latents[:, :, :total_generated_latent_frames] + + if history_video is None: + if not output_type == "latent": + current_video = real_history_latents.to(vae_dtype) / self.vae.config.scaling_factor + current_video = self.vae.decode(current_video, return_dict=False)[0] + else: + history_video = [real_history_latents] + else: + if not output_type == "latents": + section_latent_frames = ( + (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2) + ) + overlapped_frames = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1 + current_video = ( + real_history_latents[:, :, :section_latent_frames].to(vae_dtype) + / self.vae.config.scaling_factor + ) + current_video = self.vae.decode(current_video, return_dict=False)[0] + current_video = self._soft_append(current_video, history_video, overlapped_frames) + else: + history_video.append(real_history_latents) + + self._current_timestep = None + + if not output_type == "latent": + video = self.video_processor.postprocess_video(history_video, output_type=output_type) + else: + video = history_video + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideoFramepackPipelineOutput(frames=video) + + def _soft_append(current: torch.Tensor, history: torch.Tensor, overlap: int = 0): + if overlap <= 0: + return torch.cat([current, history], dim=2) + + assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})" + assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})" + + weights = torch.linspace(1, 0, overlap, dtype=current.dtype, device=current.device).view(1, 1, -1, 1, 1) + blended = weights * current[:, :, -overlap:] + (1 - weights) * history[:, :, :overlap] + output = torch.cat([current[:, :, :-overlap], blended, history[:, :, overlap:]], dim=2) + + return output.to(current) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_output.py b/src/diffusers/pipelines/hunyuan_video/pipeline_output.py index c5cb853e3932..fae0370a53b7 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_output.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_output.py @@ -1,5 +1,8 @@ from dataclasses import dataclass +from typing import List, Union +import numpy as np +import PIL.Image import torch from diffusers.utils import BaseOutput @@ -18,3 +21,19 @@ class HunyuanVideoPipelineOutput(BaseOutput): """ frames: torch.Tensor + + +@dataclass +class HunyuanVideoFramepackPipelineOutput(BaseOutput): + r""" + Output class for HunyuanVideo pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. Or, a list of torch tensors where each tensor + corresponds to a latent that decodes to multiple frames. 
+ """ + + frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]], List[torch.Tensor]] From e2fde455496be5eac25f16101fcd49f09c93e23c Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Apr 2025 01:07:11 +0200 Subject: [PATCH 03/19] fixes --- src/diffusers/__init__.py | 2 + .../transformer_hunyuan_video_framepack.py | 14 ++-- src/diffusers/pipelines/__init__.py | 2 + .../pipelines/hunyuan_video/__init__.py | 2 + .../pipeline_hunyuan_video_framepack.py | 74 ++++++++++--------- 5 files changed, 52 insertions(+), 42 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index c64a116b525a..7d4f2c999ab8 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -377,6 +377,7 @@ "HunyuanDiTPAGPipeline", "HunyuanDiTPipeline", "HunyuanSkyreelsImageToVideoPipeline", + "HunyuanVideoFramepackPipeline", "HunyuanVideoImageToVideoPipeline", "HunyuanVideoPipeline", "I2VGenXLPipeline", @@ -952,6 +953,7 @@ HunyuanDiTPAGPipeline, HunyuanDiTPipeline, HunyuanSkyreelsImageToVideoPipeline, + HunyuanVideoFramepackPipeline, HunyuanVideoImageToVideoPipeline, HunyuanVideoPipeline, I2VGenXLPipeline, diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py index 11b03ff7b858..2910e809aadc 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py @@ -207,7 +207,7 @@ def forward( encoder_attention_mask: torch.Tensor, pooled_projections: torch.Tensor, image_embeds: torch.Tensor, - latent_indices: torch.Tensor, + indices_latents: torch.Tensor, guidance: Optional[torch.Tensor] = None, latents_clean: Optional[torch.Tensor] = None, indices_latents_clean: Optional[torch.Tensor] = None, @@ -241,11 +241,11 @@ def forward( hidden_states, image_rotary_emb = self._pack_history_states( hidden_states, - latent_indices, + indices_latents, latents_clean, - indices_latents_clean, latents_history_2x, indices_latents_history_2x, + indices_latents_clean, latents_history_4x, indices_latents_history_4x, ) @@ -318,7 +318,7 @@ def forward( def _pack_history_states( self, hidden_states: torch.Tensor, - latent_indices: torch.Tensor, + indices_latents: torch.Tensor, latents_clean: Optional[torch.Tensor] = None, latents_history_2x: Optional[torch.Tensor] = None, latents_history_4x: Optional[torch.Tensor] = None, @@ -327,13 +327,13 @@ def _pack_history_states( indices_latents_history_4x: Optional[torch.Tensor] = None, ): batch_size, num_channels, num_frames, height, width = hidden_states.shape - if latent_indices is None: - latent_indices = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1) + if indices_latents is None: + indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1) hidden_states = self.x_embedder(hidden_states) hidden_states = hidden_states.flatten(2).transpose(1, 2) image_rotary_emb = self.rope( - frame_indices=latent_indices, height=height, width=width, device=hidden_states.device + frame_indices=indices_latents, height=height, width=width, device=hidden_states.device ) latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder( diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 011f23ed371c..faa5abfb03bb 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -227,6 +227,7 @@ "HunyuanVideoPipeline", "HunyuanSkyreelsImageToVideoPipeline", 
"HunyuanVideoImageToVideoPipeline", + "HunyuanVideoFramepackPipeline", ] _import_structure["kandinsky"] = [ "KandinskyCombinedPipeline", @@ -589,6 +590,7 @@ from .hidream_image import HiDreamImagePipeline from .hunyuan_video import ( HunyuanSkyreelsImageToVideoPipeline, + HunyuanVideoFramepackPipeline, HunyuanVideoImageToVideoPipeline, HunyuanVideoPipeline, ) diff --git a/src/diffusers/pipelines/hunyuan_video/__init__.py b/src/diffusers/pipelines/hunyuan_video/__init__.py index d9cacad24f17..d42d38fac979 100644 --- a/src/diffusers/pipelines/hunyuan_video/__init__.py +++ b/src/diffusers/pipelines/hunyuan_video/__init__.py @@ -24,6 +24,7 @@ else: _import_structure["pipeline_hunyuan_skyreels_image2video"] = ["HunyuanSkyreelsImageToVideoPipeline"] _import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"] + _import_structure["pipeline_hunyuan_video_framepack"] = ["HunyuanVideoFramepackPipeline"] _import_structure["pipeline_hunyuan_video_image2video"] = ["HunyuanVideoImageToVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -36,6 +37,7 @@ else: from .pipeline_hunyuan_skyreels_image2video import HunyuanSkyreelsImageToVideoPipeline from .pipeline_hunyuan_video import HunyuanVideoPipeline + from .pipeline_hunyuan_video_framepack import HunyuanVideoFramepackPipeline from .pipeline_hunyuan_video_image2video import HunyuanVideoImageToVideoPipeline else: diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py index 4a9a7850719a..7f3b7ff250b2 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# Copyright 2025 The Framepack Team, The HunyuanVideo Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -29,7 +29,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput from ...loaders import HunyuanVideoLoraLoaderMixin -from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel +from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoFramepackTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor @@ -182,7 +182,7 @@ def __init__( self, text_encoder: LlamaModel, tokenizer: LlamaTokenizerFast, - transformer: HunyuanVideoTransformer3DModel, + transformer: HunyuanVideoFramepackTransformer3DModel, vae: AutoencoderKLHunyuanVideo, scheduler: FlowMatchEulerDiscreteScheduler, text_encoder_2: CLIPTextModel, @@ -353,13 +353,14 @@ def encode_prompt( return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask def encode_image( - self, image: PipelineImageInput, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + self, image: torch.Tensor, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None ): - device = device or self.image_encoder.device - dtype = dtype or self.image_encoder.dtype - image = self.image_processor(images=image, return_tensors="pt").to(device=device, dtype=dtype) + image = (image + 1) / 2.0 # [-1, 1] -> [0, 1] + image = self.feature_extractor(images=image, return_tensors="pt").to( + device=self.image_encoder.device, dtype=self.image_encoder.dtype + ) image_embeds = self.image_encoder(**image).last_hidden_state - return image_embeds + return image_embeds.to(dtype=dtype) def check_inputs( self, @@ -445,12 +446,11 @@ def prepare_image_latents( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if latents is not None: - return latents.to(device=device, dtype=dtype) - image = image.to(device=self.vae.device, dtype=self.vae.dtype) - latents = self.vae.encode(image).latent_dist.sample(generator=generator) - latents = latents * self.vae.config.scaling_factor - return latents + if latents is None: + image = image.unsqueeze(2).to(device=self.vae.device, dtype=self.vae.dtype) + latents = self.vae.encode(image).latent_dist.sample(generator=generator) + latents = latents * self.vae.config.scaling_factor + return latents.to(device=device, dtype=dtype) def enable_vae_slicing(self): r""" @@ -703,7 +703,7 @@ def __call__( # 4. Prepare image image = self.video_processor.preprocess(image, height, width) - image_embeds = self.encode_image(image, device=device) + image_embeds = self.encode_image(image, device=device).to(transformer_dtype) # 4. 
Prepare timesteps sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas @@ -715,25 +715,28 @@ def __call__( num_latent_sections = max(1, (num_frames + window_size - 1) // window_size) # Specific to the released checkpoint: https://huggingface.co/lllyasviel/FramePackI2V_HY # TODO: find a more generic way in future if there are more checkpoints - history_sizes = [1, 9, 16] + history_sizes = [1, 2, 16] history_latents = torch.zeros( batch_size, num_channels_latents, sum(history_sizes), height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial, + device=device, dtype=torch.float32, ) history_video = None - image_latents = self.prepare_image_latentss(image, generator=generator, latents=image_latents) + image_latents = self.prepare_image_latents( + image, dtype=torch.float32, device=device, generator=generator, latents=image_latents + ) latent_paddings = list(reversed(range(num_latent_sections))) if num_latent_sections > 4: latent_paddings = [3] + [2] * (num_latent_sections - 3) + [1, 0] # 6. Prepare guidance condition - guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 + guidance = torch.tensor([guidance_scale] * batch_size, dtype=transformer_dtype, device=device) * 1000.0 # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -749,7 +752,7 @@ def __call__( ( indices_prefix, indices_padding, - indices_latents_clean, + indices_latents, indices_postfix, indices_latents_history_2x, indices_latents_history_4x, @@ -757,7 +760,7 @@ def __call__( # Inverted anti-drifting sampling: Figure 2(c) in the paper indices_clean_latents = torch.cat([indices_prefix, indices_postfix], dim=1) - latents_prefix = image_latents.to(device=device, dtype=transformer_dtype) + latents_prefix = image_latents latents_postfix, latents_history_2x, latents_history_4x = history_latents[ :, :, : sum(history_sizes) ].split(history_sizes, dim=2) @@ -781,23 +784,22 @@ def __call__( continue self._current_timestep = t - latent_model_input = latents.to(transformer_dtype) - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]).to(latents.dtype) + timestep = t.expand(latents.shape[0]) noise_pred = self.transformer( - hidden_states=latent_model_input, + hidden_states=latents.to(transformer_dtype), timestep=timestep, encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, pooled_projections=pooled_prompt_embeds, - guidance=guidance, image_embeds=image_embeds, - latents_clean=latents_clean, + indices_latents=indices_latents, + guidance=guidance, + latents_clean=latents_clean.to(transformer_dtype), indices_latents_clean=indices_clean_latents, - latents_history_2x=latents_history_2x, + latents_history_2x=latents_history_2x.to(transformer_dtype), indices_latents_history_2x=indices_latents_history_2x, - latents_history_4x=latents_history_4x, + latents_history_4x=latents_history_4x.to(transformer_dtype), indices_latents_history_4x=indices_latents_history_4x, attention_kwargs=attention_kwargs, return_dict=False, @@ -805,26 +807,27 @@ def __call__( if do_true_cfg: neg_noise_pred = self.transformer( - hidden_states=latent_model_input, + hidden_states=latents.to(transformer_dtype), timestep=timestep, encoder_hidden_states=negative_prompt_embeds, encoder_attention_mask=negative_prompt_attention_mask, pooled_projections=negative_pooled_prompt_embeds, image_embeds=image_embeds, - 
latents_clean=latents_clean, + indices_latents=indices_latents, + guidance=guidance, + latents_clean=latents_clean.to(transformer_dtype), indices_latents_clean=indices_clean_latents, - latents_history_2x=latents_history_2x, + latents_history_2x=latents_history_2x.to(transformer_dtype), indices_latents_history_2x=indices_latents_history_2x, - latents_history_4x=latents_history_4x, + latents_history_4x=latents_history_4x.to(transformer_dtype), indices_latents_history_4x=indices_latents_history_4x, - guidance=guidance, attention_kwargs=attention_kwargs, return_dict=False, )[0] noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + latents = self.scheduler.step(noise_pred.float(), t, latents, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} @@ -846,7 +849,7 @@ def __call__( latents = torch.cat([latents_prefix, latents]) total_generated_latent_frames += latents.shape[2] - history_latents = torch.cat([latents.to(history_latents), history_latents], dim=2) + history_latents = torch.cat([latents, history_latents], dim=2) real_history_latents = history_latents[:, :, :total_generated_latent_frames] @@ -874,6 +877,7 @@ def __call__( self._current_timestep = None if not output_type == "latent": + history_video = history_video[:, :, :num_frames] video = self.video_processor.postprocess_video(history_video, output_type=output_type) else: video = history_video From ea7e3cd5655564504b1ce5a9de05491802a56e55 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Apr 2025 01:07:32 +0200 Subject: [PATCH 04/19] make fix-copies --- .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index b3c6efb8cdcf..6bf4d2c3b570 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -692,6 +692,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class HunyuanVideoFramepackPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class HunyuanVideoImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From ee54a1918146a23195ab1e1cd63815b0e5390ffd Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Apr 2025 02:23:33 +0200 Subject: [PATCH 05/19] update --- docs/source/en/api/pipelines/hunyuan_video.md | 1 + .../transformer_hunyuan_video_framepack.py | 90 ++++++++++++++----- .../pipeline_hunyuan_video_framepack.py | 25 +++--- 3 files changed, 83 insertions(+), 33 deletions(-) diff --git a/docs/source/en/api/pipelines/hunyuan_video.md b/docs/source/en/api/pipelines/hunyuan_video.md index 5d068c8b6ef8..a2c8e8b20dfa 100644 --- a/docs/source/en/api/pipelines/hunyuan_video.md +++ b/docs/source/en/api/pipelines/hunyuan_video.md @@ -52,6 +52,7 @@ The following models are available for the image-to-video pipeline: | 
[`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
 | [`hunyuanvideo-community/HunyuanVideo-I2V-33ch`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tencent's official HunyuanVideo 33-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
 | [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tencent's official HunyuanVideo 16-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
+| [`lllyasviel/FramePackI2V_HY`](https://huggingface.co/lllyasviel/FramePackI2V_HY) | lllyasviel's official FramePack model for HunyuanVideo, implementing the long-context video generation technique introduced in the [Framepack](https://arxiv.org/abs/2504.12626) paper. |
 
 ## Quantization
 
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
index 2910e809aadc..77686085dcec 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
@@ -22,11 +22,12 @@
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers
 from ..cache_utils import CacheMixin
-from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, get_1d_rotary_pos_embed
+from ..embeddings import get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous
 from .transformer_hunyuan_video import (
+    HunyuanVideoConditionEmbedding,
     HunyuanVideoPatchEmbed,
     HunyuanVideoSingleTransformerBlock,
     HunyuanVideoTokenRefiner,
@@ -48,11 +49,16 @@ def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], thet
 
     def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device):
         frame_indices = frame_indices.unbind(0)
-        freqs = [self._forward(f, height, width, device) for f in frame_indices]
-        freqs_cos, freqs_sin = zip(*freqs)
-        freqs_cos = torch.stack(freqs_cos, dim=0)  # [B, W * H * T, D / 2]
-        freqs_sin = torch.stack(freqs_sin, dim=0)  # [B, W * H * T, D / 2]
-        return freqs_cos, freqs_sin
+        # This is from the original code. We don't call _forward for each batch index because we know that
+        # each batch has the same frame indices. However, it may be possible that the frame indices don't
+        # always be the same for every item in a batch (such as in training). We cannot use the original
+        # implementation because our `apply_rotary_emb` function broadcasts across the batch dim.
+ # freqs = [self._forward(f, height, width, device) for f in frame_indices] + # freqs_cos, freqs_sin = zip(*freqs) + # freqs_cos = torch.stack(freqs_cos, dim=0) # [B, W * H * T, D / 2] + # freqs_sin = torch.stack(freqs_sin, dim=0) # [B, W * H * T, D / 2] + # return freqs_cos, freqs_sin + return self._forward(frame_indices[0], height, width, device) def _forward(self, frame_indices, height, width, device): height = height // self.patch_size @@ -89,7 +95,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class HunyuanVideoPatchEmbedForCleanLatents(nn.Module): +class HunyuanVideoHistoryPatchEmbed(nn.Module): def __init__(self, inner_dim: int): super().__init__() self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) @@ -147,6 +153,7 @@ def __init__( pooled_projection_dim: int = 768, rope_theta: float = 256.0, rope_axes_dim: Tuple[int] = (16, 56, 56), + image_condition_type: Optional[str] = None, has_image_proj: int = False, image_proj_dim: int = 1152, has_clean_x_embedder: int = False, @@ -161,7 +168,9 @@ def __init__( self.context_embedder = HunyuanVideoTokenRefiner( text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers ) - self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim) + self.time_text_embed = HunyuanVideoConditionEmbedding( + inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type + ) # 2. RoPE self.rope = HunyuanVideoFramepackRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) @@ -195,7 +204,7 @@ def __init__( self.clean_x_embedder = None if has_clean_x_embedder: - self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(inner_dim) + self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(inner_dim) self.use_gradient_checkpointing = False @@ -238,19 +247,20 @@ def forward( post_patch_num_frames = num_frames // p_t post_patch_height = height // p post_patch_width = width // p + original_context_length = post_patch_num_frames * post_patch_height * post_patch_width hidden_states, image_rotary_emb = self._pack_history_states( hidden_states, indices_latents, latents_clean, latents_history_2x, - indices_latents_history_2x, - indices_latents_clean, latents_history_4x, + indices_latents_clean, + indices_latents_history_2x, indices_latents_history_4x, ) - temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections, guidance) + temb, _ = self.time_text_embed(timestep, pooled_projections, guidance) encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask) encoder_hidden_states_image = self.image_projection(image_embeds) @@ -298,6 +308,7 @@ def forward( ) # 5. 
Output projection + hidden_states = hidden_states[:, -original_context_length:] hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) @@ -331,44 +342,75 @@ def _pack_history_states( indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1) hidden_states = self.x_embedder(hidden_states) - hidden_states = hidden_states.flatten(2).transpose(1, 2) image_rotary_emb = self.rope( frame_indices=indices_latents, height=height, width=width, device=hidden_states.device ) + image_rotary_emb = list(image_rotary_emb) # convert tuple to list for in-place modification + pph, ppw = height // self.config.patch_size, width // self.config.patch_size latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder( latents_clean, latents_history_2x, latents_history_4x ) if latents_clean is not None: + hidden_states = torch.cat([latents_clean, hidden_states], dim=1) + image_rotary_emb_clean = self.rope( frame_indices=indices_latents_clean, height=height, width=width, device=latents_clean.device ) - hidden_states = torch.cat([latents_clean, hidden_states], dim=1) - image_rotary_emb = torch.cat([image_rotary_emb_clean, image_rotary_emb], dim=1) + image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0) + image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0) if latents_history_2x is not None and indices_latents_history_2x is not None: + hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1) + image_rotary_emb_history_2x = self.rope( frame_indices=indices_latents_history_2x, height=height, width=width, device=latents_history_2x.device ) - image_rotary_emb_history_2x = _pad_for_3d_conv(image_rotary_emb_history_2x, (2, 2, 2)) - image_rotary_emb_history_2x = _center_down_sample_3d(image_rotary_emb_history_2x, (2, 2, 2)) - hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1) - image_rotary_emb = torch.cat([image_rotary_emb_history_2x, image_rotary_emb], dim=1) + image_rotary_emb_history_2x = self._pad_rotary_emb( + image_rotary_emb_history_2x, indices_latents_history_2x.size(1), pph, ppw, (2, 2, 2) + ) + image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0) + image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0) if latents_history_4x is not None and indices_latents_history_4x is not None: + hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1) + image_rotary_emb_history_4x = self.rope( frame_indices=indices_latents_history_4x, height=height, width=width, device=latents_history_4x.device ) - image_rotary_emb_history_4x = _pad_for_3d_conv(image_rotary_emb_history_4x, (4, 4, 4)) - image_rotary_emb_history_4x = _center_down_sample_3d(image_rotary_emb_history_4x, (4, 4, 4)) - hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1) - image_rotary_emb = torch.cat([image_rotary_emb_history_4x, image_rotary_emb], dim=1) + image_rotary_emb_history_4x = self._pad_rotary_emb( + image_rotary_emb_history_4x, indices_latents_history_4x.size(1), pph, ppw, (4, 4, 4) + ) + image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0) + image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0) return hidden_states, image_rotary_emb + def _pad_rotary_emb( + self, + image_rotary_emb: Tuple[torch.Tensor], + num_frames: int, + height: int, + width: int, + kernel_size: Tuple[int, int, int], + 
): + # freqs_cos, freqs_sin have shape [W * H * T, D / 2], where D is attention head dim + freqs_cos, freqs_sin = image_rotary_emb + freqs_cos = freqs_cos.unsqueeze(0).permute(0, 2, 1).unflatten(2, (num_frames, height, width)) + freqs_sin = freqs_sin.unsqueeze(0).permute(0, 2, 1).unflatten(2, (num_frames, height, width)) + freqs_cos = _pad_for_3d_conv(freqs_cos, kernel_size) + freqs_sin = _pad_for_3d_conv(freqs_sin, kernel_size) + freqs_cos = _center_down_sample_3d(freqs_cos, kernel_size) + freqs_sin = _center_down_sample_3d(freqs_sin, kernel_size) + freqs_cos = freqs_cos.flatten(2).permute(0, 2, 1).squeeze(0) + freqs_sin = freqs_sin.flatten(2).permute(0, 2, 1).squeeze(0) + return freqs_cos, freqs_sin + def _pad_for_3d_conv(x, kernel_size): + if isinstance(x, (tuple, list)): + return tuple(_pad_for_3d_conv(i, kernel_size) for i in x) b, c, t, h, w = x.shape pt, ph, pw = kernel_size pad_t = (pt - (t % pt)) % pt @@ -378,4 +420,6 @@ def _pad_for_3d_conv(x, kernel_size): def _center_down_sample_3d(x, kernel_size): + if isinstance(x, (tuple, list)): + return tuple(_center_down_sample_3d(i, kernel_size) for i in x) return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py index 7f3b7ff250b2..a2ef2a3341bd 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py @@ -705,14 +705,10 @@ def __call__( image = self.video_processor.preprocess(image, height, width) image_embeds = self.encode_image(image, device=device).to(transformer_dtype) - # 4. Prepare timesteps - sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) - # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels - window_size = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1 - num_latent_sections = max(1, (num_frames + window_size - 1) // window_size) + window_num_frames = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1 + num_latent_sections = max(1, (num_frames + window_num_frames - 1) // window_num_frames) # Specific to the released checkpoint: https://huggingface.co/lllyasviel/FramePackI2V_HY # TODO: find a more generic way in future if there are more checkpoints history_sizes = [1, 2, 16] @@ -739,10 +735,15 @@ def __call__( guidance = torch.tensor([guidance_scale] * batch_size, dtype=transformer_dtype, device=device) * 1000.0 # 7. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - self._num_timesteps = len(timesteps) + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas for i in range(num_latent_sections): + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, sigmas=sigmas + ) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + current_latent_padding = latent_paddings[i] is_last_section = current_latent_padding == 0 latent_padding_size = current_latent_padding * latent_window_size @@ -771,7 +772,7 @@ def __call__( num_channels_latents, height, width, - num_frames, + window_num_frames, dtype=torch.float32, device=device, generator=generator, @@ -877,7 +878,11 @@ def __call__( self._current_timestep = None if not output_type == "latent": - history_video = history_video[:, :, :num_frames] + generated_frames = history_video.size(2) + generated_frames = ( + generated_frames - 1 + ) // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + history_video = history_video[:, :, :generated_frames] video = self.video_processor.postprocess_video(history_video, output_type=output_type) else: video = history_video From f0d286b144456470660bc1c837d2d5f93d595928 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Apr 2025 06:49:02 +0200 Subject: [PATCH 06/19] add flux mu shift --- .../pipeline_hunyuan_video_framepack.py | 72 +++++++++++++------ 1 file changed, 49 insertions(+), 23 deletions(-) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py index a2ef2a3341bd..05e4a635358c 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py @@ -13,6 +13,7 @@ # limitations under the License. 
import inspect +import math from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -89,6 +90,20 @@ } +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -356,11 +371,11 @@ def encode_image( self, image: torch.Tensor, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None ): image = (image + 1) / 2.0 # [-1, 1] -> [0, 1] - image = self.feature_extractor(images=image, return_tensors="pt").to( + image = self.feature_extractor(images=image, return_tensors="pt", do_rescale=False).to( device=self.image_encoder.device, dtype=self.image_encoder.dtype ) image_embeds = self.image_encoder(**image).last_hidden_state - return image_embeds.to(dtype=dtype) + return image_embeds.to(device=device, dtype=dtype) def check_inputs( self, @@ -722,6 +737,7 @@ def __call__( dtype=torch.float32, ) history_video = None + total_generated_latent_frames = 0 image_latents = self.prepare_image_latents( image, dtype=torch.float32, device=device, generator=generator, latents=image_latents @@ -735,19 +751,10 @@ def __call__( guidance = torch.tensor([guidance_scale] * batch_size, dtype=transformer_dtype, device=device) * 1000.0 # 7. Denoising loop - sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas - for i in range(num_latent_sections): - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, sigmas=sigmas - ) - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - self._num_timesteps = len(timesteps) - current_latent_padding = latent_paddings[i] is_last_section = current_latent_padding == 0 latent_padding_size = current_latent_padding * latent_window_size - total_generated_latent_frames = 0 indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, *history_sizes])).unsqueeze(0) ( @@ -779,6 +786,25 @@ def __call__( latents=latents, ) + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + image_seq_len = ( + latents.shape[1] * latents.shape[2] * latents.shape[3] / self.transformer.config.patch_size**2 + ) + exp_max = 7.0 + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + mu = min(mu, math.log(exp_max)) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu + ) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -847,7 +873,7 @@ def __call__( xm.mark_step() if is_last_section: - latents = torch.cat([latents_prefix, latents]) + latents = torch.cat([image_latents, latents], dim=2) total_generated_latent_frames += latents.shape[2] history_latents = torch.cat([latents, history_latents], dim=2) @@ -857,11 +883,11 @@ 
def __call__( if history_video is None: if not output_type == "latent": current_video = real_history_latents.to(vae_dtype) / self.vae.config.scaling_factor - current_video = self.vae.decode(current_video, return_dict=False)[0] + history_video = self.vae.decode(current_video, return_dict=False)[0] else: history_video = [real_history_latents] else: - if not output_type == "latents": + if not output_type == "latent": section_latent_frames = ( (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2) ) @@ -871,7 +897,7 @@ def __call__( / self.vae.config.scaling_factor ) current_video = self.vae.decode(current_video, return_dict=False)[0] - current_video = self._soft_append(current_video, history_video, overlapped_frames) + history_video = self._soft_append(current_video, history_video, overlapped_frames) else: history_video.append(real_history_latents) @@ -895,15 +921,15 @@ def __call__( return HunyuanVideoFramepackPipelineOutput(frames=video) - def _soft_append(current: torch.Tensor, history: torch.Tensor, overlap: int = 0): + def _soft_append(self, history: torch.Tensor, current: torch.Tensor, overlap: int = 0): if overlap <= 0: - return torch.cat([current, history], dim=2) + return torch.cat([history, current], dim=2) - assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})" - assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})" + assert history.shape[2] >= overlap, f"Current length ({history.shape[2]}) must be >= overlap ({overlap})" + assert current.shape[2] >= overlap, f"History length ({current.shape[2]}) must be >= overlap ({overlap})" - weights = torch.linspace(1, 0, overlap, dtype=current.dtype, device=current.device).view(1, 1, -1, 1, 1) - blended = weights * current[:, :, -overlap:] + (1 - weights) * history[:, :, :overlap] - output = torch.cat([current[:, :, :-overlap], blended, history[:, :, overlap:]], dim=2) + weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1) + blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap] + output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2) - return output.to(current) + return output.to(history) From 6e1b557b8f9e5caed0467851c28296e2c50f3389 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Apr 2025 08:38:59 +0200 Subject: [PATCH 07/19] update example snippet --- .../pipeline_hunyuan_video_framepack.py | 46 +++++++++++++------ 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py index 05e4a635358c..3fb4085d1683 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py @@ -53,25 +53,41 @@ Examples: ```python >>> import torch - >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel - >>> from diffusers.utils import export_to_video + >>> from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel + >>> from diffusers.utils import export_to_video, load_image + >>> from transformers import SiglipImageProcessor, SiglipVisionModel - >>> model_id = "hunyuanvideo-community/HunyuanVideo" - >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( - ... 
model_id, subfolder="transformer", torch_dtype=torch.bfloat16 + >>> transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained( + ... "lllyasviel/FramePackI2V_HY", torch_dtype=torch.bfloat16 + ... ) + >>> feature_extractor = SiglipImageProcessor.from_pretrained( + ... "lllyasviel/flux_redux_bfl", subfolder="feature_extractor" + ... ) + >>> image_encoder = SiglipVisionModel.from_pretrained( + ... "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16 + ... ) + >>> pipe = HunyuanVideoFramepackPipeline.from_pretrained( + ... "hunyuanvideo-community/HunyuanVideo", + ... transformer=transformer, + ... feature_extractor=feature_extractor, + ... image_encoder=image_encoder, + ... torch_dtype=torch.float16, ... ) - >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) >>> pipe.vae.enable_tiling() >>> pipe.to("cuda") + >>> image = load_image("inputs/penguin.png") >>> output = pipe( - ... prompt="A cat walks on the grass, realistic", - ... height=320, - ... width=512, - ... num_frames=61, + ... image=image, + ... prompt="A penguin dancing in the snow", + ... height=832, + ... width=480, + ... num_frames=91, ... num_inference_steps=30, + ... guidance_scale=9.0, + ... generator=torch.Generator().manual_seed(0), ... ).frames[0] - >>> export_to_video(output, "output.mp4", fps=15) + >>> export_to_video(output, "output.mp4", fps=30) ``` """ @@ -882,8 +898,8 @@ def __call__( if history_video is None: if not output_type == "latent": - current_video = real_history_latents.to(vae_dtype) / self.vae.config.scaling_factor - history_video = self.vae.decode(current_video, return_dict=False)[0] + current_latents = real_history_latents.to(vae_dtype) / self.vae.config.scaling_factor + history_video = self.vae.decode(current_latents, return_dict=False)[0] else: history_video = [real_history_latents] else: @@ -892,11 +908,11 @@ def __call__( (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2) ) overlapped_frames = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1 - current_video = ( + current_latents = ( real_history_latents[:, :, :section_latent_frames].to(vae_dtype) / self.vae.config.scaling_factor ) - current_video = self.vae.decode(current_video, return_dict=False)[0] + current_video = self.vae.decode(current_latents, return_dict=False)[0] history_video = self._soft_append(current_video, history_video, overlapped_frames) else: history_video.append(real_history_latents) From 88e8dd38053b6f2574fc46f2559dd0decd6d7038 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Apr 2025 09:45:38 +0200 Subject: [PATCH 08/19] debug --- .../transformer_hunyuan_video_framepack.py | 216 +++++++++++++----- .../pipeline_hunyuan_video_framepack.py | 20 +- 2 files changed, 169 insertions(+), 67 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py index 77686085dcec..0bfc1ac39b51 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn @@ -22,7 +22,6 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers from ..cache_utils import CacheMixin -from ..embeddings import get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous @@ -38,48 +37,92 @@ logger = get_logger(__name__) # pylint: disable=invalid-name -class HunyuanVideoFramepackRotaryPosEmbed(nn.Module): - def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None: +# class HunyuanVideoFramepackRotaryPosEmbed(nn.Module): +# def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None: +# super().__init__() + +# self.patch_size = patch_size +# self.patch_size_t = patch_size_t +# self.rope_dim = rope_dim +# self.theta = theta + +# def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device): +# frame_indices = frame_indices.unbind(0) +# # This is from the original code. We don't call _forward for each batch index because we know that +# # each batch has the same frame indices. However, it may be possible that the frame indices don't +# # always be the same for every item in a batch (such as in training). We cannot use the original +# # implementation because our `apply_rotary_emb` function broadcasts across the batch dim. +# # freqs = [self._forward(f, height, width, device) for f in frame_indices] +# # freqs_cos, freqs_sin = zip(*freqs) +# # freqs_cos = torch.stack(freqs_cos, dim=0) # [B, W * H * T, D / 2] +# # freqs_sin = torch.stack(freqs_sin, dim=0) # [B, W * H * T, D / 2] +# # return freqs_cos, freqs_sin +# return self._forward(frame_indices[0], height, width, device) + +# def _forward(self, frame_indices, height, width, device): +# height = height // self.patch_size +# width = width // self.patch_size +# grid = torch.meshgrid( +# frame_indices.to(device=device, dtype=torch.float32), +# torch.arange(0, height, device=device, dtype=torch.float32), +# torch.arange(0, width, device=device, dtype=torch.float32), +# indexing="ij", +# ) # 3 * [W, H, T] +# grid = torch.stack(grid, dim=0) # [3, W, H, T] + +# freqs = [] +# for i in range(3): +# freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True) +# freqs.append(freq) + +# freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2) +# freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2) + +# return freqs_cos, freqs_sin + + +class HunyuanVideoRotaryPosEmbed(nn.Module): + def __init__(self, rope_dim, theta): super().__init__() - - self.patch_size = patch_size - self.patch_size_t = patch_size_t - self.rope_dim = rope_dim + self.DT, self.DY, self.DX = rope_dim self.theta = theta - def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device): - frame_indices = frame_indices.unbind(0) - # This is from the original code. We don't call _forward for each batch index because we know that - # each batch has the same frame indices. However, it may be possible that the frame indices don't - # always be the same for every item in a batch (such as in training). 
We cannot use the original - # implementation because our `apply_rotary_emb` function broadcasts across the batch dim. - # freqs = [self._forward(f, height, width, device) for f in frame_indices] - # freqs_cos, freqs_sin = zip(*freqs) - # freqs_cos = torch.stack(freqs_cos, dim=0) # [B, W * H * T, D / 2] - # freqs_sin = torch.stack(freqs_sin, dim=0) # [B, W * H * T, D / 2] - # return freqs_cos, freqs_sin - return self._forward(frame_indices[0], height, width, device) - - def _forward(self, frame_indices, height, width, device): - height = height // self.patch_size - width = width // self.patch_size - grid = torch.meshgrid( + @torch.no_grad() + def get_frequency(self, dim, pos): + T, H, W = pos.shape + freqs = 1.0 / ( + self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim) + ) + freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0) + return freqs.cos(), freqs.sin() + + @torch.no_grad() + def forward_inner(self, frame_indices, height, width, device): + # TODO(aryan) + height = height // 2 + width = width // 2 + GT, GY, GX = torch.meshgrid( frame_indices.to(device=device, dtype=torch.float32), torch.arange(0, height, device=device, dtype=torch.float32), torch.arange(0, width, device=device, dtype=torch.float32), indexing="ij", - ) # 3 * [W, H, T] - grid = torch.stack(grid, dim=0) # [3, W, H, T] + ) - freqs = [] - for i in range(3): - freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True) - freqs.append(freq) + FCT, FST = self.get_frequency(self.DT, GT) + FCY, FSY = self.get_frequency(self.DY, GY) + FCX, FSX = self.get_frequency(self.DX, GX) - freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2) - freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2) + result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0) - return freqs_cos, freqs_sin + return result.to(device) + + @torch.no_grad() + def forward(self, frame_indices, height, width, device): + return self.forward_inner(frame_indices[0], height, width, device).unsqueeze(0) + # frame_indices = frame_indices.unbind(0) + # results = [self.forward_inner(f, height, width, device) for f in frame_indices] + # results = torch.stack(results, dim=0) + # return results class FramepackClipVisionProjection(nn.Module): @@ -173,7 +216,8 @@ def __init__( ) # 2. RoPE - self.rope = HunyuanVideoFramepackRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) + # self.rope = HunyuanVideoFramepackRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) + self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta) # 3. 
Dual stream transformer blocks self.transformer_blocks = nn.ModuleList( @@ -280,10 +324,14 @@ def forward( effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,] effective_sequence_length = latent_sequence_length + effective_condition_sequence_length - for i in range(batch_size): - attention_mask[i, : effective_sequence_length[i]] = True - # [B, 1, 1, N], for broadcasting across attention heads - attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) + if batch_size == 1: + encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]] + attention_mask = None + else: + for i in range(batch_size): + attention_mask[i, : effective_sequence_length[i]] = True + # [B, 1, 1, N], for broadcasting across attention heads + attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) if torch.is_grad_enabled() and self.gradient_checkpointing: for block in self.transformer_blocks: @@ -345,8 +393,7 @@ def _pack_history_states( image_rotary_emb = self.rope( frame_indices=indices_latents, height=height, width=width, device=hidden_states.device ) - image_rotary_emb = list(image_rotary_emb) # convert tuple to list for in-place modification - pph, ppw = height // self.config.patch_size, width // self.config.patch_size + image_rotary_emb = image_rotary_emb.flatten(2).transpose(1, 2) latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder( latents_clean, latents_history_2x, latents_history_4x @@ -358,8 +405,8 @@ def _pack_history_states( image_rotary_emb_clean = self.rope( frame_indices=indices_latents_clean, height=height, width=width, device=latents_clean.device ) - image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0) - image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0) + image_rotary_emb_clean = image_rotary_emb_clean.flatten(2).transpose(1, 2) + image_rotary_emb = torch.cat([image_rotary_emb_clean, image_rotary_emb], dim=1) if latents_history_2x is not None and indices_latents_history_2x is not None: hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1) @@ -367,11 +414,10 @@ def _pack_history_states( image_rotary_emb_history_2x = self.rope( frame_indices=indices_latents_history_2x, height=height, width=width, device=latents_history_2x.device ) - image_rotary_emb_history_2x = self._pad_rotary_emb( - image_rotary_emb_history_2x, indices_latents_history_2x.size(1), pph, ppw, (2, 2, 2) - ) - image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0) - image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0) + image_rotary_emb_history_2x = _pad_for_3d_conv(image_rotary_emb_history_2x, (2, 2, 2)) + image_rotary_emb_history_2x = _center_down_sample_3d(image_rotary_emb_history_2x, (2, 2, 2)) + image_rotary_emb_history_2x = image_rotary_emb_history_2x.flatten(2).transpose(1, 2) + image_rotary_emb = torch.cat([image_rotary_emb_history_2x, image_rotary_emb], dim=1) if latents_history_4x is not None and indices_latents_history_4x is not None: hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1) @@ -379,13 +425,73 @@ def _pack_history_states( image_rotary_emb_history_4x = self.rope( frame_indices=indices_latents_history_4x, height=height, width=width, device=latents_history_4x.device ) - image_rotary_emb_history_4x = self._pad_rotary_emb( - image_rotary_emb_history_4x, indices_latents_history_4x.size(1), pph, ppw, (4, 4, 4) - ) - image_rotary_emb[0] = 
torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0) - image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0) - - return hidden_states, image_rotary_emb + image_rotary_emb_history_4x = _pad_for_3d_conv(image_rotary_emb_history_4x, (4, 4, 4)) + image_rotary_emb_history_4x = _center_down_sample_3d(image_rotary_emb_history_4x, (4, 4, 4)) + image_rotary_emb_history_4x = image_rotary_emb_history_4x.flatten(2).transpose(1, 2) + image_rotary_emb = torch.cat([image_rotary_emb_history_4x, image_rotary_emb], dim=1) + + return hidden_states, image_rotary_emb.squeeze(0).chunk(2, dim=-1) + + # def _pack_history_states( + # self, + # hidden_states: torch.Tensor, + # indices_latents: torch.Tensor, + # latents_clean: Optional[torch.Tensor] = None, + # latents_history_2x: Optional[torch.Tensor] = None, + # latents_history_4x: Optional[torch.Tensor] = None, + # indices_latents_clean: Optional[torch.Tensor] = None, + # indices_latents_history_2x: Optional[torch.Tensor] = None, + # indices_latents_history_4x: Optional[torch.Tensor] = None, + # ): + # batch_size, num_channels, num_frames, height, width = hidden_states.shape + # if indices_latents is None: + # indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1) + + # hidden_states = self.x_embedder(hidden_states) + # image_rotary_emb = self.rope( + # frame_indices=indices_latents, height=height, width=width, device=hidden_states.device + # ) + # image_rotary_emb = list(image_rotary_emb) # convert tuple to list for in-place modification + # pph, ppw = height // self.config.patch_size, width // self.config.patch_size + + # latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder( + # latents_clean, latents_history_2x, latents_history_4x + # ) + + # if latents_clean is not None: + # hidden_states = torch.cat([latents_clean, hidden_states], dim=1) + + # image_rotary_emb_clean = self.rope( + # frame_indices=indices_latents_clean, height=height, width=width, device=latents_clean.device + # ) + # image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0) + # image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0) + + # if latents_history_2x is not None and indices_latents_history_2x is not None: + # hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1) + + # image_rotary_emb_history_2x = self.rope( + # frame_indices=indices_latents_history_2x, height=height, width=width, device=latents_history_2x.device + # ) + # image_rotary_emb_history_2x = self._pad_rotary_emb( + # image_rotary_emb_history_2x, indices_latents_history_2x.size(1), pph, ppw, (2, 2, 2) + # ) + # image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0) + # image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0) + + # if latents_history_4x is not None and indices_latents_history_4x is not None: + # hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1) + + # image_rotary_emb_history_4x = self.rope( + # frame_indices=indices_latents_history_4x, height=height, width=width, device=latents_history_4x.device + # ) + # image_rotary_emb_history_4x = self._pad_rotary_emb( + # image_rotary_emb_history_4x, indices_latents_history_4x.size(1), pph, ppw, (4, 4, 4) + # ) + # image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0) + # image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], 
image_rotary_emb[1]], dim=0) + + # return hidden_states, image_rotary_emb def _pad_rotary_emb( self, diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py index 3fb4085d1683..06116f6b46e2 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py @@ -551,7 +551,6 @@ def __call__( guidance_scale: float = 6.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, image_latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, pooled_prompt_embeds: Optional[torch.Tensor] = None, @@ -614,10 +613,8 @@ def __call__( generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. - latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor is generated by sampling using the supplied random `generator`. + image_latents (`torch.Tensor`, *optional*): + Pre-encoded image latents. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. @@ -767,10 +764,9 @@ def __call__( guidance = torch.tensor([guidance_scale] * batch_size, dtype=transformer_dtype, device=device) * 1000.0 # 7. 
Denoising loop - for i in range(num_latent_sections): - current_latent_padding = latent_paddings[i] - is_last_section = current_latent_padding == 0 - latent_padding_size = current_latent_padding * latent_window_size + for k in range(num_latent_sections): + is_last_section = latent_paddings[k] == 0 + latent_padding_size = latent_paddings[k] * latent_window_size indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, *history_sizes])).unsqueeze(0) ( @@ -799,12 +795,12 @@ def __call__( dtype=torch.float32, device=device, generator=generator, - latents=latents, + latents=None, ) sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas image_seq_len = ( - latents.shape[1] * latents.shape[2] * latents.shape[3] / self.transformer.config.patch_size**2 + latents.shape[2] * latents.shape[3] * latents.shape[4] / self.transformer.config.patch_size**2 ) exp_max = 7.0 mu = calculate_shift( @@ -887,7 +883,7 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - + if is_last_section: latents = torch.cat([image_latents, latents], dim=2) From 65c24f685f51bb889f6e18b0f7cc8d901e6e06a1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Apr 2025 11:05:53 +0200 Subject: [PATCH 09/19] cleanup --- .../transformer_hunyuan_video_framepack.py | 222 +++++------------- .../pipeline_hunyuan_video_framepack.py | 8 +- 2 files changed, 64 insertions(+), 166 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py index 0bfc1ac39b51..026612ad6b9a 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -22,6 +22,7 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers from ..cache_utils import CacheMixin +from ..embeddings import get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous @@ -37,92 +38,49 @@ logger = get_logger(__name__) # pylint: disable=invalid-name -# class HunyuanVideoFramepackRotaryPosEmbed(nn.Module): -# def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None: -# super().__init__() - -# self.patch_size = patch_size -# self.patch_size_t = patch_size_t -# self.rope_dim = rope_dim -# self.theta = theta - -# def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device): -# frame_indices = frame_indices.unbind(0) -# # This is from the original code. We don't call _forward for each batch index because we know that -# # each batch has the same frame indices. However, it may be possible that the frame indices don't -# # always be the same for every item in a batch (such as in training). We cannot use the original -# # implementation because our `apply_rotary_emb` function broadcasts across the batch dim. 
-# # freqs = [self._forward(f, height, width, device) for f in frame_indices] -# # freqs_cos, freqs_sin = zip(*freqs) -# # freqs_cos = torch.stack(freqs_cos, dim=0) # [B, W * H * T, D / 2] -# # freqs_sin = torch.stack(freqs_sin, dim=0) # [B, W * H * T, D / 2] -# # return freqs_cos, freqs_sin -# return self._forward(frame_indices[0], height, width, device) - -# def _forward(self, frame_indices, height, width, device): -# height = height // self.patch_size -# width = width // self.patch_size -# grid = torch.meshgrid( -# frame_indices.to(device=device, dtype=torch.float32), -# torch.arange(0, height, device=device, dtype=torch.float32), -# torch.arange(0, width, device=device, dtype=torch.float32), -# indexing="ij", -# ) # 3 * [W, H, T] -# grid = torch.stack(grid, dim=0) # [3, W, H, T] - -# freqs = [] -# for i in range(3): -# freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True) -# freqs.append(freq) - -# freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2) -# freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2) - -# return freqs_cos, freqs_sin - - -class HunyuanVideoRotaryPosEmbed(nn.Module): - def __init__(self, rope_dim, theta): +class HunyuanVideoFramepackRotaryPosEmbed(nn.Module): + def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None: super().__init__() - self.DT, self.DY, self.DX = rope_dim + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.rope_dim = rope_dim self.theta = theta - @torch.no_grad() - def get_frequency(self, dim, pos): - T, H, W = pos.shape - freqs = 1.0 / ( - self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim) - ) - freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0) - return freqs.cos(), freqs.sin() - - @torch.no_grad() - def forward_inner(self, frame_indices, height, width, device): - # TODO(aryan) - height = height // 2 - width = width // 2 - GT, GY, GX = torch.meshgrid( + def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device): + # This is from the original code. We don't call _forward for each batch index because we know that + # each batch has the same frame indices. However, it may be possible that the frame indices don't + # always be the same for every item in a batch (such as in training). 
We cannot use the original + # implementation because our `apply_rotary_emb` function broadcasts across the batch dim, so we'd + # need to first implement another attention processor or modify the existing one with different apply_rotary_emb + # frame_indices = frame_indices.unbind(0) + # freqs = [self._forward(f, height, width, device) for f in frame_indices] + # freqs_cos, freqs_sin = zip(*freqs) + # freqs_cos = torch.stack(freqs_cos, dim=0) # [B, W * H * T, D / 2] + # freqs_sin = torch.stack(freqs_sin, dim=0) # [B, W * H * T, D / 2] + # return freqs_cos, freqs_sin + return self._forward(frame_indices, height, width, device) + + def _forward(self, frame_indices, height, width, device): + height = height // self.patch_size + width = width // self.patch_size + grid = torch.meshgrid( frame_indices.to(device=device, dtype=torch.float32), torch.arange(0, height, device=device, dtype=torch.float32), torch.arange(0, width, device=device, dtype=torch.float32), indexing="ij", - ) - - FCT, FST = self.get_frequency(self.DT, GT) - FCY, FSY = self.get_frequency(self.DY, GY) - FCX, FSX = self.get_frequency(self.DX, GX) + ) # 3 * [W, H, T] + grid = torch.stack(grid, dim=0) # [3, W, H, T] - result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0) + freqs = [] + for i in range(3): + freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True) + freqs.append(freq) - return result.to(device) + freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2) + freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2) - @torch.no_grad() - def forward(self, frame_indices, height, width, device): - return self.forward_inner(frame_indices[0], height, width, device).unsqueeze(0) - # frame_indices = frame_indices.unbind(0) - # results = [self.forward_inner(f, height, width, device) for f in frame_indices] - # results = torch.stack(results, dim=0) - # return results + return freqs_cos, freqs_sin class FramepackClipVisionProjection(nn.Module): @@ -216,8 +174,7 @@ def __init__( ) # 2. RoPE - # self.rope = HunyuanVideoFramepackRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) - self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta) + self.rope = HunyuanVideoFramepackRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) # 3. 
Dual stream transformer blocks self.transformer_blocks = nn.ModuleList( @@ -320,18 +277,17 @@ def forward( attention_mask = torch.zeros( batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool ) # [B, N] - effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,] effective_sequence_length = latent_sequence_length + effective_condition_sequence_length - if batch_size == 1: - encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]] - attention_mask = None - else: - for i in range(batch_size): - attention_mask[i, : effective_sequence_length[i]] = True - # [B, 1, 1, N], for broadcasting across attention heads - attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) + # if batch_size == 1: + # encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]] + # attention_mask = None + # else: + for i in range(batch_size): + attention_mask[i, : effective_sequence_length[i]] = True + # [B, 1, 1, N], for broadcasting across attention heads + attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) if torch.is_grad_enabled() and self.gradient_checkpointing: for block in self.transformer_blocks: @@ -393,7 +349,8 @@ def _pack_history_states( image_rotary_emb = self.rope( frame_indices=indices_latents, height=height, width=width, device=hidden_states.device ) - image_rotary_emb = image_rotary_emb.flatten(2).transpose(1, 2) + image_rotary_emb = list(image_rotary_emb) # convert tuple to list for in-place modification + pph, ppw = height // self.config.patch_size, width // self.config.patch_size latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder( latents_clean, latents_history_2x, latents_history_4x @@ -405,8 +362,8 @@ def _pack_history_states( image_rotary_emb_clean = self.rope( frame_indices=indices_latents_clean, height=height, width=width, device=latents_clean.device ) - image_rotary_emb_clean = image_rotary_emb_clean.flatten(2).transpose(1, 2) - image_rotary_emb = torch.cat([image_rotary_emb_clean, image_rotary_emb], dim=1) + image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0) + image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0) if latents_history_2x is not None and indices_latents_history_2x is not None: hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1) @@ -414,10 +371,11 @@ def _pack_history_states( image_rotary_emb_history_2x = self.rope( frame_indices=indices_latents_history_2x, height=height, width=width, device=latents_history_2x.device ) - image_rotary_emb_history_2x = _pad_for_3d_conv(image_rotary_emb_history_2x, (2, 2, 2)) - image_rotary_emb_history_2x = _center_down_sample_3d(image_rotary_emb_history_2x, (2, 2, 2)) - image_rotary_emb_history_2x = image_rotary_emb_history_2x.flatten(2).transpose(1, 2) - image_rotary_emb = torch.cat([image_rotary_emb_history_2x, image_rotary_emb], dim=1) + image_rotary_emb_history_2x = self._pad_rotary_emb( + image_rotary_emb_history_2x, indices_latents_history_2x.size(0), pph, ppw, (2, 2, 2) + ) + image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0) + image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0) if latents_history_4x is not None and indices_latents_history_4x is not None: hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1) @@ -425,73 +383,13 @@ def _pack_history_states( image_rotary_emb_history_4x = self.rope( 
frame_indices=indices_latents_history_4x, height=height, width=width, device=latents_history_4x.device ) - image_rotary_emb_history_4x = _pad_for_3d_conv(image_rotary_emb_history_4x, (4, 4, 4)) - image_rotary_emb_history_4x = _center_down_sample_3d(image_rotary_emb_history_4x, (4, 4, 4)) - image_rotary_emb_history_4x = image_rotary_emb_history_4x.flatten(2).transpose(1, 2) - image_rotary_emb = torch.cat([image_rotary_emb_history_4x, image_rotary_emb], dim=1) - - return hidden_states, image_rotary_emb.squeeze(0).chunk(2, dim=-1) - - # def _pack_history_states( - # self, - # hidden_states: torch.Tensor, - # indices_latents: torch.Tensor, - # latents_clean: Optional[torch.Tensor] = None, - # latents_history_2x: Optional[torch.Tensor] = None, - # latents_history_4x: Optional[torch.Tensor] = None, - # indices_latents_clean: Optional[torch.Tensor] = None, - # indices_latents_history_2x: Optional[torch.Tensor] = None, - # indices_latents_history_4x: Optional[torch.Tensor] = None, - # ): - # batch_size, num_channels, num_frames, height, width = hidden_states.shape - # if indices_latents is None: - # indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1) - - # hidden_states = self.x_embedder(hidden_states) - # image_rotary_emb = self.rope( - # frame_indices=indices_latents, height=height, width=width, device=hidden_states.device - # ) - # image_rotary_emb = list(image_rotary_emb) # convert tuple to list for in-place modification - # pph, ppw = height // self.config.patch_size, width // self.config.patch_size - - # latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder( - # latents_clean, latents_history_2x, latents_history_4x - # ) - - # if latents_clean is not None: - # hidden_states = torch.cat([latents_clean, hidden_states], dim=1) - - # image_rotary_emb_clean = self.rope( - # frame_indices=indices_latents_clean, height=height, width=width, device=latents_clean.device - # ) - # image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0) - # image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0) - - # if latents_history_2x is not None and indices_latents_history_2x is not None: - # hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1) - - # image_rotary_emb_history_2x = self.rope( - # frame_indices=indices_latents_history_2x, height=height, width=width, device=latents_history_2x.device - # ) - # image_rotary_emb_history_2x = self._pad_rotary_emb( - # image_rotary_emb_history_2x, indices_latents_history_2x.size(1), pph, ppw, (2, 2, 2) - # ) - # image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0) - # image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0) - - # if latents_history_4x is not None and indices_latents_history_4x is not None: - # hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1) - - # image_rotary_emb_history_4x = self.rope( - # frame_indices=indices_latents_history_4x, height=height, width=width, device=latents_history_4x.device - # ) - # image_rotary_emb_history_4x = self._pad_rotary_emb( - # image_rotary_emb_history_4x, indices_latents_history_4x.size(1), pph, ppw, (4, 4, 4) - # ) - # image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0) - # image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0) - - # return hidden_states, image_rotary_emb + image_rotary_emb_history_4x = 
self._pad_rotary_emb( + image_rotary_emb_history_4x, indices_latents_history_4x.size(0), pph, ppw, (4, 4, 4) + ) + image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0) + image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0) + + return hidden_states, image_rotary_emb def _pad_rotary_emb( self, diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py index 06116f6b46e2..c04527ed6003 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py @@ -768,7 +768,7 @@ def __call__( is_last_section = latent_paddings[k] == 0 latent_padding_size = latent_paddings[k] * latent_window_size - indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, *history_sizes])).unsqueeze(0) + indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, *history_sizes])) ( indices_prefix, indices_padding, @@ -776,9 +776,9 @@ def __call__( indices_postfix, indices_latents_history_2x, indices_latents_history_4x, - ) = indices.split([1, latent_padding_size, latent_window_size, *history_sizes], dim=1) + ) = indices.split([1, latent_padding_size, latent_window_size, *history_sizes], dim=0) # Inverted anti-drifting sampling: Figure 2(c) in the paper - indices_clean_latents = torch.cat([indices_prefix, indices_postfix], dim=1) + indices_clean_latents = torch.cat([indices_prefix, indices_postfix], dim=0) latents_prefix = image_latents latents_postfix, latents_history_2x, latents_history_4x = history_latents[ @@ -883,7 +883,7 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - + if is_last_section: latents = torch.cat([image_latents, latents], dim=2) From 68a7fa61f37d7729bdf168c0298262bb963e7c1f Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Apr 2025 11:06:25 +0200 Subject: [PATCH 10/19] batch_size=1 optimization --- .../transformer_hunyuan_video_framepack.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py index 026612ad6b9a..a6df4b2571e2 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py @@ -280,14 +280,14 @@ def forward( effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,] effective_sequence_length = latent_sequence_length + effective_condition_sequence_length - # if batch_size == 1: - # encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]] - # attention_mask = None - # else: - for i in range(batch_size): - attention_mask[i, : effective_sequence_length[i]] = True - # [B, 1, 1, N], for broadcasting across attention heads - attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) + if batch_size == 1: + encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]] + attention_mask = None + else: + for i in range(batch_size): + attention_mask[i, : effective_sequence_length[i]] = True + # [B, 1, 1, N], for broadcasting across attention heads + attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) if torch.is_grad_enabled() and self.gradient_checkpointing: for block in self.transformer_blocks: @@ 
-311,7 +311,6 @@ def forward( hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb ) - # 5. Output projection hidden_states = hidden_states[:, -original_context_length:] hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) From 8df1710febef22d2ae43277522455972e0301979 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Apr 2025 11:38:58 +0200 Subject: [PATCH 11/19] add pipeline test --- .../transformer_hunyuan_video_framepack.py | 12 +- .../test_hunyuan_video_framepack.py | 376 ++++++++++++++++++ 2 files changed, 382 insertions(+), 6 deletions(-) create mode 100644 tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py index a6df4b2571e2..e9aaafd4f4f1 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py @@ -97,11 +97,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class HunyuanVideoHistoryPatchEmbed(nn.Module): - def __init__(self, inner_dim: int): + def __init__(self, in_channels: int, inner_dim: int): super().__init__() - self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) - self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) - self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + self.proj = nn.Conv3d(in_channels, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(in_channels, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(in_channels, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) def forward( self, @@ -131,7 +131,7 @@ class HunyuanVideoFramepackTransformer3DModel( _no_split_modules = [ "HunyuanVideoTransformerBlock", "HunyuanVideoSingleTransformerBlock", - "HunyuanVideoPatchEmbedForCleanLatents", # TODO + "HunyuanVideoHistoryPatchEmbed", "HunyuanVideoTokenRefiner", ] @@ -205,7 +205,7 @@ def __init__( self.clean_x_embedder = None if has_clean_x_embedder: - self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(inner_dim) + self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim) self.use_gradient_checkpointing = False diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py b/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py new file mode 100644 index 000000000000..b4e296669b32 --- /dev/null +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py @@ -0,0 +1,376 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
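+
+# Fast tests for HunyuanVideoFramepackPipeline: tiny dummy components and short (2-step) runs
+# are used to smoke-test output shapes, callbacks, attention slicing and VAE tiling rather
+# than generation quality.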
+ +import inspect +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import ( + CLIPTextConfig, + CLIPTextModel, + CLIPTokenizer, + LlamaConfig, + LlamaModel, + LlamaTokenizer, + SiglipImageProcessor, + SiglipVisionModel, +) + +from diffusers import ( + AutoencoderKLHunyuanVideo, + FasterCacheConfig, + FlowMatchEulerDiscreteScheduler, + HunyuanVideoFramepackPipeline, + HunyuanVideoFramepackTransformer3DModel, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..test_pipelines_common import ( + FasterCacheTesterMixin, + PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, + to_np, +) + + +enable_full_determinism() + + +class HunyuanVideoFramepackPipelineFastTests( + PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase +): + pipeline_class = HunyuanVideoFramepackPipeline + params = frozenset( + ["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"] + ) + batch_params = frozenset(["image", "prompt"]) + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + + supports_dduf = False + # there is no xformers processor for Flux + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + faster_cache_config = FasterCacheConfig( + spatial_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(-1, 901), + unconditional_batch_skip_range=2, + attention_weight_callback=lambda _: 0.5, + is_guidance_distilled=True, + ) + + def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): + torch.manual_seed(0) + transformer = HunyuanVideoFramepackTransformer3DModel( + in_channels=4, + out_channels=4, + num_attention_heads=2, + attention_head_dim=10, + num_layers=num_layers, + num_single_layers=num_single_layers, + num_refiner_layers=1, + patch_size=2, + patch_size_t=1, + guidance_embeds=True, + text_embed_dim=16, + pooled_projection_dim=8, + rope_axes_dim=(2, 4, 4), + image_condition_type=None, + has_image_proj=True, + image_proj_dim=32, + has_clean_x_embedder=True, + ) + + torch.manual_seed(0) + vae = AutoencoderKLHunyuanVideo( + in_channels=3, + out_channels=3, + latent_channels=4, + down_block_types=( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + up_block_types=( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + block_out_channels=(8, 8, 8, 8), + layers_per_block=1, + act_fn="silu", + norm_num_groups=4, + scaling_factor=0.476986, + spatial_compression_ratio=8, + temporal_compression_ratio=4, + mid_block_add_attention=True, + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + + llama_text_encoder_config = LlamaConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=16, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=2, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=8, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=2, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + 
text_encoder = LlamaModel(llama_text_encoder_config) + tokenizer = LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer") + + torch.manual_seed(0) + text_encoder_2 = CLIPTextModel(clip_text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + feature_extractor = SiglipImageProcessor.from_pretrained( + "hf-internal-testing/tiny-random-SiglipVisionModel", size={"height": 30, "width": 30} + ) + image_encoder = SiglipVisionModel.from_pretrained( + "hf-internal-testing/tiny-random-SiglipVisionModel", torch_dtype=torch.float16 + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "feature_extractor": feature_extractor, + "image_encoder": image_encoder, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image_height = 32 + image_width = 32 + image = Image.new("RGB", (image_width, image_height)) + inputs = { + "image": image, + "prompt": "dance monkey", + "prompt_template": { + "template": "{}", + "crop_start": 0, + }, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 4.5, + "height": image_height, + "width": image_width, + "num_frames": 9, + "latent_window_size": 3, + "max_sequence_length": 256, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (13, 3, 32, 32)) + expected_video = torch.randn(13, 3, 32, 32) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset 
+ inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + # Seems to require higher tolerance than the other tests + expected_diff_max = 0.6 + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + # TODO(aryan): Create a dummy gemma model with smol vocab size + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error." + ) + def test_inference_batch_consistent(self): + pass + + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. 
This test uses a long prompt that causes the error." + ) + def test_inference_batch_single_identical(self): + pass From ccf593ebba35762d784c348476757bc5f319aebf Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Apr 2025 21:59:50 +0200 Subject: [PATCH 12/19] fix for model cpu offloading' --- .../hunyuan_video/pipeline_hunyuan_video_framepack.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py index c04527ed6003..eaf8b21efd11 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py @@ -206,7 +206,7 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). """ - model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( @@ -386,12 +386,13 @@ def encode_prompt( def encode_image( self, image: torch.Tensor, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None ): + device = device or self._execution_device image = (image + 1) / 2.0 # [-1, 1] -> [0, 1] image = self.feature_extractor(images=image, return_tensors="pt", do_rescale=False).to( - device=self.image_encoder.device, dtype=self.image_encoder.dtype + device=device, dtype=self.image_encoder.dtype ) image_embeds = self.image_encoder(**image).last_hidden_state - return image_embeds.to(device=device, dtype=dtype) + return image_embeds.to(dtype=dtype) def check_inputs( self, @@ -477,8 +478,9 @@ def prepare_image_latents( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: + device = device or self._execution_device if latents is None: - image = image.unsqueeze(2).to(device=self.vae.device, dtype=self.vae.dtype) + image = image.unsqueeze(2).to(device=device, dtype=self.vae.dtype) latents = self.vae.encode(image).latent_dist.sample(generator=generator) latents = latents * self.vae.config.scaling_factor return latents.to(device=device, dtype=dtype) From 00707630710d141ced0a8837ce65387309385b01 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Apr 2025 23:34:07 +0200 Subject: [PATCH 13/19] add last_image support; credits: https://github.com/lllyasviel/FramePack/pull/167 --- .../pipeline_hunyuan_video_framepack.py | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py index eaf8b21efd11..ce1c73fe12d5 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py @@ -539,6 +539,7 @@ def interrupt(self): def __call__( self, image: PipelineImageInput, + last_image: Optional[PipelineImageInput] = None, prompt: Union[str, List[str]] = None, prompt_2: Union[str, List[str]] = None, negative_prompt: Union[str, List[str]] = None, @@ -554,6 +555,7 @@ def __call__( num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, image_latents: 
Optional[torch.Tensor] = None, + last_image_latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, pooled_prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, @@ -574,6 +576,11 @@ def __call__( The call function to the pipeline for generation. Args: + image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`): + The image to be used as the starting point for the video generation. + last_image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`, *optional*): + The optional last image to be used as the ending point for the video generation. This is useful for + generating transitions between two images. prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. @@ -616,7 +623,9 @@ def __call__( A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. image_latents (`torch.Tensor`, *optional*): - Pre-encoded image latents. + Pre-encoded image latents. If not provided, the image will be encoded using the VAE. + last_image_latents (`torch.Tensor`, *optional*): + Pre-encoded last image latents. If not provided, the last image will be encoded using the VAE. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. @@ -734,6 +743,12 @@ def __call__( # 4. Prepare image image = self.video_processor.preprocess(image, height, width) image_embeds = self.encode_image(image, device=device).to(transformer_dtype) + if last_image is not None: + # Credits: https://github.com/lllyasviel/FramePack/pull/167 + # Users can modify the weighting strategy applied here + last_image = self.video_processor.preprocess(last_image, height, width) + last_image_embeds = self.encode_image(last_image, device=device).to(transformer_dtype) + last_image_embeds = (image_embeds + last_image_embeds) / 2 # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels @@ -757,6 +772,10 @@ def __call__( image_latents = self.prepare_image_latents( image, dtype=torch.float32, device=device, generator=generator, latents=image_latents ) + if last_image is not None: + last_image_latents = self.prepare_image_latents( + last_image, dtype=torch.float32, device=device, generator=generator + ) latent_paddings = list(reversed(range(num_latent_sections))) if num_latent_sections > 4: @@ -767,7 +786,8 @@ def __call__( # 7. 
Denoising loop for k in range(num_latent_sections): - is_last_section = latent_paddings[k] == 0 + is_first_section = k == 0 + is_last_section = k == num_latent_sections - 1 latent_padding_size = latent_paddings[k] * latent_window_size indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, *history_sizes])) @@ -786,6 +806,8 @@ def __call__( latents_postfix, latents_history_2x, latents_history_4x = history_latents[ :, :, : sum(history_sizes) ].split(history_sizes, dim=2) + if last_image is not None and is_first_section: + latents_postfix = last_image_latents latents_clean = torch.cat([latents_prefix, latents_postfix], dim=2) latents = self.prepare_latents( From 556711842838b0b76e39a9c59362456ef8e03c62 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Apr 2025 23:35:47 +0200 Subject: [PATCH 14/19] update example with flf2v --- .../pipeline_hunyuan_video_framepack.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py index ce1c73fe12d5..1a2368cef989 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py @@ -51,6 +51,8 @@ EXAMPLE_DOC_STRING = """ Examples: + ##### Image-to-Video + ```python >>> import torch >>> from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel @@ -89,6 +91,53 @@ ... ).frames[0] >>> export_to_video(output, "output.mp4", fps=30) ``` + + ##### First and Last Image-to-Video + + ```python + >>> import torch + >>> from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel + >>> from diffusers.utils import export_to_video, load_image + >>> from transformers import SiglipImageProcessor, SiglipVisionModel + + >>> transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained( + ... "lllyasviel/FramePackI2V_HY", torch_dtype=torch.bfloat16 + ... ) + >>> feature_extractor = SiglipImageProcessor.from_pretrained( + ... "lllyasviel/flux_redux_bfl", subfolder="feature_extractor" + ... ) + >>> image_encoder = SiglipVisionModel.from_pretrained( + ... "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16 + ... ) + >>> pipe = HunyuanVideoFramepackPipeline.from_pretrained( + ... "hunyuanvideo-community/HunyuanVideo", + ... transformer=transformer, + ... feature_extractor=feature_extractor, + ... image_encoder=image_encoder, + ... torch_dtype=torch.float16, + ... ) + >>> pipe.to("cuda") + + >>> prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." + >>> first_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png" + ... ) + >>> last_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png" + ... ) + >>> output = pipe( + ... image=first_image, + ... last_image=last_image, + ... prompt=prompt, + ... height=512, + ... width=512, + ... num_frames=91, + ... num_inference_steps=30, + ... guidance_scale=9.0, + ... 
generator=torch.Generator().manual_seed(0), + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=30) + ``` """ From dbe378dfed518470ae23b0f185bbf593a4a1924e Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Apr 2025 23:38:46 +0200 Subject: [PATCH 15/19] update penguin url --- .../hunyuan_video/pipeline_hunyuan_video_framepack.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py index 1a2368cef989..4fc338c4d784 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py @@ -78,7 +78,9 @@ >>> pipe.vae.enable_tiling() >>> pipe.to("cuda") - >>> image = load_image("inputs/penguin.png") + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png" + ... ) >>> output = pipe( ... image=image, ... prompt="A penguin dancing in the snow", From 1c1c2c02a48dd8c2fbd453e6079d5978a005b8dd Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Apr 2025 23:44:24 +0200 Subject: [PATCH 16/19] fix test --- tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py b/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py index b4e296669b32..f4408e7cd5ae 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py @@ -174,9 +174,7 @@ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): feature_extractor = SiglipImageProcessor.from_pretrained( "hf-internal-testing/tiny-random-SiglipVisionModel", size={"height": 30, "width": 30} ) - image_encoder = SiglipVisionModel.from_pretrained( - "hf-internal-testing/tiny-random-SiglipVisionModel", torch_dtype=torch.float16 - ) + image_encoder = SiglipVisionModel.from_pretrained("hf-internal-testing/tiny-random-SiglipVisionModel") components = { "transformer": transformer, From 2f9efa970e930758c4fa44e55a43611414605e87 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 5 May 2025 12:20:58 +0200 Subject: [PATCH 17/19] address review comment: https://github.com/huggingface/diffusers/pull/11428#discussion_r2071032371 --- .../transformer_hunyuan_video_framepack.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py index e9aaafd4f4f1..4f600b47d552 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py @@ -48,20 +48,6 @@ def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], thet self.theta = theta def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device): - # This is from the original code. We don't call _forward for each batch index because we know that - # each batch has the same frame indices. However, it may be possible that the frame indices don't - # always be the same for every item in a batch (such as in training). 
We cannot use the original - # implementation because our `apply_rotary_emb` function broadcasts across the batch dim, so we'd - # need to first implement another attention processor or modify the existing one with different apply_rotary_emb - # frame_indices = frame_indices.unbind(0) - # freqs = [self._forward(f, height, width, device) for f in frame_indices] - # freqs_cos, freqs_sin = zip(*freqs) - # freqs_cos = torch.stack(freqs_cos, dim=0) # [B, W * H * T, D / 2] - # freqs_sin = torch.stack(freqs_sin, dim=0) # [B, W * H * T, D / 2] - # return freqs_cos, freqs_sin - return self._forward(frame_indices, height, width, device) - - def _forward(self, frame_indices, height, width, device): height = height // self.patch_size width = width // self.patch_size grid = torch.meshgrid( From 17a57b2f6baf26fa6f25857c740769cd20e5d3a4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 5 May 2025 13:10:19 +0200 Subject: [PATCH 18/19] address review comment: https://github.com/huggingface/diffusers/pull/11428#discussion_r2071087689 --- .../transformer_hunyuan_video_framepack.py | 91 +++++++++---------- 1 file changed, 45 insertions(+), 46 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py index 4f600b47d552..58b811569403 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py @@ -236,15 +236,42 @@ def forward( post_patch_width = width // p original_context_length = post_patch_num_frames * post_patch_height * post_patch_width + if indices_latents is None: + indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1) + + hidden_states = self.x_embedder(hidden_states) + image_rotary_emb = self.rope( + frame_indices=indices_latents, height=height, width=width, device=hidden_states.device + ) + + latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder( + latents_clean, latents_history_2x, latents_history_4x + ) + + if latents_clean is not None and indices_latents_clean is not None: + image_rotary_emb_clean = self.rope( + frame_indices=indices_latents_clean, height=height, width=width, device=hidden_states.device + ) + if latents_history_2x is not None and indices_latents_history_2x is not None: + image_rotary_emb_history_2x = self.rope( + frame_indices=indices_latents_history_2x, height=height, width=width, device=hidden_states.device + ) + if latents_history_4x is not None and indices_latents_history_4x is not None: + image_rotary_emb_history_4x = self.rope( + frame_indices=indices_latents_history_4x, height=height, width=width, device=hidden_states.device + ) + hidden_states, image_rotary_emb = self._pack_history_states( hidden_states, - indices_latents, latents_clean, latents_history_2x, latents_history_4x, - indices_latents_clean, - indices_latents_history_2x, - indices_latents_history_4x, + image_rotary_emb, + image_rotary_emb_clean, + image_rotary_emb_history_2x, + image_rotary_emb_history_4x, + post_patch_height, + post_patch_width, ) temb, _ = self.time_text_embed(timestep, pooled_projections, guidance) @@ -318,76 +345,48 @@ def forward( def _pack_history_states( self, hidden_states: torch.Tensor, - indices_latents: torch.Tensor, latents_clean: Optional[torch.Tensor] = None, latents_history_2x: Optional[torch.Tensor] = None, latents_history_4x: Optional[torch.Tensor] = None, - indices_latents_clean: Optional[torch.Tensor] = None, - 
indices_latents_history_2x: Optional[torch.Tensor] = None, - indices_latents_history_4x: Optional[torch.Tensor] = None, + image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] = None, + image_rotary_emb_clean: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + image_rotary_emb_history_2x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + image_rotary_emb_history_4x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + height: int = None, + width: int = None, ): - batch_size, num_channels, num_frames, height, width = hidden_states.shape - if indices_latents is None: - indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1) - - hidden_states = self.x_embedder(hidden_states) - image_rotary_emb = self.rope( - frame_indices=indices_latents, height=height, width=width, device=hidden_states.device - ) image_rotary_emb = list(image_rotary_emb) # convert tuple to list for in-place modification - pph, ppw = height // self.config.patch_size, width // self.config.patch_size - latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder( - latents_clean, latents_history_2x, latents_history_4x - ) - - if latents_clean is not None: + if latents_clean is not None and image_rotary_emb_clean is not None: hidden_states = torch.cat([latents_clean, hidden_states], dim=1) - - image_rotary_emb_clean = self.rope( - frame_indices=indices_latents_clean, height=height, width=width, device=latents_clean.device - ) image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0) image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0) - if latents_history_2x is not None and indices_latents_history_2x is not None: + if latents_history_2x is not None and image_rotary_emb_history_2x is not None: hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1) - - image_rotary_emb_history_2x = self.rope( - frame_indices=indices_latents_history_2x, height=height, width=width, device=latents_history_2x.device - ) - image_rotary_emb_history_2x = self._pad_rotary_emb( - image_rotary_emb_history_2x, indices_latents_history_2x.size(0), pph, ppw, (2, 2, 2) - ) + image_rotary_emb_history_2x = self._pad_rotary_emb(image_rotary_emb_history_2x, height, width, (2, 2, 2)) image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0) image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0) - if latents_history_4x is not None and indices_latents_history_4x is not None: + if latents_history_4x is not None and image_rotary_emb_history_4x is not None: hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1) - - image_rotary_emb_history_4x = self.rope( - frame_indices=indices_latents_history_4x, height=height, width=width, device=latents_history_4x.device - ) - image_rotary_emb_history_4x = self._pad_rotary_emb( - image_rotary_emb_history_4x, indices_latents_history_4x.size(0), pph, ppw, (4, 4, 4) - ) + image_rotary_emb_history_4x = self._pad_rotary_emb(image_rotary_emb_history_4x, height, width, (4, 4, 4)) image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0) image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0) - return hidden_states, image_rotary_emb + return hidden_states, tuple(image_rotary_emb) def _pad_rotary_emb( self, image_rotary_emb: Tuple[torch.Tensor], - num_frames: int, height: int, width: int, kernel_size: Tuple[int, int, int], ): # freqs_cos, freqs_sin have 
shape [W * H * T, D / 2], where D is attention head dim freqs_cos, freqs_sin = image_rotary_emb - freqs_cos = freqs_cos.unsqueeze(0).permute(0, 2, 1).unflatten(2, (num_frames, height, width)) - freqs_sin = freqs_sin.unsqueeze(0).permute(0, 2, 1).unflatten(2, (num_frames, height, width)) + freqs_cos = freqs_cos.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width)) + freqs_sin = freqs_sin.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width)) freqs_cos = _pad_for_3d_conv(freqs_cos, kernel_size) freqs_sin = _pad_for_3d_conv(freqs_sin, kernel_size) freqs_cos = _center_down_sample_3d(freqs_cos, kernel_size) From e959a40a0444e244f014a6384e0f1194d0802da5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 6 May 2025 14:22:37 +0530 Subject: [PATCH 19/19] Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py --- .../pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py index 4fc338c4d784..43db740c6d7b 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py @@ -49,6 +49,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# TODO(yiyi): We can pack the checkpoints nicely with modular loader EXAMPLE_DOC_STRING = """ Examples: ##### Image-to-Video
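
The `_pad_rotary_emb` refactor above is easier to follow with a standalone sketch. The 2x and 4x history latents are patchified with larger 3D kernels (`proj_2x` uses (2, 4, 4), `proj_4x` uses (4, 8, 8)), so they contribute far fewer tokens than the clean latents; their rotary tables are therefore padded up to a multiple of the kernel and collapsed to one entry per window before being concatenated in front of the main sequence. The helpers below are illustrative stand-ins only (replicate padding and average pooling are assumptions), not the library's `_pad_for_3d_conv` / `_center_down_sample_3d`.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(x: torch.Tensor, kernel_size) -> torch.Tensor:
    # x: [B, C, T, H, W]; pad T/H/W up to multiples of the kernel (replicate padding assumed).
    _, _, t, h, w = x.shape
    kt, kh, kw = kernel_size
    pad = (0, -w % kw, 0, -h % kh, 0, -t % kt)  # F.pad order: (W_l, W_r, H_l, H_r, T_l, T_r)
    return F.pad(x, pad, mode="replicate")


def center_downsample(x: torch.Tensor, kernel_size) -> torch.Tensor:
    # Collapse each kernel-sized window to a single value (average pooling as a stand-in).
    return F.avg_pool3d(x, kernel_size, stride=kernel_size)


# Toy rotary table for 3 history frames of a 3x3 patch grid with head-dim/2 = 4,
# shaped [T * H * W, D / 2] like the tensors handled inside _pad_rotary_emb.
t, h, w, half_dim = 3, 3, 3, 4
freqs_cos = torch.randn(t * h * w, half_dim)

# Same reshaping as _pad_rotary_emb: [T*H*W, D/2] -> [1, D/2, T, H, W]
grid = freqs_cos.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, h, w))
grid = pad_to_multiple(grid, (2, 2, 2))      # 3x3x3 -> 4x4x4
grid = center_downsample(grid, (2, 2, 2))    # 4x4x4 -> 2x2x2, one entry per window
freqs_cos_2x = grid.flatten(2).transpose(1, 2).squeeze(0)  # back to [T'*H'*W', D/2]
print(freqs_cos_2x.shape)  # torch.Size([8, 4])
```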
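
Similarly, the per-section index bookkeeping in the pipeline's denoising loop can be reproduced in isolation. The sketch below mirrors the `indices.split(...)` call from the diff; `latent_window_size`, `history_sizes` and the number of sections are toy values rather than the pipeline's defaults, and the padding adjustment applied when `num_latent_sections > 4` is omitted.

```python
import torch

# Toy values for illustration; the real ones come from the pipeline inputs/config.
latent_window_size = 9
history_sizes = [1, 2, 16]            # assumed sizes for clean / 2x / 4x history latents
num_latent_sections = 3
latent_paddings = list(reversed(range(num_latent_sections)))   # [2, 1, 0]

for k in range(num_latent_sections):
    is_last_section = k == num_latent_sections - 1
    latent_padding_size = latent_paddings[k] * latent_window_size

    indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, *history_sizes]))
    (
        indices_prefix,                # position of the encoded first-frame latent
        indices_padding,               # positions left empty between the prefix and the window
        indices_latents,               # positions denoised in this section
        indices_postfix,               # clean history latent adjacent to the window
        indices_latents_history_2x,    # 2x-compressed history positions
        indices_latents_history_4x,    # 4x-compressed history positions
    ) = indices.split([1, latent_padding_size, latent_window_size, *history_sizes], dim=0)

    # Inverted anti-drifting sampling (Figure 2(c) in the paper): clean conditioning comes from
    # the first-frame prefix and from the already-generated history on the other side.
    indices_clean_latents = torch.cat([indices_prefix, indices_postfix], dim=0)

    print(k, is_last_section, indices_latents.tolist(), indices_clean_latents.tolist())
```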