11import datasets
22import tempfile
33import torch .nn as nn
4- from transformers import Trainer , TrainingArguments
5- from typing import Any
4+ from transformers import Trainer , TrainingArguments , PreTrainedTokenizerBase
5+ from typing import Any , Callable
6+ from functools import partial
7+
8+ from .data import (
9+ transform_reflogprob_mlm ,
10+ transform_reflogprob_clm ,
11+ )
12+ from .model import (
13+ compute_reflogprob_mlm ,
14+ compute_reflogprob_clm ,
15+ )
616
717
def run_inference(
    model: nn.Module,
    tokenizer: PreTrainedTokenizerBase,  # TODO: create an adapter for this
    dataset: datasets.Dataset,
    compute_fn: Callable[..., Any],
    data_transform_fn: Callable[..., dict[str, Any]] | None = None,
    data_transform_on_the_fly: bool = False,
    data_transform_kwargs: dict[str, Any] | None = None,
    inference_kwargs: dict[str, Any] | None = None,
) -> Any:
    """Prepare ``dataset`` and run ``compute_fn`` over it with ``model``.

    The dataset is first passed through ``_process_dataset`` together with
    the tokenizer and the transform settings. The model is then wrapped so
    that its forward pass delegates to ``compute_fn(model, ...)``, and the
    wrapped model plus the prepared dataset are handed to ``_run_inference``.

    Args:
        model: Module handed to ``compute_fn`` as its first argument.
        tokenizer: Forwarded to ``_process_dataset`` for the data transform.
        dataset: Input dataset to prepare and run on.
        compute_fn: Callable invoked as ``compute_fn(model, *args, **kwargs)``
            on every forward call of the wrapped model.
        data_transform_fn: Optional transform forwarded to
            ``_process_dataset``.
        data_transform_on_the_fly: Forwarded to ``_process_dataset``.
        data_transform_kwargs: Forwarded to ``_process_dataset``.
        inference_kwargs: Extra keyword arguments for ``_run_inference``;
            ``None`` or an empty dict means no extras.

    Returns:
        Whatever ``_run_inference`` returns for the prepared dataset.
    """
    prepared = _process_dataset(
        dataset=dataset,
        tokenizer=tokenizer,
        data_transform_fn=data_transform_fn,
        data_transform_on_the_fly=data_transform_on_the_fly,
        data_transform_kwargs=data_transform_kwargs,
    )
    wrapped_model = _ModelComputeFnWrapper(model, compute_fn)
    extra_args = inference_kwargs if inference_kwargs else {}
    return _run_inference(wrapped_model, prepared, **extra_args)
40+
41+
# `run_inference` pre-bound to the MLM reflogprob transform and compute
# functions; callers still supply model, tokenizer, dataset and any options.
run_reflogprob_mlm = partial(
    run_inference,
    data_transform_fn=transform_reflogprob_mlm,
    compute_fn=compute_reflogprob_mlm,
)
47+
# `run_inference` pre-bound to the CLM reflogprob transform and compute
# functions; callers still supply model, tokenizer, dataset and any options.
run_reflogprob_clm = partial(
    run_inference,
    data_transform_fn=transform_reflogprob_clm,
    compute_fn=compute_reflogprob_clm,
)
53+
54+
55+ def _run_inference (
956 model : nn .Module ,
1057 dataset : datasets .Dataset ,
1158 ** kwargs : Any ,
@@ -29,7 +76,55 @@ def run_inference(
2976 """
3077 training_args = TrainingArguments (
3178 output_dir = tempfile .TemporaryDirectory ().name ,
32- ** kwargs ,
79+ ** ( kwargs or {}) ,
3380 )
3481 trainer = Trainer (model = model , args = training_args )
3582 return trainer .predict (test_dataset = dataset ).predictions
83+
84+
class _ModelComputeFnWrapper(nn.Module):
    """Module whose forward pass is ``compute_fn(model, *args, **kwargs)``.

    Lets an arbitrary computation over a model be driven through any API
    that expects an ``nn.Module`` and calls it directly.
    """

    def __init__(self, model: nn.Module, compute_fn: Callable[..., Any]):
        """Store the wrapped ``model`` and the ``compute_fn`` to delegate to."""
        super().__init__()
        self.compute_fn = compute_fn
        self.model = model

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``compute_fn`` with the wrapped model followed by all inputs."""
        fn = self.compute_fn
        wrapped = self.model
        return fn(wrapped, *args, **kwargs)
93+
94+
def _process_dataset(
    dataset: datasets.Dataset,
    tokenizer: PreTrainedTokenizerBase,
    data_transform_fn: Callable[..., dict[str, Any]] | None = None,
    data_transform_on_the_fly: bool = False,
    data_transform_kwargs: dict[str, Any] | None = None,
) -> datasets.Dataset:
    """Apply an optional per-example transform to ``dataset``.

    Args:
        dataset: Dataset to transform.
        tokenizer: Bound to the transform as the ``tokenizer`` keyword.
        data_transform_fn: Per-example transform. When ``None`` the dataset
            is returned unchanged.
        data_transform_on_the_fly: If true, register the transform with
            ``dataset.with_transform`` (wrapped so it accepts column
            batches); otherwise apply it with ``dataset.map``.
        data_transform_kwargs: Extra keyword arguments for ``with_transform``
            or ``map``. ``None`` is treated as no extra arguments.

    Returns:
        The original dataset when no transform is given, otherwise the
        result of ``with_transform`` / ``map``.
    """
    if data_transform_fn is None:
        return dataset

    # The parameter defaults to None, which cannot be **-unpacked; fall back
    # to an empty mapping so the default call path does not raise TypeError.
    transform_kwargs = data_transform_kwargs or {}
    bound_transform = partial(data_transform_fn, tokenizer=tokenizer)

    if data_transform_on_the_fly:
        return dataset.with_transform(
            _make_batch_transform(bound_transform),
            **transform_kwargs,
        )
    return dataset.map(bound_transform, **transform_kwargs)
114+
115+
def _make_batch_transform(
    transform_fn: Callable[[dict[str, Any]], dict[str, Any]],
) -> Callable[[dict[str, list[Any]]], dict[str, list[Any]]]:
    """Lift a per-example transform into one that works on column batches.

    Args:
        transform_fn: Maps a single example (column name -> value) to a
            transformed example.

    Returns:
        A function that takes a batch (column name -> list of values),
        applies ``transform_fn`` to each row, and returns the results as a
        batch. Output columns are taken from the first transformed example.
        An empty batch yields an empty dict.
    """

    def batch_transform_fn(batch: dict[str, list[Any]]) -> dict[str, list[Any]]:
        columns = list(batch)
        # zip(*values) walks the batch row by row; each row is rebuilt as a
        # single-example dict before being handed to the transform.
        transformed = [
            transform_fn(dict(zip(columns, row))) for row in zip(*batch.values())
        ]
        if not transformed:
            # No rows means there is no first example to read output
            # columns from; return an empty batch instead of IndexError.
            return {}
        return {key: [ex[key] for ex in transformed] for key in transformed[0]}

    return batch_transform_fn
0 commit comments