Skip to content

Commit 30ef8af

Browse files
fix(load): restore LabelAttentionConfig object from loaded dict
used as a namespace after, so no converting it throws a bug
1 parent 601fa46 commit 30ef8af

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

torchTextClassifiers/torchTextClassifiers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,10 @@ def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassif
671671

672672
# Reconstruct model_config
673673
model_config = ModelConfig.from_dict(metadata["model_config"])
674+
if type(model_config.label_attention_config) is dict:
675+
model_config.label_attention_config = LabelAttentionConfig(
676+
**model_config.label_attention_config
677+
)
674678

675679
# Create instance
676680
instance = cls(

0 commit comments

Comments
 (0)