-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbaseline.py
More file actions
42 lines (33 loc) · 1.19 KB
/
baseline.py
File metadata and controls
42 lines (33 loc) · 1.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#FYI: This baseline ignores all the model information and predicts all the clients as honest.
import json
import csv
import torch
import os
# Paths
METADATA_PATH = "metadata.json"
MODEL_PATH = "model/global_model.pt"
OUTPUT_PATH = "submission.csv"
def main():
# Load metadata
with open(METADATA_PATH, "r") as f:
metadata = json.load(f)
num_clients = metadata["federated_learning"]["num_clients"]
# (Optional) Load model to show how it would be done
# This baseline does NOT use the model for predictions
try:
state_dict = torch.load(MODEL_PATH, map_location="cpu")
print("Loaded global_model.pt successfully.")
except Exception as e:
print("Warning: could not load model:", e)
# Generate trivial predictions: all clients are honest
predictions = []
for client_id in range(num_clients):
predictions.append((client_id, "honest"))
# Write submission file
with open(OUTPUT_PATH, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["client_id", "predicted_label"])
writer.writerows(predictions)
print(f"Baseline submission written to {OUTPUT_PATH}")
if __name__ == "__main__":
main()