Skip to content

Commit ea2955a

Browse files
author
root
committed
Harden scheduled training state recovery
1 parent ca4cf96 commit ea2955a

3 files changed

Lines changed: 148 additions & 56 deletions

File tree

src/scheduler.rs

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ use crate::profile::ModerationProfile;
1111
use crate::sample_rpc::SampleRpcConfig;
1212
use crate::training::{evaluate_training_need, fetch_training_sample_count_via_rpc};
1313

14+
const RUNNING_STATUS_STALE_SECS: u64 = 2 * 60 * 60;
15+
1416
#[derive(Debug, Clone, PartialEq, Eq)]
1517
pub struct TrainingSubprocessCommand {
1618
pub program: String,
@@ -95,7 +97,7 @@ pub fn cooldown_allows_training(
9597
Ok(elapsed_secs >= cooldown_secs)
9698
}
9799

98-
pub fn training_launch_allowed(profile: &ModerationProfile) -> Result<bool> {
100+
pub fn training_launch_allowed(profile: &ModerationProfile, now_unix_secs: u64) -> Result<bool> {
99101
let Some(status) = profile.training_status() else {
100102
return Ok(true);
101103
};
@@ -104,7 +106,46 @@ pub fn training_launch_allowed(profile: &ModerationProfile) -> Result<bool> {
104106
return Ok(true);
105107
};
106108

107-
Ok(state != "running")
109+
if state != "running" {
110+
return Ok(true);
111+
}
112+
113+
let Some(timestamp) = status.get("timestamp").and_then(|value| value.as_u64()) else {
114+
return Ok(true);
115+
};
116+
if now_unix_secs.saturating_sub(timestamp) >= RUNNING_STATUS_STALE_SECS {
117+
return Ok(true);
118+
}
119+
120+
let Some(pid) = status
121+
.get("pid")
122+
.and_then(|value| value.as_u64())
123+
.and_then(|value| u32::try_from(value).ok())
124+
else {
125+
return Ok(false);
126+
};
127+
128+
Ok(!training_process_matches_profile(pid, &profile.profile_name))
129+
}
130+
131+
fn training_process_matches_profile(pid: u32, profile_name: &str) -> bool {
132+
let cmdline_path = format!("/proc/{pid}/cmdline");
133+
let Ok(raw) = fs::read(cmdline_path) else {
134+
return false;
135+
};
136+
if raw.is_empty() {
137+
return false;
138+
}
139+
140+
let args = raw
141+
.split(|byte| *byte == 0)
142+
.filter(|part| !part.is_empty())
143+
.filter_map(|part| std::str::from_utf8(part).ok())
144+
.collect::<Vec<_>>();
145+
146+
args.windows(2).any(|window| {
147+
window[0] == "train-profile" && window[1] == profile_name
148+
})
108149
}
109150

110151
pub fn build_training_subprocess_command(
@@ -139,7 +180,7 @@ pub async fn plan_training_round(
139180
let mut planned = Vec::new();
140181

141182
for scanned in scan_profiles(root_dir)? {
142-
if !training_launch_allowed(&scanned.profile)? {
183+
if !training_launch_allowed(&scanned.profile, now_unix_secs)? {
143184
continue;
144185
}
145186

src/training.rs

Lines changed: 102 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -247,60 +247,110 @@ pub async fn run_training_subprocess_from_args(args: &[String]) -> Result<()> {
247247
.cloned()
248248
.ok_or_else(|| anyhow!("usage: {} train-profile <profile-name>", binary_name(args)))?;
249249
let root_dir = std::env::current_dir().context("failed to resolve current dir")?;
250-
let settings = crate::config::Settings::load(&root_dir)?;
251-
let rpc = SampleRpcConfig::from_settings(&settings)?;
252-
let profile = ModerationProfile::load(&root_dir, &profile_name)?;
253-
254-
let initial_sample_count = fetch_training_sample_count_via_rpc(&rpc, &profile)
255-
.await
256-
.unwrap_or(0);
257-
write_training_status(
258-
&profile,
259-
&serde_json::json!({
260-
"status": "running",
261-
"message": "training in progress",
262-
"timestamp": current_unix_secs(),
263-
"profile": profile.profile_name,
264-
"model_type": profile.config.local_model_type,
265-
"sample_count": initial_sample_count
266-
}),
267-
)?;
268-
269-
match run_profile_training(&rpc, &profile).await {
270-
Ok(output) => {
271-
write_training_status(
272-
&profile,
273-
&serde_json::json!({
274-
"status": "success",
275-
"message": "training completed",
276-
"timestamp": current_unix_secs(),
277-
"profile": profile.profile_name,
278-
"model_type": profile.config.local_model_type,
279-
"sample_count": output.sample_count,
280-
"pass_count": output.pass_count,
281-
"violation_count": output.violation_count,
282-
"runtime_json_path": output.runtime_json_path,
283-
"runtime_coef_path": output.runtime_coef_path,
284-
"model_marker_path": output.model_marker_path
285-
}),
286-
)?;
287-
Ok(())
288-
}
289-
Err(err) => {
290-
write_training_status(
291-
&profile,
292-
&serde_json::json!({
293-
"status": "failed",
294-
"message": format!("{err:#}"),
295-
"timestamp": current_unix_secs(),
296-
"profile": profile.profile_name,
297-
"model_type": profile.config.local_model_type,
298-
"sample_count": initial_sample_count
299-
}),
300-
)?;
301-
Err(err)
250+
let result = async {
251+
let settings = crate::config::Settings::load(&root_dir)?;
252+
let rpc = SampleRpcConfig::from_settings(&settings)?;
253+
let profile = ModerationProfile::load(&root_dir, &profile_name)?;
254+
255+
let initial_sample_count = fetch_training_sample_count_via_rpc(&rpc, &profile)
256+
.await
257+
.unwrap_or(0);
258+
write_training_status(
259+
&profile,
260+
&serde_json::json!({
261+
"status": "running",
262+
"message": "training in progress",
263+
"timestamp": current_unix_secs(),
264+
"started_at": current_unix_secs(),
265+
"pid": std::process::id(),
266+
"profile": profile.profile_name,
267+
"model_type": profile.config.local_model_type,
268+
"sample_count": initial_sample_count
269+
}),
270+
)?;
271+
272+
match run_profile_training(&rpc, &profile).await {
273+
Ok(output) => {
274+
write_training_status(
275+
&profile,
276+
&serde_json::json!({
277+
"status": "success",
278+
"message": "training completed",
279+
"timestamp": current_unix_secs(),
280+
"profile": profile.profile_name,
281+
"model_type": profile.config.local_model_type,
282+
"sample_count": output.sample_count,
283+
"pass_count": output.pass_count,
284+
"violation_count": output.violation_count,
285+
"runtime_json_path": output.runtime_json_path,
286+
"runtime_coef_path": output.runtime_coef_path,
287+
"model_marker_path": output.model_marker_path
288+
}),
289+
)?;
290+
Ok(())
291+
}
292+
Err(err) => {
293+
write_training_status(
294+
&profile,
295+
&serde_json::json!({
296+
"status": "failed",
297+
"message": format!("{err:#}"),
298+
"timestamp": current_unix_secs(),
299+
"profile": profile.profile_name,
300+
"model_type": profile.config.local_model_type,
301+
"sample_count": initial_sample_count
302+
}),
303+
)?;
304+
Err(err)
305+
}
302306
}
303307
}
308+
.await;
309+
310+
if let Err(err) = &result {
311+
let _ = write_training_failure_status(
312+
&root_dir,
313+
&profile_name,
314+
format!("{err:#}"),
315+
);
316+
}
317+
318+
result
319+
}
320+
321+
fn write_training_failure_status(
322+
root_dir: &Path,
323+
profile_name: &str,
324+
message: String,
325+
) -> Result<()> {
326+
let status_path = root_dir
327+
.join("configs")
328+
.join("mod_profiles")
329+
.join(profile_name)
330+
.join(".train_status.json");
331+
if let Some(parent) = status_path.parent() {
332+
fs::create_dir_all(parent)
333+
.with_context(|| format!("failed to create profile dir {}", parent.display()))?;
334+
}
335+
336+
let payload = serde_json::json!({
337+
"status": "failed",
338+
"message": message,
339+
"timestamp": current_unix_secs(),
340+
"profile": profile_name,
341+
});
342+
let tmp_path = status_path.with_extension("json.tmp");
343+
let encoded = serde_json::to_vec_pretty(&payload).context("failed to encode training failure status")?;
344+
fs::write(&tmp_path, encoded)
345+
.with_context(|| format!("failed to write training failure status tmp {}", tmp_path.display()))?;
346+
fs::rename(&tmp_path, &status_path).with_context(|| {
347+
format!(
348+
"failed to replace training failure status {} with {}",
349+
status_path.display(),
350+
tmp_path.display()
351+
)
352+
})?;
353+
Ok(())
304354
}
305355

306356
pub fn train_hashlinear_runtime(

tests/scheduler_tests.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,8 @@ fn scheduler_skips_profile_when_training_is_already_running() {
204204
}),
205205
);
206206

207-
let decision = scheduler::training_launch_allowed(&profile).expect("training launch decision");
207+
let decision = scheduler::training_launch_allowed(&profile, current_unix_secs())
208+
.expect("training launch decision");
208209

209210
assert!(!decision);
210211
}

0 commit comments

Comments
 (0)