From e5063143f2793f386510e4d78c126dda443ae187 Mon Sep 17 00:00:00 2001 From: ishan-modi Date: Thu, 24 Apr 2025 01:15:26 +0530 Subject: [PATCH 01/11] update --- src/diffusers/models/auto_model.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py index 1b742463aa2e..2eaee50f2ef7 100644 --- a/src/diffusers/models/auto_model.py +++ b/src/diffusers/models/auto_model.py @@ -156,10 +156,14 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi "subfolder": subfolder, } - config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) - orig_class_name = config["_class_name"] - - library = importlib.import_module("diffusers") + try: + config = cls.load_config(os.path.join(pretrained_model_or_path, "model_index.json"), **load_config_kwargs) + library, orig_class_name = config["subfolder"] + except Exception: + # Fallback to loading the config from the config.json file + config = cls.load_config(os.path.join(pretrained_model_or_path, cls.config_name), **load_config_kwargs) + library = importlib.import_module("diffusers") + orig_class_name = config["_class_name"] model_cls = getattr(library, orig_class_name, None) if model_cls is None: From 85024b0aa294443d3814bf3e8e79231fb5062b2f Mon Sep 17 00:00:00 2001 From: ishan-modi Date: Thu, 24 Apr 2025 01:31:32 +0530 Subject: [PATCH 02/11] update --- src/diffusers/models/auto_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py index 2eaee50f2ef7..1ae28774dd71 100644 --- a/src/diffusers/models/auto_model.py +++ b/src/diffusers/models/auto_model.py @@ -153,7 +153,6 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi "token": token, "local_files_only": local_files_only, "revision": revision, - "subfolder": subfolder, } try: @@ -161,6 +160,7 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi library, orig_class_name = config["subfolder"] except Exception: # Fallback to loading the config from the config.json file + load_config_kwargs["subfolder"] = subfolder config = cls.load_config(os.path.join(pretrained_model_or_path, cls.config_name), **load_config_kwargs) library = importlib.import_module("diffusers") orig_class_name = config["_class_name"] From 6a0d0bec46b4c5b03f2bfa86d63a42af221541a6 Mon Sep 17 00:00:00 2001 From: ishan-modi Date: Thu, 24 Apr 2025 02:00:06 +0530 Subject: [PATCH 03/11] update --- src/diffusers/models/auto_model.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py index 1ae28774dd71..1ba2284aafa3 100644 --- a/src/diffusers/models/auto_model.py +++ b/src/diffusers/models/auto_model.py @@ -22,7 +22,7 @@ class AutoModel(ConfigMixin): - config_name = "config.json" + config_name = "model_index.json" def __init__(self, *args, **kwargs): raise EnvironmentError( @@ -153,15 +153,18 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi "token": token, "local_files_only": local_files_only, "revision": revision, + "subfolder": subfolder, } try: - config = cls.load_config(os.path.join(pretrained_model_or_path, "model_index.json"), **load_config_kwargs) - library, orig_class_name = config["subfolder"] + mindex_kwargs = {k: v for k, v in load_config_kwargs.items() if k != "subfolder"} + config = cls.load_config(os.path.join(pretrained_model_or_path), **mindex_kwargs) + library, orig_class_name = config[subfolder] + library = importlib.import_module(library) except Exception: # Fallback to loading the config from the config.json file - load_config_kwargs["subfolder"] = subfolder - config = cls.load_config(os.path.join(pretrained_model_or_path, cls.config_name), **load_config_kwargs) + cls.config_name = "config.json" + config = cls.load_config(os.path.join(pretrained_model_or_path), **load_config_kwargs) library = importlib.import_module("diffusers") orig_class_name = config["_class_name"] From d86b0f28af060bf53c76b7a0f98e883aff21ddd8 Mon Sep 17 00:00:00 2001 From: ishan-modi Date: Thu, 24 Apr 2025 03:41:28 +0530 Subject: [PATCH 04/11] update --- src/diffusers/models/auto_model.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py index 1ba2284aafa3..8ade229164bb 100644 --- a/src/diffusers/models/auto_model.py +++ b/src/diffusers/models/auto_model.py @@ -17,12 +17,13 @@ from typing import Optional, Union from huggingface_hub.utils import validate_hf_hub_args +from huggingface_hub import hf_hub_download from ..configuration_utils import ConfigMixin class AutoModel(ConfigMixin): - config_name = "model_index.json" + config_name = "config.json" def __init__(self, *args, **kwargs): raise EnvironmentError( @@ -158,13 +159,14 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi try: mindex_kwargs = {k: v for k, v in load_config_kwargs.items() if k != "subfolder"} - config = cls.load_config(os.path.join(pretrained_model_or_path), **mindex_kwargs) + mindex_kwargs["filename"] = "model_index.json" + config_path = hf_hub_download(pretrained_model_or_path, **mindex_kwargs) + config = cls.load_config(config_path, **load_config_kwargs) library, orig_class_name = config[subfolder] library = importlib.import_module(library) except Exception: # Fallback to loading the config from the config.json file - cls.config_name = "config.json" - config = cls.load_config(os.path.join(pretrained_model_or_path), **load_config_kwargs) + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) library = importlib.import_module("diffusers") orig_class_name = config["_class_name"] From 528e00239e21775d2cc3bd609152c3b0c4967f7b Mon Sep 17 00:00:00 2001 From: ishan-modi Date: Tue, 29 Apr 2025 12:59:13 +0530 Subject: [PATCH 05/11] addressed PR comments --- src/diffusers/models/auto_model.py | 21 +++++++++++++++---- .../pipelines/pipeline_loading_utils.py | 2 +- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py index 8ade229164bb..285a57c25668 100644 --- a/src/diffusers/models/auto_model.py +++ b/src/diffusers/models/auto_model.py @@ -16,10 +16,12 @@ import os from typing import Optional, Union +from huggingface_hub import constants, hf_hub_download from huggingface_hub.utils import validate_hf_hub_args -from huggingface_hub import hf_hub_download +from .. import pipelines from ..configuration_utils import ConfigMixin +from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates class AutoModel(ConfigMixin): @@ -163,14 +165,25 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi config_path = hf_hub_download(pretrained_model_or_path, **mindex_kwargs) config = cls.load_config(config_path, **load_config_kwargs) library, orig_class_name = config[subfolder] - library = importlib.import_module(library) + model_cls, _ = get_class_obj_and_candidates( + library_name=library, + class_name=orig_class_name, + importable_classes=ALL_IMPORTABLE_CLASSES, + pipelines=pipelines, + is_pipeline_module=hasattr(pipelines, library), + component_name=subfolder, + cache_dir=constants.HF_HUB_CACHE, + ) except Exception: - # Fallback to loading the config from the config.json file + # Fallback to loading the config from the config.json file and `diffusers` library + import traceback + + traceback.print_exc() config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) library = importlib.import_module("diffusers") orig_class_name = config["_class_name"] + model_cls = getattr(library, orig_class_name, None) - model_cls = getattr(library, orig_class_name, None) if model_cls is None: raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.") diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 9788d758e9bc..ac8a321fde1f 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -332,7 +332,7 @@ def maybe_raise_or_warn( def get_class_obj_and_candidates( - library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None + library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name, cache_dir ): """Simple helper method to retrieve class object of module as well as potential parent class objects""" component_folder = os.path.join(cache_dir, component_name) From 6e92f40325035593d1f3e5036ac80332498e5c09 Mon Sep 17 00:00:00 2001 From: ishan-modi Date: Tue, 29 Apr 2025 13:58:58 +0530 Subject: [PATCH 06/11] update --- src/diffusers/models/auto_model.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py index 285a57c25668..40cd7bf2c8b6 100644 --- a/src/diffusers/models/auto_model.py +++ b/src/diffusers/models/auto_model.py @@ -176,9 +176,6 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi ) except Exception: # Fallback to loading the config from the config.json file and `diffusers` library - import traceback - - traceback.print_exc() config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) library = importlib.import_module("diffusers") orig_class_name = config["_class_name"] From f697631f96accd61f78be34294b9754ca664383e Mon Sep 17 00:00:00 2001 From: ishan-modi Date: Fri, 2 May 2025 11:16:47 +0530 Subject: [PATCH 07/11] addressed PR comments --- src/diffusers/models/auto_model.py | 15 ++++++++------- src/diffusers/pipelines/pipeline_loading_utils.py | 6 +++--- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py index 40cd7bf2c8b6..977a0474b24a 100644 --- a/src/diffusers/models/auto_model.py +++ b/src/diffusers/models/auto_model.py @@ -16,10 +16,9 @@ import os from typing import Optional, Union -from huggingface_hub import constants, hf_hub_download -from huggingface_hub.utils import validate_hf_hub_args +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError, validate_hf_hub_args -from .. import pipelines from ..configuration_utils import ConfigMixin from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates @@ -160,6 +159,9 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi } try: + # To avoid circular import problem. + from diffusers import pipelines + mindex_kwargs = {k: v for k, v in load_config_kwargs.items() if k != "subfolder"} mindex_kwargs["filename"] = "model_index.json" config_path = hf_hub_download(pretrained_model_or_path, **mindex_kwargs) @@ -171,11 +173,10 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi importable_classes=ALL_IMPORTABLE_CLASSES, pipelines=pipelines, is_pipeline_module=hasattr(pipelines, library), - component_name=subfolder, - cache_dir=constants.HF_HUB_CACHE, ) - except Exception: - # Fallback to loading the config from the config.json file and `diffusers` library + except EntryNotFoundError: + # If `model_index.json` is not found, we try to load the model from the + # `config.json` file and `diffusers` library. config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) library = importlib.import_module("diffusers") orig_class_name = config["_class_name"] diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index ac8a321fde1f..64567bf02a31 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -332,17 +332,17 @@ def maybe_raise_or_warn( def get_class_obj_and_candidates( - library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name, cache_dir + library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None ): """Simple helper method to retrieve class object of module as well as potential parent class objects""" - component_folder = os.path.join(cache_dir, component_name) + component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None if is_pipeline_module: pipeline_module = getattr(pipelines, library_name) class_obj = getattr(pipeline_module, class_name) class_candidates = dict.fromkeys(importable_classes.keys(), class_obj) - elif os.path.isfile(os.path.join(component_folder, library_name + ".py")): + elif component_folder and os.path.isfile(os.path.join(component_folder, library_name + ".py")): # load custom component class_obj = get_class_from_dynamic_module( component_folder, module_file=library_name + ".py", class_name=class_name From 5614a156b192dc0ad7fece00b49c1731c5e91bb2 Mon Sep 17 00:00:00 2001 From: ishan-modi Date: Fri, 2 May 2025 11:59:14 +0530 Subject: [PATCH 08/11] added tests --- tests/models/test_models_auto.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 tests/models/test_models_auto.py diff --git a/tests/models/test_models_auto.py b/tests/models/test_models_auto.py new file mode 100644 index 000000000000..ee56e38dcb59 --- /dev/null +++ b/tests/models/test_models_auto.py @@ -0,0 +1,20 @@ +import unittest +from unittest.mock import patch + +from huggingface_hub.utils import EntryNotFoundError +from transformers import CLIPTextModel + +from diffusers.models import AutoModel, UNet2DConditionModel + + +class TestAutoModel(unittest.TestCase): + @patch("diffusers.models.auto_model.hf_hub_download", side_effect=EntryNotFoundError("File not found")) + def test_from_pretrained_falls_back_on_entry_error(self, mock_hf_hub_download): + model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet") + assert isinstance(model, UNet2DConditionModel) + + def test_from_pretrained_loads_successfully( + self + ): + model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder") + assert isinstance(model, CLIPTextModel) From f6b6b427221caba00bd1a3d7220c700098fde087 Mon Sep 17 00:00:00 2001 From: ishan-modi Date: Fri, 2 May 2025 15:01:23 +0530 Subject: [PATCH 09/11] addressed PR comments --- src/diffusers/models/auto_model.py | 44 ++++++++++++++++++------------ 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py index 977a0474b24a..45fae413f4cd 100644 --- a/src/diffusers/models/auto_model.py +++ b/src/diffusers/models/auto_model.py @@ -158,24 +158,32 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi "subfolder": subfolder, } - try: - # To avoid circular import problem. - from diffusers import pipelines - - mindex_kwargs = {k: v for k, v in load_config_kwargs.items() if k != "subfolder"} - mindex_kwargs["filename"] = "model_index.json" - config_path = hf_hub_download(pretrained_model_or_path, **mindex_kwargs) - config = cls.load_config(config_path, **load_config_kwargs) - library, orig_class_name = config[subfolder] - model_cls, _ = get_class_obj_and_candidates( - library_name=library, - class_name=orig_class_name, - importable_classes=ALL_IMPORTABLE_CLASSES, - pipelines=pipelines, - is_pipeline_module=hasattr(pipelines, library), - ) - except EntryNotFoundError: - # If `model_index.json` is not found, we try to load the model from the + if subfolder is not None: + try: + # To avoid circular import problem. + from diffusers import pipelines + + mindex_kwargs = {k: v for k, v in load_config_kwargs.items() if k != "subfolder"} + mindex_kwargs["filename"] = "model_index.json" + config_path = hf_hub_download(pretrained_model_or_path, **mindex_kwargs) + config = cls.load_config(config_path, **load_config_kwargs) + library, orig_class_name = config[subfolder] + model_cls, _ = get_class_obj_and_candidates( + library_name=library, + class_name=orig_class_name, + importable_classes=ALL_IMPORTABLE_CLASSES, + pipelines=pipelines, + is_pipeline_module=hasattr(pipelines, library), + ) + except EntryNotFoundError: + # If `model_index.json` is not found, we load the model from the + # `config.json` file and `diffusers` library. + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) + library = importlib.import_module("diffusers") + orig_class_name = config["_class_name"] + model_cls = getattr(library, orig_class_name, None) + else: + # If `subfolder` is not provided, we load the model from the # `config.json` file and `diffusers` library. config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) library = importlib.import_module("diffusers") From 4e5cac1d19f156c3b710a8d6d7d8be55ca9eb777 Mon Sep 17 00:00:00 2001 From: ishan-modi Date: Fri, 2 May 2025 15:57:01 +0530 Subject: [PATCH 10/11] updates --- src/diffusers/models/auto_model.py | 60 +++++++++++++++--------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py index 45fae413f4cd..723b5b23b7b4 100644 --- a/src/diffusers/models/auto_model.py +++ b/src/diffusers/models/auto_model.py @@ -12,17 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import importlib import os from typing import Optional, Union -from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError, validate_hf_hub_args from ..configuration_utils import ConfigMixin from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates +from ..utils import logging +logger = logging.get_logger(__name__) + class AutoModel(ConfigMixin): config_name = "config.json" @@ -155,40 +156,39 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi "token": token, "local_files_only": local_files_only, "revision": revision, - "subfolder": subfolder, } - if subfolder is not None: - try: - # To avoid circular import problem. - from diffusers import pipelines + library = None + orig_class_name = None + from diffusers import pipelines + + # Always attempt to fetch model_index.json first + try: + cls.config_name = "model_index.json" + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) - mindex_kwargs = {k: v for k, v in load_config_kwargs.items() if k != "subfolder"} - mindex_kwargs["filename"] = "model_index.json" - config_path = hf_hub_download(pretrained_model_or_path, **mindex_kwargs) - config = cls.load_config(config_path, **load_config_kwargs) + if subfolder is not None and subfolder in config: library, orig_class_name = config[subfolder] - model_cls, _ = get_class_obj_and_candidates( - library_name=library, - class_name=orig_class_name, - importable_classes=ALL_IMPORTABLE_CLASSES, - pipelines=pipelines, - is_pipeline_module=hasattr(pipelines, library), - ) - except EntryNotFoundError: - # If `model_index.json` is not found, we load the model from the - # `config.json` file and `diffusers` library. - config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) - library = importlib.import_module("diffusers") - orig_class_name = config["_class_name"] - model_cls = getattr(library, orig_class_name, None) - else: - # If `subfolder` is not provided, we load the model from the - # `config.json` file and `diffusers` library. + + except EntryNotFoundError as e: + logger.debug(e) + + # Unable to load from model_index.json so fallback to loading from config + if library is None and orig_class_name is None: + cls.config_name = "config.json" + load_config_kwargs.update({"subfolder": subfolder}) + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) - library = importlib.import_module("diffusers") orig_class_name = config["_class_name"] - model_cls = getattr(library, orig_class_name, None) + library = "diffusers" + + model_cls, _ = get_class_obj_and_candidates( + library_name=library, + class_name=orig_class_name, + importable_classes=ALL_IMPORTABLE_CLASSES, + pipelines=pipelines, + is_pipeline_module=hasattr(pipelines, library), + ) if model_cls is None: raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.") From 24f16f68f507b2bd39193a65b68ab47d315cfda1 Mon Sep 17 00:00:00 2001 From: ishan-modi Date: Sun, 4 May 2025 14:19:39 +0530 Subject: [PATCH 11/11] update --- src/diffusers/models/auto_model.py | 34 ++++++++++++++++++++++++------ tests/models/test_models_auto.py | 20 ++++++++++++------ 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py index 723b5b23b7b4..f8daf43a664f 100644 --- a/src/diffusers/models/auto_model.py +++ b/src/diffusers/models/auto_model.py @@ -15,7 +15,7 @@ import os from typing import Optional, Union -from huggingface_hub.utils import EntryNotFoundError, validate_hf_hub_args +from huggingface_hub.utils import validate_hf_hub_args from ..configuration_utils import ConfigMixin from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates @@ -169,18 +169,38 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi if subfolder is not None and subfolder in config: library, orig_class_name = config[subfolder] + load_config_kwargs.update({"subfolder": subfolder}) - except EntryNotFoundError as e: + except EnvironmentError as e: logger.debug(e) # Unable to load from model_index.json so fallback to loading from config if library is None and orig_class_name is None: cls.config_name = "config.json" - load_config_kwargs.update({"subfolder": subfolder}) - - config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) - orig_class_name = config["_class_name"] - library = "diffusers" + config = cls.load_config(pretrained_model_or_path, subfolder=subfolder, **load_config_kwargs) + + if "_class_name" in config: + # If we find a class name in the config, we can try to load the model as a diffusers model + orig_class_name = config["_class_name"] + library = "diffusers" + load_config_kwargs.update({"subfolder": subfolder}) + else: + # If we don't find a class name in the config, we can try to load the model as a transformers model + logger.warning( + f"Doesn't look like a diffusers model. Loading {pretrained_model_or_path} as a transformer model." + ) + if "architectures" in config and len(config["architectures"]) > 0: + if len(config["architectures"]) > 1: + logger.warning( + f"Found multiple architectures in {pretrained_model_or_path}. Using the first one: {config['architectures'][0]}" + ) + orig_class_name = config["architectures"][0] + library = "transformers" + load_config_kwargs.update({"subfolder": "" if subfolder is None else subfolder}) + else: + raise ValueError( + f"Couldn't find model associated with the config file at {pretrained_model_or_path}." + ) model_cls, _ = get_class_obj_and_candidates( library_name=library, diff --git a/tests/models/test_models_auto.py b/tests/models/test_models_auto.py index ee56e38dcb59..5b7a5e5b3bcf 100644 --- a/tests/models/test_models_auto.py +++ b/tests/models/test_models_auto.py @@ -1,20 +1,26 @@ import unittest from unittest.mock import patch -from huggingface_hub.utils import EntryNotFoundError -from transformers import CLIPTextModel +from transformers import AlbertForMaskedLM, CLIPTextModel from diffusers.models import AutoModel, UNet2DConditionModel class TestAutoModel(unittest.TestCase): - @patch("diffusers.models.auto_model.hf_hub_download", side_effect=EntryNotFoundError("File not found")) - def test_from_pretrained_falls_back_on_entry_error(self, mock_hf_hub_download): + @patch("diffusers.models.AutoModel.load_config", side_effect=[EnvironmentError("File not found"), {"_class_name": "UNet2DConditionModel"}]) + def test_load_from_config_diffusers_with_subfolder(self, mock_load_config): model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet") assert isinstance(model, UNet2DConditionModel) - def test_from_pretrained_loads_successfully( - self - ): + @patch("diffusers.models.AutoModel.load_config", side_effect=[EnvironmentError("File not found"), {"architectures": [ "CLIPTextModel"]}]) + def test_load_from_config_transformers_with_subfolder(self, mock_load_config): + model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder") + assert isinstance(model, CLIPTextModel) + + def test_load_from_config_without_subfolder(self): + model = AutoModel.from_pretrained("hf-internal-testing/tiny-albert") + assert isinstance(model, AlbertForMaskedLM) + + def test_load_from_model_index(self): model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder") assert isinstance(model, CLIPTextModel)