Skip to content

Commit 3ba8e1a

Browse files
committed
fix(code-quality): apply pre-commit ruff fixes
1 parent e09b385 commit 3ba8e1a

File tree

3 files changed

+18
-15
lines changed

3 files changed

+18
-15
lines changed

examples/dataset/conversation_utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,13 @@ def normalize_messages(example: dict[str, Any], idx: int) -> dict[str, Any]:
186186
msg["tool_calls"] = m["tool_calls"]
187187
normalized.append(msg)
188188
elif role == "tool":
189-
normalized.append({
190-
"role": "tool",
191-
"content": m.get("content") or "",
192-
"tool_call_id": m.get("tool_call_id", ""),
193-
})
189+
normalized.append(
190+
{
191+
"role": "tool",
192+
"content": m.get("content") or "",
193+
"tool_call_id": m.get("tool_call_id", ""),
194+
}
195+
)
194196
elif role == "developer":
195197
# Map developer-role messages to system per OpenAI schema conventions.
196198
normalized.append({"role": "system", "content": m.get("content") or ""})

examples/dataset/make_nemotron_ptv2_dataset.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,14 @@
5959
import os
6060
from pathlib import Path
6161

62-
from datasets import concatenate_datasets, load_dataset
63-
6462
from conversation_utils import (
6563
has_tool_turns,
6664
load_augmentations,
6765
make_augment_fn,
6866
normalize_messages,
6967
strip_assistant_turns,
7068
)
69+
from datasets import concatenate_datasets, load_dataset
7170

7271
logging.basicConfig(
7372
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", datefmt="%H:%M:%S"

examples/dataset/make_nemotron_ptv3_dataset.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,20 @@
7878
import argparse
7979
import logging
8080
import os
81-
from dataclasses import dataclass, field
81+
from dataclasses import dataclass
8282
from pathlib import Path
8383
from typing import Any
8484

8585
import yaml
86+
from conversation_utils import (
87+
has_tool_turns,
88+
load_augmentations,
89+
make_augment_fn,
90+
normalize_messages,
91+
strip_assistant_turns,
92+
)
8693
from datasets import concatenate_datasets, load_dataset
8794

88-
from conversation_utils import has_tool_turns, load_augmentations, make_augment_fn, normalize_messages, strip_assistant_turns
89-
9095
logging.basicConfig(
9196
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", datefmt="%H:%M:%S"
9297
)
@@ -238,17 +243,14 @@ def main() -> None:
238243
for spec in dataset_specs:
239244
logger.info("Loading %s (augment=%s)", spec.repo_id, spec.augment)
240245
for split in spec.splits:
241-
ds = load_split(spec.repo_id, split, spec.cap_per_split, args.num_proc,
242-
args.mode)
246+
ds = load_split(spec.repo_id, split, spec.cap_per_split, args.num_proc, args.mode)
243247
if args.mode == "generate" and not spec.augment:
244248
non_augmentable_parts.append(ds)
245249
else:
246250
augmentable_parts.append(ds)
247251

248252
augmentable = concatenate_datasets(augmentable_parts) if augmentable_parts else None
249-
non_augmentable = (
250-
concatenate_datasets(non_augmentable_parts) if non_augmentable_parts else None
251-
)
253+
non_augmentable = concatenate_datasets(non_augmentable_parts) if non_augmentable_parts else None
252254
if augmentable is not None:
253255
logger.info("Augmentable rows: %d", len(augmentable))
254256
if non_augmentable is not None:

0 commit comments

Comments
 (0)