Skip to content

Commit 70ccf3a

Browse files
committed
feat: delete Dataset (InftyAI#201)
* Add delete dataset feature Signed-off-by: kerthcet <kerthcet@gmail.com> * Add delete dataset feature Signed-off-by: kerthcet <kerthcet@gmail.com> --------- Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 8052880 commit 70ccf3a

17 files changed

Lines changed: 391 additions & 142 deletions

File tree

alphatrion/artifact/artifact.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99

1010

1111
class Artifact:
12-
def __init__(self, team_id: str, insecure: bool = False):
13-
self._team_id = team_id
12+
def __init__(self, insecure: bool = False):
1413
self._url = get_registry_url()
1514
self._client = oras.client.OrasClient(
1615
hostname=self._url.strip("/"), auth_backend="token", insecure=insecure
@@ -50,7 +49,7 @@ def push(
5049
if version is None:
5150
version = utiltime.now_2_hash()
5251

53-
path = f"{self._team_id}/{repo_name}:{version}"
52+
path = f"{repo_name}:{version}"
5453
target = f"{self._url}/{path}"
5554

5655
try:
@@ -61,7 +60,7 @@ def push(
6160
return path
6261

6362
def list_versions(self, repo_name: str) -> list[str]:
64-
target = f"{self._url}/{self._team_id}/{repo_name}"
63+
target = f"{self._url}/{repo_name}"
6564
try:
6665
tags = self._client.get_tags(target)
6766
return tags
@@ -91,7 +90,7 @@ def pull(
9190
(defaults to ORAS temp directory)
9291
:return: list of absolute file paths that were downloaded
9392
"""
94-
path = f"{self._team_id}/{repo_name}:{version}"
93+
path = f"{repo_name}:{version}"
9594
target = f"{self._url}/{path}"
9695

9796
if output_dir:
@@ -115,7 +114,7 @@ def pull(
115114
os.chdir(original_dir)
116115

117116
def delete(self, repo_name: str, versions: str | list[str]):
118-
target = f"{self._url}/{self._team_id}/{repo_name}"
117+
target = f"{self._url}/{repo_name}"
119118

120119
try:
121120
self._client.delete_tags(target, tags=versions)

alphatrion/log/log.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
from alphatrion.snapshot.snapshot import (
1111
checkpoint_path,
1212
)
13+
from alphatrion.storage import runtime as storage_runtime
1314

1415
BEST_RESULT_PATH = "best_result_path"
15-
EXECUTION_RESULT = "execution_result"
1616

1717

1818
async def log_artifact(
@@ -45,7 +45,7 @@ async def log_artifact(
4545
if runtime is None:
4646
raise RuntimeError("Runtime is not initialized. Please call init() first.")
4747

48-
if not runtime.artifact_storage_enabled():
48+
if not storage_runtime.artifact_storage_enabled():
4949
raise RuntimeError(
5050
"Artifact storage is not enabled in the runtime."
5151
"Set ENABLE_ARTIFACT_STORAGE=true in the environment variables."
@@ -59,7 +59,7 @@ async def log_artifact(
5959

6060
loop = asyncio.get_running_loop()
6161
return await loop.run_in_executor(
62-
None, runtime._artifact.push, repo_name, paths, version
62+
None, runtime._artifact.push, f"{runtime.team_id}/{repo_name}", paths, version
6363
)
6464

6565

alphatrion/runtime/runtime.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import uuid
44

55
from alphatrion import envs
6-
from alphatrion.artifact.artifact import Artifact
76
from alphatrion.storage import runtime as storage_runtime
87
from alphatrion.storage.sqlstore import SQLStore
98

@@ -58,6 +57,7 @@ def __init__(
5857
storage_runtime.init()
5958
self._metadb = storage_runtime.storage_runtime().metadb
6059
self._tracestore = storage_runtime.storage_runtime().tracestore
60+
self._artifact = storage_runtime.storage_runtime().artifact
6161

6262
self._user_id = user_id
6363
self._team_id = team_id
@@ -74,18 +74,9 @@ def __init__(
7474
self._team_id = teams[0].uuid
7575

7676
self._root_path = os.getenv(envs.ROOT_PATH, os.path.expanduser("~/.alphatrion"))
77-
78-
artifact_insecure = os.getenv(envs.ARTIFACT_INSECURE, "false").lower() == "true"
79-
80-
if self.artifact_storage_enabled():
81-
self._artifact = Artifact(team_id=self._team_id, insecure=artifact_insecure)
82-
8377
if not os.path.exists(self._root_path):
8478
os.makedirs(self._root_path, exist_ok=True)
8579

86-
def artifact_storage_enabled(self) -> bool:
87-
return os.getenv(envs.ENABLE_ARTIFACT_STORAGE, "true").lower() == "true"
88-
8980
@property
9081
def metadb(self) -> SQLStore:
9182
return self._metadb
@@ -94,6 +85,10 @@ def metadb(self) -> SQLStore:
9485
def tracestore(self):
9586
return self._tracestore
9687

88+
@property
89+
def artifact(self):
90+
return self._artifact
91+
9792
@property
9893
def user_id(self) -> uuid.UUID:
9994
return self._user_id

alphatrion/server/graphql/resolvers.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,10 @@ async def list_artifact_tags(
352352
) -> list[ArtifactTag]:
353353
"""List tags for a repository."""
354354

355-
arf = artifact.Artifact(team_id=team_id, insecure=True)
356-
return [ArtifactTag(name=tag) for tag in arf.list_versions(repo_name)]
355+
arf = runtime.storage_runtime().artifact
356+
return [
357+
ArtifactTag(name=tag) for tag in arf.list_versions(f"{team_id}/{repo_name}")
358+
]
357359

358360
@staticmethod
359361
async def list_artifact_files(
@@ -362,8 +364,8 @@ async def list_artifact_files(
362364
"""List files in an artifact without loading content."""
363365

364366
try:
365-
arf = artifact.Artifact(team_id=team_id, insecure=True)
366-
file_paths = arf.pull(repo_name=repo_name, version=tag)
367+
arf = runtime.storage_runtime().artifact
368+
file_paths = arf.pull(repo_name=f"{team_id}/{repo_name}", version=tag)
367369

368370
if not file_paths:
369371
return []
@@ -405,11 +407,11 @@ async def get_artifact_content(
405407
"""Get artifact content from registry."""
406408
try:
407409
# Initialize artifact client
408-
arf = artifact.Artifact(team_id=team_id, insecure=True)
410+
arf = runtime.storage_runtime().artifact
409411

410412
# Pull the artifact - ORAS will manage temp directory
411413
# Returns absolute paths to files in ORAS temp directory
412-
file_paths = arf.pull(repo_name=repo_name, version=tag)
414+
file_paths = arf.pull(repo_name=f"{team_id}/{repo_name}", version=tag)
413415

414416
if not file_paths:
415417
raise RuntimeError("No files found in artifact")
@@ -875,4 +877,21 @@ def delete_experiments(experiment_ids: list[strawberry.ID]) -> int:
875877
@staticmethod
876878
def delete_dataset(dataset_id: strawberry.ID) -> bool:
877879
metadb = runtime.storage_runtime().metadb
880+
artifact = runtime.storage_runtime().artifact
881+
dataset = metadb.get_dataset(dataset_id=dataset_id)
882+
883+
# delete the artifact file as well
884+
if dataset:
885+
try:
886+
repo_name, version = dataset.path.split(":", 1)
887+
artifact.delete(repo_name=repo_name, versions=version)
888+
except Exception as e:
889+
print(f"Failed to delete artifact for dataset {dataset_id}: {e}")
890+
878891
return metadb.delete_dataset(dataset_id=dataset_id)
892+
893+
@staticmethod
894+
def delete_datasets(dataset_ids: list[strawberry.ID]) -> bool:
895+
for id in dataset_ids:
896+
GraphQLMutations.delete_dataset(dataset_id=id)
897+
return True

alphatrion/server/graphql/schema.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def datasets(
141141

142142
dataset: Dataset | None = strawberry.field(resolver=GraphQLResolvers.get_dataset)
143143

144+
144145
@strawberry.type
145146
class Mutation:
146147
@strawberry.mutation
@@ -175,5 +176,9 @@ def delete_experiments(self, experiment_ids: list[strawberry.ID]) -> int:
175176
def delete_dataset(self, dataset_id: strawberry.ID) -> bool:
176177
return GraphQLMutations.delete_dataset(dataset_id=dataset_id)
177178

179+
@strawberry.mutation
180+
def delete_datasets(self, dataset_ids: list[strawberry.ID]) -> bool:
181+
return GraphQLMutations.delete_datasets(dataset_ids=dataset_ids)
182+
178183

179184
schema = strawberry.Schema(query=Query, mutation=Mutation)

alphatrion/server/graphql/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ class RemoveUserFromTeamInput:
253253
user_id: strawberry.ID
254254
team_id: strawberry.ID
255255

256+
256257
# Artifact types
257258
@strawberry.type
258259
class ArtifactRepository:

alphatrion/storage/runtime.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from traceloop.sdk import Traceloop
77

88
from alphatrion import envs
9+
from alphatrion.artifact.artifact import Artifact
910
from alphatrion.storage.sqlstore import SQLStore
1011
from alphatrion.storage.tracestore import TraceStore
1112
from alphatrion.tracing.clickhouse_exporter import ClickHouseSpanExporter
@@ -54,6 +55,10 @@ def __init__(self):
5455
tracer_provider = trace.get_tracer_provider()
5556
tracer_provider.add_span_processor(ContextAttributesSpanProcessor())
5657

58+
artifact_insecure = os.getenv(envs.ARTIFACT_INSECURE, "false").lower() == "true"
59+
if artifact_storage_enabled():
60+
self._artifact = Artifact(insecure=artifact_insecure)
61+
5762
self._inited = True
5863

5964
@property
@@ -70,6 +75,10 @@ def flush(self):
7075
if isinstance(tracer_provider, TracerProvider):
7176
tracer_provider.force_flush(timeout_millis=5000)
7277

78+
@property
79+
def artifact(self):
80+
return self._artifact
81+
7382

7483
def init():
7584
"""
@@ -85,3 +94,7 @@ def storage_runtime() -> StorageRuntime:
8594
if __STORAGE_RUNTIME__ is None:
8695
raise RuntimeError("StorageRuntime is not initialized. Call init() first.")
8796
return __STORAGE_RUNTIME__
97+
98+
99+
def artifact_storage_enabled() -> bool:
100+
return os.getenv(envs.ENABLE_ARTIFACT_STORAGE, "true").lower() == "true"
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import { useMutation, useQueryClient } from '@tanstack/react-query';
2+
import { graphqlMutation, mutations } from '../lib/graphql-client';
3+
4+
interface DeleteDatasetResponse {
5+
deleteDataset: boolean;
6+
}
7+
8+
interface DeleteDatasetsResponse {
9+
deleteDatasets: number;
10+
}
11+
12+
/**
13+
* Hook to delete a single dataset
14+
*/
15+
export function useDeleteDataset() {
16+
const queryClient = useQueryClient();
17+
18+
return useMutation({
19+
mutationFn: async (datasetId: string) => {
20+
const data = await graphqlMutation<DeleteDatasetResponse>(
21+
mutations.deleteDataset,
22+
{ datasetId }
23+
);
24+
return data.deleteDataset;
25+
},
26+
onSuccess: () => {
27+
// Invalidate datasets queries to refetch the list
28+
queryClient.invalidateQueries({ queryKey: ['datasets'] });
29+
queryClient.invalidateQueries({ queryKey: ['dataset'] });
30+
},
31+
});
32+
}
33+
34+
/**
35+
* Hook to delete multiple datasets in batch
36+
*/
37+
export function useDeleteDatasets() {
38+
const queryClient = useQueryClient();
39+
40+
return useMutation({
41+
mutationFn: async (datasetIds: string[]) => {
42+
const data = await graphqlMutation<DeleteDatasetsResponse>(
43+
mutations.deleteDatasets,
44+
{ datasetIds }
45+
);
46+
return data.deleteDatasets;
47+
},
48+
onSuccess: () => {
49+
// Invalidate datasets queries to refetch the list
50+
queryClient.invalidateQueries({ queryKey: ['datasets'] });
51+
queryClient.invalidateQueries({ queryKey: ['dataset'] });
52+
},
53+
});
54+
}

dashboard/src/lib/graphql-client.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,4 +430,16 @@ export const mutations = {
430430
deleteExperiments(experimentIds: $experimentIds)
431431
}
432432
`,
433+
434+
deleteDataset: `
435+
mutation DeleteDataset($datasetId: ID!) {
436+
deleteDataset(datasetId: $datasetId)
437+
}
438+
`,
439+
440+
deleteDatasets: `
441+
mutation DeleteDatasets($datasetIds: [ID!]!) {
442+
deleteDatasets(datasetIds: $datasetIds)
443+
}
444+
`,
433445
};

0 commit comments

Comments
 (0)