Skip to content

Commit 658952d

Browse files
Butaniumclaude
andcommitted
Add dict_class to save_pretrained config for local model loading
save_pretrained now includes the model class name in config.json, enabling load_dictionary_model to determine the correct class when loading from a local save_pretrained directory. _from_pretrained strips dict_class before instantiation to avoid __init__ TypeError. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent c2617ad commit 658952d

1 file changed

Lines changed: 11 additions & 0 deletions

File tree

dictionary_learning/dictionary.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,17 @@ def decode(self, f):
162162
"""
163163
pass
164164

165+
def save_pretrained(self, *args, **kwargs):
166+
"""Save model with dict_class in config for local loading."""
167+
self._hub_mixin_config["dict_class"] = type(self).__name__
168+
return super().save_pretrained(*args, **kwargs)
169+
170+
@classmethod
171+
def _from_pretrained(cls, *, model_id, **kwargs):
172+
"""Strip dict_class from config before instantiation."""
173+
kwargs.pop("dict_class", None)
174+
return super()._from_pretrained(model_id=model_id, **kwargs)
175+
165176
@classmethod
166177
@abstractmethod
167178
def from_pretrained(

0 commit comments

Comments
 (0)