Skip to content

Commit 48807bb

Browse files
fix: ensure SentenceTransformer models load directly to device
1 parent 79eba56 commit 48807bb

3 files changed

Lines changed: 10 additions & 10 deletions

File tree

rankify/retrievers/bge_reasoner_retriever.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def __init__(
9999
def _load_model(self):
100100
# ReasonEmbed SentenceTransformer models
101101
if self.model_id == "bge-reasoner-embed":
102-
self.model = SentenceTransformer(self.checkpoint or "BAAI/bge-reasoner-embed-qwen3-8b-0923", trust_remote_code=True)
102+
self.model = SentenceTransformer(self.checkpoint or "BAAI/bge-reasoner-embed-qwen3-8b-0923", trust_remote_code=True, device=self.device)
103103
self.model = self.model.to(self.device)
104104
else:
105105
raise ValueError(f"The model {self.model_id} is not supported")

rankify/retrievers/diver_dense_retriever.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,18 +149,18 @@ def __init__(
149149
def _load_model(self):
150150
# SentenceTransformer models
151151
if self.model_id == "bge":
152-
self.model = SentenceTransformer('BAAI/bge-large-en-v1.5')
152+
self.model = SentenceTransformer('BAAI/bge-large-en-v1.5', device=self.device)
153153
elif self.model_id == "sbert":
154-
self.model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
154+
self.model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2', device=self.device)
155155
elif self.model_id == "contriever_st":
156-
self.model = SentenceTransformer('nishimoto/contriever-sentencetransformer')
156+
self.model = SentenceTransformer('nishimoto/contriever-sentencetransformer', device=self.device)
157157
elif self.model_id == "nomic":
158-
self.model = SentenceTransformer(self.checkpoint or "nomic-ai/nomic-embed-text-v1", trust_remote_code=True)
158+
self.model = SentenceTransformer(self.checkpoint or "nomic-ai/nomic-embed-text-v1", trust_remote_code=True, device=self.device)
159159
elif self.model_id == "inst-l":
160-
self.model = SentenceTransformer("hkunlp/instructor-large")
160+
self.model = SentenceTransformer("hkunlp/instructor-large", device=self.device)
161161
self.model.max_seq_length = self.doc_max_length
162162
elif self.model_id == "inst-xl":
163-
self.model = SentenceTransformer("hkunlp/instructor-xl")
163+
self.model = SentenceTransformer("hkunlp/instructor-xl", device=self.device)
164164
self.model.max_seq_length = self.doc_max_length
165165

166166
# HF AutoModel models (sf, e5, rader)

rankify/retrievers/reasonembed_retriever.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,13 @@ def __init__(
102102
def _load_model(self):
103103
# ReasonEmbed SentenceTransformer models
104104
if self.model_id == "qwen3-8b":
105-
self.model = SentenceTransformer(self.checkpoint or "hanhainebula/reason-embed-qwen3-8b-0928", trust_remote_code=True)
105+
self.model = SentenceTransformer(self.checkpoint or "hanhainebula/reason-embed-qwen3-8b-0928", trust_remote_code=True, device=self.device)
106106
self.model = self.model.to(self.device)
107107
elif self.model_id == "qwen3-4b":
108-
self.model = SentenceTransformer(self.checkpoint or "hanhainebula/reason-embed-qwen3-4b-0928", trust_remote_code=True)
108+
self.model = SentenceTransformer(self.checkpoint or "hanhainebula/reason-embed-qwen3-4b-0928", trust_remote_code=True, device=self.device)
109109
self.model = self.model.to(self.device)
110110
elif self.model_id == "llama-8b":
111-
self.model = SentenceTransformer(self.checkpoint or "hanhainebula/reason-embed-llama-3.1-8b-0928", trust_remote_code=True)
111+
self.model = SentenceTransformer(self.checkpoint or "hanhainebula/reason-embed-llama-3.1-8b-0928", trust_remote_code=True, device=self.device)
112112
self.model = self.model.to(self.device)
113113
else:
114114
raise ValueError(f"The model {self.model_id} is not supported")

0 commit comments

Comments
 (0)