-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path client_1.py
More file actions
41 lines (34 loc) · 1.53 KB
/
client_1.py
File metadata and controls
41 lines (34 loc) · 1.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from flwr.client.numpy_client import NumPyClient
from task import load_model, load_data
import flwr as fl
class FlowerClient(NumPyClient):
    """Flower NumPy client that trains and evaluates a model on one local data split.

    The wrapped ``model`` must provide ``get_weights``, ``set_weights``,
    ``fit`` and ``evaluate`` (the only methods used here); presumably a
    compiled Keras model returned by ``task.load_model`` — TODO confirm.
    """

    def __init__(self, model, data, epochs, batch_size, verbose):
        """Store the model, the local data split and default training settings.

        Args:
            model: Object exposing ``get_weights``/``set_weights``/``fit``/``evaluate``.
            data: 4-tuple ``(x_train, y_train, x_test, y_test)``.
            epochs: Default number of local epochs per ``fit`` call.
            batch_size: Default mini-batch size for local training.
            verbose: Verbosity level forwarded to ``model.fit``.
        """
        self.model = model
        self.x_train, self.y_train, self.x_test, self.y_test = data
        self.epochs = epochs
        self.batch_size = batch_size
        self.verbose = verbose

    def get_parameters(self, config):
        """Return the current model weights; ``config`` is not used."""
        return self.model.get_weights()

    def set_parameters(self, parameters):
        """Load ``parameters`` (as produced by ``get_parameters``) into the model."""
        self.model.set_weights(parameters)

    def fit(self, parameters, config):
        """Train on the local training split, starting from ``parameters``.

        ``config`` may override the constructor defaults through optional
        ``"epochs"`` and ``"batch_size"`` entries (for example sent by a
        server-side ``on_fit_config_fn``). Missing keys, or an empty or
        ``None`` config, fall back to the values given to ``__init__``.

        Returns:
            Tuple of (updated weights, number of training examples, empty metrics dict).
        """
        self.set_parameters(parameters)
        config = config or {}
        # Config values arrive as generic scalars; coerce to int only when
        # supplied so a constructor default such as batch_size=None is kept.
        epochs = int(config["epochs"]) if "epochs" in config else self.epochs
        batch_size = (
            int(config["batch_size"]) if "batch_size" in config else self.batch_size
        )
        self.model.fit(
            self.x_train,
            self.y_train,
            epochs=epochs,
            batch_size=batch_size,
            verbose=self.verbose,
        )
        return self.get_parameters(config), len(self.x_train), {}

    def evaluate(self, parameters, config):
        """Evaluate ``parameters`` on the local test split.

        Returns:
            Tuple of (loss, number of test examples, ``{"accuracy": accuracy}``).
        """
        self.set_parameters(parameters)
        # Assumes the model was compiled with exactly one metric (accuracy),
        # so evaluate() yields [loss, accuracy] — TODO confirm in task.load_model.
        loss, accuracy = self.model.evaluate(self.x_test, self.y_test, verbose=0)
        # Flower only serialises built-in scalars (bool/bytes/float/int/str);
        # cast in case the backend returns NumPy scalar types.
        return float(loss), len(self.x_test), {"accuracy": float(accuracy)}
# Client 1 trains on partition 0 of a 2-way split of the data.
PARTITION_ID = 0
NUM_PARTITIONS = 2
SERVER_ADDRESS = "localhost:8080"

if __name__ == "__main__":
    # Guarded so that merely importing this module (e.g. from tests or a
    # multiprocessing "spawn" child) does not load data or connect to the server.
    data = load_data(partition_id=PARTITION_ID, num_partitions=NUM_PARTITIONS)
    model = load_model()
    # to_client() adapts the NumPyClient to the Client interface start_client expects.
    fl.client.start_client(
        server_address=SERVER_ADDRESS,
        client=FlowerClient(
            model, data, epochs=1, batch_size=32, verbose=1
        ).to_client(),
    )