@@ -346,6 +346,20 @@ def create_experiment(
346346 uid = uuid .uuid4 ()
347347
348348 session = self ._session ()
349+ # TODO: add back the validation.
350+ # # verify user is in the team
351+ # membership = (
352+ # session.query(TeamMember)
353+ # .filter(
354+ # TeamMember.user_id == user_id,
355+ # TeamMember.team_id == team_id,
356+ # )
357+ # .first()
358+ # )
359+ # if membership is None:
360+ # session.close()
361+ # raise ValueError("User must be a member of the team to create experiment")
362+
349363 new_exp = Experiment (
350364 uuid = uid ,
351365 team_id = team_id ,
@@ -396,22 +410,22 @@ def get_experiment(self, experiment_id: uuid.UUID) -> Experiment | None:
396410 return exp
397411
398412 # Different team may have the same experiment name.
399- def get_exp_by_name (self , name : str , team_id : uuid .UUID ) -> Experiment | None :
413+ def get_exp_by_name (
414+ self , name : str , team_id : uuid .UUID , include_deleted : bool = False
415+ ) -> Experiment | None :
400416 # make sure the team exists
401417 team = self .get_team (team_id )
402418 if team is None :
403419 return None
404420
405421 session = self ._session ()
406- trial = (
407- session .query (Experiment )
408- .filter (
409- Experiment .name == name ,
410- Experiment .team_id == team_id ,
411- Experiment .is_del == 0 ,
412- )
413- .first ()
422+ query = session .query (Experiment ).filter (
423+ Experiment .name == name ,
424+ Experiment .team_id == team_id ,
414425 )
426+ if not include_deleted :
427+ query = query .filter (Experiment .is_del == 0 )
428+ trial = query .first ()
415429 session .close ()
416430 return trial
417431
@@ -532,6 +546,70 @@ def list_exps_by_timeframe(
532546 session .close ()
533547 return exps
534548
549+ def delete_experiment (self , experiment_id : uuid .UUID ) -> bool :
550+ session = self ._session ()
551+
552+ # Try to delete the experiment
553+ exp = (
554+ session .query (Experiment )
555+ .filter (Experiment .uuid == experiment_id , Experiment .is_del == 0 )
556+ .first ()
557+ )
558+
559+ if exp and exp .status == Status .RUNNING :
560+ raise ValueError (
561+ "Cannot delete a running experiment. Please stop it first."
562+ )
563+
564+ # Delete all runs associated with this experiment
565+ # (regardless of experiment status)
566+ session .query (Run ).filter (Run .experiment_id == experiment_id ).update (
567+ {Run .is_del : 1 }, synchronize_session = False
568+ )
569+ if exp :
570+ exp .is_del = 1
571+ session .commit ()
572+ session .close ()
573+ return True
574+
575+ # Even if experiment doesn't exist, commit the run deletions
576+ session .commit ()
577+ session .close ()
578+ return False
579+
580+ def delete_experiments (self , experiment_ids : list [uuid .UUID ]) -> int :
581+ """
582+ Batch delete experiments by setting is_del flag.
583+ Also deletes all associated runs.
584+ Returns the number of experiments successfully deleted.
585+ """
586+ session = self ._session ()
587+ # Delete the experiments
588+ # if experiment is running, skip deletion for that experiment
589+ filtered_exps = (
590+ session .query (Experiment .uuid )
591+ .filter (
592+ Experiment .uuid .in_ (experiment_ids ),
593+ Experiment .is_del == 0 ,
594+ Experiment .status != Status .RUNNING ,
595+ )
596+ .all ()
597+ )
598+ filtered_exp_ids = [exp_id for (exp_id ,) in filtered_exps ] # unpack tuples
599+
600+ deleted_count = (
601+ session .query (Experiment )
602+ .filter (Experiment .uuid .in_ (filtered_exp_ids ))
603+ .update ({Experiment .is_del : 1 }, synchronize_session = False )
604+ )
605+ # Delete all runs associated with these experiments
606+ session .query (Run ).filter (Run .experiment_id .in_ (filtered_exp_ids )).update (
607+ {Run .is_del : 1 }, synchronize_session = False
608+ )
609+ session .commit ()
610+ session .close ()
611+ return deleted_count
612+
535613 # ---------- Run APIs ----------
536614
537615 def create_run (
0 commit comments