Skip to content

Commit 5aa4eaa

Browse files
committed
Fix PR 166 runtime cleanup and chatbot test collection
1 parent e8ecce8 commit 5aa4eaa

4 files changed

Lines changed: 43 additions & 12 deletions

File tree

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
Shows all RAG retrievals, tool calls, and final responses.
55
66
Usage:
7-
python test_agent.py # interactive mode
8-
python test_agent.py -b # batch: run 20-question PyTC eval
9-
python test_agent.py "your question" # single question
7+
python agent_cli.py # interactive mode
8+
python agent_cli.py -b # batch: run 20-question PyTC eval
9+
python agent_cli.py "your question" # single question
1010
"""
1111

1212
import os
@@ -69,7 +69,7 @@ def run_batch():
6969
print(f"{'#'*80}")
7070

7171

72-
def test_single(question: str):
72+
def run_single(question: str):
7373
"""Test the agent with a single question."""
7474
print(f"\n{'='*80}")
7575
print(f"QUESTION: {question}")
@@ -127,6 +127,6 @@ def interactive_mode():
127127
elif args.interactive:
128128
interactive_mode()
129129
elif args.question:
130-
test_single(" ".join(args.question))
130+
run_single(" ".join(args.question))
131131
else:
132132
interactive_mode()

server_pytc/services/model.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -444,12 +444,18 @@ def _matches_pytc_mode_process(cmdline: list[str], mode: str) -> bool:
444444
if script_path not in normalized:
445445
return False
446446

447-
try:
447+
# Support both the legacy "--mode train/test" CLI and the newer
448+
# "--inference" flag that PyTC now uses to switch the entrypoint into test mode.
449+
if "--mode" in normalized:
448450
mode_index = normalized.index("--mode")
449-
except ValueError:
450-
return False
451+
return mode_index + 1 < len(normalized) and normalized[mode_index + 1] == mode
451452

452-
return mode_index + 1 < len(normalized) and normalized[mode_index + 1] == mode
453+
is_inference = "--inference" in normalized
454+
if mode == "test":
455+
return is_inference
456+
if mode == "train":
457+
return not is_inference
458+
return False
453459

454460

455461
def stop_pytc_processes(mode: str):

tests/test_pytc_runtime_routes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,11 @@ def test_write_temp_config_uses_origin_parent_for_relative_bases(self):
129129
config_path = model_service._write_temp_config(
130130
"foo: bar\n",
131131
"training",
132-
config_origin_path="tutorials/neuron_snemi.yaml",
132+
config_origin_path="configs/SNEMI/SNEMI-Base.yaml",
133133
)
134134
written_path = pathlib.Path(config_path)
135135
expected_parent = (
136-
model_service._project_root() / "pytorch_connectomics" / "tutorials"
136+
model_service._project_root() / "pytorch_connectomics" / "configs" / "SNEMI"
137137
)
138138

139139
self.assertTrue(written_path.exists())

tests/test_worker_model_service.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class WorkerModelServiceTests(unittest.TestCase):
99
def tearDown(self):
1010
model_service.cleanup_temp_files()
1111

12-
def test_matches_pytc_mode_process_with_config_between_script_and_mode(self):
12+
def test_matches_pytc_mode_process_supports_legacy_cli(self):
1313
script_path = str(model_service._pytc_script_path())
1414
cmdline = [
1515
"/usr/bin/python",
@@ -23,6 +23,31 @@ def test_matches_pytc_mode_process_with_config_between_script_and_mode(self):
2323
self.assertTrue(model_service._matches_pytc_mode_process(cmdline, "train"))
2424
self.assertFalse(model_service._matches_pytc_mode_process(cmdline, "test"))
2525

26+
def test_matches_pytc_mode_process_supports_current_cli(self):
27+
script_path = str(model_service._pytc_script_path())
28+
train_cmdline = [
29+
"/usr/bin/python",
30+
script_path,
31+
"--config-file",
32+
"/tmp/runtime.yaml",
33+
]
34+
inference_cmdline = [
35+
"/usr/bin/python",
36+
script_path,
37+
"--config-file",
38+
"/tmp/runtime.yaml",
39+
"--inference",
40+
]
41+
42+
self.assertTrue(model_service._matches_pytc_mode_process(train_cmdline, "train"))
43+
self.assertFalse(model_service._matches_pytc_mode_process(train_cmdline, "test"))
44+
self.assertTrue(
45+
model_service._matches_pytc_mode_process(inference_cmdline, "test")
46+
)
47+
self.assertFalse(
48+
model_service._matches_pytc_mode_process(inference_cmdline, "train")
49+
)
50+
2651
def test_cleanup_temp_files_is_scoped_by_kind(self):
2752
training_file = tempfile.NamedTemporaryFile(delete=False)
2853
inference_file = tempfile.NamedTemporaryFile(delete=False)

0 commit comments

Comments
 (0)