-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathserver.py
More file actions
89 lines (71 loc) · 3.15 KB
/
server.py
File metadata and controls
89 lines (71 loc) · 3.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import argparse
from pathlib import Path
from typing import Any
import flwr as fl
import pandas as pd
from flwr.common.parameter import ndarrays_to_parameters
from flwr.common.typing import Parameters
from examples.models.mlp_classifier import MLP
from fl4health.client_managers.poisson_sampling_manager import PoissonSamplingClientManager
from fl4health.feature_alignment.tab_features_info_encoder import TabularFeaturesInfoEncoder
from fl4health.metrics.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from fl4health.servers.tabular_feature_alignment_server import TabularFeatureAlignmentServer
from fl4health.strategies.basic_fedavg import BasicFedAvg
from fl4health.utils.config import load_config
# Server-side CSV used to build an example "source of truth" for tabular feature alignment.
# main() only reads it when the loaded config sets ``source_specified`` to a truthy value.
DATA_PATH: str = "examples/feature_alignment_example/mimic3d_hospital1.csv"
# Default configuration file; used when ``--config_path`` is not supplied on the command line.
CONFIG_PATH: str = "examples/feature_alignment_example/config.yaml"
def get_initial_model_parameters(input_dim: int, output_dim: int) -> Parameters:
    """Instantiate a fresh MLP and serialize its initial weights as Flower ``Parameters``.

    Args:
        input_dim: First positional argument forwarded to ``MLP`` (its input size).
        output_dim: Second positional argument forwarded to ``MLP`` (its output size).

    Returns:
        The MLP's ``state_dict`` tensors, in state-dict order, each moved to CPU and
        converted to a numpy array, then packed with ``ndarrays_to_parameters``.
    """
    initial_model = MLP(input_dim, output_dim)
    # Only the tensors are needed; ordering follows the state_dict, so iterating
    # ``.values()`` yields the same sequence that ``.items()`` did, without unused keys.
    return ndarrays_to_parameters([val.cpu().numpy() for val in initial_model.state_dict().values()])
def construct_tab_feature_info_encoder(
    data_path: Path, id_column: str, target_column: str
) -> TabularFeaturesInfoEncoder:
    """Load a CSV file and derive a ``TabularFeaturesInfoEncoder`` from its contents.

    Args:
        data_path: Location of the CSV file, read with ``pandas.read_csv``.
        id_column: Column name forwarded as the identifier column.
        target_column: Column name forwarded as the target column.

    Returns:
        The encoder built by ``TabularFeaturesInfoEncoder.encoder_from_dataframe``.
    """
    frame = pd.read_csv(data_path)
    encoder = TabularFeaturesInfoEncoder.encoder_from_dataframe(frame, id_column, target_column)
    return encoder
def main(config: dict[str, Any]) -> None:
    """Configure and launch the Flower server for the tabular feature-alignment example.

    Args:
        config: Loaded configuration. Keys read directly here:
            ``n_clients``, the minimum client count for fit, evaluation and availability;
            ``source_specified``, whether to build a server-side source of truth from
            ``DATA_PATH``; and ``n_server_rounds``, the number of FL rounds. The whole
            dict is also passed through to ``TabularFeatureAlignmentServer``.
    """
    manager = PoissonSamplingClientManager()

    # One client count gates fitting, evaluation, and the wait before rounds begin.
    n_clients = config["n_clients"]
    fedavg = BasicFedAvg(
        min_fit_clients=n_clients,
        min_evaluate_clients=n_clients,
        min_available_clients=n_clients,
        on_fit_config_fn=None,
        on_evaluate_config_fn=None,
        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        # No weights up front; the server is given ``initialize_parameters`` instead
        # (presumably invoked once feature dimensions are aligned — confirm).
        initial_parameters=None,
    )

    # Either derive the feature schema from the server-side CSV, or provide none.
    source_of_truth = (
        construct_tab_feature_info_encoder(Path(DATA_PATH), "hadm_id", "LOSgroupNum")
        if config["source_specified"]
        else None
    )

    server = TabularFeatureAlignmentServer(
        client_manager=manager,
        config=config,
        initialize_parameters=get_initial_model_parameters,
        strategy=fedavg,
        tabular_features_source_of_truth=source_of_truth,
        accept_failures=False,
    )
    fl.server.start_server(
        server=server,
        server_address="0.0.0.0:8080",
        config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
    )
if __name__ == "__main__":
    # Command-line entry point: read the config path, load it, and start the server.
    arg_parser = argparse.ArgumentParser(description="FL Server Main")
    arg_parser.add_argument(
        "--config_path",
        action="store",
        type=str,
        default=CONFIG_PATH,
        help="Path to configuration file.",
    )
    cli_args = arg_parser.parse_args()
    main(load_config(cli_args.config_path))