Skip to content

Commit b3c3324

Browse files
authored
Delete experiments should delete runs as well (#191)
* Delete experiments should delete runs as well Signed-off-by: kerthcet <kerthcet@gmail.com> * fix lint Signed-off-by: kerthcet <kerthcet@gmail.com> * fix lint Signed-off-by: kerthcet <kerthcet@gmail.com> * Add 4 tags per line Signed-off-by: kerthcet <kerthcet@gmail.com> * update layout Signed-off-by: kerthcet <kerthcet@gmail.com> * fix test Signed-off-by: kerthcet <kerthcet@gmail.com> --------- Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent e8a5be8 commit b3c3324

17 files changed

Lines changed: 1443 additions & 627 deletions

File tree

alphatrion/server/graphql/resolvers.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,3 +692,21 @@ def remove_user_from_team(input: RemoveUserFromTeamInput) -> bool:
692692

693693
# Remove user from team (deletes TeamMember entry)
694694
return metadb.remove_user_from_team(user_id=user_id, team_id=team_id)
695+
696+
@staticmethod
697+
# TODO: We should have the team_id in the header for authz, and verify the
698+
# team_id matches the experiment's team_id before allowing deletion.
699+
def delete_experiment(experiment_id: strawberry.ID) -> bool:
700+
metadb = runtime.storage_runtime().metadb
701+
# Soft delete experiment by setting is_del flag
702+
return metadb.delete_experiment(experiment_id=experiment_id)
703+
704+
@staticmethod
705+
# TODO: We should have the team_id in the header for authz, and verify the
706+
# team_id matches the experiment's team_id before allowing deletion.
707+
def delete_experiments(experiment_ids: list[strawberry.ID]) -> int:
708+
metadb = runtime.storage_runtime().metadb
709+
# Convert strawberry IDs to UUIDs
710+
uuids = [uuid.UUID(exp_id) for exp_id in experiment_ids]
711+
# Soft delete experiments by setting is_del flag
712+
return metadb.delete_experiments(experiment_ids=uuids)

alphatrion/server/graphql/schema.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,5 +126,13 @@ def add_user_to_team(self, input: AddUserToTeamInput) -> bool:
126126
def remove_user_from_team(self, input: RemoveUserFromTeamInput) -> bool:
127127
return GraphQLMutations.remove_user_from_team(input=input)
128128

129+
@strawberry.mutation
130+
def delete_experiment(self, experiment_id: strawberry.ID) -> bool:
131+
return GraphQLMutations.delete_experiment(experiment_id=experiment_id)
132+
133+
@strawberry.mutation
134+
def delete_experiments(self, experiment_ids: list[strawberry.ID]) -> int:
135+
return GraphQLMutations.delete_experiments(experiment_ids=experiment_ids)
136+
129137

130138
schema = strawberry.Schema(query=Query, mutation=Mutation)

alphatrion/storage/sql_models.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,9 @@ class Experiment(Base):
114114

115115
uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
116116
team_id = Column(UUID(as_uuid=True), nullable=False)
117-
user_id = Column(UUID(as_uuid=True), nullable=True)
117+
user_id = Column(
118+
UUID(as_uuid=True), nullable=True, comment="User who created the experiment"
119+
)
118120
name = Column(String, nullable=False)
119121
description = Column(String, nullable=True)
120122
meta = Column(
@@ -171,7 +173,9 @@ class Run(Base):
171173
uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
172174
team_id = Column(UUID(as_uuid=True), nullable=False)
173175
experiment_id = Column(UUID(as_uuid=True), nullable=False)
174-
user_id = Column(UUID(as_uuid=True), nullable=True)
176+
user_id = Column(
177+
UUID(as_uuid=True), nullable=True, comment="User who created the run"
178+
)
175179
meta = Column(
176180
MutableDict.as_mutable(JSON),
177181
nullable=True,

alphatrion/storage/sqlstore.py

Lines changed: 87 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,20 @@ def create_experiment(
346346
uid = uuid.uuid4()
347347

348348
session = self._session()
349+
# TODO: add back the validation.
350+
# # verify user is in the team
351+
# membership = (
352+
# session.query(TeamMember)
353+
# .filter(
354+
# TeamMember.user_id == user_id,
355+
# TeamMember.team_id == team_id,
356+
# )
357+
# .first()
358+
# )
359+
# if membership is None:
360+
# session.close()
361+
# raise ValueError("User must be a member of the team to create experiment")
362+
349363
new_exp = Experiment(
350364
uuid=uid,
351365
team_id=team_id,
@@ -396,22 +410,22 @@ def get_experiment(self, experiment_id: uuid.UUID) -> Experiment | None:
396410
return exp
397411

398412
# Different team may have the same experiment name.
399-
def get_exp_by_name(self, name: str, team_id: uuid.UUID) -> Experiment | None:
413+
def get_exp_by_name(
414+
self, name: str, team_id: uuid.UUID, include_deleted: bool = False
415+
) -> Experiment | None:
400416
# make sure the team exists
401417
team = self.get_team(team_id)
402418
if team is None:
403419
return None
404420

405421
session = self._session()
406-
trial = (
407-
session.query(Experiment)
408-
.filter(
409-
Experiment.name == name,
410-
Experiment.team_id == team_id,
411-
Experiment.is_del == 0,
412-
)
413-
.first()
422+
query = session.query(Experiment).filter(
423+
Experiment.name == name,
424+
Experiment.team_id == team_id,
414425
)
426+
if not include_deleted:
427+
query = query.filter(Experiment.is_del == 0)
428+
trial = query.first()
415429
session.close()
416430
return trial
417431

@@ -532,6 +546,70 @@ def list_exps_by_timeframe(
532546
session.close()
533547
return exps
534548

549+
def delete_experiment(self, experiment_id: uuid.UUID) -> bool:
550+
session = self._session()
551+
552+
# Try to delete the experiment
553+
exp = (
554+
session.query(Experiment)
555+
.filter(Experiment.uuid == experiment_id, Experiment.is_del == 0)
556+
.first()
557+
)
558+
559+
if exp and exp.status == Status.RUNNING:
560+
raise ValueError(
561+
"Cannot delete a running experiment. Please stop it first."
562+
)
563+
564+
# Delete all runs associated with this experiment
565+
# (regardless of experiment status)
566+
session.query(Run).filter(Run.experiment_id == experiment_id).update(
567+
{Run.is_del: 1}, synchronize_session=False
568+
)
569+
if exp:
570+
exp.is_del = 1
571+
session.commit()
572+
session.close()
573+
return True
574+
575+
# Even if experiment doesn't exist, commit the run deletions
576+
session.commit()
577+
session.close()
578+
return False
579+
580+
def delete_experiments(self, experiment_ids: list[uuid.UUID]) -> int:
581+
"""
582+
Batch delete experiments by setting is_del flag.
583+
Also deletes all associated runs.
584+
Returns the number of experiments successfully deleted.
585+
"""
586+
session = self._session()
587+
# Delete the experiments
588+
# if experiment is running, skip deletion for that experiment
589+
filtered_exps = (
590+
session.query(Experiment.uuid)
591+
.filter(
592+
Experiment.uuid.in_(experiment_ids),
593+
Experiment.is_del == 0,
594+
Experiment.status != Status.RUNNING,
595+
)
596+
.all()
597+
)
598+
filtered_exp_ids = [exp_id for (exp_id,) in filtered_exps] # unpack tuples
599+
600+
deleted_count = (
601+
session.query(Experiment)
602+
.filter(Experiment.uuid.in_(filtered_exp_ids))
603+
.update({Experiment.is_del: 1}, synchronize_session=False)
604+
)
605+
# Delete all runs associated with these experiments
606+
session.query(Run).filter(Run.experiment_id.in_(filtered_exp_ids)).update(
607+
{Run.is_del: 1}, synchronize_session=False
608+
)
609+
session.commit()
610+
session.close()
611+
return deleted_count
612+
535613
# ---------- Run APIs ----------
536614

537615
def create_run(
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import * as React from 'react';
2+
import { cn } from '../../lib/utils';
3+
4+
export interface CheckboxProps
5+
extends React.InputHTMLAttributes<HTMLInputElement> {}
6+
7+
const Checkbox = React.forwardRef<HTMLInputElement, CheckboxProps>(
8+
({ className, ...props }, ref) => {
9+
return (
10+
<input
11+
type="checkbox"
12+
className={cn(
13+
'h-4 w-4 rounded border-gray-300 text-primary cursor-pointer',
14+
'focus:ring-0 focus:ring-offset-0 focus:outline-none',
15+
'checked:border-primary checked:bg-primary',
16+
className
17+
)}
18+
ref={ref}
19+
{...props}
20+
/>
21+
);
22+
}
23+
);
24+
25+
Checkbox.displayName = 'Checkbox';
26+
27+
export { Checkbox };
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import { useMutation, useQueryClient } from '@tanstack/react-query';
2+
import { graphqlMutation, mutations } from '../lib/graphql-client';
3+
4+
interface DeleteExperimentResponse {
5+
deleteExperiment: boolean;
6+
}
7+
8+
interface DeleteExperimentsResponse {
9+
deleteExperiments: number;
10+
}
11+
12+
/**
13+
* Hook to delete a single experiment
14+
*/
15+
export function useDeleteExperiment() {
16+
const queryClient = useQueryClient();
17+
18+
return useMutation({
19+
mutationFn: async (experimentId: string) => {
20+
const data = await graphqlMutation<DeleteExperimentResponse>(
21+
mutations.deleteExperiment,
22+
{ experimentId }
23+
);
24+
return data.deleteExperiment;
25+
},
26+
onSuccess: () => {
27+
// Invalidate experiments queries to refetch the list
28+
queryClient.invalidateQueries({ queryKey: ['experiments'] });
29+
queryClient.invalidateQueries({ queryKey: ['experiment'] });
30+
},
31+
});
32+
}
33+
34+
/**
35+
* Hook to delete multiple experiments in batch
36+
*/
37+
export function useDeleteExperiments() {
38+
const queryClient = useQueryClient();
39+
40+
return useMutation({
41+
mutationFn: async (experimentIds: string[]) => {
42+
const data = await graphqlMutation<DeleteExperimentsResponse>(
43+
mutations.deleteExperiments,
44+
{ experimentIds }
45+
);
46+
return data.deleteExperiments;
47+
},
48+
onSuccess: () => {
49+
// Invalidate experiments queries to refetch the list
50+
queryClient.invalidateQueries({ queryKey: ['experiments'] });
51+
queryClient.invalidateQueries({ queryKey: ['experiment'] });
52+
},
53+
});
54+
}

dashboard/src/lib/graphql-client.ts

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@ import axios from 'axios';
33
/**
44
* GraphQL client for AlphaTrion backend
55
*
6-
* The backend provides a read-only GraphQL API at /graphql
7-
* with queries for teams, experiments, runs, and metrics.
8-
*
9-
* No subscriptions or mutations are currently supported.
6+
* The backend provides a GraphQL API at /graphql
7+
* with queries and mutations for teams, experiments, runs, and metrics.
108
*/
119

1210
// Use relative URL to work with proxy in development
@@ -63,6 +61,17 @@ export async function graphqlQuery<T>(
6361
}
6462
}
6563

64+
/**
65+
* Execute a GraphQL mutation
66+
*/
67+
export async function graphqlMutation<T>(
68+
mutation: string,
69+
variables?: Record<string, unknown>
70+
): Promise<T> {
71+
// Mutations use the same endpoint and logic as queries
72+
return graphqlQuery<T>(mutation, variables);
73+
}
74+
6675
// GraphQL query templates
6776
export const queries = {
6877
listTeams: `
@@ -343,3 +352,18 @@ export const queries = {
343352
`,
344353

345354
};
355+
356+
// GraphQL mutation templates
357+
export const mutations = {
358+
deleteExperiment: `
359+
mutation DeleteExperiment($experimentId: ID!) {
360+
deleteExperiment(experimentId: $experimentId)
361+
}
362+
`,
363+
364+
deleteExperiments: `
365+
mutation DeleteExperiments($experimentIds: [ID!]!) {
366+
deleteExperiments(experimentIds: $experimentIds)
367+
}
368+
`,
369+
};

dashboard/src/pages/experiments/[id].tsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ export function ExperimentDetailPage() {
366366
<TableRow className="hover:bg-transparent border-b">
367367
<TableHead className="h-11 text-xs font-semibold uppercase tracking-wider text-muted-foreground bg-muted/50">UUID</TableHead>
368368
<TableHead className="h-11 text-xs font-semibold uppercase tracking-wider text-muted-foreground bg-muted/50">Status</TableHead>
369-
<TableHead className="h-11 text-xs font-semibold uppercase tracking-wider text-muted-foreground bg-muted/50 text-right">Created</TableHead>
369+
<TableHead className="h-11 text-xs font-semibold uppercase tracking-wider text-muted-foreground bg-muted/50">Created</TableHead>
370370
</TableRow>
371371
</TableHeader>
372372
<TableBody>
@@ -385,7 +385,7 @@ export function ExperimentDetailPage() {
385385
{run.status}
386386
</Badge>
387387
</TableCell>
388-
<TableCell className="py-3 text-sm text-muted-foreground text-right">
388+
<TableCell className="py-3 text-sm text-muted-foreground">
389389
{formatDistanceToNow(new Date(run.createdAt), {
390390
addSuffix: true,
391391
})}

0 commit comments

Comments
 (0)