-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathembed_docs.py
More file actions
123 lines (103 loc) · 3.58 KB
/
Copy pathembed_docs.py
File metadata and controls
123 lines (103 loc) · 3.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# /// script
# description = "Embed text from pdfs"
# requires-python = ">=3.12, <3.13"
# dependencies = ["daft[openai]>=0.7.10", "pymupdf", "sentence-transformers", "spacy", "en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl", "python-dotenv"]
# ///
import pymupdf
import spacy
import daft
from daft import DataType, col
from daft.functions import embed_text, unnest
@daft.func(
    return_dtype=DataType.list(
        DataType.struct(
            {
                # int32 rather than uint8: a 0-based page index would overflow
                # uint8 on any PDF with more than 255 pages.
                "page_number": DataType.int32(),
                "page_text": DataType.string(),
                "page_image_bytes": DataType.binary(),
            }
        )
    )
)
def extract_pdf(file: daft.File):
    """Extract the text and a rendered image of every page of a PDF.

    Args:
        file: The PDF to read. It is spilled to a temporary file because
            PyMuPDF is opened here by filename.

    Returns:
        One dict per page, in document order:

        - ``page_number``: the 0-based page index.
        - ``page_text``: the page's plain text from ``Page.get_text("text")``.
        - ``page_image_bytes``: the page rendered with ``Page.get_pixmap()``
          and serialized with ``Pixmap.tobytes()`` (PNG per the PyMuPDF docs;
          confirm for the pinned version).
    """
    # The document is closed before the temp file is removed. The list is
    # fully built inside the `with`, so no page outlives its document.
    with file.to_tempfile() as tmp, pymupdf.Document(
        filename=str(tmp.name), filetype="pdf"
    ) as doc:
        return [
            {
                "page_number": pno,
                "page_text": page.get_text("text"),
                "page_image_bytes": page.get_pixmap().tobytes(),
            }
            for pno, page in enumerate(doc)
        ]
@daft.cls()
class SpaCyChunkText:
    """Split text into sentences using a spaCy pipeline.

    Decorated with ``@daft.cls`` so that :meth:`chunk_text` can be applied to
    a Daft column expression.
    """

    def __init__(self, model="en_core_web_sm"):
        # Load the spaCy pipeline once per instance instead of once per row.
        self.nlp = spacy.load(model)

    @daft.method(
        return_dtype=DataType.list(
            DataType.struct(
                {
                    "sent_id": DataType.int32(),
                    "sent_start": DataType.int32(),
                    "sent_end": DataType.int32(),
                    "sent_text": DataType.string(),
                    "sent_ents": DataType.list(DataType.string()),
                }
            )
        )
    )
    def chunk_text(self, text: str):
        """Split one row's text into sentences.

        Args:
            text: The text of a single row (a string, not a list of strings).

        Returns:
            One dict per sentence, in order:

            - ``sent_id``: the sentence's 0-based position within ``text``.
            - ``sent_start`` / ``sent_end``: the spaCy ``Span.start`` /
              ``Span.end`` token indices, not character offsets.
            - ``sent_text``: the sentence text.
            - ``sent_ents``: the texts of the named entities inside the
              sentence.

            Empty or missing text yields an empty list.
        """
        # Guard against missing text: self.nlp(None) would raise, and "" has
        # no sentences anyway.
        if not text:
            return []
        doc = self.nlp(text)
        return [
            {
                "sent_id": i,
                "sent_start": sent.start,
                "sent_end": sent.end,
                "sent_text": sent.text,
                "sent_ents": [ent.text for ent in sent.ents],
            }
            for i, sent in enumerate(doc.sents)
        ]
if __name__ == "__main__":
    from dotenv import load_dotenv

    # Config
    # Hugging Face model ids, not used by the pipeline below. embeddinggemma
    # embeds text and PaliGemma 2 takes images. Sentences are embedded with
    # the OpenAI model named in OPENAI_TEXT_EMBED_MODEL.
    TEXT_EMBED_MODEL = "google/embeddinggemma-300m"
    IMAGE_EMBED_MODEL = "google/paligemma2-3b-mix-448"
    OPENAI_TEXT_EMBED_MODEL = "text-embedding-3-small"
    SPACY_MODEL = "en_core_web_sm"
    MAX_DOCS = 5  # upper bound on the number of PDFs taken from `uri`
    uri = "hf://datasets/Eventual-Inc/sample-files/papers/*.pdf"
    output_path = ".data/embed_text"

    # Load variables from a local .env file into the environment (presumably
    # the OpenAI API key that `embed_text` needs; confirm).
    load_dotenv()

    # The spaCy model is installed from the wheel pinned in the script
    # dependencies; nothing is downloaded here. Loading it up front only makes
    # a missing install fail before the pipeline starts.
    spacy.load(SPACY_MODEL)
    chunker = SpaCyChunkText(SPACY_MODEL)

    # Discover at most MAX_DOCS pdfs, download them, and extract one row per page.
    # NOTE(review): `extract_pdf` is annotated to take a `daft.File`, but this
    # passes the binary column produced by `.download()`. Confirm that Daft
    # performs that conversion.
    pages = (
        daft.from_glob_path(uri)
        .limit(MAX_DOCS)
        .with_column("documents", col("path").download())
        # Extract text from pdf pages
        .with_column("pages", extract_pdf(col("documents")))
        .explode("pages")
        .select(col("path"), unnest(col("pages")))
    ).collect()

    # Decoded RGB page images resized to 256x256. The sentence `select` below
    # does not carry this column through, so it is not written out.
    pages = pages.with_column(
        "images",
        col("page_image_bytes").decode_image().convert_image("RGB").resize(256, 256),
    )
    pages.explain()

    sentences = (
        pages
        # Chunk page text into sentences. The unnested page struct exposes the
        # text as `page_text`; there is no `text` column.
        .with_column(
            "text_normalized",
            col("page_text").normalize(nfd_unicode=True, white_space=True),
        )
        .with_column("sentences", chunker.chunk_text(col("text_normalized")))
        .explode("sentences")
        .select(col("path"), col("page_number"), unnest(col("sentences")))
        .where(col("sent_end") - col("sent_start") > 1)  # remove sentences that are too short
        # Embed sentences
        .with_column(
            "text_embedding",
            embed_text(col("sent_text"), model=OPENAI_TEXT_EMBED_MODEL),
        )
    ).collect()  # materialize once; otherwise write_parquet and show would each re-run the embedding calls

    sentences.write_parquet(output_path)
    sentences.show()