Skip to content

Commit 4271c9c

Browse files
committed
fix(examples): address PR #1176 review feedback
- Fix module docstring in conversation_utils.py: clarify that user_suffix augmentations apply to all user turns, not just the last
- Map developer role to system in normalize_messages() to preserve developer instructions rather than silently dropping them
- Raise ValueError instead of warning+return when combined dataset is empty in make_nemotron_ptv3_dataset.py
- Re-raise all exceptions in query.py LLM.generate() so datasets.map() halts on any failure (not only connection errors)
- Make tool role explicit in synthesize(); raise ValueError for truly unknown roles to catch typos/unsupported roles early
- Fix local file format detection in query.py: use "parquet" loader for .parquet files instead of hardcoded "json"
- Replace eval with direct argv in query.sh to eliminate shell injection

Signed-off-by: chenhany <chenhany@nvidia.com>
1 parent 663d0e1 commit 4271c9c

4 files changed

Lines changed: 17 additions & 15 deletions

File tree

examples/dataset/conversation_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@
2929
skeleton to the target model turn-by-turn, appending each generated response before
3030
sending the next user turn, so the model produces coherent multi-turn continuations.
3131
32-
Augmentations are applied only to the *last* user message (the new prompt), not to
33-
earlier user turns that are already part of the established context.
32+
Augmentations (``user_suffix``) are applied to *all* user messages so that the
33+
language or style instruction is present at every turn — important for multi-turn
34+
synthetic generation where the model must maintain the requested style throughout.
3435
"""
3536

3637
import logging
@@ -190,5 +191,8 @@ def normalize_messages(example: dict[str, Any], idx: int) -> dict[str, Any]:
190191
"content": m.get("content") or "",
191192
"tool_call_id": m.get("tool_call_id", ""),
192193
})
194+
elif role == "developer":
195+
# Map developer-role messages to system per OpenAI schema conventions.
196+
normalized.append({"role": "system", "content": m.get("content") or ""})
193197
# other roles (e.g. function, unknown) are dropped
194198
return {"messages": normalized}

examples/dataset/make_nemotron_ptv3_dataset.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,7 @@ def main() -> None:
271271
parts_to_combine.append(non_augmentable)
272272

273273
if not parts_to_combine:
274-
logger.warning("No data to combine — all rows were filtered out. Exiting.")
275-
return
274+
raise ValueError("No data to combine — all rows were filtered out.")
276275

277276
combined = concatenate_datasets(parts_to_combine)
278277
logger.info("Combined (pre-shuffle): %d rows", len(combined))

tools/launcher/common/query.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,9 @@ def generate(self, messages, verbose=False, **chat_template_kwargs):
7878
new_message = {"role": "assistant", "content": new_message}
7979
except Exception as e:
8080
print(e)
81-
8281
if "Connection error" in str(e):
8382
early_termination = True
84-
raise # propagate so datasets.map() halts the shard
85-
86-
new_message = None
83+
raise # always propagate so datasets.map() halts the shard
8784

8885
return new_message
8986

@@ -168,10 +165,11 @@ def synthesize(data):
168165
elif role == "assistant":
169166
# Original assistant messages are not used — the model generates fresh responses.
170167
pass
171-
else:
172-
# Skip unknown roles (e.g. tool) — agentic datasets include tool turns
173-
# that are not sent to the generation model.
168+
elif role == "tool":
169+
# Tool turns are not sent to the generation model — skip them.
174170
pass
171+
else:
172+
raise ValueError(f"Unexpected message role {role!r} in conversation.")
175173

176174
# Restore the full reasoning trace for the last generated assistant turn.
177175
if enable_thinking and last_full_message is not None:
@@ -185,7 +183,9 @@ def synthesize(data):
185183

186184
# Support both HF Hub repo IDs and local file paths (.jsonl, .json, .parquet, etc.)
187185
if os.path.isfile(args.data):
188-
dataset = load_dataset("json", data_files=args.data, split=args.data_split)
186+
ext = os.path.splitext(args.data)[1].lower()
187+
fmt = "parquet" if ext == ".parquet" else "json"
188+
dataset = load_dataset(fmt, data_files={"train": args.data}, split=args.data_split)
189189
else:
190190
dataset = load_dataset(args.data, split=args.data_split)
191191

tools/launcher/common/vllm/query.sh

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,8 @@ while true; do
119119
done
120120

121121
pip3 install -q datasets openai 2>/dev/null || true
122-
cmd="python3 common/query.py http://localhost:8000/v1 ${MODEL} ${QUERY_ARGS[*]}"
123-
echo "Running command: $cmd"
124-
eval $cmd
122+
echo "Running: python3 common/query.py http://localhost:8000/v1 ${MODEL} ${QUERY_ARGS[*]}"
123+
python3 common/query.py http://localhost:8000/v1 "${MODEL}" "${QUERY_ARGS[@]}"
125124
echo "Main process exit"
126125

127126
kill $SERVER_PID

0 commit comments

Comments (0)