Skip to content

Commit a84ce55

Browse files
Dynamically add boot function to bridge (#964)
* Dynamically add boot function to bridge * Fix imports * Fix circular import * Keep mypy happy * Edit tests to use TransformerBridge.boot_transformers instead of boot * Remove exposing transformers in model_bridge init * removed pretrained and replaced with nn module as param * removed extra import and ran import * added lazy import again --------- Co-authored-by: Bryce Meyer <bryce13950@gmail.com>
1 parent 54baf48 commit a84ce55

6 files changed

Lines changed: 25 additions & 16 deletions

File tree

tests/integration/model_bridge/test_bridge_integration.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
import torch
99

1010
from transformer_lens.ActivationCache import ActivationCache
11-
from transformer_lens.boot import boot
11+
from transformer_lens.model_bridge import TransformerBridge
1212

1313

1414
def test_model_initialization():
1515
"""Test that the model can be initialized correctly."""
1616
model_name = "gpt2" # Use a smaller model for testing
17-
bridge = boot(model_name)
17+
bridge = TransformerBridge.boot_transformers(model_name)
1818

1919
assert bridge is not None, "Bridge should be initialized"
2020
assert bridge.tokenizer is not None, "Tokenizer should be initialized"
@@ -24,7 +24,7 @@ def test_model_initialization():
2424
def test_text_generation():
2525
"""Test basic text generation functionality."""
2626
model_name = "gpt2" # Use a smaller model for testing
27-
bridge = boot(model_name)
27+
bridge = TransformerBridge.boot_transformers(model_name)
2828

2929
if bridge.tokenizer.pad_token is None:
3030
bridge.tokenizer.pad_token = bridge.tokenizer.eos_token
@@ -39,7 +39,7 @@ def test_text_generation():
3939
def test_hooks():
4040
"""Test that hooks can be added and removed correctly."""
4141
model_name = "gpt2" # Use a smaller model for testing
42-
bridge = boot(model_name)
42+
bridge = TransformerBridge.boot_transformers(model_name)
4343

4444
if bridge.tokenizer.pad_token is None:
4545
bridge.tokenizer.pad_token = bridge.tokenizer.eos_token
@@ -77,7 +77,7 @@ def test_hook(tensor, hook):
7777
def test_cache():
7878
"""Test that the cache functionality works correctly."""
7979
model_name = "gpt2" # Use a smaller model for testing
80-
bridge = boot(model_name)
80+
bridge = TransformerBridge.boot_transformers(model_name)
8181

8282
if bridge.tokenizer.pad_token is None:
8383
bridge.tokenizer.pad_token = bridge.tokenizer.eos_token
@@ -105,7 +105,7 @@ def test_cache():
105105
def test_component_access():
106106
"""Test that model components can be accessed correctly."""
107107
model_name = "gpt2" # Use a smaller model for testing
108-
bridge = boot(model_name)
108+
bridge = TransformerBridge.boot_transformers(model_name)
109109

110110
# Test accessing various components
111111
assert hasattr(bridge, "embed"), "Bridge should have embed component"

transformer_lens/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from .BertNextSentencePrediction import BertNextSentencePrediction
1919
from . import head_detector
2020
from . import loading_from_pretrained as loading
21-
from . import boot
2221
from . import patching
2322
from . import train
2423

@@ -41,5 +40,4 @@
4140
"EasyTransformerConfig",
4241
"EasyTransformerKeyValueCache",
4342
"EasyTransformerKeyValueCacheEntry",
44-
"boot",
4543
]

transformer_lens/model_bridge/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,10 @@
33
This module provides functionality to bridge between different model architectures.
44
"""
55

6-
from transformer_lens.factories.architecture_adapter_factory import (
7-
ArchitectureAdapterFactory,
8-
)
96
from transformer_lens.model_bridge.architecture_adapter import (
107
ArchitectureAdapter,
118
)
9+
1210
from transformer_lens.model_bridge.bridge import (
1311
TransformerBridge,
1412
)
@@ -39,9 +37,11 @@
3937
TransformerLensPath,
4038
)
4139

40+
import transformer_lens.model_bridge.sources.transformers
41+
42+
4243
__all__ = [
4344
"ArchitectureAdapter",
44-
"ArchitectureAdapterFactory",
4545
"TransformerBridge",
4646
"AttentionBridge",
4747
"BlockBridge",

transformer_lens/model_bridge/architecture_adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any, cast
77

88
import torch
9-
from transformers.modeling_utils import PreTrainedModel
9+
from torch import nn
1010

1111
from transformer_lens.model_bridge.conversion_utils.conversion_steps import (
1212
WeightConversionSet,
@@ -307,7 +307,7 @@ def translate_transformer_lens_path(
307307
return remote_path.split(".")[-1]
308308
return remote_path
309309

310-
def convert_weights(self, hf_model: PreTrainedModel) -> dict[str, torch.Tensor]:
310+
def convert_weights(self, hf_model: nn.Module) -> dict[str, torch.Tensor]:
311311
"""Convert the weights from the HuggingFace format to the HookedTransformer format.
312312
313313
Args:
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""Sources module.
2+
3+
This module provides functionality to load and convert models from HuggingFace to TransformerLens format.
4+
"""

transformer_lens/boot.py renamed to transformer_lens/model_bridge/sources/transformers.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Boot module for TransformerLens.
1+
"""Transformers module for TransformerLens.
22
33
This module provides functionality to load and convert models from HuggingFace to TransformerLens format.
44
"""
@@ -14,7 +14,6 @@
1414
PreTrainedTokenizerBase,
1515
)
1616

17-
from transformer_lens.model_bridge import ArchitectureAdapterFactory
1817
from transformer_lens.model_bridge.bridge import TransformerBridge
1918
from transformer_lens.utils import get_tokenizer_with_bos
2019

@@ -38,6 +37,11 @@ def boot(
3837
Returns:
3938
The bridge to the loaded model.
4039
"""
40+
# Lazy import to avoid circular import
41+
from transformer_lens.factories.architecture_adapter_factory import (
42+
ArchitectureAdapterFactory,
43+
)
44+
4145
hf_config = AutoConfig.from_pretrained(model_name, **kwargs)
4246
adapter = ArchitectureAdapterFactory.select_architecture_adapter(hf_config)
4347
default_config = adapter.default_cfg
@@ -125,3 +129,6 @@ def setup_tokenizer(
125129
tokenizer.bos_token = tokenizer.eos_token
126130

127131
return tokenizer
132+
133+
134+
setattr(TransformerBridge, "boot_transformers", staticmethod(boot))

0 commit comments

Comments
 (0)