Skip to content

Commit a276014

Browse files
committed
participant n evaluation tested
1 parent c63f0e3 commit a276014

4 files changed

Lines changed: 378 additions & 62 deletions

File tree

metadata.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"num_classes": 10,
88

99
"federated_learning": {
10-
"num_clients": 100,
10+
"num_clients": 260,
1111
"malicious_fraction": 0.2,
1212
"num_rounds": 20,
1313
"local_epochs": 1,

private/evaluate.py

Lines changed: 116 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,117 @@
1+
# ==========================
2+
# Robust Evaluation Script
3+
# ==========================
4+
5+
import argparse
6+
import pandas as pd
7+
import torch
8+
import numpy as np
9+
10+
from sklearn.metrics import (
11+
accuracy_score,
12+
f1_score,
13+
classification_report,
14+
confusion_matrix
15+
)
16+
17+
# --------------------------
18+
# Helpers
19+
# --------------------------
20+
21+
import pandas as pd
122
import torch
2-
import csv
3-
import sys
4-
from sklearn.metrics import accuracy_score, f1_score
5-
6-
# -------------------------
7-
# Paths (adjust if needed)
8-
# -------------------------
9-
GROUND_TRUTH_PATH = "ground_truth_client_labels.pt"
10-
SUBMISSION_PATH = sys.argv[1] # path to submission.csv
11-
12-
# -------------------------
13-
# Load ground truth
14-
# -------------------------
15-
ground_truth = torch.load(GROUND_TRUTH_PATH)
16-
17-
# Convert to sorted lists
18-
gt_labels = []
19-
pred_labels = []
20-
21-
# -------------------------
22-
# Load submission
23-
# -------------------------
24-
submission = {}
25-
26-
with open(SUBMISSION_PATH, "r") as f:
27-
reader = csv.DictReader(f)
28-
if "client_id" not in reader.fieldnames or "predicted_label" not in reader.fieldnames:
29-
raise ValueError("Submission must contain 'client_id' and 'predicted_label' columns.")
30-
31-
for row in reader:
32-
client_id = int(row["client_id"])
33-
label = row["predicted_label"].strip().lower()
34-
35-
if label not in {"honest", "malicious"}:
36-
raise ValueError(f"Invalid label '{label}' for client {client_id}")
37-
38-
submission[client_id] = label
39-
40-
# -------------------------
41-
# Match predictions to ground truth
42-
# -------------------------
43-
for client_id in sorted(ground_truth.keys()):
44-
if client_id not in submission:
45-
raise ValueError(f"Missing prediction for client_id {client_id}")
46-
47-
gt_labels.append(ground_truth[client_id])
48-
pred_labels.append(submission[client_id])
49-
50-
# -------------------------
51-
# Compute metrics
52-
# -------------------------
53-
accuracy = accuracy_score(gt_labels, pred_labels)
54-
macro_f1 = f1_score(gt_labels, pred_labels, average="macro")
55-
56-
# -------------------------
57-
# Output results
58-
# -------------------------
59-
print("Evaluation Results")
60-
print("------------------")
61-
print(f"Accuracy : {accuracy:.4f}")
62-
print(f"Macro F1 : {macro_f1:.4f}")
23+
24+
# Import-time sanity check on the default input files.
#
# These statements run on every import/execution, before any CLI arguments are
# parsed, and always read the hard-coded default paths.  They are kept
# best-effort: when a default file is absent (e.g. the caller passes explicit
# paths on the command line) the check is skipped instead of aborting the
# whole script with FileNotFoundError.
try:
    sub = pd.read_csv("submission.csv")
    gt = torch.load("ground_truth_client_labels.pt")
except FileNotFoundError as err:
    print("Skipping default-file sanity check:", err)
else:
    print("Submission rows:", len(sub))
    print("Ground truth len:", len(gt))
29+
30+
31+
def load_ground_truth(path):
    """Load ground-truth client labels as a list of label strings.

    The object stored at ``path`` is read with ``torch.load`` and may be:

    * a tensor, list, tuple or array of truthy/falsy flags, where a truthy
      entry means ``"malicious"`` and a falsy one ``"honest"``;
    * a sequence of ``"honest"`` / ``"malicious"`` strings (case and
      surrounding whitespace are ignored);
    * a dict mapping client id to either of the above, in which case the
      values are ordered by sorted client id.

    Args:
        path: Location of the serialized ground-truth file.

    Returns:
        A list of ``"honest"`` / ``"malicious"`` strings, one per client.

    Raises:
        ValueError: If a string entry is neither ``"honest"`` nor
            ``"malicious"``.
    """
    gt = torch.load(path, map_location="cpu")

    # A {client_id: label} mapping cannot be turned into an iterable array
    # directly (np.array(dict) is 0-d), so flatten it in client-id order.
    if isinstance(gt, dict):
        gt = [gt[client_id] for client_id in sorted(gt)]

    if isinstance(gt, torch.Tensor):
        gt = gt.cpu().numpy()
    elif isinstance(gt, (list, tuple)):
        # Unwrap single-element tensors so they behave like plain scalars.
        gt = [
            x.item() if isinstance(x, torch.Tensor) and x.numel() == 1 else x
            for x in gt
        ]

    labels = []
    for x in np.array(gt):
        if isinstance(x, str):
            # Any non-empty string is truthy, so string labels must be
            # compared by value rather than run through the truthiness test.
            name = x.strip().lower()
            if name not in ("honest", "malicious"):
                raise ValueError(f"Invalid ground-truth label '{x}'")
            labels.append(name)
        else:
            labels.append("malicious" if x else "honest")
    return labels
43+
44+
45+
def load_submission(path):
    """Read a submission CSV and return its client ids and predicted labels.

    The CSV must have a ``client_id`` column and a label column named either
    ``predicted_label`` (preferred) or ``label``.  Client ids may be numeric
    or strings of the form ``client_<n>``.  Labels are stripped and
    lower-cased before validation.

    Args:
        path: Path to the submission CSV file.

    Returns:
        A ``(ids, labels)`` tuple of equal-length lists, in file order:
        ``ids`` holds the client ids and ``labels`` the normalized
        ``"honest"`` / ``"malicious"`` strings.

    Raises:
        ValueError: If a required column is missing or a label is not
            ``"honest"`` or ``"malicious"``.
    """
    df = pd.read_csv(path)

    # Accept both column styles
    if "predicted_label" in df.columns:
        label_col = "predicted_label"
    elif "label" in df.columns:
        label_col = "label"
    else:
        raise ValueError("Submission must contain 'label' or 'predicted_label'")

    if "client_id" not in df.columns:
        raise ValueError("Submission must contain a 'client_id' column")

    # Accept both ID styles.  Checking "not numeric" rather than
    # "dtype == object" also covers pandas' dedicated string dtype.
    ids = df["client_id"]
    if not pd.api.types.is_numeric_dtype(ids):
        ids = (
            ids.astype(str)
            .str.strip()
            .str.replace("client_", "", regex=False)
            .astype(int)
        )

    # Normalize case/whitespace so "Honest " is not scored as a wrong class,
    # then reject anything outside the two known labels.
    labels = df[label_col].astype(str).str.strip().str.lower()
    invalid = labels[~labels.isin(["honest", "malicious"])]
    if not invalid.empty:
        raise ValueError(
            f"Invalid label(s) in submission: {sorted(set(invalid))}"
        )

    return ids.tolist(), labels.tolist()
65+
66+
67+
# --------------------------
68+
# Main Evaluation
69+
# --------------------------
70+
71+
def evaluate(submission_path, gt_path):
    """Score a submission against the ground truth and print a report.

    Predictions are ordered by ``client_id`` and compared position by
    position with the ground-truth label list.  Accuracy, macro F1, the
    confusion matrix and the classification report are printed.

    Args:
        submission_path: Path to the submission CSV.
        gt_path: Path to the serialized ground-truth labels.

    Returns:
        A ``(accuracy, macro_f1)`` tuple.

    Raises:
        ValueError: If the submission contains duplicate client ids or its
            number of rows differs from the number of ground-truth labels.
    """
    gt_labels = load_ground_truth(gt_path)
    ids, pred_labels = load_submission(submission_path)

    # Alignment below is purely positional, so a repeated client id would
    # shift every later prediction onto the wrong ground-truth entry.
    if len(set(ids)) != len(ids):
        raise ValueError("Duplicate client_id entries in submission")

    # Sort predictions by client_id
    pred_sorted = [label for _, label in sorted(zip(ids, pred_labels))]

    if len(pred_sorted) != len(gt_labels):
        raise ValueError("Prediction length mismatch with ground truth")

    # Metrics
    acc = accuracy_score(gt_labels, pred_sorted)
    macro_f1 = f1_score(gt_labels, pred_sorted, average="macro")

    print("\n==============================")
    print(" Evaluation Results")
    print("==============================")
    print(f"Accuracy : {acc:.4f}")
    print(f"Macro F1 : {macro_f1:.4f}")

    print("\nConfusion Matrix")
    print(confusion_matrix(gt_labels, pred_sorted))

    print("\nClassification Report")
    print(classification_report(gt_labels, pred_sorted))

    return acc, macro_f1
99+
100+
101+
# --------------------------
102+
# CLI
103+
# --------------------------
104+
105+
if __name__ == "__main__":
    # One parser handles every invocation.  The submission path is an
    # optional positional so that `--gt PATH` can be combined with the
    # default submission file instead of failing on a missing argument.
    parser = argparse.ArgumentParser()
    parser.add_argument("submission", nargs="?", default=None)
    parser.add_argument("--gt", default="ground_truth_client_labels.pt")
    args = parser.parse_args()

    if args.submission is None:
        print("No submission provided — using submission.csv")
        evaluate("submission.csv", args.gt)
    else:
        evaluate(args.submission, args.gt)
117+

private/result-img.png

268 KB
Loading

0 commit comments

Comments
 (0)