-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy path__init__.py
More file actions
87 lines (77 loc) · 2.65 KB
/
Copy path__init__.py
File metadata and controls
87 lines (77 loc) · 2.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
from pathlib import Path
from typing import Optional
import torch
from langchain_community.vectorstores import Chroma
from langchain_core.embeddings import Embeddings
from langchain_huggingface import (HuggingFaceEmbeddings,
HuggingFaceEndpointEmbeddings)
from langchain_openai import OpenAIEmbeddings
from data_generation.metadata_csv_loader import MetaDataCSVLoader
from data_generation.uniprot.csv_generator import generate_uniprot_csv
def _make_embeddings(hf_model: Optional[str], device: Optional[str]) -> Embeddings:
    """Return the embedding backend selected by ``hf_model``.

    Selection rules, checked in order:

    * ``None`` -> OpenAI ``text-embedding-3-large``.
    * ``"openai/text-embedding-..."`` -> that OpenAI model, with the
      ``"openai/"`` prefix stripped.
    * any other value while ``HUGGINGFACEHUB_API_TOKEN`` is set in the
      environment -> Hugging Face endpoint embeddings for ``hf_model``.
    * any other value -> a local ``HuggingFaceEmbeddings`` model on ``device``.
    """
    openai_prefix = "openai/"
    if hf_model is None:  # Use OpenAI
        print("Using OpenAI embeddings")
        return OpenAIEmbeddings(
            model="text-embedding-3-large",
            chunk_size=500,
            show_progress_bar=True,
        )
    if hf_model.startswith(openai_prefix + "text-embedding-"):
        return OpenAIEmbeddings(
            model=hf_model[len(openai_prefix):],
            chunk_size=500,
            show_progress_bar=True,
        )
    # Presence of the variable (even if empty) selects the hosted endpoint.
    api_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    if api_token is not None:
        return HuggingFaceEndpointEmbeddings(
            huggingfacehub_api_token=api_token,
            model=hf_model,
        )
    # Release cached GPU memory before loading a local model. Match indexed
    # devices such as "cuda:1" as well as the bare "cuda".
    if device is not None and (device == "cuda" or device.startswith("cuda:")):
        torch.cuda.empty_cache()
    # NOTE: trust_remote_code=True allows the model repository to run its own
    # Python code when loading; only use hf_model ids from trusted sources.
    return HuggingFaceEmbeddings(
        model_name=hf_model,
        model_kwargs={"device": device, "trust_remote_code": True},
        encode_kwargs={"batch_size": 12, "normalize_embeddings": False},
    )


def upload_to_chromadb(
    embeddings_dir: str,
    file: str,
    embedding_table: str,
    hf_model: Optional[str] = None,
    device: Optional[str] = None,
) -> Chroma:
    """Load a CSV as documents, embed them, and persist them into a Chroma store.

    Args:
        embeddings_dir: Parent directory for the persisted store. The store is
            written to ``os.path.join(embeddings_dir, embedding_table)``.
        file: Path of the CSV file read with ``MetaDataCSVLoader`` (UTF-8).
        embedding_table: Table name. It selects which CSV columns are kept as
            document metadata and names the persist sub-directory. Only
            ``"uniprot_data"`` is currently known.
        hf_model: Embedding model selector. ``None`` uses OpenAI
            ``text-embedding-3-large``. ``"openai/text-embedding-*"`` uses that
            OpenAI model. Any other value is treated as a Hugging Face model id:
            it is served through the endpoint API when
            ``HUGGINGFACEHUB_API_TOKEN`` is set, and loaded locally otherwise.
        device: Device for the local Hugging Face model (e.g. ``"cpu"``,
            ``"cuda"``, ``"cuda:1"``). The other backends ignore it.

    Returns:
        The ``Chroma`` vector store built from the loaded documents.

    Raises:
        KeyError: If ``embedding_table`` has no known metadata-column mapping.
    """
    metadata_columns: dict[str, list] = {
        "uniprot_data": [
            "gene_names",
            "short_protein_name",
            "full_protein_name",
            "protein_family",
            "biological_pathways",
        ],
    }
    try:
        columns = metadata_columns[embedding_table]
    except KeyError:
        # Keep KeyError for callers, but say which tables are supported.
        raise KeyError(
            f"Unknown embedding_table {embedding_table!r}; "
            f"expected one of {sorted(metadata_columns)}"
        ) from None

    loader = MetaDataCSVLoader(
        file_path=file,
        metadata_columns=columns,
        encoding="utf-8",
    )
    docs = loader.load()
    print(f"Loaded {len(docs)} documents from {file}")

    embeddings_instance = _make_embeddings(hf_model, device)
    return Chroma.from_documents(
        documents=docs,
        embedding=embeddings_instance,
        persist_directory=os.path.join(embeddings_dir, embedding_table),
    )
def generate_uniprot_embeddings(
    embedding_path: Path,
    hf_model: Optional[str] = None,
    device: Optional[str] = None,
    **_,
) -> None:
    """Generate the UniProt CSV and load its embeddings into a Chroma store.

    The CSV is produced by ``generate_uniprot_csv(embedding_path)``. It is then
    handed to ``upload_to_chromadb`` with ``embedding_path`` as the embeddings
    directory and ``"uniprot_data"`` as the table. Finally, the entry count of
    the resulting Chroma collection is printed.

    Args:
        embedding_path: Directory passed to the CSV generator and used as the
            parent directory for the persisted store.
        hf_model: Embedding model selector, forwarded to ``upload_to_chromadb``.
        device: Device for a local Hugging Face model, forwarded unchanged.
        **_: Extra keyword arguments are accepted and ignored, so callers can
            pass a shared option set.
    """
    uniprot_csv = generate_uniprot_csv(embedding_path)
    vector_store = upload_to_chromadb(
        embeddings_dir=str(embedding_path),
        file=str(uniprot_csv),
        embedding_table="uniprot_data",
        hf_model=hf_model,
        device=device,
    )
    # Relies on the store's private `_collection` handle to read the count.
    entry_count = vector_store._collection.count()
    print(entry_count)