-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathmodel.py
More file actions
303 lines (250 loc) · 11.1 KB
/
model.py
File metadata and controls
303 lines (250 loc) · 11.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
import os
import signal
import subprocess
import tempfile
import psutil
import atexit
import sys
import pathlib
# TODO: Global process tracking for proper cleanup
_training_process = None
_inference_process = None
_temp_files = []
import sys
import pathlib
# ... (imports)
def start_training(dict: dict):
    """Launch the PyTorch Connectomics training script as a subprocess.

    Stops any previously tracked training process, writes the request's YAML
    config to a tracked temp file, builds the CLI from ``dict["arguments"]``,
    starts ``pytorch_connectomics/scripts/main.py`` via subprocess.Popen,
    streams its output from a daemon thread, and launches TensorBoard on the
    request's ``outputPath``.

    NOTE(review): the parameter shadows the builtin ``dict``; it is kept
    because renaming would change the public signature for keyword callers.

    Args:
        dict: Request payload. Keys used here: "trainingConfig" (YAML text,
            required — KeyError if absent), "arguments" (flag-name -> value
            mapping; None values skipped), "outputPath", "logPath".

    Returns:
        {"status": "started", "pid": <subprocess pid>} on success.

    Raises:
        FileNotFoundError: if the training script is missing.
        Exception: anything that prevents the subprocess from starting is
            re-raised after the temp config file is deleted.
    """
    print("\n========== MODEL.PY: START_TRAINING FUNCTION CALLED ==========")
    global _training_process
    print(f"[MODEL.PY] Input dict keys: {list(dict.keys())}")
    print(f"[MODEL.PY] Arguments: {dict.get('arguments', {})}")
    print(
        f"[MODEL.PY] *** Log path from request: {dict.get('logPath', 'NOT PROVIDED')}"
    )
    print(
        f"[MODEL.PY] Training config length: {len(dict.get('trainingConfig', ''))} chars"
    )
    # Parse YAML to show what OUTPUT_PATH is being used
    # (diagnostic only — failures here are logged and ignored).
    try:
        import yaml
        config_obj = yaml.safe_load(dict.get("trainingConfig", ""))
        dataset_output_path = config_obj.get("DATASET", {}).get(
            "OUTPUT_PATH", "NOT SET"
        )
        print(f"[MODEL.PY] *** YAML DATASET.OUTPUT_PATH: {dataset_output_path}")
        print(
            f"[MODEL.PY] NOTE: PyTorch Connectomics will write checkpoints to OUTPUT_PATH"
        )
        print(f"[MODEL.PY] NOTE: TensorBoard logs should go to logPath")
    except Exception as e:
        print(f"[MODEL.PY] Could not parse YAML to check OUTPUT_PATH: {e}")
    # TODO: Stop existing training process if running
    # Only one tracked training process at a time; poll() is None => alive.
    if _training_process and _training_process.poll() is None:
        print("[MODEL.PY] Existing training process detected, stopping it first...")
        stop_training()
    # Use absolute path relative to this file
    # server_pytc/services/model.py -> server_pytc/ -> pytc-client/ -> pytorch_connectomics/scripts/main.py
    print("[MODEL.PY] Resolving script path...")
    current_dir = pathlib.Path(__file__).parent.parent.parent
    print(f"[MODEL.PY] Current dir (project root): {current_dir}")
    script_path = current_dir / "pytorch_connectomics" / "scripts" / "main.py"
    print(f"[MODEL.PY] Script path: {script_path}")
    if not script_path.exists():
        print(f"[MODEL.PY] ✗ ERROR: Training script not found at {script_path}")
        raise FileNotFoundError(f"Training script not found at {script_path}")
    else:
        print(f"[MODEL.PY] ✓ Training script found")
    # Use the same interpreter that runs this server for the subprocess.
    print(f"[MODEL.PY] Python executable: {sys.executable}")
    command = [sys.executable, str(script_path)]
    print(f"[MODEL.PY] Processing command-line arguments...")
    for key, value in dict.get("arguments", {}).items():
        if value is not None:
            print(f"[MODEL.PY] Adding --{key} {value}")
            command.extend([f"--{key}", str(value)])
    # TODO: Write the value to a temporary file and track it for cleanup
    # delete=False so the file survives for the subprocess; tracked in
    # _temp_files so cleanup_temp_files() can remove it later.
    print("[MODEL.PY] Creating temporary YAML config file...")
    temp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", suffix=".yaml")
    config_content = dict["trainingConfig"]
    print(f"[MODEL.PY] Writing config ({len(config_content)} chars) to temp file...")
    temp_file.write(config_content)
    temp_filepath = temp_file.name
    temp_file.close()
    _temp_files.append(temp_filepath)
    print(f"[MODEL.PY] ✓ Temp config file created at: {temp_filepath}")
    # Show first few lines of the temp file for debugging
    with open(temp_filepath, "r") as f:
        first_lines = "".join(f.readlines()[:20])
    print(f"[MODEL.PY] Temp file preview (first 20 lines):\n{first_lines}\n")
    command.extend(["--config-file", str(temp_filepath)])
    # TODO: Execute the command using subprocess.Popen for proper async handling
    print(f"[MODEL.PY] Final command: {' '.join(command)}")
    print("[MODEL.PY] Starting subprocess...")
    try:
        _training_process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # Merge stderr into stdout
            text=True,
            bufsize=1,  # Line buffered
            cwd=str(current_dir),  # Set working directory
        )
        print(
            f"[MODEL.PY] ✓ Training process started with PID: {_training_process.pid}"
        )
        # Start a thread to read and log subprocess output
        import threading
        def log_subprocess_output():
            # Daemon thread: drains the merged stdout pipe line by line so
            # the child never blocks on a full pipe, then logs the exit code.
            print(
                f"[MODEL.PY] === Training subprocess output (PID {_training_process.pid}) ==="
            )
            try:
                for line in _training_process.stdout:
                    print(f"[TRAINING:{_training_process.pid}] {line.rstrip()}")
                # Get exit code
                _training_process.wait()
                print(
                    f"[MODEL.PY] === Training subprocess finished with exit code: {_training_process.returncode} ==="
                )
            except Exception as e:
                print(f"[MODEL.PY] Error reading subprocess output: {e}")
        output_thread = threading.Thread(target=log_subprocess_output, daemon=True)
        output_thread.start()
        # Initialize TensorBoard to monitor the OUTPUT_PATH where PyTorch Connectomics writes logs
        # PyTorch Connectomics writes logs to {OUTPUT_PATH}/log{timestamp}/
        output_path = dict.get("outputPath")
        log_path = dict.get("logPath")
        print(f"[MODEL.PY] *** Output path from request: {output_path}")
        print(
            f"[MODEL.PY] *** Log path from request: {log_path} (for compatibility only)"
        )
        if output_path:
            print(f"[MODEL.PY] *** Initializing TensorBoard to monitor: {output_path}")
            print(
                f"[MODEL.PY] NOTE: PyTorch Connectomics writes logs to {{OUTPUT_PATH}}/log{{timestamp}}/"
            )
            print(
                f"[MODEL.PY] NOTE: TensorBoard will automatically find event files in subdirectories"
            )
            initialize_tensorboard(output_path)
            print(f"[MODEL.PY] ✓ TensorBoard initialized for directory: {output_path}")
        else:
            print(
                f"[MODEL.PY] ⚠ WARNING: No outputPath provided, TensorBoard not initialized"
            )
        result = {"status": "started", "pid": _training_process.pid}
        print(f"[MODEL.PY] Returning: {result}")
        print("========== MODEL.PY: END OF START_TRAINING ==========\n")
        return result
    except Exception as e:
        print(
            f"[MODEL.PY] ✗ ERROR starting training process: {type(e).__name__}: {str(e)}"
        )
        import traceback
        print(traceback.format_exc())
        # Cleanup temp file if process failed to start
        if os.path.exists(temp_filepath):
            print(f"[MODEL.PY] Cleaning up temp file: {temp_filepath}")
            os.unlink(temp_filepath)
            _temp_files.remove(temp_filepath)
        print("========== MODEL.PY: END OF START_TRAINING (WITH ERROR) ==========\n")
        raise
def stop_process_by_name(process_name):
    """Terminate every process whose command line contains *process_name*.

    Scans all processes with psutil, asks each match to terminate gracefully,
    and waits up to 10 seconds.  Bug fix: previously a graceful-termination
    timeout was silently swallowed and the process was left running; it is
    now escalated to a hard kill.  Per-process errors (already exited,
    permission denied) are ignored so one bad match does not abort the sweep.

    Args:
        process_name: Substring matched against the process's joined cmdline.
    """
    try:
        for proc in psutil.process_iter(["pid", "name", "cmdline"]):
            try:
                cmdline = " ".join(proc.info["cmdline"] or [])
                if process_name not in cmdline:
                    continue
                print(f"Terminating process {proc.info['pid']}: {cmdline}")
                proc.terminate()
                try:
                    # Wait up to 10 seconds for graceful termination.
                    proc.wait(timeout=10)
                except psutil.TimeoutExpired:
                    # Did not exit in time — force-kill (fix for leaked process).
                    proc.kill()
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                # Process already terminated or we don't have permission
                continue
    except Exception as e:
        print(f"Error stopping processes by name '{process_name}': {e}")
def cleanup_temp_files():
    """Delete all tracked temporary files, keeping only the ones that failed."""
    global _temp_files
    survivors = []
    for path in _temp_files:
        try:
            if os.path.exists(path):
                os.unlink(path)
                print(f"Cleaned up temp file: {path}")
        except Exception as e:
            print(f"Error cleaning up temp file {path}: {e}")
            # Deletion failed — keep the entry so a later sweep can retry.
            survivors.append(path)
    # In-place slice assignment preserves the shared list object.
    _temp_files[:] = survivors
def stop_training():
    """Stop the tracked training process, stray workers, TensorBoard, and temp files.

    Returns:
        {"status": "stopped"} always, even if nothing was running.
    """
    global _training_process
    proc = _training_process
    # TODO: Stop the tracked training process first
    if proc and proc.poll() is None:
        try:
            print(f"Terminating training process PID: {proc.pid}")
            proc.terminate()
            proc.wait(timeout=10)
        except subprocess.TimeoutExpired:
            # Graceful stop timed out — escalate to SIGKILL.
            print("Force killing training process...")
            proc.kill()
            proc.wait()
        except Exception as e:
            print(f"Error stopping training process: {e}")
        finally:
            _training_process = None
    # Stop any remaining processes by name as fallback
    stop_process_by_name("python pytorch_connectomics/scripts/main.py")
    stop_tensorboard()
    cleanup_temp_files()
    return {"status": "stopped"}
# URL of the running TensorBoard instance; None until initialize_tensorboard runs.
tensorboard_url = None
def initialize_tensorboard(logPath):
    """Launch TensorBoard serving *logPath* and record its URL in a module global.

    Bug fix: ``tensorboard_url`` was previously assigned as a function local
    (no ``global`` statement), so the module-level variable stayed None and
    get_tensorboard() never returned the real URL.

    Args:
        logPath: Directory TensorBoard should watch for event files.
    """
    global tensorboard_url  # fix: write the module global, not a local
    print(f"[MODEL.PY] initialize_tensorboard called with logPath: {logPath}")
    from tensorboard import program
    tb = program.TensorBoard()
    try:
        print(f"[MODEL.PY] Configuring TensorBoard with logdir: {logPath}")
        # Bind to all interfaces so the client can reach it from outside.
        tb.configure(argv=[None, "--logdir", logPath, "--host", "0.0.0.0"])
        tensorboard_url = tb.launch()
        print(f"[MODEL.PY] ✓ TensorBoard is running at {tensorboard_url}")
    except Exception as e:
        # Best effort: assume the default TensorBoard address if launch failed.
        tensorboard_url = "http://localhost:6006/"
        print(
            f"[MODEL.PY] ⚠ TensorBoard fallback to {tensorboard_url} due to error: {e}"
        )
def get_tensorboard():
    """Return the recorded TensorBoard URL (module global; may be None if never launched)."""
    return tensorboard_url
def stop_tensorboard():
    """Terminate any process whose command line mentions "tensorboard"."""
    stop_process_by_name("tensorboard")
def start_inference(dict: dict):
    """Run the PyTorch Connectomics script synchronously in inference mode.

    NOTE(review): the parameter shadows the builtin ``dict``; kept for
    signature compatibility with existing callers.

    Args:
        dict: Request payload with "inferenceConfig" (YAML text, required)
            and "arguments" (flag-name -> value mapping; None values skipped).

    Raises:
        FileNotFoundError: if the inference script cannot be located.
    """
    # Use absolute path relative to this file
    # server_pytc/services/model.py -> project root -> pytorch_connectomics/scripts/main.py
    current_dir = pathlib.Path(__file__).parent.parent.parent
    script_path = current_dir / "pytorch_connectomics" / "scripts" / "main.py"
    if not script_path.exists():
        print(f"Error: Inference script not found at {script_path}")
        raise FileNotFoundError(f"Inference script not found at {script_path}")
    command = [sys.executable, str(script_path), "--inference"]
    # Write the config to a temporary file and track it for cleanup
    # (bug fix: the file was previously never tracked, so it leaked on disk).
    with tempfile.NamedTemporaryFile(
        delete=False, mode="w", suffix=".yaml"
    ) as temp_file:
        temp_file.write(dict["inferenceConfig"])
        temp_filepath = temp_file.name
    _temp_files.append(temp_filepath)
    command.extend(["--config-file", str(temp_filepath)])
    for key, value in dict["arguments"].items():
        if value is not None:
            command.extend([f"--{key}", str(value)])
    # Execute the command synchronously.
    print(command)
    try:
        # Bug fix: subprocess.call never raises CalledProcessError, so the
        # handler below was dead code; check_call raises on a nonzero exit.
        subprocess.check_call(command)
    except subprocess.CalledProcessError as e:
        print(f"Error occurred: {e}")
    print("start_inference")
def stop_inference():
    """Stop any running inference process and shut down TensorBoard.

    Bug fix: previously called the undefined ``stop_process(...)``, which
    raised NameError; it now uses stop_process_by_name like stop_training.
    """
    process_name = "python pytorch_connectomics/scripts/main.py"
    stop_process_by_name(process_name)
    stop_tensorboard()