<!--Copyright 2026 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

-->

# Model structure rules

Transformers enforces a set of static rules on every `modeling_*.py`, `modular_*.py`, and `configuration_*.py` file. The [mlinter](https://github.com/huggingface/transformers/tree/main/utils/mlinter) tool checks them as part of `make typing` and errors out if violations are found.

These rules encode the conventions expected when adding or changing modeling code. They keep the codebase consistent and ensure compatibility with features like pipeline parallelism, device maps, and weight tying.

## Running the checker

`make typing` runs `mlinter` alongside the `ty` type checker. Run `mlinter` on its own with the following commands.

```bash
python -m utils.mlinter                 # check all modeling files
python -m utils.mlinter --changed-only  # check only files changed vs origin/main
python -m utils.mlinter --list-rules    # list all rules and their enabled status
python -m utils.mlinter --rule TRF001   # show built-in docs for a specific rule
```

The `--changed-only` flag is the fastest option during development. It only checks the files you've modified relative to the main branch.

## Fixing a violation

When a rule violation is detected, the error looks like this:

```
src/transformers/models/acme/modeling_acme.py:18: TRF013: AcmeModel.__init__ does not call self.post_init().
```

Use the rule ID to look up the fix in the [rules reference](#rules-reference). TRF013 is triggered when a [`PreTrainedModel`] subclass doesn't call `self.post_init()`. That method performs essential finalization steps, and omitting it causes runtime bugs.

```diff
 class AcmeModel(AcmePreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.layers = nn.ModuleList(
             [AcmeDecoderLayer(config) for _ in range(config.num_hidden_layers)]
         )
+        self.post_init()
```

## Rules reference

Each rule below lists what it enforces and a diff showing the fix. Run `python -m utils.mlinter --rule TRF001` to see the built-in docs for any rule.

<!-- BEGIN RULES REFERENCE -->

### TRF001

Checks naming consistency between `<Model>PreTrainedModel` and `config_class`. A mismatched `config_class` can break loading, auto classes, and developer expectations.

```diff
 class AcmePreTrainedModel(PreTrainedModel):
-    config_class = WileConfig
+    config_class = AcmeConfig
```

### TRF002

Checks that `base_model_prefix`, when set, is a non-empty, whitespace-free string literal. Invalid prefixes can break weight loading key mapping and base model access patterns.

```diff
 class AcmePreTrainedModel(PreTrainedModel):
-    base_model_prefix = ""
+    base_model_prefix = "model"
```

### TRF003

Detects `forward` methods that use the legacy `if not return_dict: return (x,)` pattern. This branching is error-prone and verbose. Use the `capture_output` or `can_return_tuple` decorators instead.

```diff
-def forward(self, x, return_dict=None):
-    if not return_dict:
-        return (x,)
-    return AcmeModelOutput(last_hidden_state=x)
+@can_return_tuple
+def forward(self, x):
+    return AcmeModelOutput(last_hidden_state=x)
```

### TRF004

Checks that no model class defines a `tie_weights` method. Overriding `tie_weights` can break loading, `device_map` computation, and saving. Declare tied weights with the `_tied_weights_keys` class attribute instead.

```diff
-def tie_weights(self):
-    self.lm_head.weight = self.emb.weight
+class AcmeForCausalLM(AcmePreTrainedModel):
+    _tied_weights_keys = ["lm_head.weight"]
```

### TRF005

Checks that `_no_split_modules`, when present, is a list of non-empty class-name strings. Malformed values can break device-map partitioning and sharding behavior.

```diff
-_no_split_modules = [SomeLayerClass, ""]
+_no_split_modules = ["AcmeDecoderLayer", "AcmeAttention"]
```

### TRF006

Checks that `forward` methods which expose cache arguments actually use those arguments in the method body. Unused cache arguments can indicate incomplete caching support and inconsistent API behavior.

```diff
 def forward(self, x, past_key_values=None, use_cache=False):
+    if use_cache:
+        ...
     return x
```
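
For context, a `forward` that genuinely supports caching threads the arguments down to its sublayers rather than accepting and ignoring them. The sketch below is illustrative only; it assumes the decoder layers live in `self.layers` and accept these keyword arguments, which is not something the rule itself mandates.

```py
# Illustrative sketch: the cache arguments flow into the computation instead of being ignored.
def forward(self, x, past_key_values=None, use_cache=False):
    hidden_states = x
    for decoder_layer in self.layers:
        hidden_states = decoder_layer(
            hidden_states,
            past_key_values=past_key_values,  # reused and updated by the layer
            use_cache=use_cache,
        )
    return hidden_states
```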

### TRF007

Checks for `self` attribute assignments after `self.post_init()` in `__init__`. Mutating model structure after `post_init` can bypass intended initialization/finalization logic.

```diff
 def __init__(self, config):
     ...
-    self.post_init()
-    self.proj = nn.Linear(...)
+    self.proj = nn.Linear(...)
+    self.post_init()
```

### TRF008

Checks `add_start_docstrings` usage on model classes for non-empty docstring arguments. Empty decorator usage produces unclear docs and weakens generated API documentation quality.

```diff
-@add_start_docstrings("")
+@add_start_docstrings("The Acme model.")
 class AcmeModel(AcmePreTrainedModel):
     ...
```

### TRF009

Checks modeling files for cross-model imports, such as `transformers.models.other_model.*` or relative `from ..other_model` imports. Cross-model implementation imports violate the single-file policy and make model behavior harder to inspect and maintain.

```diff
-from transformers.models.llama.modeling_llama import LlamaAttention
+# Keep implementation local to this file.
+# If reusing code, copy it with a # Copied from comment.
```
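
When you do want to reuse another model's implementation, the usual approach is to duplicate the code in the local file and mark it with a `# Copied from` comment so repo tooling can keep the copy in sync. A minimal sketch, with a purely illustrative class body:

```py
# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->Acme
class AcmeAttention(nn.Module):
    def __init__(self, config: AcmeConfig, layer_idx: int):
        super().__init__()
        # ... local copy of the attention implementation ...
```

Cross-model reuse through inheritance typically belongs in the `modular_*.py` file instead, where the generated `modeling_*.py` ends up with the expanded, self-contained code.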

### TRF010

Checks direct `PreTrainedConfig`/`PretrainedConfig` subclasses in `configuration_*.py` and `modular_*.py` for an explicit `@strict(accept_kwargs=True)` decorator. Without `strict`, new config classes miss the repo's runtime type-validation contract and drift from the dataclass-based config standard.

```diff
+@strict(accept_kwargs=True)
 class AcmeConfig(PreTrainedConfig):
     ...
```

### TRF011

In `forward()` methods of `PreTrainedModel` subclasses, checks for attribute accesses on submodules that would not exist on `torch.nn.Identity`. This includes attribute accesses on loop variables iterating over `self.layers`, and `self.<submodule>.<attr>` chains where `<attr>` is not a standard `nn.Module` attribute. Pipeline parallelism may replace any submodule with `torch.nn.Identity`. Accessing custom attributes (e.g. `decoder_layer.attention_type`) on a replaced module raises `AttributeError` at runtime. Per-layer metadata should be read from `self.config` instead.

```diff
 def forward(self, ...):
-    for decoder_layer in self.layers:
+    for i, decoder_layer in enumerate(self.layers):
         hidden_states = decoder_layer(
             hidden_states,
-            attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+            attention_mask=causal_mask_mapping[self.config.layer_types[i]],
         )
```

### TRF012

Checks that `_init_weights(self, module)` does not use in-place operations (e.g. `.normal_()`, `.zero_()`) directly on module weights. We rely on internal flags set on parameters to track whether they need re-initialization. In-place ops bypass this mechanism. Use the `init` primitives instead.

```diff
+from transformers import initialization as init
+
 def _init_weights(self, module):
-    module.weight.normal_(mean=0.0, std=0.02)
+    init.normal_(module.weight, mean=0.0, std=0.02)
```
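
A fuller `_init_weights` follows the same pattern for each module type it handles. The sketch below is an assumption-laden illustration: it presumes the config exposes `initializer_range` and only uses the primitive shown above (`init.normal_`); other parameter types would go through the corresponding `init` helpers in the same way.

```py
# Illustrative sketch: every branch goes through the init primitives,
# never through in-place tensor methods such as .normal_() or .zero_().
def _init_weights(self, module):
    std = self.config.initializer_range
    if isinstance(module, nn.Linear):
        init.normal_(module.weight, mean=0.0, std=std)
    elif isinstance(module, nn.Embedding):
        init.normal_(module.weight, mean=0.0, std=std)
```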

### TRF013

Checks that every `PreTrainedModel` subclass with an `__init__` method calls `self.post_init()`. In modular files, calling `super().__init__()` is also accepted since it propagates `post_init` from the parent. `post_init` performs essential finalization (weight initialization, gradient checkpointing setup, etc.). Omitting it causes subtle runtime bugs.

```diff
 class AcmeModel(AcmePreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.layers = nn.ModuleList(...)
+        self.post_init()
```
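
In a `modular_*.py` file, delegating to the parent is enough. A minimal sketch, where the `LlamaModel` parent is purely illustrative:

```py
class AcmeModel(LlamaModel):
    def __init__(self, config):
        # No explicit self.post_init() needed: the parent __init__ already calls it.
        super().__init__(config)
```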

### TRF014

Checks whether `trust_remote_code` is passed or used (e.g. as a kwarg) within native model integration files. `trust_remote_code` allows arbitrary code loading, including binaries, and should remain a power-user option rather than a standard use case. Native integrations must not depend on it, as remote code cannot be reviewed or maintained within transformers.

```diff
 class AcmeModel(AcmePreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
-        self.model = AutoModel.from_pretrained(..., trust_remote_code=True)
+        self.model = AutoModel.from_pretrained(...)
```

<!-- END RULES REFERENCE -->

## Suppressing violations

If you need to suppress a rule violation, use one of the two options below.

### Inline suppression

Add a `# trf-ignore: RULE_ID` comment on the violating line. Include an explanation so reviewers understand why the suppression is justified.

```py
# trf-ignore: TRF011 - mask is derived from self.config, not the layer
hidden_states = layer(hidden_states, attention_mask=mask_from_config)
```

Don't use `trf-ignore` to silence violations that should be fixed in the code.

### `allowlist_models`

For models with legacy code that can't be fixed immediately, add the model's directory name to the relevant rule's `allowlist_models` list in `utils/mlinter/rules.toml`.

```toml
[rules.TRF004]
allowlist_models = ["existing_model", "your_model_name"]
```