model.py
import torch

from model_factory import ModelFactory


class F2LLM:

    def __init__(self,
                 model_path,
                 max_seq_length=512,
                 args=None,
                 ):
        self.args = args
        self.dtype = torch.bfloat16
        self.device = None  # set after accelerator.prepare

        # Use model factory to create adapter
        self.adapter = ModelFactory.create_adapter(model_path, max_seq_length, args)

        # Load model and tokenizer
        self.lm = self.adapter.load_model()
        self.lm.config.use_cache = False  # KV cache is for generation; not needed for encoding
        self.tokenizer = self.adapter.load_tokenizer()

        self.max_seq_length = max_seq_length

    def set_device(self):
        # The device is only known once accelerator.prepare has placed the model.
        self.device = self.lm.device
    def forward(self, batch):
        bs = batch['bs']
        # Rows are laid out as [queries | positive passages | hard negatives],
        # so everything beyond the first 2*bs rows is hard negatives.
        num_hard_neg = (len(batch['input_ids']) - 2 * bs) // bs
        outputs = self.lm(batch['input_ids'],
                          batch['attention_mask'],
                          )
        passage_features_all_tokens = outputs.last_hidden_state
        # Last-token pooling: embed each sequence with the hidden state of its
        # final non-padding token (index seq_lens[i] - 1).
        return {
            'query_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i] - 1]] for i in range(bs)]),
            'passage_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i] - 1]] for i in range(bs, 2 * bs)]),
            'negative_passage_features': None if num_hard_neg == 0 else torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i] - 1]] for i in range(2 * bs, len(batch['seq_lens']))]).view(bs, num_hard_neg, -1),
        }
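

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): wiring the
# encoder into Hugging Face Accelerate, as suggested by the
# accelerator.prepare comment above, and running one forward pass. The
# checkpoint path and the hand-built batch are placeholders; the real
# collation logic lives elsewhere in this repo, and right-padding is assumed
# so that seq_lens[i] - 1 indexes the last real token of each row.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = F2LLM('Qwen/Qwen3-0.6B', max_seq_length=512)  # placeholder checkpoint
    model.lm = accelerator.prepare(model.lm)
    model.set_device()  # only valid after accelerator.prepare

    # One query, its positive passage, and one hard negative (num_hard_neg = 1).
    texts = ['what is 2 + 2?',    # query
             '2 + 2 equals 4.',   # positive passage
             'the sky is blue.']  # hard negative
    enc = model.tokenizer(texts, padding=True, return_tensors='pt').to(model.device)
    batch = {
        'bs': 1,
        'input_ids': enc['input_ids'],
        'attention_mask': enc['attention_mask'],
        'seq_lens': enc['attention_mask'].sum(dim=1).tolist(),
    }
    with torch.no_grad():
        feats = model.forward(batch)
    print(feats['query_passage_features'].shape)     # -> [1, 1, hidden_size]
    print(feats['negative_passage_features'].shape)  # -> [1, 1, hidden_size]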