Hunyuan Video Framepack #11428

Merged
merged 23 commits into main from integrations/framepack on May 6, 2025

Conversation

a-r-r-o-w
Member

@a-r-r-o-w a-r-r-o-w commented Apr 27, 2025

This PR adds support for FramePack: https://github.com/lllyasviel/FramePack

I2V example

import torch
from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel
from diffusers.utils import export_to_video, load_image
from transformers import SiglipImageProcessor, SiglipVisionModel

transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
    "lllyasviel/FramePackI2V_HY", torch_dtype=torch.bfloat16
)
feature_extractor = SiglipImageProcessor.from_pretrained(
    "lllyasviel/flux_redux_bfl", subfolder="feature_extractor"
)
image_encoder = SiglipVisionModel.from_pretrained(
    "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
)
pipe = HunyuanVideoFramepackPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    transformer=transformer,
    feature_extractor=feature_extractor,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
)
pipe.vae.enable_tiling()
pipe.to("cuda")

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png")
output = pipe(
    image=image,
    prompt="A penguin dancing in the snow",
    height=832,
    width=480,
    num_frames=91,
    num_inference_steps=30,
    guidance_scale=9.0,
    generator=torch.Generator().manual_seed(0),
).frames[0]
export_to_video(output, "output.mp4", fps=30)
output.mp4

FLF2V example

Credits: lllyasviel/FramePack#167

import torch
from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel
from diffusers.utils import export_to_video, load_image
from transformers import SiglipImageProcessor, SiglipVisionModel

transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
    "lllyasviel/FramePackI2V_HY", torch_dtype=torch.bfloat16
)
feature_extractor = SiglipImageProcessor.from_pretrained(
    "lllyasviel/flux_redux_bfl", subfolder="feature_extractor"
)
image_encoder = SiglipVisionModel.from_pretrained(
    "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
)
pipe = HunyuanVideoFramepackPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    transformer=transformer,
    feature_extractor=feature_extractor,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
)
pipe.to("cuda")

prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
first_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png"
)
last_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png"
)
output = pipe(
    image=first_image,
    last_image=last_image,
    prompt=prompt,
    height=512,
    width=512,
    num_frames=91,
    num_inference_steps=30,
    guidance_scale=9.0,
    generator=torch.Generator().manual_seed(0),
).frames[0]
export_to_video(output, "output.mp4", fps=30)
output.mp4

TODO:

  • end frame
  • fix test
  • upload penguin to documentation repo

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w a-r-r-o-w marked this pull request as ready for review April 29, 2025 06:39
@a-r-r-o-w a-r-r-o-w requested a review from yiyixuxu April 29, 2025 09:41
@nitinmukesh

nitinmukesh commented Apr 29, 2025

Thank you @a-r-r-o-w

For some reason I'm getting an error:

NotImplementedError: Cannot copy out of meta tensor; no data!

I replaced
pipe.to("cuda")

with
pipe.enable_sequential_cpu_offload()

and also tried
pipe.enable_model_cpu_offload()
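
For reference, a minimal sketch of the substitution being described, assuming pipe is the HunyuanVideoFramepackPipeline built in the example above. Only one of these placement/offloading calls should be active at a time; the trade-off comments reflect general diffusers behavior, not anything specific to this pipeline:

# pipe.to("cuda")                       # keep everything resident on the GPU (highest VRAM use)
pipe.enable_model_cpu_offload()         # move whole components to the GPU only while they run
# pipe.enable_sequential_cpu_offload()  # move individual layers on demand; lowest VRAM, slowest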

(sddw-dev) C:\aiOWN\diffuser_webui>python framepack3.py
Fetching 3 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  5.99it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 15.53it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████| 7/7 [00:03<00:00,  1.81it/s]
Traceback (most recent call last):
  File "C:\aiOWN\diffuser_webui\framepack3.py", line 27, in <module>
    output = pipe(
             ^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\diffusers\pipelines\hunyuan_video\pipeline_hunyuan_video_framepack.py", line 755, in __call__
    image_latents = self.prepare_image_latents(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\diffusers\pipelines\hunyuan_video\pipeline_hunyuan_video_framepack.py", line 482, in prepare_image_latents
    latents = self.vae.encode(image).latent_dist.sample(generator=generator)
              ^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\diffusers\utils\accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\diffusers\models\autoencoders\autoencoder_kl_hunyuan_video.py", line 821, in encode
    h = self._encode(x)
        ^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\diffusers\models\autoencoders\autoencoder_kl_hunyuan_video.py", line 795, in _encode
    return self.tiled_encode(x)
           ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\diffusers\models\autoencoders\autoencoder_kl_hunyuan_video.py", line 928, in tiled_encode
    tile = self.encoder(tile)
           ^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\diffusers\models\autoencoders\autoencoder_kl_hunyuan_video.py", line 501, in forward
    hidden_states = self.conv_in(hidden_states)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\diffusers\models\autoencoders\autoencoder_kl_hunyuan_video.py", line 79, in forward
    return self.conv(hidden_states)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\torch\nn\modules\conv.py", line 725, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\torch\nn\modules\conv.py", line 720, in _conv_forward
    return F.conv3d(
           ^^^^^^^^^
NotImplementedError: Could not run 'aten::slow_conv3d_forward' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::slow_conv3d_forward' is only available for these backends: [CPU, Meta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

@tin2tin

tin2tin commented Apr 29, 2025

Also, bitsandbytes seems to crash on a 4090, or I'm doing something wrong:

import torch
from diffusers import BitsAndBytesConfig, HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel
from diffusers.utils import export_to_video, load_image
from transformers import SiglipImageProcessor, SiglipVisionModel

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
    "lllyasviel/FramePackI2V_HY",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)
feature_extractor = SiglipImageProcessor.from_pretrained(
    "lllyasviel/flux_redux_bfl", subfolder="feature_extractor"
)
image_encoder = SiglipVisionModel.from_pretrained(
    "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
)
pipe = HunyuanVideoFramepackPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    transformer=transformer,
    feature_extractor=feature_extractor,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
)
#pipe.vae.enable_tiling()
#pipe.enable_model_cpu_offload()
pipe.to("cuda")

image = load_image("https://i.ibb.co/35CWK8rv/pinguin.png")
output = pipe(
    image=image,
    prompt="A penguin dancing in the snow",
    height=832,
    width=480,
    num_frames=91,
    num_inference_steps=30,
    guidance_scale=9.0,
    generator=torch.Generator().manual_seed(0),
).frames[0]
export_to_video(output, "output.mp4", fps=30)

@a-r-r-o-w
Member Author

@nitinmukesh The latest commit should fix the problem you're facing with enable_model_cpu_offload. I'm taking a look at sequential offloading now

@a-r-r-o-w
Member Author

a-r-r-o-w commented Apr 29, 2025

@tin2tin Tested your code with the latest commit and it seems to run for both full-CUDA and CPU-offload. Could you try again?

output.mp4

@a-r-r-o-w
Member Author

@nitinmukesh Unable to get the sequential offloading to work at the moment.


It fails in the output projection layer of the attention in SigLIP's pooling head:

  • pooling head init: link
  • multihead attention init from torch: link
  • layer where the error occurs in torch: link
Reproducer:
import torch
from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel
from diffusers.utils import export_to_video, load_image
from transformers import SiglipImageProcessor, SiglipVisionModel

transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained("lllyasviel/FramePackI2V_HY", torch_dtype=torch.bfloat16)
feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder="feature_extractor")
image_encoder = SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16)
pipe = HunyuanVideoFramepackPipeline.from_pretrained("hunyuanvideo-community/HunyuanVideo", transformer=transformer, feature_extractor=feature_extractor, image_encoder=image_encoder, torch_dtype=torch.float16)
pipe.enable_sequential_cpu_offload()

image = load_image("inputs/penguin.png")
output = pipe(
    image=image,
    prompt="A penguin dancing in the snow",
    height=832,
    width=480,
    num_frames=31,
    num_inference_steps=2,
    guidance_scale=9.0,
    generator=torch.Generator().manual_seed(0),
).frames[0]
export_to_video(output, "output.mp4", fps=30)
Stack trace:
Traceback (most recent call last):
  File "/home/aryan/work/diffusers/workflows/integrations/framepack_hunyuan_video/test_pipeline.py", line 41, in <module>
    output = pipe(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py", line 736, in __call__
    image_embeds = self.encode_image(image, device=device).to(transformer_dtype)
  File "/home/aryan/work/diffusers/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py", line 394, in encode_image
    image_embeds = self.image_encoder(**image).last_hidden_state
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/aryan/accelerate-dev/src/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/transformers/utils/generic.py", line 965, in wrapper
    output = func(self, *args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1030, in forward
    return self.vision_model(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/transformers/utils/generic.py", line 965, in wrapper
    output = func(self, *args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 944, in forward
    pooler_output = self.head(last_hidden_state) if self.use_head else None
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/aryan/accelerate-dev/src/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 969, in forward
    hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/aryan/accelerate-dev/src/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1373, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/functional.py", line 6383, in multi_head_attention_forward
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 291, in _fn
    result = fn(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_decomp/decompositions.py", line 83, in inner
    r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_decomp/decompositions.py", line 1460, in addmm
    return out + beta * self
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 291, in _fn
    result = fn(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 143, in _fn
    result = fn(**bound.arguments)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_refs/__init__.py", line 1145, in add
    output = prims.add(a, b)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_library/fake_impl.py", line 95, in meta_kernel
    return fake_impl_holder.kernel(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_library/utils.py", line 31, in __call__
    return self.func(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/library.py", line 1193, in inner
    return func(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 592, in fake_impl
    return self._abstract_fn(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_prims/__init__.py", line 403, in _prim_elementwise_meta
    utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 764, in check_same_device
    raise RuntimeError(msg)
RuntimeError: Tensor on device meta is not on the expected device cuda:0!

I'm not sure what to look into specifically, since it could be an unhandled case in accelerate, or the layer implementation may simply not be compatible with sequential offloading. Sequential offloading requires a module's forward method to be invoked so that the pre-forward hook can move its weights to the correct device, but as can be seen in the torch implementation, F.multi_head_attention_forward is called with the output projection's weight tensors directly instead of going through the linear layer's forward. IIRC, the same problem occurs when applying sequential/group offloading to HunyuanDiT, because it does something similar:

x, _ = F.multi_head_attention_forward(
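
A minimal sketch of the failure mode (simplified, hypothetical module, not the actual SigLIP code): sequential offloading attaches a pre-forward hook to every submodule, and that hook only fires when the submodule's own forward() runs.

import torch.nn as nn

class PoolingHeadLike(nn.Module):
    # Stand-in for the SigLIP pooling head's use of nn.MultiheadAttention.
    def __init__(self, dim=8):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, num_heads=2, batch_first=True)

    def forward(self, x):
        # nn.MultiheadAttention.forward passes self.out_proj.weight / .bias directly
        # into F.multi_head_attention_forward, so out_proj.forward() is never called,
        # the pre-forward hook that would move out_proj off the meta/offload device
        # never fires, and the final addmm sees tensors on mismatched devices.
        return self.attention(x, x, x)[0]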

cc @SunMarc

@tin2tin

tin2tin commented Apr 29, 2025

It seems to be running the inference with tiling and CPU offload:
Fetching 3 files: 100%|██████████████████████████████████████████████████████████████████████████| 3/3 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████| 3/3 [00:19<00:00, 6.57s/it]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 63.47it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████| 7/7 [00:01<00:00, 6.19it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 30/30 [02:43<00:00, 5.44s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 30/30 [02:48<00:00, 5.60s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 30/30 [02:46<00:00, 5.57s/it]

Forgot to change the path to something I can find on my computer, so I need to rerun it.

@a-r-r-o-w
Member Author

Would recommend installing and using imageio instead of opencv, since the opencv path is planned to be deprecated soon:

opencv-python==4.10.0.84
imageio==2.35.1
imageio-ffmpeg==0.5.1
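
If you want to write the frames without going through export_to_video, here is a minimal sketch using imageio directly. It assumes output is the list of PIL frames returned by the pipeline in the examples above and that the ffmpeg plugin (imageio-ffmpeg) is installed:

import imageio
import numpy as np

# Convert the PIL frames to uint8 arrays and write an mp4 at 30 fps.
imageio.mimsave("output.mp4", [np.asarray(frame) for frame in output], fps=30)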

@tin2tin

tin2tin commented Apr 29, 2025

Yes, it is working here too: 8 minutes and 51 seconds. Thank you!
Looking forward to trying flf2v (but too late now).

output.mp4

@tin2tin

tin2tin commented Apr 30, 2025

The flf2v example produces a video with bitsandbytes too (loading: 16 GB, inference: 12 GB, 4 minutes and 43 seconds):

import torch
from diffusers import (
    BitsAndBytesConfig, 
    HunyuanVideoFramepackPipeline,
    HunyuanVideoFramepackTransformer3DModel,
)
from diffusers.utils import export_to_video, load_image
from transformers import SiglipImageProcessor, SiglipVisionModel

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
    "lllyasviel/FramePackI2V_HY",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)

feature_extractor = SiglipImageProcessor.from_pretrained(
    "lllyasviel/flux_redux_bfl", subfolder="feature_extractor", 
)

image_encoder = SiglipVisionModel.from_pretrained(
    "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
)

pipe = HunyuanVideoFramepackPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    transformer=transformer,
    feature_extractor=feature_extractor,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
)

pipe.vae.enable_tiling()
pipe.enable_model_cpu_offload()
prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
first_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png"
)

last_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png"
)

output = pipe(
    image=first_image,
    last_image=last_image,
    prompt=prompt,
    height=512,
    width=512,
    num_frames=91,
    num_inference_steps=30,
    guidance_scale=9.0,
    generator=torch.Generator().manual_seed(0),
).frames[0]
export_to_video(output, "C:/Users/peter/Downloads/output.mp4", fps=30)
output.mp4

@nitinmukesh

@a-r-r-o-w

Thank you for supporting model offload.
I have logged an issue and hopefully it will be fixed (fingers crossed):
huggingface/accelerate#3542

@tin2tin

tin2tin commented Apr 30, 2025

@a-r-r-o-w Tried to improve the image and prompt for testing. If you need it later for a feature demonstration:
flf2v_input_first_frame
flf2v_input_last_frame
prompt = "CG animation style, a small blue bird takes off from the ground, quickly flapping its wings up and down repeatedly. The camera follows the bird upward, capturing its flight up in the air. A close-up, low-angle perspective."
generator=torch.Generator().manual_seed(1000),

bird_output.mp4

And a little prince:
prompt = "CG animation style, the small warrior turns towards the camera and waves his arms, handheld"
generator=torch.Generator().manual_seed(1000),
prince_img_1
prince_img_2

prince_output.mp4

@tin2tin

tin2tin commented May 1, 2025

@a-r-r-o-w Apparently, it is also possible to implement t2v: lllyasviel/FramePack#266 and Hunyuan LoRA support: https://github.com/colinurbs/FramePack-Studio/blob/main/diffusers_helper/lora_utils.py but maybe that would be better for later patches?

@a-r-r-o-w
Member Author

@tin2tin T2V addition sounds good! I'll add that in a follow-up PR once this is merged. LoRA should work already I think
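
In case it helps anyone who wants to try it, a hypothetical sketch (untested): this assumes the Framepack pipeline exposes the standard diffusers LoRA-loading API, and "some-user/hunyuan-video-lora" is a placeholder for a HunyuanVideo-compatible LoRA in diffusers format.

# Load a LoRA into the pipeline built in the examples above and set its weight.
pipe.load_lora_weights("some-user/hunyuan-video-lora", adapter_name="style")
pipe.set_adapters(["style"], adapter_weights=[0.9])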

Collaborator

@yiyixuxu yiyixuxu left a comment

thanks! looking great!
I left one comment on the transformer;

@tin2tin

tin2tin commented May 3, 2025

FramePack F1 released:
FramePack-F1 is the FramePack variant with forward-only sampling.
lllyasviel/FramePack#459
https://huggingface.co/lllyasviel/FramePack_F1_I2V_HY_20250503
lllyasviel/FramePack@0f4df00

@a-r-r-o-w
Member Author

@yiyixuxu Addressed review comments. Can you take another look?

Collaborator

@yiyixuxu yiyixuxu left a comment

thanks!

@tin2tin

tin2tin commented May 6, 2025

Are the memory handling and speed improvements of FramePack included in this patch? I noticed a huge speed difference between running FramePack standalone (ca. 1 min per 1 sec of video) and via Diffusers (ca. 10 min per 1 sec) on a 4090.

@a-r-r-o-w
Member Author

@tin2tin 1 sec of video as in 31 frames generated and saved at 30 fps? I can't seem to reproduce it taking 10 mins on a 4090.
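
(For reference on the arithmetic: the examples above request num_frames=91, which at fps=30 comes out to roughly 3 seconds of video, so per-second timing comparisons need to account for the actual frame count generated.)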

@a-r-r-o-w a-r-r-o-w merged commit d7ffe60 into main May 6, 2025
16 checks passed
@a-r-r-o-w a-r-r-o-w deleted the integrations/framepack branch May 6, 2025 09:29
@tin2tin

tin2tin commented May 6, 2025

@tin2tin 1 sec of video as in 31 frames generated and saved at 30 fps? I can't seem to reproduce it taking 10 mins on a 4090.

Will test some more.
