Skip to content

Commit a5ed2f4

Browse files
Add missing field declarations to dataclass child classes
Co-authored-by: ChenZiHong-Gavin <58508660+ChenZiHong-Gavin@users.noreply.github.com>
1 parent 5202d73 commit a5ed2f4

6 files changed

Lines changed: 16 additions & 4 deletions

File tree

graphgen/models/generator/aggregated_generator.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from dataclasses import dataclass
22
from typing import Any
33

4-
from graphgen.bases import BaseGenerator
4+
from graphgen.bases import BaseGenerator, BaseLLMClient
55
from graphgen.templates import AGGREGATED_GENERATION_PROMPT
66
from graphgen.utils import compute_content_hash, detect_main_language, logger
77

@@ -15,6 +15,8 @@ class AggregatedGenerator(BaseGenerator):
1515
2. question generation: Generate relevant questions based on the rephrased text.
1616
"""
1717

18+
llm_client: BaseLLMClient = None
19+
1820
@staticmethod
1921
def build_prompt(
2022
batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]

graphgen/models/generator/atomic_generator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from dataclasses import dataclass
22
from typing import Any
33

4-
from graphgen.bases import BaseGenerator
4+
from graphgen.bases import BaseGenerator, BaseLLMClient
55
from graphgen.templates import ATOMIC_GENERATION_PROMPT
66
from graphgen.utils import compute_content_hash, detect_main_language, logger
77

88

99
@dataclass
1010
class AtomicGenerator(BaseGenerator):
11+
llm_client: BaseLLMClient = None
12+
1113
@staticmethod
1214
def build_prompt(
1315
batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]

graphgen/models/generator/cot_generator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from dataclasses import dataclass
22
from typing import Any
33

4-
from graphgen.bases import BaseGenerator
4+
from graphgen.bases import BaseGenerator, BaseLLMClient
55
from graphgen.templates import COT_GENERATION_PROMPT
66
from graphgen.utils import compute_content_hash, detect_main_language, logger
77

88

99
@dataclass
1010
class CoTGenerator(BaseGenerator):
11+
llm_client: BaseLLMClient = None
12+
1113
@staticmethod
1214
def build_prompt(
1315
batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]

graphgen/models/generator/multi_hop_generator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from dataclasses import dataclass
22
from typing import Any
33

4-
from graphgen.bases import BaseGenerator
4+
from graphgen.bases import BaseGenerator, BaseLLMClient
55
from graphgen.templates import MULTI_HOP_GENERATION_PROMPT
66
from graphgen.utils import compute_content_hash, detect_main_language, logger
77

88

99
@dataclass
1010
class MultiHopGenerator(BaseGenerator):
11+
llm_client: BaseLLMClient = None
12+
1113
@staticmethod
1214
def build_prompt(
1315
batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]

graphgen/models/tokenizer/hf_tokenizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
@dataclass
1010
class HFTokenizer(BaseTokenizer):
11+
model_name: str = "cl100k_base"
12+
1113
def __post_init__(self):
1214
self.enc = AutoTokenizer.from_pretrained(self.model_name)
1315

graphgen/models/tokenizer/tiktoken_tokenizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
@dataclass
1010
class TiktokenTokenizer(BaseTokenizer):
11+
model_name: str = "cl100k_base"
12+
1113
def __post_init__(self):
1214
self.enc = tiktoken.get_encoding(self.model_name)
1315

0 commit comments

Comments
 (0)