Skip to content

Commit 3fd14f1

Browse files
authored
[AutoModel] Allow registering auto_map to model config (#13186)
* update * update
1 parent e7fe4ce commit 3fd14f1

File tree

3 files changed

+115
-0
lines changed

3 files changed

+115
-0
lines changed

docs/source/en/using-diffusers/automodel.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,5 +97,32 @@ If the custom model inherits from the [`ModelMixin`] class, it gets access to th
9797
> )
9898
> ```
9999
100+
### Saving custom models
101+
102+
Use [`~ConfigMixin.register_for_auto_class`] to add the `auto_map` entry to `config.json` automatically when saving. This avoids having to manually edit the config file.
103+
104+
```py
105+
# my_model.py
106+
from diffusers import ModelMixin, ConfigMixin
107+
108+
class MyCustomModel(ModelMixin, ConfigMixin):
109+
...
110+
111+
MyCustomModel.register_for_auto_class("AutoModel")
112+
113+
model = MyCustomModel(...)
114+
model.save_pretrained("./my_model")
115+
```
116+
117+
The saved `config.json` will include the `auto_map` field.
118+
119+
```json
120+
{
121+
"auto_map": {
122+
"AutoModel": "my_model.MyCustomModel"
123+
}
124+
}
125+
```
126+
100127
> [!NOTE]
101128
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.

src/diffusers/configuration_utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,38 @@ class ConfigMixin:
107107
has_compatibles = False
108108

109109
_deprecated_kwargs = []
110+
_auto_class = None
111+
112+
@classmethod
113+
def register_for_auto_class(cls, auto_class="AutoModel"):
114+
"""
115+
Register this class with the given auto class so that it can be loaded with `AutoModel.from_pretrained(...,
116+
trust_remote_code=True)`.
117+
118+
When the config is saved, the resulting `config.json` will include an `auto_map` entry mapping the auto class
119+
to this class's module and class name.
120+
121+
Args:
122+
auto_class (`str` or type, *optional*, defaults to `"AutoModel"`):
123+
The auto class to register this class with. Can be a string (e.g. `"AutoModel"`) or the class itself.
124+
Currently only `"AutoModel"` is supported.
125+
126+
Example:
127+
128+
```python
129+
from diffusers import ModelMixin, ConfigMixin
130+
131+
132+
class MyCustomModel(ModelMixin, ConfigMixin): ...
133+
134+
135+
MyCustomModel.register_for_auto_class("AutoModel")
136+
```
137+
"""
138+
if auto_class != "AutoModel":
139+
raise ValueError(f"Only 'AutoModel' is supported, got '{auto_class}'.")
140+
141+
cls._auto_class = auto_class
110142

111143
def register_to_config(self, **kwargs):
112144
if self.config_name is None:
@@ -621,6 +653,12 @@ def to_json_saveable(value):
621653
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
622654
_ = config_dict.pop("_pre_quantization_dtype", None)
623655

656+
if getattr(self, "_auto_class", None) is not None:
657+
module = self.__class__.__module__.split(".")[-1]
658+
auto_map = config_dict.get("auto_map", {})
659+
auto_map[self._auto_class] = f"{module}.{self.__class__.__name__}"
660+
config_dict["auto_map"] = auto_map
661+
624662
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
625663

626664
def to_json_file(self, json_file_path: str | os.PathLike):

tests/models/test_models_auto.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
import torch
88
from transformers import CLIPTextModel, LongformerModel
99

10+
from diffusers import ConfigMixin
1011
from diffusers.models import AutoModel, UNet2DConditionModel
12+
from diffusers.models.modeling_utils import ModelMixin
1113

1214

1315
class TestAutoModel(unittest.TestCase):
@@ -143,3 +145,51 @@ def test_from_config_with_model_type_routes_to_transformers(self, mock_get_class
143145
def test_from_config_raises_on_none(self):
144146
with self.assertRaises(ValueError, msg="Please provide a `pretrained_model_name_or_path_or_dict`"):
145147
AutoModel.from_config(None)
148+
149+
150+
class TestRegisterForAutoClass(unittest.TestCase):
151+
def test_register_for_auto_class_sets_attribute(self):
152+
class DummyModel(ModelMixin, ConfigMixin):
153+
config_name = "config.json"
154+
155+
DummyModel.register_for_auto_class("AutoModel")
156+
self.assertEqual(DummyModel._auto_class, "AutoModel")
157+
158+
def test_register_for_auto_class_rejects_unsupported(self):
159+
class DummyModel(ModelMixin, ConfigMixin):
160+
config_name = "config.json"
161+
162+
with self.assertRaises(ValueError, msg="Only 'AutoModel' is supported"):
163+
DummyModel.register_for_auto_class("AutoPipeline")
164+
165+
def test_auto_map_in_saved_config(self):
166+
class DummyModel(ModelMixin, ConfigMixin):
167+
config_name = "config.json"
168+
169+
DummyModel.register_for_auto_class("AutoModel")
170+
model = DummyModel()
171+
172+
with tempfile.TemporaryDirectory() as tmpdir:
173+
model.save_config(tmpdir)
174+
config_path = os.path.join(tmpdir, "config.json")
175+
with open(config_path, "r") as f:
176+
config = json.load(f)
177+
178+
self.assertIn("auto_map", config)
179+
self.assertIn("AutoModel", config["auto_map"])
180+
module_name = DummyModel.__module__.split(".")[-1]
181+
self.assertEqual(config["auto_map"]["AutoModel"], f"{module_name}.DummyModel")
182+
183+
def test_no_auto_map_without_register(self):
184+
class DummyModel(ModelMixin, ConfigMixin):
185+
config_name = "config.json"
186+
187+
model = DummyModel()
188+
189+
with tempfile.TemporaryDirectory() as tmpdir:
190+
model.save_config(tmpdir)
191+
config_path = os.path.join(tmpdir, "config.json")
192+
with open(config_path, "r") as f:
193+
config = json.load(f)
194+
195+
self.assertNotIn("auto_map", config)

0 commit comments

Comments
 (0)