-
Notifications
You must be signed in to change notification settings - Fork 81
Expand file tree
/
Copy pathvqa_generator.py
More file actions
136 lines (128 loc) · 4.62 KB
/
vqa_generator.py
File metadata and controls
136 lines (128 loc) · 4.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import re
from typing import Any
from graphgen.bases import BaseGenerator
from graphgen.templates import VQA_GENERATION_PROMPT
from graphgen.utils import compute_content_hash, detect_main_language, logger
class VQAGenerator(BaseGenerator):
@staticmethod
def build_prompt(
batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
) -> str:
nodes, edges = batch
entities_str = "\n".join(
[
f"{index + 1}. {node[0]}: {node[1]['description']}"
for index, node in enumerate(nodes)
]
)
relationships_str = "\n".join(
[
f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
for index, edge in enumerate(edges)
]
)
language = detect_main_language(entities_str + relationships_str)
prompt = VQA_GENERATION_PROMPT[language].format(
entities=entities_str, relationships=relationships_str
)
return prompt
@staticmethod
def parse_response(response: str) -> Any:
"""
Parse the LLM response and return the generated QAs
:param response
:return: QA pairs
"""
qa_pairs = {}
pattern = r"<question>(.*?)</question>\s*<answer>(.*?)</answer>"
matches = re.findall(pattern, response, re.DOTALL)
if matches:
for question, answer in matches:
question = question.strip().strip('"').strip("'")
answer = answer.strip().strip('"').strip("'")
logger.debug("Question: %s", question)
logger.debug("Answer: %s", answer)
qa_pairs[compute_content_hash(question)] = {
"question": question,
"answer": answer,
}
else:
logger.warning("Error parsing the response %s", response)
return qa_pairs
async def generate(
self,
batch: tuple[
list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
],
) -> dict[str, Any]:
"""
Generate QAs based on a given batch.
:param batch
:return: QA pairs
"""
result = {}
prompt = self.build_prompt(batch)
response = await self.llm_client.generate_answer(prompt)
qa_pairs = self.parse_response(response) # generate one or more QA pairs
nodes, _ = batch
for node in nodes:
node_data = node[1]
if "image_data" in node_data and node_data["image_data"]:
img_path = node_data["image_data"]["img_path"]
for qa in qa_pairs.values():
qa["img_path"] = img_path
result.update(qa_pairs)
return result
@staticmethod
def format_generation_results(
results: list[dict], output_data_format: str
) -> list[dict[str, Any]]:
if output_data_format == "Alpaca":
results = [
{
"instruction": v["question"],
"input": "",
"output": v["answer"],
"image": v.get("img_path", ""),
}
for item in results
for k, v in item.items()
]
elif output_data_format == "Sharegpt":
results = [
{
"conversations": [
{
"from": "human",
"value": [
{"text": v["question"], "image": v.get("img_path", "")}
],
},
{"from": "gpt", "value": [{"text": v["answer"]}]},
]
}
for item in results
for k, v in item.items()
]
elif output_data_format == "ChatML":
results = [
{
"messages": [
{
"role": "user",
"content": [
{"text": v["question"], "image": v.get("img_path", "")}
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": v["answer"]}],
},
]
}
for item in results
for k, v in item.items()
]
else:
raise ValueError(f"Unknown output data format: {output_data_format}")
return results