-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfederated_learning_demo.rs
More file actions
154 lines (136 loc) · 5.69 KB
/
federated_learning_demo.rs
File metadata and controls
154 lines (136 loc) · 5.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
//! Federated learning with distributed coordination demo.
//!
//! This example shows how to use the enhanced federated learning crate
//! with distributed coordination across multiple agents.
use std::sync::Arc;
use tokio::sync::mpsc;
use federated_learning::prelude::*;
use federated_learning::model::{Model, Layer, LayerType};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("=== Federated Learning with Distributed Coordination Demo ===");
// Create a simple neural network model
let model = Model {
name: "demo-model".to_string(),
layers: vec![
Layer {
layer_type: LayerType::Dense(128),
activation: Some("relu".to_string()),
trainable: true,
},
Layer {
layer_type: LayerType::Dense(64),
activation: Some("relu".to_string()),
trainable: true,
},
Layer {
layer_type: LayerType::Dense(10),
activation: Some("softmax".to_string()),
trainable: true,
},
],
parameter_count: 128 * 64 + 64 * 10, // simplified
version: 1,
metadata: std::collections::HashMap::new(),
};
// Create event channel for monitoring
let (event_tx, mut event_rx) = mpsc::unbounded_channel();
// Create distributed training configuration
let config = DistributedTrainingConfig {
min_agents: 2,
max_agents: Some(5),
max_rounds: Some(3),
round_timeout_secs: 60,
aggregation: AggregationConfig::default(),
enable_privacy: true,
checkpoint_interval: Some(1),
};
// Create coordinator
let coordinator = Arc::new(DistributedTrainingCoordinator::new(
config,
model,
event_tx,
)?);
println!("β Created distributed training coordinator");
// Start event listener task
let coordinator_clone = coordinator.clone();
let event_task = tokio::spawn(async move {
while let Some(event) = event_rx.recv().await {
match event {
DistributedTrainingEvent::RoundStarted { round_id, participants, model_version } => {
println!("[EVENT] Round {} started with {} participants (model v{})",
round_id, participants, model_version);
}
DistributedTrainingEvent::UpdateReceived { client_id, round_id, samples } => {
println!("[EVENT] Update from {} for round {} ({} samples)",
client_id, round_id, samples);
}
DistributedTrainingEvent::AggregationCompleted { round_id, aggregated_model_version, participant_count } => {
println!("[EVENT] Round {} aggregated (model v{}, {} participants)",
round_id, aggregated_model_version, participant_count);
}
DistributedTrainingEvent::RoundFailed { round_id, reason } => {
println!("[EVENT] Round {} failed: {}", round_id, reason);
}
}
}
});
// Simulate registering participants
println!("\n=== Registering Participants ===");
for i in 1..=3 {
let agent_id = format!("agent-{}", i);
// In a real scenario, you'd create actual FederatedClient instances
// For demo, we'll use placeholder
// let client = FederatedClient::new(...);
// coordinator.register_participant(agent_id, client).await?;
println!(" Registered {}", agent_id);
}
// Get initial stats
let stats = coordinator.get_stats().await;
println!("\n=== Initial Stats ===");
println!("Round: {}", stats.round);
println!("Active participants: {}", stats.active_participants);
println!("Model version: {}", stats.model_version);
// Start a training round
println!("\n=== Starting Training Round ===");
let round_started = coordinator.start_round().await?;
if round_started {
println!("Round started successfully");
} else {
println!("Not enough participants to start round");
}
// Simulate receiving updates (in a real scenario, these would come from agents)
println!("\n=== Simulating Updates ===");
for i in 1..=2 {
let agent_id = format!("agent-{}", i);
// Create a mock update
let update = ClientUpdate {
client_id: agent_id.clone(),
round_id: 1,
parameters: vec![0.1, 0.2, 0.3], // dummy parameters
sample_count: 100 * i,
metadata: std::collections::HashMap::new(),
};
// In a real implementation, you'd call:
// coordinator.receive_update(&agent_id, update).await?;
println!(" Simulated update from {} ({} samples)", agent_id, update.sample_count);
}
// Get updated stats
let stats = coordinator.get_stats().await;
println!("\n=== Updated Stats ===");
println!("Round: {}", stats.round);
println!("Total updates: {}", stats.total_updates);
// Demonstrate mesh integration
println!("\n=== Mesh Integration ===");
let mesh_integration = MeshFederatedIntegration::new(coordinator.clone());
println!("Created mesh integration");
// Simulate handling a mesh message
let mock_message = b"mock federated learning message";
mesh_integration.handle_message("agent-1", mock_message).await?;
println!("Handled mock mesh message");
// Stop event listener
drop(coordinator); // This will cause the event channel to close
let _ = event_task.await;
println!("\n=== Demo Complete ===");
Ok(())
}