-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathaggregation_algorithm.py
More file actions
69 lines (53 loc) · 2.23 KB
/
aggregation_algorithm.py
File metadata and controls
69 lines (53 loc) · 2.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Script for aggregating results from local trainings."""
import json
from pathlib import Path
from typing import Any, List
import numpy as np
import requests
from feltlabs.config import AggregationConfig, parse_aggregation_args
from feltlabs.core.cryptography import decrypt_nacl
def load_local_models(config: AggregationConfig) -> List[bytes]:
    """Load results from local algorithm (models) for aggregation.

    The model URLs are provided in custom algorithm data file. This will usually be
    URLs to models from local trainings. For testing purposes this can be local paths.

    When ``config.download_models`` is set, each entry of ``model_urls`` is either a
    URL string or a dict of keyword arguments forwarded to ``requests.get``.
    Otherwise each entry is treated as a local file path.

    Args:
        config: config object containing path to custom data and private key

    Returns:
        list of bytes - local results loaded as bytes and decrypted with
        ``config.private_key``

    Raises:
        TypeError: if, while downloading, a model URL entry is neither a string
            nor a dict.
        requests.HTTPError: if a model download responds with an error status.
    """
    # Seconds to wait for the server to connect / send data. Without a timeout a
    # stalled host would hang the aggregation forever; dict entries may override it.
    default_timeout = 60

    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with config.custom_data_path.open("r", encoding="utf-8") as f:
        conf = json.load(f)

    data_array = []
    for url in conf["model_urls"]:
        if not config.download_models:
            data_array.append(Path(url).read_bytes())
            continue

        if isinstance(url, dict):
            # Explicit keys in the entry (including "timeout") take precedence.
            res = requests.get(**{"timeout": default_timeout, **url})
        elif isinstance(url, str):
            res = requests.get(url, timeout=default_timeout)
        else:
            raise TypeError(f"Invalid model URL (type {type(url)}): {url}")
        # Fail loudly rather than trying to decrypt an HTTP error page.
        res.raise_for_status()
        data_array.append(res.content)

    return [decrypt_nacl(config.private_key, val) for val in data_array]
def main(config: AggregationConfig):
    """Main function executing the local result loading, aggregation and saving outputs.

    Each decrypted local result is interpreted as a raw buffer of numpy's default
    float dtype (float64). The aggregated value is the mean over all values of all
    local results. Its text form is written UTF-8 encoded to
    ``config.output_folder / "model"``.

    Args:
        config: aggregation config object provided by FELT containing all paths

    Raises:
        ValueError: if there are no local models to aggregate.
    """
    # Load the (downloaded/read and decrypted) local results as raw bytes
    local_models = load_local_models(config)
    if not local_models:
        # np.mean of an empty list returns nan and would silently save "nan" as
        # the final model, so reject this case explicitly.
        raise ValueError("No local models to aggregate.")

    # Run the aggregation algorithm: np.mean without an axis averages every value
    # of every model (assumes all local results have the same length)
    models = [np.frombuffer(m) for m in local_models]
    final_value = np.mean(models)

    # Get final output values
    model_bytes = bytes(str(final_value), "utf-8")

    # Save models into output folder. You have to name output file as "model"
    with open(config.output_folder / "model", "wb") as f:
        f.write(model_bytes)
    print("Training finished.")
if __name__ == "__main__":
    # The FELT Labs argument parser is the recommended entry point: it fills in
    # every input and output path, so the parsed config is handed straight to main.
    main(parse_aggregation_args())