Skip to content

Commit d3e2cf7

Browse files
committed
refactor: remove annoying state param
1 parent 588555d commit d3e2cf7

1 file changed

Lines changed: 61 additions & 81 deletions

File tree

src/main.rs

Lines changed: 61 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use clap::Parser;
22
use serde::{Deserialize, Serialize};
3-
use std::{collections::HashMap, fmt::Display};
3+
use std::{cell::RefCell, collections::HashMap, fmt::Display, rc::Rc};
44
use tokio::fs;
55
mod ztapi;
66
use crate::ztapi::Client;
@@ -16,14 +16,13 @@ async fn main() -> anyhow::Result<()> {
1616
}
1717

1818
env_logger::init();
19-
Args::parse().apply((), &mut ()).await
19+
Args::parse().apply(()).await
2020
}
2121

2222
trait Apply {
2323
type Context;
24-
type PersistentState;
2524

26-
async fn apply(self, ctx: Self::Context, ps: &mut Self::PersistentState) -> anyhow::Result<()>;
25+
async fn apply(self, ctx: Self::Context) -> anyhow::Result<()>;
2726
}
2827

2928
#[derive(Debug, Parser)]
@@ -51,28 +50,14 @@ struct Args {
5150
/// Base URL for the ZeroTier API endpoint
5251
endpoint: String,
5352

54-
#[clap(long = "state_path", short = 's', env = "STATE_PATH", default_value = "state.json")]
55-
/// Path to state file which contains datas used only by ztcli
56-
state_path: String,
57-
58-
#[clap(long = "disable_state", short = 'S', env = "DISABLE_STATE", default_value = "false")]
59-
/// If true, ztcli will not read state from filesystem, and will not sync changes to it
60-
disable_state: bool,
61-
6253
#[clap(subcommand)]
6354
cmd: Command,
6455
}
6556

66-
#[derive(Serialize, Deserialize, Default)]
67-
struct States {
68-
ctrl_net_states: CtrlNetStates,
69-
}
70-
7157
impl Apply for Args {
7258
type Context = ();
73-
type PersistentState = ();
7459

75-
async fn apply(self, _: Self::Context, _: &mut Self::PersistentState) -> anyhow::Result<()> {
60+
async fn apply(self, _: Self::Context) -> anyhow::Result<()> {
7661
let Some(token) = self.token.clone().or_else(|| match std::fs::read_to_string(&self.token_path) {
7762
Ok(content) => Some(content),
7863
Err(e) => {
@@ -83,20 +68,11 @@ impl Apply for Args {
8368
anyhow::bail!("No authentication token provided");
8469
};
8570

86-
let mut ps: States = if self.disable_state || !std::fs::exists(&self.state_path)? {
87-
Default::default()
88-
} else {
89-
let content = fs::read_to_string(&self.state_path).await?;
90-
serde_json::from_str(&content)?
91-
};
71+
//TODO:
9272

9373
let client = ztapi::Client::new(self.endpoint.as_str(), &token)?;
94-
self.cmd.apply(client, &mut ps.ctrl_net_states).await?;
74+
self.cmd.apply(client).await?;
9575

96-
if !self.disable_state {
97-
let content = serde_json::to_string(&ps)?;
98-
fs::write(&self.state_path, content).await?;
99-
}
10076
Ok(())
10177
}
10278
}
@@ -136,12 +112,11 @@ enum Command {
136112

137113
impl Apply for Command {
138114
type Context = Client;
139-
type PersistentState = CtrlNetStates;
140115

141-
async fn apply(self, client: Self::Context, ps: &mut Self::PersistentState) -> anyhow::Result<()> {
116+
async fn apply(self, client: Self::Context) -> anyhow::Result<()> {
142117
match self {
143118
#[cfg(feature = "clap_complete")]
144-
Self::Completions(args) => args.apply((), &mut ()).await?,
119+
Self::Completions(args) => args.apply(()).await?,
145120
#[cfg(feature = "clap_usage")]
146121
Self::Usage => {
147122
use clap::CommandFactory as _;
@@ -168,15 +143,15 @@ impl Apply for Command {
168143
pretty_print(&i);
169144
}
170145
}
171-
Self::Network(args) => args.apply(client, &mut ()).await?,
146+
Self::Network(args) => args.apply(client).await?,
172147
Self::ListPeers => {
173148
let r = client.get_peers().await?;
174149
log::info!("Peers information: {:?}", r);
175150
for i in r {
176151
pretty_print(&i);
177152
}
178153
}
179-
Self::PeerInfo(args) => args.apply(client, &mut ()).await?,
154+
Self::PeerInfo(args) => args.apply(client).await?,
180155

181156
Self::ControllerStatus => {
182157
let r = client.get_controller_status().await?;
@@ -190,7 +165,7 @@ impl Apply for Command {
190165
pretty_print(&i);
191166
}
192167
}
193-
Self::Controller(args) => args.apply(client, ps).await?,
168+
Self::Controller(args) => args.apply(client).await?,
194169
}
195170
Ok(())
196171
}
@@ -205,9 +180,8 @@ struct CompletionsArgs {
205180
#[cfg(feature = "clap_complete")]
206181
impl Apply for CompletionsArgs {
207182
type Context = ();
208-
type PersistentState = ();
209183

210-
async fn apply(self, _: Self::Context, _: &mut Self::PersistentState) -> anyhow::Result<()> {
184+
async fn apply(self, _: Self::Context) -> anyhow::Result<()> {
211185
use clap::CommandFactory as _;
212186

213187
clap_complete::generate(
@@ -241,11 +215,8 @@ struct CtrlNetStates {
241215

242216
impl Apply for CtrlNetArgs {
243217
type Context = Client;
244-
type PersistentState = CtrlNetStates;
245218

246-
async fn apply(self, client: Self::Context, ps: &mut Self::PersistentState) -> anyhow::Result<()> {
247-
self.cmd.apply((client, self.network_id), ps).await
248-
}
219+
async fn apply(self, client: Self::Context) -> anyhow::Result<()> { self.cmd.apply((client, self.network_id)).await }
249220
}
250221

251222
#[derive(Debug, Parser)]
@@ -269,17 +240,21 @@ enum CtrlNetCmds {
269240
Members,
270241
}
271242

243+
type CtrlNetMemTagStatesRef = Rc<RefCell<Option<CtrlNetMemTagStates>>>;
244+
272245
impl Apply for CtrlNetCmds {
273246
type Context = (Client, String);
274-
type PersistentState = CtrlNetStates;
275247

276-
async fn apply(self, (client, network_id): Self::Context, ps: &mut Self::PersistentState) -> anyhow::Result<()> {
277-
let next_ps = if let Some(inner) = ps.networks.get_mut(&network_id) {
278-
inner
279-
} else {
280-
ps.networks.insert(network_id.clone(), Default::default());
281-
ps.networks.get_mut(&network_id).unwrap() // We can assume this unwrap is always safe
248+
async fn apply(self, (client, network_id): Self::Context) -> anyhow::Result<()> {
249+
const STATE_PATH: &str = "ztcli.json";
250+
251+
let mut controller_state: CtrlNetStates = {
252+
let content = fs::read_to_string(STATE_PATH).await?;
253+
serde_json::from_str(&content)?
282254
};
255+
256+
let network_state = controller_state.networks.remove(&network_id).unwrap_or_default();
257+
let network_state = Rc::new(RefCell::new(Some(network_state)));
283258
match self {
284259
Self::Create(args) => {
285260
let r = client.generate_controller_network(&network_id, &(*args).into()).await?;
@@ -291,7 +266,7 @@ impl Apply for CtrlNetCmds {
291266
log::info!("Network updated successfully: {:?}", r);
292267
pretty_print(&r);
293268
}
294-
Self::Member(args) => args.apply((client, network_id), next_ps).await?,
269+
Self::Member(args) => args.apply((client, network_id.clone(), network_state.clone())).await?,
295270
Self::Info => {
296271
let r = client.get_controller_network(&network_id).await?;
297272
log::info!("Network information: {:?}", r);
@@ -301,16 +276,29 @@ impl Apply for CtrlNetCmds {
301276
let r = client.get_controller_network_members(&network_id).await?;
302277

303278
log::info!("Members information: {:?}", r);
304-
for (k, v) in r {
305-
let msg = if let Some(tag) = next_ps.tags.get_mut(&k) {
306-
format!("{k}: {v} ({tag})")
307-
} else {
308-
format!("{k}: {v}")
309-
};
310-
pretty_print(&msg);
279+
let network_state = controller_state.networks.get(network_id.as_str());
280+
if let Some(network_state) = network_state {
281+
for (k, v) in r {
282+
let msg = if let Some(tag) = network_state.tags.get(&k) {
283+
format!("{k}: {v} ({tag})")
284+
} else {
285+
format!("{k}: {v}")
286+
};
287+
pretty_print(&msg);
288+
}
289+
} else {
290+
for (k, v) in r {
291+
pretty_print(&format!("{k}: {v}"));
292+
}
311293
}
312294
}
313295
}
296+
let content = {
297+
let mut state = network_state.borrow_mut();
298+
controller_state.networks.insert(network_id, state.take().unwrap());
299+
serde_json::to_string(&controller_state)?
300+
};
301+
fs::write(STATE_PATH, content).await?;
314302
Ok(())
315303
}
316304
}
@@ -473,11 +461,10 @@ struct CtrlNetMemArgs {
473461
}
474462

475463
impl Apply for CtrlNetMemArgs {
476-
type Context = (Client, String);
477-
type PersistentState = CtrlNetMemTagStates;
464+
type Context = (Client, String, CtrlNetMemTagStatesRef);
478465

479-
async fn apply(self, (client, network_id): Self::Context, ps: &mut Self::PersistentState) -> anyhow::Result<()> {
480-
self.cmd.apply((client, network_id, self.member_id), ps).await
466+
async fn apply(self, (client, network_id, ps): Self::Context) -> anyhow::Result<()> {
467+
self.cmd.apply((client, network_id, self.member_id, ps)).await
481468
}
482469
}
483470

@@ -494,12 +481,9 @@ enum CtrlNetMemCmds {
494481
}
495482

496483
impl Apply for CtrlNetMemCmds {
497-
type Context = (Client, String, String);
498-
type PersistentState = CtrlNetMemTagStates;
484+
type Context = (Client, String, String, CtrlNetMemTagStatesRef);
499485

500-
async fn apply(
501-
self, (client, network_id, member_id): Self::Context, ps: &mut Self::PersistentState,
502-
) -> anyhow::Result<()> {
486+
async fn apply(self, (client, network_id, member_id, ps): Self::Context) -> anyhow::Result<()> {
503487
match self {
504488
Self::Info => {
505489
let r = client.get_controller_network_member(&network_id, &member_id).await?;
@@ -513,7 +497,7 @@ impl Apply for CtrlNetMemCmds {
513497
pretty_print(&r);
514498
}
515499
Self::Tag(args) => {
516-
args.apply(member_id, ps).await?;
500+
args.apply((member_id, ps)).await?;
517501
}
518502
}
519503
Ok(())
@@ -601,19 +585,20 @@ struct CtrlNetMemTagStates {
601585
}
602586

603587
impl Apply for CtrlNetMemTagArgs {
604-
type Context = String;
605-
type PersistentState = CtrlNetMemTagStates;
588+
type Context = (String, CtrlNetMemTagStatesRef);
606589

607-
async fn apply(self, id: Self::Context, ps: &mut Self::PersistentState) -> anyhow::Result<()> {
590+
async fn apply(self, (id, ref_state): Self::Context) -> anyhow::Result<()> {
591+
let mut borrowed_state = ref_state.borrow_mut();
592+
let network_state = borrowed_state.as_mut().unwrap();
608593
if let Some(tag) = self.tag {
609594
if tag.is_empty() {
610-
ps.tags.remove(&id);
595+
network_state.tags.remove(&id);
611596
log::info!("Tag removed for {}", id);
612597
} else {
613-
ps.tags.insert(id.clone(), tag.clone());
598+
network_state.tags.insert(id.clone(), tag.clone());
614599
log::info!("Tag set for {}: {}", id, tag);
615600
}
616-
} else if let Some(tag) = ps.tags.get(&id) {
601+
} else if let Some(tag) = network_state.tags.get(&id) {
617602
pretty_print(tag);
618603
} else {
619604
log::info!("No tag found for {}", id);
@@ -631,9 +616,8 @@ struct PeerInfoArgs {
631616

632617
impl Apply for PeerInfoArgs {
633618
type Context = Client;
634-
type PersistentState = ();
635619

636-
async fn apply(self, client: Self::Context, _: &mut Self::PersistentState) -> anyhow::Result<()> {
620+
async fn apply(self, client: Self::Context) -> anyhow::Result<()> {
637621
let r = client.get_peer(&self.peer_id).await?;
638622
log::info!("Peer information: {:?}", r);
639623
pretty_print(&r);
@@ -653,11 +637,8 @@ struct NetEditArgs {
653637

654638
impl Apply for NetEditArgs {
655639
type Context = Client;
656-
type PersistentState = ();
657640

658-
async fn apply(self, client: Self::Context, ps: &mut Self::PersistentState) -> anyhow::Result<()> {
659-
self.cmd.apply((client, self.network_id), ps).await
660-
}
641+
async fn apply(self, client: Self::Context) -> anyhow::Result<()> { self.cmd.apply((client, self.network_id)).await }
661642
}
662643

663644
#[derive(Debug, Parser)]
@@ -674,9 +655,8 @@ enum NetEditCmds {
674655

675656
impl Apply for NetEditCmds {
676657
type Context = (Client, String);
677-
type PersistentState = ();
678658

679-
async fn apply(self, (client, network_id): Self::Context, _: &mut Self::PersistentState) -> anyhow::Result<()> {
659+
async fn apply(self, (client, network_id): Self::Context) -> anyhow::Result<()> {
680660
match self {
681661
Self::Info => {
682662
let r = client.get_network(&network_id).await?;

0 commit comments

Comments
 (0)