diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py
index 1b742463aa2e..f8daf43a664f 100644
--- a/src/diffusers/models/auto_model.py
+++ b/src/diffusers/models/auto_model.py
@@ -12,15 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import importlib
 import os
 from typing import Optional, Union
 
 from huggingface_hub.utils import validate_hf_hub_args
 
 from ..configuration_utils import ConfigMixin
+from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
+from ..utils import logging
+
+logger = logging.get_logger(__name__)
 
 
 class AutoModel(ConfigMixin):
     config_name = "config.json"
 
@@ -153,15 +156,60 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi
             "token": token,
             "local_files_only": local_files_only,
             "revision": revision,
-            "subfolder": subfolder,
         }
 
-        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
-        orig_class_name = config["_class_name"]
-
-        library = importlib.import_module("diffusers")
+        library = None
+        orig_class_name = None
+        from diffusers import pipelines
+
+        # Always attempt to fetch model_index.json first
+        try:
+            cls.config_name = "model_index.json"
+            config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
+
+            if subfolder is not None and subfolder in config:
+                library, orig_class_name = config[subfolder]
+                load_config_kwargs.update({"subfolder": subfolder})
+
+        except EnvironmentError as e:
+            logger.debug(e)
+
+        # Unable to load from model_index.json so fallback to loading from config
+        if library is None and orig_class_name is None:
+            cls.config_name = "config.json"
+            config = cls.load_config(pretrained_model_or_path, subfolder=subfolder, **load_config_kwargs)
+
+            if "_class_name" in config:
+                # If we find a class name in the config, we can try to load the model as a diffusers model
+                orig_class_name = config["_class_name"]
+                library = "diffusers"
+                load_config_kwargs.update({"subfolder": subfolder})
+            else:
+                # If we don't find a class name in the config, we can try to load the model as a transformers model
+                logger.warning(
+                    f"Doesn't look like a diffusers model. Loading {pretrained_model_or_path} as a transformer model."
+                )
+                if "architectures" in config and len(config["architectures"]) > 0:
+                    if len(config["architectures"]) > 1:
+                        logger.warning(
+                            f"Found multiple architectures in {pretrained_model_or_path}. Using the first one: {config['architectures'][0]}"
+                        )
+                    orig_class_name = config["architectures"][0]
+                    library = "transformers"
+                    load_config_kwargs.update({"subfolder": "" if subfolder is None else subfolder})
+                else:
+                    raise ValueError(
+                        f"Couldn't find model associated with the config file at {pretrained_model_or_path}."
+                    )
+
+        model_cls, _ = get_class_obj_and_candidates(
+            library_name=library,
+            class_name=orig_class_name,
+            importable_classes=ALL_IMPORTABLE_CLASSES,
+            pipelines=pipelines,
+            is_pipeline_module=hasattr(pipelines, library),
+        )
 
-        model_cls = getattr(library, orig_class_name, None)
 
         if model_cls is None:
             raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index 9788d758e9bc..64567bf02a31 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -335,14 +335,14 @@ def get_class_obj_and_candidates(
     library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
 ):
     """Simple helper method to retrieve class object of module as well as potential parent class objects"""
-    component_folder = os.path.join(cache_dir, component_name)
+    component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None
 
     if is_pipeline_module:
         pipeline_module = getattr(pipelines, library_name)
         class_obj = getattr(pipeline_module, class_name)
         class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
-    elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
+    elif component_folder and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
         # load custom component
         class_obj = get_class_from_dynamic_module(
             component_folder, module_file=library_name + ".py", class_name=class_name
diff --git a/tests/models/test_models_auto.py b/tests/models/test_models_auto.py
new file mode 100644
index 000000000000..5b7a5e5b3bcf
--- /dev/null
+++ b/tests/models/test_models_auto.py
@@ -0,0 +1,26 @@
+import unittest
+from unittest.mock import patch
+
+from transformers import AlbertForMaskedLM, CLIPTextModel
+
+from diffusers.models import AutoModel, UNet2DConditionModel
+
+
+class TestAutoModel(unittest.TestCase):
+    @patch("diffusers.models.AutoModel.load_config", side_effect=[EnvironmentError("File not found"), {"_class_name": "UNet2DConditionModel"}])
+    def test_load_from_config_diffusers_with_subfolder(self, mock_load_config):
+        model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet")
+        assert isinstance(model, UNet2DConditionModel)
+
+    @patch("diffusers.models.AutoModel.load_config", side_effect=[EnvironmentError("File not found"), {"architectures": ["CLIPTextModel"]}])
+    def test_load_from_config_transformers_with_subfolder(self, mock_load_config):
+        model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
+        assert isinstance(model, CLIPTextModel)
+
+    def test_load_from_config_without_subfolder(self):
+        model = AutoModel.from_pretrained("hf-internal-testing/tiny-albert")
+        assert isinstance(model, AlbertForMaskedLM)
+
+    def test_load_from_model_index(self):
+        model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
+        assert isinstance(model, CLIPTextModel)
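
Not part of the diff itself: a minimal usage sketch of the resolution order the patch introduces in AutoModel.from_pretrained (try the repo-level model_index.json first, then fall back to the component's config.json, routing through "_class_name" for diffusers classes or "architectures" for transformers classes). The repo ids are the tiny checkpoints already exercised in tests/models/test_models_auto.py; it assumes diffusers with this change plus transformers are installed.

    from diffusers.models import AutoModel

    # model_index.json route: the repo-level index maps "unet" to a diffusers class
    # (UNet2DConditionModel) and "text_encoder" to a transformers class (CLIPTextModel).
    unet = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet")
    text_encoder = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")

    # config.json fallback: no model_index.json and no "_class_name", so the "architectures"
    # field routes loading through transformers (AlbertForMaskedLM for this checkpoint).
    albert = AutoModel.from_pretrained("hf-internal-testing/tiny-albert")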