-
Notifications
You must be signed in to change notification settings - Fork 82
Expand file tree
/
Copy pathbase_generator.py
More file actions
77 lines (67 loc) · 2.36 KB
/
base_generator.py
File metadata and controls
77 lines (67 loc) · 2.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from abc import ABC, abstractmethod
from typing import Any
from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
class BaseGenerator(ABC):
    """
    Abstract base for components that generate QAs from batched graph data.

    Subclasses provide prompt construction (`build_prompt`) and response
    parsing (`parse_response`); this base class chains them together
    around an injected LLM client.
    """

    def __init__(self, llm_client: BaseLLMWrapper):
        # Client used to obtain LLM completions for the built prompts.
        self.llm_client = llm_client

    @staticmethod
    @abstractmethod
    def build_prompt(
        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """Render the given batch into a single prompt string for the LLM."""

    @staticmethod
    @abstractmethod
    def parse_response(response: str) -> list[dict]:
        """Extract the generated QA dicts from a raw LLM response."""

    async def generate(
        self,
        batch: tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ],
    ) -> list[dict]:
        """
        Run the full pipeline for one batch: build a prompt, query the
        LLM, and parse the answer.

        :param batch: node/edge batch to generate QAs from
        :return: list of QA-pair dicts (one response may yield several)
        """
        raw_answer = await self.llm_client.generate_answer(self.build_prompt(batch))
        return self.parse_response(raw_answer)

    @staticmethod
    def format_generation_results(
        result: dict, output_data_format: str
    ) -> dict[str, Any]:
        """
        Convert one parsed QA dict into the requested export schema.

        :param result: dict with "question"/"answer" and optional "options"
        :param output_data_format: one of "Alpaca", "Sharegpt", "ChatML"
        :return: dict shaped for the chosen format
        :raises ValueError: if the format name is not recognized
        """
        question = result.get("question", "")
        answer = result.get("answer", "")

        # Multiple-choice items carry an options mapping; append it to the
        # question text in key order so every export format sees it.
        options = result.get("options")
        if options:
            rendered = "\n".join(
                f"{key}. {options[key]}" for key in sorted(options.keys())
            )
            question += f"\nOptions:\n{rendered}"

        if output_data_format == "Alpaca":
            formatted: dict[str, Any] = {
                "instruction": question,
                "input": "",
                "output": answer,
            }
        elif output_data_format == "Sharegpt":
            formatted = {
                "conversations": [
                    {"from": "human", "value": question},
                    {"from": "gpt", "value": answer},
                ]
            }
        elif output_data_format == "ChatML":
            formatted = {
                "messages": [
                    {"role": "user", "content": question},
                    {"role": "assistant", "content": answer},
                ]
            }
        else:
            raise ValueError(f"Unknown output data format: {output_data_format}")
        return formatted