# Adding a new model family

This guide describes the expected steps for adding a new model family to
`colpali_engine.models`. A model family is a backbone-specific package such as
`qwen3`, `gemma3`, `idefics3`, or `paligemma` that exposes one or more retriever
variants.

Most families contain:

- A `Col*` late-interaction model that returns one normalized vector per token
  and is scored with MaxSim.
- Optionally, a `Bi*` dense retrieval model that pools to one normalized vector
  per input.
- One processor per variant, responsible for image/text formatting and scoring.

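The MaxSim interaction mentioned above can be sketched as a small scoring function. This is illustrative only; in this codebase the equivalent logic lives behind the processors' `score` methods, and the function name here is hypothetical:

```python
import torch

def maxsim_score(qs: torch.Tensor, ps: torch.Tensor) -> torch.Tensor:
    """Late-interaction (MaxSim) scoring sketch.

    qs: (num_queries, q_len, dim) query token embeddings.
    ps: (num_passages, p_len, dim) passage token embeddings.
    Both are assumed L2-normalized per token, with padding embeddings zeroed.
    """
    # Token-to-token cosine similarities for every query/passage pair.
    sim = torch.einsum("qnd,pmd->qpnm", qs, ps)
    # For each query token, keep its best-matching passage token, then sum.
    return sim.max(dim=-1).values.sum(dim=-1)  # (num_queries, num_passages)
```

Because padding embeddings are zeroed, padded tokens contribute nothing to either the max or the sum.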
## 1. Choose the package layout

Create a family directory under `colpali_engine/models`:

```text
colpali_engine/models/<family>/
    __init__.py
    col<family>/
        __init__.py
        modeling_col<family>.py
        processing_col<family>.py
    bi<family>/
        __init__.py
        modeling_bi<family>.py
        processing_bi<family>.py
```

Only add `bi<family>/` if the family supports a dense bi-encoder variant.
Follow the naming style used by nearby families. For example, Qwen variants use
`ColQwen3`, `ColQwen3Processor`, `BiQwen3`, and `BiQwen3Processor`.

## 2. Implement the Col model

The `Col*` class should usually inherit from the corresponding Transformers
backbone model, not from a generic wrapper. See
`colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py` and
`colpali_engine/models/gemma3/colgemma3/modeling_colgemma3.py` for current
patterns.
| 44 | + |
| 45 | +The class should define: |
| 46 | + |
| 47 | +- `main_input_name = "doc_input_ids"` for Transformers compatibility. |
| 48 | +- A retrieval projection layer, usually `self.custom_text_proj`. |
| 49 | +- `self.dim`, the embedding size returned by the retriever head. |
| 50 | +- `self.padding_side`, matching the processor and backbone requirements. |
| 51 | +- `self.mask_non_image_embeddings` when the family supports image-only masking. |
| 52 | +- A `forward` method that returns a `torch.Tensor` shaped |
| 53 | + `(batch_size, sequence_length, dim)`. |
| 54 | + |
| 55 | +The forward pass should: |
| 56 | + |
| 57 | +1. Accept the batch produced by the processor for both images and text. |
| 58 | +2. Adapt processor-specific image tensors before calling the backbone when the |
| 59 | + backbone expects a flattened visual-token layout. |
| 60 | +3. Call the parent model with `use_cache=False`, `output_hidden_states=True`, |
| 61 | + and `return_dict=True`. |
| 62 | +4. Project the last hidden states with `custom_text_proj`. |
| 63 | +5. L2-normalize the projected embeddings on the last dimension. |
| 64 | +6. Multiply by `attention_mask.unsqueeze(-1)` so padding tokens score as zero. |
| 65 | +7. If `mask_non_image_embeddings=True`, zero non-image token embeddings for |
| 66 | + image batches. |
| 67 | + |
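Steps 3 through 6 can be sketched in isolation. The class and the `backbone` argument below are stand-ins, not a real family implementation; in a real model the backbone is the inherited parent class rather than an injected module:

```python
import torch
from torch import nn

class ColForwardSketch(nn.Module):
    """Illustrative Col* forward pass; names follow this guide's conventions."""

    def __init__(self, backbone: nn.Module, hidden_size: int, dim: int = 128):
        super().__init__()
        self.backbone = backbone
        self.dim = dim
        self.custom_text_proj = nn.Linear(hidden_size, dim)

    def forward(self, input_ids, attention_mask, **kwargs) -> torch.Tensor:
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=False,
            output_hidden_states=True,
            return_dict=True,
            **kwargs,
        )
        hidden = outputs.hidden_states[-1]             # (batch, seq, hidden)
        proj = self.custom_text_proj(hidden)           # (batch, seq, dim)
        proj = proj / proj.norm(dim=-1, keepdim=True)  # step 5: L2-normalize
        return proj * attention_mask.unsqueeze(-1)     # step 6: zero padding
```

Steps 2 and 7 are family-specific and are omitted here.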
Expose patch metadata needed by interpretability when the backbone supports it:

```python
@property
def patch_size(self) -> int:
    return self.visual.config.patch_size

@property
def spatial_merge_size(self) -> int:
    return self.visual.config.spatial_merge_size
```

Adjust the properties to match the backbone config. Some models only expose
`patch_size`.

## 3. Handle checkpoint key mappings

Adapter checkpoints often contain PEFT or backbone-specific prefixes that do not
match the retriever class. Add a `_checkpoint_conversion_mapping` to the model
when needed:

```python
_checkpoint_conversion_mapping = {
    r"^base_model\.model\.custom_text_proj": "custom_text_proj",
    r"^model\.visual": "visual",
    r"^model\.language_model": "language_model",
    r"^model\.": "",
}
```

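The effect of such a mapping is a set of ordered regex rewrites over state-dict keys. A sketch of the effect, not of the Transformers internals (`remap_key` is a hypothetical helper):

```python
import re

_checkpoint_conversion_mapping = {
    r"^base_model\.model\.custom_text_proj": "custom_text_proj",
    r"^model\.visual": "visual",
    r"^model\.language_model": "language_model",
    r"^model\.": "",
}

def remap_key(key: str) -> str:
    # First matching pattern wins, so keep the specific prefixes listed
    # before the catch-all `^model\.` entry.
    for pattern, replacement in _checkpoint_conversion_mapping.items():
        new_key, n = re.subn(pattern, replacement, key)
        if n:
            return new_key
    return key
```

Keys that match no pattern (for example `lm_head.weight`) pass through unchanged.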
Override `from_pretrained` to pass the mapping through `key_mapping`:

```python
@classmethod
def from_pretrained(cls, *args, **kwargs):
    key_mapping = kwargs.pop("key_mapping", None)
    if key_mapping is None:
        key_mapping = dict(getattr(super(), "_checkpoint_conversion_mapping", {}))
        key_mapping.update(getattr(cls, "_checkpoint_conversion_mapping", {}))
    return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)
```

If Transformers requires registration for the model type, register the mapping
with `register_checkpoint_conversion_mapping`, as in the Qwen and ModernVBert
implementations.

Add tests to `tests/models/test_checkpoint_key_mappings.py` for every custom
mapping that rewrites adapter keys.

## 4. Implement the processor

Processors should inherit from `BaseVisualRetrieverProcessor` and the matching
Transformers processor:

```python
class ColNewFamilyProcessor(BaseVisualRetrieverProcessor, NewFamilyProcessor):
    ...
```

The processor must implement:

- `process_images(self, images)`: converts PIL images to model-ready batches.
- `process_texts(self, texts)`: converts text inputs to model-ready batches.
- `score(self, qs, ps, device=None, **kwargs)`: delegates to
  `score_multi_vector` for `Col*` models.
- `get_n_patches(...)`: returns `(n_patches_x, n_patches_y)` for
  interpretability.

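For a ViT-style backbone, `get_n_patches` typically reduces to simple arithmetic over the processed image size. The standalone function below is illustrative only; a real implementation must mirror the backbone's exact resize logic, and the `(height, width)` convention is an assumption:

```python
def get_n_patches(
    image_size: tuple[int, int],
    patch_size: int,
    spatial_merge_size: int = 1,
) -> tuple[int, int]:
    """Illustrative patch-count arithmetic for a ViT-style backbone.

    image_size is (height, width) in pixels after the processor's resizing.
    """
    height, width = image_size
    # Pixels covered per output token along each axis, accounting for
    # backbones that merge neighboring patches into one token.
    stride = patch_size * spatial_merge_size
    return width // stride, height // stride  # (n_patches_x, n_patches_y)
```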
Set prompt and token attributes when the backbone needs them:

```python
visual_prompt_prefix = "..."
query_prefix = "..."
query_augmentation_token = "..."
image_token = "..."
```

Use the backbone's chat template or special tokens consistently with the
checkpoint used for training. Also set `self.tokenizer.padding_side` in
`__init__` when the model requires left or right padding.

If the processor pads per-image visual tensors for distributed training, the
model forward pass must undo that padding before calling the backbone. The Qwen
processors and models are the reference pattern for this.

## 5. Implement an optional Bi model

Add a `Bi*` model when the family needs dense single-vector retrieval. The class
normally shares the same backbone and processor conventions, but the forward pass
returns `(batch_size, hidden_size_or_dim)` instead of per-token embeddings.

Support the local pooling styles used elsewhere when possible:

- `cls`: first token.
- `last`: last token.
- `mean`: attention-mask-weighted mean.

Normalize the pooled embedding before returning it. Its processor `score` method
should delegate to `score_single_vector`.

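The three pooling styles plus the final normalization can be sketched as one helper. The function name and signature are illustrative, not an API in this codebase:

```python
import torch

def pool(hidden: torch.Tensor, attention_mask: torch.Tensor, style: str) -> torch.Tensor:
    """Pool (batch, seq, hidden) token states to one normalized vector per input."""
    if style == "cls":
        pooled = hidden[:, 0]
    elif style == "last":
        # Index of the last real token; assumes right padding.
        lengths = attention_mask.long().sum(dim=1) - 1
        pooled = hidden[torch.arange(hidden.size(0)), lengths]
    elif style == "mean":
        # Attention-mask-weighted mean over real tokens only.
        mask = attention_mask.unsqueeze(-1)
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    else:
        raise ValueError(f"unknown pooling style: {style}")
    return pooled / pooled.norm(dim=-1, keepdim=True)  # L2-normalize
```

Note that `last` pooling must account for the padding side the processor configures; the sketch assumes right padding.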
## 6. Export the new classes

Wire the imports at every package level:

```python
# colpali_engine/models/<family>/col<family>/__init__.py
from .modeling_col<family> import ColNewFamily
from .processing_col<family> import ColNewFamilyProcessor

# colpali_engine/models/<family>/__init__.py
from .col<family> import ColNewFamily, ColNewFamilyProcessor
from .bi<family> import BiNewFamily, BiNewFamilyProcessor

# colpali_engine/models/__init__.py
from .<family> import BiNewFamily, BiNewFamilyProcessor, ColNewFamily, ColNewFamilyProcessor
```

Keep these exports stable because users import models directly from
`colpali_engine.models`.

| 187 | + |
| 188 | +## 7. Add tests |
| 189 | + |
| 190 | +Create tests under `tests/models/<family>/<variant>/`, following existing model |
| 191 | +families. |
| 192 | + |
| 193 | +Processor tests should verify: |
| 194 | + |
| 195 | +- `from_pretrained` returns the custom processor class. |
| 196 | +- `process_images` returns expected keys and tensor batch dimensions. |
| 197 | +- `process_texts` returns text tensors with the expected batch size. |
| 198 | +- `process_queries` remains compatible with the legacy evaluator path. |
| 199 | + |
| 200 | +Model tests should verify: |
| 201 | + |
| 202 | +- `from_pretrained` returns the custom model class. |
| 203 | +- Image forward pass returns a tensor with shape |
| 204 | + `(batch_size, sequence_length, model.dim)` for `Col*`. |
| 205 | +- Query forward pass returns the same embedding dimension. |
| 206 | +- Retrieval smoke tests rank matching image/query pairs correctly when a small |
| 207 | + public checkpoint is available. |
| 208 | + |
| 209 | +Use `@pytest.mark.slow` for tests that download or run full checkpoints. |
| 210 | + |
| 211 | +Run the targeted tests before opening a PR: |
| 212 | + |
| 213 | +```bash |
| 214 | +pytest tests/models/<family> |
| 215 | +pytest tests/models/test_checkpoint_key_mappings.py |
| 216 | +``` |
| 217 | + |
| 218 | +Run the linter before submitting: |
| 219 | + |
| 220 | +```bash |
| 221 | +ruff check . |
| 222 | +``` |
| 223 | + |
## 8. Add training and example entry points when needed

If the family is trainable from this repository, add a config under
`scripts/configs/<family>/` and update any training scripts that need to import
the new classes. Keep the config names aligned with the model class names, for
example `train_col<family>_model.py` or `train_col<family>_model.yaml`.

If interpretability is supported, add an example under
`examples/interpretability/<variant>/` and make sure `get_n_patches` plus
`get_image_mask` return masks in the same token order as the model embeddings.

## 9. Update user-facing documentation

When the checkpoint is public and supported, update the model table in
`README.md` with:

- The Hugging Face model id.
- The base backbone.
- The license.
- Notes about dynamic resolution, embedding dimension, or masking behavior.
- Whether the model is currently supported.

Add usage snippets only if loading or preprocessing differs from the existing
quick start pattern.

## Review checklist

Before submitting the change, check that:

- The model and processor can be imported from `colpali_engine.models`.
- `process_images`, `process_texts`, and `process_queries` all work.
- `model(**processor.process_images(...))` and
  `model(**processor.process_queries(...))` return normalized tensors.
- Padding embeddings are zeroed for `Col*` outputs.
- Checkpoint mappings load LoRA or adapter checkpoints without manual key edits.
- Slow tests are marked, and fast tests do not download large checkpoints unless
  the existing family tests already do the same.