-
Notifications
You must be signed in to change notification settings - Fork 80
Expand file tree
/
Copy pathlongform.py
More file actions
111 lines (90 loc) · 3.12 KB
/
longform.py
File metadata and controls
111 lines (90 loc) · 3.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# https://arxiv.org/pdf/2304.08460
# https://github.com/akoksal/LongForm/tree/main
import argparse
import asyncio
import json
import os
from dataclasses import dataclass
from typing import List
from dotenv import load_dotenv
from tqdm.asyncio import tqdm as tqdm_async
from graphgen.models import OpenAIClient
from graphgen.utils import compute_content_hash, create_event_loop
# Reverse-instruction prompt (LongForm, see the paper/repo linked at the top of
# this file): the document is presented as the "Output" and the model is asked
# which instruction X it could be answering. `{doc}` is a str.format placeholder
# for the document text.
PROMPT_TEMPLATE = """Instruction: X
Output:{doc}
What kind of instruction could this be the answer to?
X:"""
@dataclass
class LongForm:
    """Synthesize instruction/answer pairs with the LongForm recipe.

    Each text chunk is treated as the *answer*; the LLM is prompted with
    ``PROMPT_TEMPLATE`` to produce the instruction (question) that the chunk
    could be answering.
    """

    # Async LLM client; its ``generate_answer(prompt)`` coroutine is awaited
    # per chunk (presumably returning the model's text). Must be set before
    # calling ``generate`` / ``async_generate``.
    llm_client: OpenAIClient = None
    # Upper bound on simultaneously in-flight LLM requests.
    max_concurrent: int = 1000

    def generate(self, docs: List[List[dict]]) -> dict:
        """Synchronous wrapper around :meth:`async_generate`.

        Args:
            docs: A list of documents, each a list of chunk dicts carrying the
                chunk text under the ``"content"`` key.

        Returns:
            The mapping produced by :meth:`async_generate`.

        Raises:
            ValueError: if ``llm_client`` is not set.
        """
        loop = create_event_loop()
        return loop.run_until_complete(self.async_generate(docs))

    async def async_generate(self, docs: List[List[dict]]) -> dict:
        """Generate one instruction per chunk, with bounded concurrency.

        Args:
            docs: A list of documents, each a list of chunk dicts carrying the
                chunk text under the ``"content"`` key.

        Returns:
            A dict mapping ``compute_content_hash(question)`` to
            ``{"question": question, "answer": chunk_content}``. Chunks whose
            generated questions hash identically overwrite one another; chunks
            whose LLM call raises are reported and skipped.

        Raises:
            ValueError: if ``llm_client`` is not set.
            KeyError: if a chunk dict has no ``"content"`` key.
        """
        if self.llm_client is None:
            # Fail fast instead of printing one AttributeError per chunk and
            # silently returning an empty result.
            raise ValueError("LongForm.llm_client must be set before generating.")

        final_results = {}
        semaphore = asyncio.Semaphore(self.max_concurrent)

        async def process_chunk(content: str) -> dict:
            async with semaphore:
                # Present the chunk as the "Output" and ask the model for the
                # instruction it answers (the LongForm reverse-instruction step).
                prompt = PROMPT_TEMPLATE.format(doc=content)
                question = await self.llm_client.generate_answer(prompt)
                return {
                    compute_content_hash(question): {
                        "question": question,
                        "answer": content,
                    }
                }

        tasks = [process_chunk(chunk["content"]) for doc in docs for chunk in doc]

        for result in tqdm_async(
            asyncio.as_completed(tasks),
            total=len(tasks),
            desc="Generating using LongForm",
        ):
            try:
                qa = await result
            except Exception as e:  # pylint: disable=broad-except
                # Best effort: one failed chunk must not abort the whole run.
                print(f"Error: {e}")
                continue
            final_results.update(qa)

        return final_results
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate LongForm-style instruction data from text chunks."
    )
    parser.add_argument(
        "--input_file",
        help=(
            "Input file path. raw: JSONL (or a single JSON list) of "
            '{"content": ...} objects; chunked: a JSON list of documents, '
            "each a list of such objects."
        ),
        default="resources/input_examples/json_demo.json",
        type=str,
    )
    parser.add_argument(
        "--data_type",
        help="Data type of input file. (Raw context or chunked context)",
        choices=["raw", "chunked"],
        default="raw",
        type=str,
    )
    parser.add_argument(
        "--output_file",
        help="Output file path.",
        default="cache/data/longform.json",
        type=str,
    )
    args = parser.parse_args()

    # Model name, API key and base URL are read from the environment
    # (optionally populated from a .env file by load_dotenv).
    load_dotenv()
    llm_client = OpenAIClient(
        model_name=os.getenv("SYNTHESIZER_MODEL"),
        api_key=os.getenv("SYNTHESIZER_API_KEY"),
        base_url=os.getenv("SYNTHESIZER_BASE_URL"),
    )
    longform = LongForm(llm_client=llm_client)

    if args.data_type == "raw":
        with open(args.input_file, "r", encoding="utf-8") as f:
            raw_text = f.read()
        try:
            # Whole-file JSON document (e.g. a list of records, as in the
            # default *.json input).
            records = json.loads(raw_text)
        except json.JSONDecodeError:
            # Otherwise JSON Lines; skip blank lines, which json.loads rejects.
            records = [
                json.loads(line) for line in raw_text.splitlines() if line.strip()
            ]
        if not isinstance(records, list):
            records = [records]
        # Each raw record becomes a single-chunk document.
        data = [[chunk] for chunk in records]
    else:  # "chunked": already a list of documents, each a list of chunks
        with open(args.input_file, "r", encoding="utf-8") as f:
            data = json.load(f)

    results = longform.generate(data)

    # Save results, creating the output directory if it does not exist yet.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4, ensure_ascii=False)