diff --git a/nemo_automodel/__init__.py b/nemo_automodel/__init__.py index 8bdc4be35d..febbe6e066 100644 --- a/nemo_automodel/__init__.py +++ b/nemo_automodel/__init__.py @@ -35,6 +35,10 @@ "nemo_automodel._transformers.auto_model", "NeMoAutoModelForSequenceClassification", ), + "NeMoAutoModelForTokenClassification": ( + "nemo_automodel._transformers.auto_model", + "NeMoAutoModelForTokenClassification", + ), "NeMoAutoModelForTextToWaveform": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForTextToWaveform"), "NeMoAutoModelBiEncoder": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelBiEncoder"), "NeMoAutoModelCrossEncoder": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelCrossEncoder"), diff --git a/nemo_automodel/_transformers/__init__.py b/nemo_automodel/_transformers/__init__.py index 03f70f3da4..aa00fdf539 100644 --- a/nemo_automodel/_transformers/__init__.py +++ b/nemo_automodel/_transformers/__init__.py @@ -25,6 +25,10 @@ "nemo_automodel._transformers.auto_model", "NeMoAutoModelForSequenceClassification", ), + "NeMoAutoModelForTokenClassification": ( + "nemo_automodel._transformers.auto_model", + "NeMoAutoModelForTokenClassification", + ), "NeMoAutoModelForTextToWaveform": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForTextToWaveform"), "NeMoAutoModelBiEncoder": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelBiEncoder"), "NeMoAutoModelCrossEncoder": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelCrossEncoder"), @@ -37,6 +41,7 @@ "NeMoAutoModelForImageTextToText", "NeMoAutoModelForMultimodalLM", "NeMoAutoModelForSequenceClassification", + "NeMoAutoModelForTokenClassification", "NeMoAutoModelForTextToWaveform", "NeMoAutoModelBiEncoder", "NeMoAutoModelCrossEncoder", diff --git a/nemo_automodel/_transformers/auto_model.py b/nemo_automodel/_transformers/auto_model.py index 3df1e2d5ba..965ba3964f 100644 --- a/nemo_automodel/_transformers/auto_model.py +++ b/nemo_automodel/_transformers/auto_model.py 
@@ -42,6 +42,7 @@ AutoModelForMultimodalLM, AutoModelForSequenceClassification, AutoModelForTextToWaveform, + AutoModelForTokenClassification, PreTrainedModel, ) from transformers.initialization import no_init_weights # noqa: E402 @@ -736,6 +737,33 @@ class NeMoAutoModelForSequenceClassification(_BaseNeMoAutoModelClass, AutoModelF pass +class NeMoAutoModelForTokenClassification(_BaseNeMoAutoModelClass, AutoModelForTokenClassification): + """Drop-in replacement for ``transformers.AutoModelForTokenClassification`` with custom-kernels. + + The class only overrides ``from_pretrained`` and ``from_config`` to add the + optional ``use_liger_kernel`` flag. If the flag is ``True`` (default) and + the Liger kernel is available, the model's attention layers are + monkey-patched in place. If patching fails for any reason, the call is + retried once with ``use_liger_kernel=False`` so that users still obtain a + functional model. + + Notes: + ----- + - No changes are made to the model's public API; forward signatures, + generation utilities, and weight shapes remain identical. + - Only decoder-style (causal) architectures are currently supported by the + Liger patch; encoder-style models — including the BERT checkpoints used in + the examples below — silently fall back to the stock implementation. + + Examples: + -------- + >>> model = NeMoAutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english") # try Liger + >>> model = NeMoAutoModelForTokenClassification.from_pretrained( + ... "dbmdz/bert-large-cased-finetuned-conll03-english", use_liger_kernel=False) # skip Liger + """ + + pass + + class NeMoAutoModelForTextToWaveform(_BaseNeMoAutoModelClass, AutoModelForTextToWaveform): """Drop-in replacement for ``transformers.AutoModelForTextToWaveform`` with custom-kernels.