Skip to content

Commit fcc1ca1

Browse files
committed
feat(api): add batch cancellation for tasks and workers
1 parent c553c00 commit fcc1ca1

10 files changed

Lines changed: 905 additions & 124 deletions

File tree

netmito/src/api/tasks.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ pub fn tasks_router(st: InfraPool) -> Router<InfraPool> {
3232
.route("/{uuid}/artifacts/{content_type}", delete(delete_artifact))
3333
.route("/{uuid}/artifacts", post(upload_artifact))
3434
.route("/query", post(query_tasks))
35+
.route("/cancel", post(cancel_tasks))
3536
.route_layer(middleware::from_fn_with_state(
3637
st.clone(),
3738
user_auth_middleware,
@@ -149,6 +150,24 @@ pub async fn query_tasks(
149150
Ok(Json(tasks))
150151
}
151152

153+
pub async fn cancel_tasks(
154+
Extension(u): Extension<AuthUser>,
155+
State(pool): State<InfraPool>,
156+
Json(req): Json<TasksCancelByFilterReq>,
157+
) -> Result<Json<TasksCancelByFilterResp>, ApiError> {
158+
let resp = service::task::cancel_tasks_by_filter(u.id, &pool, req)
159+
.await
160+
.map_err(|e| match e {
161+
crate::error::Error::AuthError(err) => ApiError::AuthError(err),
162+
crate::error::Error::ApiError(e) => e,
163+
_ => {
164+
tracing::error!("{}", e);
165+
ApiError::InternalServerError
166+
}
167+
})?;
168+
Ok(Json(resp))
169+
}
170+
152171
pub async fn upload_artifact(
153172
Extension(u): Extension<AuthUser>,
154173
State(pool): State<InfraPool>,

netmito/src/api/workers.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ pub fn workers_router(st: InfraPool) -> Router<InfraPool> {
5454
)
5555
.route("/", post(user_register_worker))
5656
.route("/query", post(user_query_workers))
57+
.route("/shutdown", post(user_shutdown_workers))
5758
.route_layer(middleware::from_fn_with_state(
5859
st.clone(),
5960
user_auth_middleware,
@@ -364,3 +365,21 @@ pub async fn user_query_workers(
364365
})?;
365366
Ok(Json(resp))
366367
}
368+
369+
pub async fn user_shutdown_workers(
370+
Extension(u): Extension<AuthUser>,
371+
State(pool): State<InfraPool>,
372+
Json(req): Json<WorkersShutdownByFilterReq>,
373+
) -> Result<Json<WorkersShutdownByFilterResp>, ApiError> {
374+
let resp = service::worker::shutdown_workers_by_filter(u.id, &pool, req)
375+
.await
376+
.map_err(|e| match e {
377+
crate::error::Error::AuthError(err) => ApiError::AuthError(err),
378+
crate::error::Error::ApiError(e) => e,
379+
_ => {
380+
tracing::error!("{}", e);
381+
ApiError::InternalServerError
382+
}
383+
})?;
384+
Ok(Json(resp))
385+
}

netmito/src/client/http.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,30 @@ impl MitoHttpClient {
737737
}
738738
}
739739

740+
pub async fn shutdown_workers_by_filter(
741+
&mut self,
742+
req: WorkersShutdownByFilterReq,
743+
) -> crate::error::Result<WorkersShutdownByFilterResp> {
744+
self.url.set_path("workers/shutdown");
745+
let resp = self
746+
.http_client
747+
.post(self.url.as_str())
748+
.bearer_auth(&self.credential)
749+
.json(&req)
750+
.send()
751+
.await
752+
.map_err(map_reqwest_err)?;
753+
if resp.status().is_success() {
754+
let resp = resp
755+
.json::<WorkersShutdownByFilterResp>()
756+
.await
757+
.map_err(RequestError::from)?;
758+
Ok(resp)
759+
} else {
760+
Err(get_error_from_resp(resp).await.into())
761+
}
762+
}
763+
740764
pub async fn replace_worker_tags(
741765
&mut self,
742766
uuid: Uuid,
@@ -842,6 +866,30 @@ impl MitoHttpClient {
842866
}
843867
}
844868

869+
pub async fn cancel_tasks_by_filter(
870+
&mut self,
871+
req: TasksCancelByFilterReq,
872+
) -> crate::error::Result<TasksCancelByFilterResp> {
873+
self.url.set_path("tasks/cancel");
874+
let resp = self
875+
.http_client
876+
.post(self.url.as_str())
877+
.bearer_auth(&self.credential)
878+
.json(&req)
879+
.send()
880+
.await
881+
.map_err(map_reqwest_err)?;
882+
if resp.status().is_success() {
883+
let resp = resp
884+
.json::<TasksCancelByFilterResp>()
885+
.await
886+
.map_err(RequestError::from)?;
887+
Ok(resp)
888+
} else {
889+
Err(get_error_from_resp(resp).await.into())
890+
}
891+
}
892+
845893
pub async fn update_task_labels(
846894
&mut self,
847895
uuid: Uuid,

netmito/src/client/mod.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,15 @@ impl MitoClient {
736736
.await
737737
}
738738

739+
pub async fn workers_batch_cancel(
740+
&mut self,
741+
args: CancelWorkersArgs,
742+
) -> crate::error::Result<WorkersShutdownByFilterResp> {
743+
self.http_client
744+
.shutdown_workers_by_filter(args.into())
745+
.await
746+
}
747+
739748
pub async fn workers_update_tags(
740749
&mut self,
741750
args: WorkerUpdateTagsArgs,
@@ -776,6 +785,13 @@ impl MitoClient {
776785
self.http_client.cancel_task_by_uuid(uuid).await
777786
}
778787

788+
pub async fn tasks_batch_cancel(
789+
&mut self,
790+
args: CancelTasksArgs,
791+
) -> crate::error::Result<TasksCancelByFilterResp> {
792+
self.http_client.cancel_tasks_by_filter(args.into()).await
793+
}
794+
779795
pub async fn tasks_update_labels(
780796
&mut self,
781797
args: UpdateTaskLabelsArgs,
@@ -1012,6 +1028,18 @@ impl MitoClient {
10121028
tracing::error!("{}", e);
10131029
}
10141030
},
1031+
WorkersCommands::CancelMany(args) => match self.workers_batch_cancel(args).await {
1032+
Ok(resp) => {
1033+
tracing::info!(
1034+
"Shutdown {} workers in group {}",
1035+
resp.shutdown_count,
1036+
resp.group_name
1037+
);
1038+
}
1039+
Err(e) => {
1040+
tracing::error!("{}", e);
1041+
}
1042+
},
10151043
WorkersCommands::UpdateTags(args) => match self.workers_update_tags(args).await {
10161044
Ok(_) => {
10171045
tracing::info!("Worker tags updated successfully");
@@ -1308,6 +1336,18 @@ impl MitoClient {
13081336
tracing::error!("{}", e);
13091337
}
13101338
},
1339+
TasksCommands::CancelMany(args) => match self.tasks_batch_cancel(args).await {
1340+
Ok(resp) => {
1341+
tracing::info!(
1342+
"Cancelled {} tasks in group {}",
1343+
resp.cancelled_count,
1344+
resp.group_name
1345+
);
1346+
}
1347+
Err(e) => {
1348+
tracing::error!("{}", e);
1349+
}
1350+
},
13111351
TasksCommands::UpdateLabels(args) => match self.tasks_update_labels(args).await {
13121352
Ok(_) => {
13131353
tracing::info!("Task labels updated successfully");

netmito/src/config/client/tasks.rs

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ use uuid::Uuid;
55
use crate::{
66
config::client::parse_resources,
77
entity::state::{TaskExecState, TaskState},
8-
schema::{ChangeTaskReq, RemoteResourceDownload, TaskSpec, TasksQueryReq, UpdateTaskLabelsReq},
8+
schema::{
9+
ChangeTaskReq, RemoteResourceDownload, TaskSpec, TasksCancelByFilterReq, TasksQueryReq,
10+
UpdateTaskLabelsReq,
11+
},
912
};
1013

1114
use super::{parse_key_val, parse_watch_task, ArtifactsArgs};
@@ -26,6 +29,8 @@ pub enum TasksCommands {
2629
Query(QueryTasksArgs),
2730
/// Cancel a task
2831
Cancel(CancelTaskArgs),
32+
/// Cancel multiple tasks subject to the filter
33+
CancelMany(CancelTasksArgs),
2934
/// Replace labels of a task
3035
UpdateLabels(UpdateTaskLabelsArgs),
3136
/// Update the spec of a task
@@ -117,6 +122,31 @@ pub struct CancelTaskArgs {
117122
pub uuid: Uuid,
118123
}
119124

125+
#[derive(Serialize, Debug, Deserialize, Args, Clone)]
126+
pub struct CancelTasksArgs {
127+
/// The username of the creator who submitted the tasks
128+
#[arg(short, long, num_args = 0.., value_delimiter = ',')]
129+
pub creators: Vec<String>,
130+
/// The name of the group the tasks belong to
131+
#[arg(short, long)]
132+
pub group: Option<String>,
133+
/// The tags of the tasks
134+
#[arg(short, long, num_args = 0.., value_delimiter = ',')]
135+
pub tags: Vec<String>,
136+
/// The labels of the tasks
137+
#[arg(short, long, num_args = 0.., value_delimiter = ',')]
138+
pub labels: Vec<String>,
139+
/// The state of the tasks
140+
#[arg(short, long, num_args = 0.., value_delimiter = ',')]
141+
pub state: Vec<TaskState>,
142+
/// The exit status of the tasks, support operators like `=`(default), `!=`, `<`, `<=`, `>`, `>=`
143+
#[arg(short, long)]
144+
pub exit_status: Option<String>,
145+
/// The priority of the tasks, support operators like `=`(default), `!=`, `<`, `<=`, `>`, `>=`
146+
#[arg(short, long)]
147+
pub priority: Option<String>,
148+
}
149+
120150
#[derive(Serialize, Debug, Deserialize, Args, Clone)]
121151
pub struct UpdateTaskLabelsArgs {
122152
/// The UUID of the task
@@ -222,3 +252,33 @@ impl From<UpdateTaskLabelsArgs> for UpdateTaskLabelsReq {
222252
}
223253
}
224254
}
255+
256+
impl From<CancelTasksArgs> for TasksCancelByFilterReq {
257+
fn from(args: CancelTasksArgs) -> Self {
258+
Self {
259+
creator_usernames: if args.creators.is_empty() {
260+
None
261+
} else {
262+
Some(args.creators.into_iter().collect())
263+
},
264+
group_name: args.group,
265+
tags: if args.tags.is_empty() {
266+
None
267+
} else {
268+
Some(args.tags.into_iter().collect())
269+
},
270+
labels: if args.labels.is_empty() {
271+
None
272+
} else {
273+
Some(args.labels.into_iter().collect())
274+
},
275+
states: if args.state.is_empty() {
276+
None
277+
} else {
278+
Some(args.state.into_iter().collect())
279+
},
280+
exit_status: args.exit_status,
281+
priority: args.priority,
282+
}
283+
}
284+
}

netmito/src/config/client/workers.rs

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::{
66
entity::role::GroupWorkerRole,
77
schema::{
88
RemoveGroupWorkerRoleReq, ReplaceWorkerLabelsReq, ReplaceWorkerTagsReq,
9-
UpdateGroupWorkerRoleReq, WorkersQueryReq,
9+
UpdateGroupWorkerRoleReq, WorkerShutdownOp, WorkersQueryReq, WorkersShutdownByFilterReq,
1010
},
1111
};
1212

@@ -22,6 +22,8 @@ pub struct WorkersArgs {
2222
pub enum WorkersCommands {
2323
/// Cancel a worker
2424
Cancel(CancelWorkerArgs),
25+
/// Cancel multiple workers subject to the filter
26+
CancelMany(CancelWorkersArgs),
2527
/// Replace tags of a worker
2628
UpdateTags(WorkerUpdateTagsArgs),
2729
/// Replace labels of a worker
@@ -46,6 +48,29 @@ pub struct CancelWorkerArgs {
4648
pub force: bool,
4749
}
4850

51+
#[derive(Serialize, Debug, Deserialize, Args, Clone)]
52+
pub struct CancelWorkersArgs {
53+
/// The name of the group has access to the workers
54+
#[arg(short, long)]
55+
pub group: Option<String>,
56+
/// The role of the group on the workers
57+
#[arg(short, long, num_args = 0.., value_delimiter = ',')]
58+
pub role: Vec<GroupWorkerRole>,
59+
/// The tags of the workers
60+
#[arg(short, long, num_args = 0.., value_delimiter = ',')]
61+
pub tags: Vec<String>,
62+
/// The labels of the workers
63+
#[arg(short, long, num_args = 0.., value_delimiter = ',')]
64+
pub labels: Vec<String>,
65+
/// The username of the creator
66+
#[arg(long)]
67+
pub creator: Option<String>,
68+
/// Whether to force the workers to shutdown.
69+
/// If not specified, the workers will be try to shutdown gracefully
70+
#[arg(short, long)]
71+
pub force: bool,
72+
}
73+
4974
#[derive(Serialize, Debug, Deserialize, Args, Clone)]
5075
pub struct WorkerUpdateTagsArgs {
5176
/// The UUID of the worker
@@ -169,3 +194,32 @@ impl From<RemoveWorkerGroupArgs> for RemoveGroupWorkerRoleReq {
169194
}
170195
}
171196
}
197+
198+
impl From<CancelWorkersArgs> for WorkersShutdownByFilterReq {
199+
fn from(args: CancelWorkersArgs) -> Self {
200+
Self {
201+
group_name: args.group,
202+
role: if args.role.is_empty() {
203+
None
204+
} else {
205+
Some(args.role.into_iter().collect())
206+
},
207+
tags: if args.tags.is_empty() {
208+
None
209+
} else {
210+
Some(args.tags.into_iter().collect())
211+
},
212+
labels: if args.labels.is_empty() {
213+
None
214+
} else {
215+
Some(args.labels.into_iter().collect())
216+
},
217+
creator_username: args.creator,
218+
op: if args.force {
219+
WorkerShutdownOp::Force
220+
} else {
221+
WorkerShutdownOp::Graceful
222+
},
223+
}
224+
}
225+
}

netmito/src/entity/role.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,18 @@ use serde::{Deserialize, Serialize};
66

77
/// The role of a user to a group.
88
#[derive(
9-
EnumIter, DeriveActiveEnum, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, ValueEnum,
9+
EnumIter,
10+
DeriveActiveEnum,
11+
Clone,
12+
Debug,
13+
PartialEq,
14+
Eq,
15+
Hash,
16+
Serialize,
17+
Deserialize,
18+
ValueEnum,
19+
Ord,
20+
PartialOrd,
1021
)]
1122
#[sea_orm(rs_type = "i32", db_type = "Integer")]
1223
pub enum UserGroupRole {

0 commit comments

Comments
 (0)