Skip to content

Commit e8ff2fe

Browse files
Merge pull request #2 from Open-Athena/refactor-model
Refactor model - MLM and CLM
2 parents fcd1918 + 4f235ca commit e8ff2fe

9 files changed

Lines changed: 373 additions & 118 deletions

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ uv venv
1313
source .venv/bin/activate
1414
```
1515

16+
Pre-requisites:
17+
- Pytorch (e.g. `uv pip install torch`)
18+
1619
Install the package:
1720

1821
```bash
@@ -32,7 +35,7 @@ Run the example script:
3235

3336
```bash
3437
source .venv/bin/activate
35-
python examples/plantcad_evolutionary_constraint.py
38+
python examples/marin_evolutionary_constraint.py
3639
```
3740

3841
## Development Setup
@@ -42,9 +45,6 @@ To set up the development environment with linting, formatting, type checking, a
4245
```bash
4346
# Install development dependencies
4447
uv pip install --group dev
45-
46-
# Install both main package and dev tools
47-
uv pip install -e . --group dev
4848
```
4949

5050
## Development Tools
@@ -53,7 +53,7 @@ uv pip install -e . --group dev
5353

5454
```bash
5555
# Run all pre-commit hooks (linting, formatting, type checking)
56-
pre-commit run --all-files
56+
pre-commit run
5757
```
5858

5959
### Running Tests

biofoundation/data.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
from typing import Any
33

44

5+
NUCLEOTIDES = list("ACGT")
6+
7+
58
def transform_reflogprob_mlm(
69
example: dict[str, Any],
710
tokenizer: PreTrainedTokenizerBase,
8-
pos: int,
9-
seq_col: str = "seq",
1011
) -> dict[str, Any]:
1112
"""Transform a sequence example for reference log probability MLM inference.
1213
@@ -19,24 +20,39 @@ def transform_reflogprob_mlm(
1920
example: Dictionary containing the sequence data. Must have a key matching
2021
`seq_col` that contains the input sequence.
2122
tokenizer: HuggingFace tokenizer for converting text to token IDs.
22-
pos: Position in the sequence to mask (0-indexed).
23-
seq_col: Key in the example dictionary that contains the sequence.
24-
Defaults to "seq".
2523
2624
Returns:
2725
Dictionary with three keys:
28-
- input_ids_BL: Token IDs with the specified position masked
29-
- pos_B: The masked position
30-
- ref_B: The reference token ID that was at the masked position
26+
- input_ids: Token IDs with the specified position masked
27+
- pos: The masked position
28+
- ref: The reference token ID that was at the masked position
3129
3230
Example:
3331
>>> example = {"seq": "ATCG"}
3432
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
3533
>>> result = transform_reflogprob_mlm(example, tokenizer, 1)
3634
>>> print(result)
37-
{'input_ids_BL': tensor([...]), 'pos_B': 1, 'ref_B': 3}
35+
{'input_ids': tensor([...]), 'pos': 1, 'ref': 3}
3836
"""
39-
input_ids = tokenizer(example[seq_col], return_tensors="pt")["input_ids"][0]
37+
pos = example["pos"]
38+
assert example["seq"][pos] in NUCLEOTIDES
39+
input_ids = tokenizer(example["seq"], return_tensors="pt")["input_ids"][0]
4040
ref = input_ids[pos].item()
4141
input_ids[pos] = tokenizer.mask_token_id
42-
return dict(input_ids_BL=input_ids, pos_B=pos, ref_B=ref)
42+
return dict(input_ids=input_ids, pos=pos, ref=ref)
43+
44+
45+
def transform_reflogprob_clm(
46+
example: dict[str, Any],
47+
tokenizer: PreTrainedTokenizerBase,
48+
) -> dict[str, Any]:
49+
pos = example["pos"]
50+
assert example["seq"][pos] in NUCLEOTIDES
51+
input_ids = tokenizer(example["seq"], return_tensors="pt")["input_ids"][0]
52+
ref = input_ids[pos].item()
53+
# Create 4 copies of the input sequence
54+
new_input_ids = input_ids.unsqueeze(0).repeat(len(NUCLEOTIDES), 1)
55+
for i, nuc in enumerate(NUCLEOTIDES):
56+
new_input_ids[i, pos] = tokenizer.encode(nuc)[0]
57+
ref = NUCLEOTIDES.index(example["seq"][pos])
58+
return dict(input_ids=new_input_ids, ref=ref)

biofoundation/inference.py

Lines changed: 98 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,58 @@
11
import datasets
22
import tempfile
33
import torch.nn as nn
4-
from transformers import Trainer, TrainingArguments
5-
from typing import Any
4+
from transformers import Trainer, TrainingArguments, PreTrainedTokenizerBase
5+
from typing import Any, Callable
6+
from functools import partial
7+
8+
from .data import (
9+
transform_reflogprob_mlm,
10+
transform_reflogprob_clm,
11+
)
12+
from .model import (
13+
compute_reflogprob_mlm,
14+
compute_reflogprob_clm,
15+
)
616

717

818
def run_inference(
19+
model: nn.Module,
20+
tokenizer: PreTrainedTokenizerBase, # TODO: create an adapter for this
21+
dataset: datasets.Dataset,
22+
compute_fn: Callable[..., Any],
23+
data_transform_fn: Callable[..., dict[str, Any]] | None = None,
24+
data_transform_on_the_fly: bool = False,
25+
data_transform_kwargs: dict[str, Any] | None = None,
26+
inference_kwargs: dict[str, Any] | None = None,
27+
) -> Any:
28+
processed_dataset = _process_dataset(
29+
dataset,
30+
tokenizer,
31+
data_transform_fn,
32+
data_transform_on_the_fly,
33+
data_transform_kwargs,
34+
)
35+
return _run_inference(
36+
_ModelComputeFnWrapper(model, compute_fn),
37+
processed_dataset,
38+
**(inference_kwargs or {}),
39+
)
40+
41+
42+
run_reflogprob_mlm = partial(
43+
run_inference,
44+
compute_fn=compute_reflogprob_mlm,
45+
data_transform_fn=transform_reflogprob_mlm,
46+
)
47+
48+
run_reflogprob_clm = partial(
49+
run_inference,
50+
compute_fn=compute_reflogprob_clm,
51+
data_transform_fn=transform_reflogprob_clm,
52+
)
53+
54+
55+
def _run_inference(
956
model: nn.Module,
1057
dataset: datasets.Dataset,
1158
**kwargs: Any,
@@ -29,7 +76,55 @@ def run_inference(
2976
"""
3077
training_args = TrainingArguments(
3178
output_dir=tempfile.TemporaryDirectory().name,
32-
**kwargs,
79+
**(kwargs or {}),
3380
)
3481
trainer = Trainer(model=model, args=training_args)
3582
return trainer.predict(test_dataset=dataset).predictions
83+
84+
85+
class _ModelComputeFnWrapper(nn.Module):
86+
def __init__(self, model: nn.Module, compute_fn: Callable[..., Any]):
87+
super().__init__()
88+
self.model = model
89+
self.compute_fn = compute_fn
90+
91+
def forward(self, *args: Any, **kwargs: Any) -> Any:
92+
return self.compute_fn(self.model, *args, **kwargs)
93+
94+
95+
def _process_dataset(
96+
dataset: datasets.Dataset,
97+
tokenizer: PreTrainedTokenizerBase,
98+
data_transform_fn: Callable[..., dict[str, Any]] | None = None,
99+
data_transform_on_the_fly: bool = False,
100+
data_transform_kwargs: dict[str, Any] | None = None,
101+
) -> datasets.Dataset:
102+
if data_transform_fn is None:
103+
return dataset
104+
data_transform_fn = partial(data_transform_fn, tokenizer=tokenizer)
105+
if data_transform_on_the_fly:
106+
return dataset.with_transform(
107+
_make_batch_transform(data_transform_fn),
108+
**data_transform_kwargs,
109+
)
110+
return dataset.map(
111+
data_transform_fn,
112+
**data_transform_kwargs,
113+
)
114+
115+
116+
def _make_batch_transform(
117+
transform_fn: Callable[[dict[str, Any]], dict[str, Any]],
118+
) -> Callable[[dict[str, list[Any]]], dict[str, list[Any]]]:
119+
def batch_transform_fn(batch: dict[str, list[Any]]) -> dict[str, list[Any]]:
120+
# Convert batch format to list of examples
121+
examples = [dict(zip(batch.keys(), values)) for values in zip(*batch.values())]
122+
# Apply transform to each example
123+
transformed_examples = [transform_fn(example) for example in examples]
124+
# Convert back to batch format
125+
return {
126+
key: [ex[key] for ex in transformed_examples]
127+
for key in transformed_examples[0].keys()
128+
}
129+
130+
return batch_transform_fn

0 commit comments

Comments
 (0)