Skip to content

Commit ea96467

Browse files
committed
address review comments
Signed-off-by: Steve Han <sthan@nvidia.com>
1 parent bf5d0be commit ea96467

4 files changed

Lines changed: 23 additions & 23 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,6 @@ htmlcov/
3535

3636
# Distribution
3737
*.tar.gz
38+
39+
# CI artifacts
40+
*artifacts/

plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/chunking.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from typing import Literal
2222

2323
import nltk
24+
import yaml
2425
from nltk.tokenize import sent_tokenize
2526

2627
logger = logging.getLogger(__name__)
@@ -40,8 +41,6 @@ def load_multi_doc_manifest(manifest_path: Path | None) -> list[list[str]]:
4041
Returns:
4142
List of bundles, each a list of file-path strings.
4243
"""
43-
import yaml
44-
4544
if not manifest_path:
4645
return []
4746

@@ -160,19 +159,19 @@ def build_bundles(
160159
return [b for b in bundles if b]
161160

162161

163-
def group_chunks_by_doc(chunks: list[dict]) -> dict[str, list[tuple[int, dict]]]:
162+
def group_chunks_by_doc(chunks: list[dict]) -> dict[str, list[dict]]:
164163
"""Group chunks by their ``doc_id`` field."""
165-
grouped: dict[str, list[tuple[int, dict]]] = defaultdict(list)
166-
for idx, chunk in enumerate(chunks):
164+
grouped: dict[str, list[dict]] = defaultdict(list)
165+
for chunk in chunks:
167166
doc_id = chunk.get("doc_id", "default")
168-
grouped[doc_id].append((idx, chunk))
167+
grouped[doc_id].append(chunk)
169168
return dict(grouped)
170169

171170

172-
def format_section_chunks(indexed_chunks: list[tuple[int, dict]], section_number: int) -> str:
173-
"""Render a list of indexed chunks into a section string."""
171+
def format_section_chunks(section_chunks: list[dict], section_number: int) -> str:
172+
"""Render a list of chunks into a section string."""
174173
section_lines: list[str] = []
175-
for _, chunk in indexed_chunks:
174+
for chunk in section_chunks:
176175
text = chunk.get("text", "").strip()
177176
if not text:
178177
continue
@@ -203,8 +202,7 @@ def chunks_to_sections_sequential(chunks: list[dict], num_sections: int = 1) ->
203202
for i in range(num_sections):
204203
start_idx = i * section_size
205204
end_idx = (i + 1) * section_size if i < num_sections - 1 else total
206-
indexed_chunks = [(j, chunks[j]) for j in range(start_idx, end_idx)]
207-
section_text = format_section_chunks(indexed_chunks, i + 1)
205+
section_text = format_section_chunks(chunks[start_idx:end_idx], i + 1)
208206
if section_text:
209207
formatted_sections.append(section_text)
210208

@@ -222,9 +220,9 @@ def chunks_to_sections_doc_balanced(chunks: list[dict], num_sections: int = 1) -
222220

223221
chunk_sizes = {doc_id: max(1, math.ceil(len(entries) / num_sections)) for doc_id, entries in grouped.items()}
224222

225-
sections: list[list[tuple[int, dict]]] = []
223+
sections: list[list[dict]] = []
226224
for part_idx in range(num_sections):
227-
part_entries: list[tuple[int, dict]] = []
225+
part_entries: list[dict] = []
228226
for doc_id, entries in grouped.items():
229227
chunk_size = chunk_sizes[doc_id]
230228
start = part_idx * chunk_size
@@ -235,8 +233,8 @@ def chunks_to_sections_doc_balanced(chunks: list[dict], num_sections: int = 1) -
235233
sections.append(part_entries)
236234

237235
formatted_sections: list[str] = []
238-
for i, indexed_chunks in enumerate(sections):
239-
section_text = format_section_chunks(indexed_chunks, i + 1)
236+
for i, section_chunks in enumerate(sections):
237+
section_text = format_section_chunks(section_chunks, i + 1)
240238
if section_text:
241239
formatted_sections.append(section_text)
242240

@@ -254,7 +252,7 @@ def chunks_to_sections_interleaved(chunks: list[dict], num_sections: int = 1) ->
254252

255253
doc_iterators = {doc_id: deque(entries) for doc_id, entries in grouped.items()}
256254
doc_order = list(grouped.keys())
257-
interleaved: list[tuple[int, dict]] = []
255+
interleaved: list[dict] = []
258256

259257
while True:
260258
added = False
@@ -276,8 +274,7 @@ def chunks_to_sections_interleaved(chunks: list[dict], num_sections: int = 1) ->
276274
for i in range(num_sections):
277275
start_idx = i * section_size
278276
end_idx = (i + 1) * section_size if i < num_sections - 1 else total
279-
indexed_chunks = interleaved[start_idx:end_idx]
280-
section_text = format_section_chunks(indexed_chunks, i + 1)
277+
section_text = format_section_chunks(interleaved[start_idx:end_idx], i + 1)
281278
if section_text:
282279
formatted_sections.append(section_text)
283280

plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/dedup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
from typing import Any
1616

1717
import numpy as np
18+
from data_designer.config.errors import BuilderConfigurationError
1819
from data_designer.config.models import GenerationType
1920
from data_designer.engine.column_generators.generators.base import (
2021
ColumnGeneratorWithModelRegistry,
2122
GenerationStrategy,
2223
)
23-
from data_designer.engine.dataset_builders.errors import DatasetGenerationError
2424
from data_designer.engine.models.facade import ModelFacade
2525

2626
from data_designer_retrieval_sdg.config import EmbeddingDedupColumnConfig
@@ -65,14 +65,14 @@ def _validate(self) -> None:
6565
from the facade or a 400 from the embeddings endpoint.
6666
6767
Raises:
68-
DatasetGenerationError: When ``self.config.model_alias`` resolves
68+
BuilderConfigurationError: When ``self.config.model_alias`` resolves
6969
to a :class:`ModelConfig` whose inference parameters are not
7070
``EmbeddingInferenceParams``.
7171
"""
7272
super()._validate()
7373
model_config = self.get_model_config(model_alias=self.config.model_alias)
7474
if model_config.generation_type != GenerationType.EMBEDDING:
75-
raise DatasetGenerationError(
75+
raise BuilderConfigurationError(
7676
f"EmbeddingDedupColumnGenerator requires an embedding model, "
7777
f"but model alias {self.config.model_alias!r} resolves to a "
7878
f"{model_config.generation_type.value!r} model. Configure a "

plugins/data-designer-retrieval-sdg/tests/test_dedup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
from unittest.mock import AsyncMock, MagicMock
88

99
import pytest
10+
from data_designer.config.errors import BuilderConfigurationError
1011
from data_designer.config.models import (
1112
ChatCompletionInferenceParams,
1213
EmbeddingInferenceParams,
1314
ModelConfig,
1415
)
15-
from data_designer.engine.dataset_builders.errors import DatasetGenerationError
1616

1717
from data_designer_retrieval_sdg.config import EmbeddingDedupColumnConfig
1818
from data_designer_retrieval_sdg.dedup import EmbeddingDedupColumnGenerator
@@ -189,7 +189,7 @@ def test_validate_rejects_chat_model() -> None:
189189
model="some/chat-model",
190190
inference_parameters=ChatCompletionInferenceParams(),
191191
)
192-
with pytest.raises(DatasetGenerationError, match="embed"):
192+
with pytest.raises(BuilderConfigurationError, match="embed"):
193193
gen._validate()
194194

195195

0 commit comments

Comments
 (0)