|
78 | 78 | import argparse |
79 | 79 | import logging |
80 | 80 | import os |
81 | | -from dataclasses import dataclass, field |
| 81 | +from dataclasses import dataclass |
82 | 82 | from pathlib import Path |
83 | 83 | from typing import Any |
84 | 84 |
|
85 | 85 | import yaml |
| 86 | +from conversation_utils import ( |
| 87 | + has_tool_turns, |
| 88 | + load_augmentations, |
| 89 | + make_augment_fn, |
| 90 | + normalize_messages, |
| 91 | + strip_assistant_turns, |
| 92 | +) |
86 | 93 | from datasets import concatenate_datasets, load_dataset |
87 | 94 |
|
88 | | -from conversation_utils import has_tool_turns, load_augmentations, make_augment_fn, normalize_messages, strip_assistant_turns |
89 | | - |
90 | 95 | logging.basicConfig( |
91 | 96 | level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", datefmt="%H:%M:%S" |
92 | 97 | ) |
@@ -238,17 +243,14 @@ def main() -> None: |
238 | 243 | for spec in dataset_specs: |
239 | 244 | logger.info("Loading %s (augment=%s)", spec.repo_id, spec.augment) |
240 | 245 | for split in spec.splits: |
241 | | - ds = load_split(spec.repo_id, split, spec.cap_per_split, args.num_proc, |
242 | | - args.mode) |
| 246 | + ds = load_split(spec.repo_id, split, spec.cap_per_split, args.num_proc, args.mode) |
243 | 247 | if args.mode == "generate" and not spec.augment: |
244 | 248 | non_augmentable_parts.append(ds) |
245 | 249 | else: |
246 | 250 | augmentable_parts.append(ds) |
247 | 251 |
|
248 | 252 | augmentable = concatenate_datasets(augmentable_parts) if augmentable_parts else None |
249 | | - non_augmentable = ( |
250 | | - concatenate_datasets(non_augmentable_parts) if non_augmentable_parts else None |
251 | | - ) |
| 253 | + non_augmentable = concatenate_datasets(non_augmentable_parts) if non_augmentable_parts else None |
252 | 254 | if augmentable is not None: |
253 | 255 | logger.info("Augmentable rows: %d", len(augmentable)) |
254 | 256 | if non_augmentable is not None: |
|
0 commit comments