@@ -247,60 +247,110 @@ pub async fn run_training_subprocess_from_args(args: &[String]) -> Result<()> {
247247 . cloned ( )
248248 . ok_or_else ( || anyhow ! ( "usage: {} train-profile <profile-name>" , binary_name( args) ) ) ?;
249249 let root_dir = std:: env:: current_dir ( ) . context ( "failed to resolve current dir" ) ?;
250- let settings = crate :: config:: Settings :: load ( & root_dir) ?;
251- let rpc = SampleRpcConfig :: from_settings ( & settings) ?;
252- let profile = ModerationProfile :: load ( & root_dir, & profile_name) ?;
253-
254- let initial_sample_count = fetch_training_sample_count_via_rpc ( & rpc, & profile)
255- . await
256- . unwrap_or ( 0 ) ;
257- write_training_status (
258- & profile,
259- & serde_json:: json!( {
260- "status" : "running" ,
261- "message" : "training in progress" ,
262- "timestamp" : current_unix_secs( ) ,
263- "profile" : profile. profile_name,
264- "model_type" : profile. config. local_model_type,
265- "sample_count" : initial_sample_count
266- } ) ,
267- ) ?;
268-
269- match run_profile_training ( & rpc, & profile) . await {
270- Ok ( output) => {
271- write_training_status (
272- & profile,
273- & serde_json:: json!( {
274- "status" : "success" ,
275- "message" : "training completed" ,
276- "timestamp" : current_unix_secs( ) ,
277- "profile" : profile. profile_name,
278- "model_type" : profile. config. local_model_type,
279- "sample_count" : output. sample_count,
280- "pass_count" : output. pass_count,
281- "violation_count" : output. violation_count,
282- "runtime_json_path" : output. runtime_json_path,
283- "runtime_coef_path" : output. runtime_coef_path,
284- "model_marker_path" : output. model_marker_path
285- } ) ,
286- ) ?;
287- Ok ( ( ) )
288- }
289- Err ( err) => {
290- write_training_status (
291- & profile,
292- & serde_json:: json!( {
293- "status" : "failed" ,
294- "message" : format!( "{err:#}" ) ,
295- "timestamp" : current_unix_secs( ) ,
296- "profile" : profile. profile_name,
297- "model_type" : profile. config. local_model_type,
298- "sample_count" : initial_sample_count
299- } ) ,
300- ) ?;
301- Err ( err)
250+ let result = async {
251+ let settings = crate :: config:: Settings :: load ( & root_dir) ?;
252+ let rpc = SampleRpcConfig :: from_settings ( & settings) ?;
253+ let profile = ModerationProfile :: load ( & root_dir, & profile_name) ?;
254+
255+ let initial_sample_count = fetch_training_sample_count_via_rpc ( & rpc, & profile)
256+ . await
257+ . unwrap_or ( 0 ) ;
258+ write_training_status (
259+ & profile,
260+ & serde_json:: json!( {
261+ "status" : "running" ,
262+ "message" : "training in progress" ,
263+ "timestamp" : current_unix_secs( ) ,
264+ "started_at" : current_unix_secs( ) ,
265+ "pid" : std:: process:: id( ) ,
266+ "profile" : profile. profile_name,
267+ "model_type" : profile. config. local_model_type,
268+ "sample_count" : initial_sample_count
269+ } ) ,
270+ ) ?;
271+
272+ match run_profile_training ( & rpc, & profile) . await {
273+ Ok ( output) => {
274+ write_training_status (
275+ & profile,
276+ & serde_json:: json!( {
277+ "status" : "success" ,
278+ "message" : "training completed" ,
279+ "timestamp" : current_unix_secs( ) ,
280+ "profile" : profile. profile_name,
281+ "model_type" : profile. config. local_model_type,
282+ "sample_count" : output. sample_count,
283+ "pass_count" : output. pass_count,
284+ "violation_count" : output. violation_count,
285+ "runtime_json_path" : output. runtime_json_path,
286+ "runtime_coef_path" : output. runtime_coef_path,
287+ "model_marker_path" : output. model_marker_path
288+ } ) ,
289+ ) ?;
290+ Ok ( ( ) )
291+ }
292+ Err ( err) => {
293+ write_training_status (
294+ & profile,
295+ & serde_json:: json!( {
296+ "status" : "failed" ,
297+ "message" : format!( "{err:#}" ) ,
298+ "timestamp" : current_unix_secs( ) ,
299+ "profile" : profile. profile_name,
300+ "model_type" : profile. config. local_model_type,
301+ "sample_count" : initial_sample_count
302+ } ) ,
303+ ) ?;
304+ Err ( err)
305+ }
302306 }
303307 }
308+ . await ;
309+
310+ if let Err ( err) = & result {
311+ let _ = write_training_failure_status (
312+ & root_dir,
313+ & profile_name,
314+ format ! ( "{err:#}" ) ,
315+ ) ;
316+ }
317+
318+ result
319+ }
320+
321+ fn write_training_failure_status (
322+ root_dir : & Path ,
323+ profile_name : & str ,
324+ message : String ,
325+ ) -> Result < ( ) > {
326+ let status_path = root_dir
327+ . join ( "configs" )
328+ . join ( "mod_profiles" )
329+ . join ( profile_name)
330+ . join ( ".train_status.json" ) ;
331+ if let Some ( parent) = status_path. parent ( ) {
332+ fs:: create_dir_all ( parent)
333+ . with_context ( || format ! ( "failed to create profile dir {}" , parent. display( ) ) ) ?;
334+ }
335+
336+ let payload = serde_json:: json!( {
337+ "status" : "failed" ,
338+ "message" : message,
339+ "timestamp" : current_unix_secs( ) ,
340+ "profile" : profile_name,
341+ } ) ;
342+ let tmp_path = status_path. with_extension ( "json.tmp" ) ;
343+ let encoded = serde_json:: to_vec_pretty ( & payload) . context ( "failed to encode training failure status" ) ?;
344+ fs:: write ( & tmp_path, encoded)
345+ . with_context ( || format ! ( "failed to write training failure status tmp {}" , tmp_path. display( ) ) ) ?;
346+ fs:: rename ( & tmp_path, & status_path) . with_context ( || {
347+ format ! (
348+ "failed to replace training failure status {} with {}" ,
349+ status_path. display( ) ,
350+ tmp_path. display( )
351+ )
352+ } ) ?;
353+ Ok ( ( ) )
304354}
305355
306356pub fn train_hashlinear_runtime (
0 commit comments