diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2483976 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.idea/ +__pycache__/ diff --git a/client.py b/client.py index 9d93e80..a7dab23 100644 --- a/client.py +++ b/client.py @@ -8,7 +8,7 @@ import copy from optimization import Optimization class Client(): - def __init__(self, cid, data, device, project_dir, model_name, local_epoch, lr, batch_size, drop_rate, stride): + def __init__(self, cid, data, device, project_dir, model_name, local_epoch, lr, batch_size, drop_rate, stride, clustering=False): self.cid = cid self.project_dir = project_dir self.model_name = model_name @@ -21,25 +21,30 @@ def __init__(self, cid, data, device, project_dir, model_name, local_epoch, lr, self.dataset_sizes = self.data.train_dataset_sizes[cid] self.train_loader = self.data.train_loaders[cid] - self.full_model = get_model(self.data.train_class_sizes[cid], drop_rate, stride) - self.classifier = self.full_model.classifier.classifier - self.full_model.classifier.classifier = nn.Sequential() - self.model = self.full_model - self.distance=0 + self.model = get_model(self.data.train_class_sizes[cid], drop_rate, stride) + self.classifier = copy.deepcopy(self.model.classifier.classifier) + self.model.classifier.classifier = nn.Sequential() + self.distance = 0 self.optimization = Optimization(self.train_loader, self.device) + self.use_clustering = clustering # print("class name size",class_names_size[cid]) - def train(self, federated_model, use_cuda): + def train(self, federated_model=None, use_cuda=False): self.y_err = [] self.y_loss = [] - - self.model.load_state_dict(federated_model.state_dict()) + if self.use_clustering: + print("using clustering, model is set before") + assert federated_model is None + # self.model.classifier.classifier = nn.Sequential() + federated_model = copy.deepcopy(self.model) + else: + self.model.load_state_dict(federated_model.state_dict()) self.model.classifier.classifier = self.classifier self.old_classifier = copy.deepcopy(self.classifier) self.model = self.model.to(self.device) - + self.model.train(True) optimizer = get_optimizer(self.model, self.lr) - scheduler = lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1) + # scheduler = lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1) criterion = nn.CrossEntropyLoss() @@ -50,8 +55,7 @@ def train(self, federated_model, use_cuda): print('Epoch {}/{}'.format(epoch, self.local_epoch - 1)) print('-' * 10) - scheduler.step() - self.model.train(True) + # scheduler.step() running_loss = 0.0 running_corrects = 0.0 @@ -60,12 +64,12 @@ def train(self, federated_model, use_cuda): b, c, h, w = inputs.shape if b < self.batch_size: continue - if use_cuda: - inputs = Variable(inputs.cuda().detach()) - labels = Variable(labels.cuda().detach()) - else: - inputs, labels = Variable(inputs), Variable(labels) - + # if use_cuda: + # inputs = Variable(inputs.cuda().detach()) + # labels = Variable(labels.cuda().detach()) + # else: + # inputs, labels = Variable(inputs), Variable(labels) + inputs, labels = inputs.to(self.device), labels.to(self.device) optimizer.zero_grad() outputs = self.model(inputs) @@ -106,6 +110,9 @@ def train(self, federated_model, use_cuda): def generate_soft_label(self, x, regularization): return self.optimization.kd_generate_soft_label(self.model, x, regularization) + def generate_custom_data_feature(self, inputs): + return self.optimization.generate_custom_data_feature(self.model, inputs) + def get_model(self): return self.model @@ -116,4 +123,8 @@ def get_train_loss(self): return self.y_loss[-1] def get_cos_distance_weight(self): - return self.distance \ No newline at end of file + return self.distance + + def set_model(self, model): + self.model = copy.deepcopy(model) + diff --git a/data_utils.py b/data_utils.py index 0aa3d65..760b61f 100644 --- a/data_utils.py +++ b/data_utils.py @@ -15,7 +15,7 @@ def __len__(self): return len(self.imgs) def __getitem__(self, index): - data,label = self.imgs[index] + data, label = self.imgs[index] return self.transform(Image.open(data)), label @@ -95,6 +95,10 @@ def preprocess_train(self): print('Train dataset sizes:', self.train_dataset_sizes) print('Train class sizes:', self.train_class_sizes) + if "cuhk02" in self.datasets: + #cuhk02 is not labeled, we only use it for feature extraction in clustering + self.datasets.remove("cuhk02") + self.client_list.remove("cuhk02") def preprocess_test(self): """preprocess testing data, constructing test loaders @@ -103,10 +107,8 @@ def preprocess_test(self): self.gallery_meta = {} self.query_meta = {} - for test_dir in self.datasets: - test_dir = 'data/'+test_dir+'/pytorch' - - dataset = test_dir.split('/')[1] + for dataset in self.datasets: + test_dir = os.path.join(self.data_dir, dataset, "pytorch") gallery_dataset = datasets.ImageFolder(os.path.join(test_dir, 'gallery')) query_dataset = datasets.ImageFolder(os.path.join(test_dir, 'query')) diff --git a/evaluate.py b/evaluate.py index 8d2519a..eea2a2a 100644 --- a/evaluate.py +++ b/evaluate.py @@ -2,11 +2,11 @@ import torch import numpy as np import os -import argparse -parser = argparse.ArgumentParser(description='Training') -parser.add_argument('--result_dir', default='.', type=str) -parser.add_argument('--dataset', default='no_dataset', type=str) -args = parser.parse_args() +# import argparse +# parser = argparse.ArgumentParser(description='Training') +# parser.add_argument('--result_dir', default='.', type=str) +# parser.add_argument('--dataset', default='no_dataset', type=str) +# args = parser.parse_args() ####################################################################### # Evaluate @@ -60,32 +60,39 @@ def compute_mAP(index, good_index, junk_index): return ap, cmc -###################################################################### -result = scipy.io.loadmat(args.result_dir + '/pytorch_result.mat') -query_feature = torch.FloatTensor(result['query_f']) -query_cam = result['query_cam'][0] -query_label = result['query_label'][0] -gallery_feature = torch.FloatTensor(result['gallery_f']) -gallery_cam = result['gallery_cam'][0] -gallery_label = result['gallery_label'][0] +def testing_model(result, dataset): + # result = scipy.io.loadmat(file_path) + # print("========= after loading ==========") + # for i in result: + # print(i, np.array(result[i]).shape) -query_feature = query_feature.cuda() -gallery_feature = gallery_feature.cuda() + query_feature = torch.FloatTensor(result['query_f']) + query_cam = np.array(result['query_cam']) + query_label = np.array(result['query_label']) + gallery_feature = torch.FloatTensor(result['gallery_f']) + gallery_cam = np.array(result['gallery_cam']) + gallery_label = np.array(result['gallery_label']) + # print(type(query_feature),query_feature[:3]) + # print(type(query_cam),query_cam[:3]) + # print(type(query_label),query_label[:3]) -print(query_feature.shape) -CMC = torch.IntTensor(len(gallery_label)).zero_() -ap = 0.0 + query_feature = query_feature.cuda() + gallery_feature = gallery_feature.cuda() -for i in range(len(query_label)): - ap_tmp, CMC_tmp = evaluate(query_feature[i], query_label[i], query_cam[i], gallery_feature, gallery_label, gallery_cam) - if CMC_tmp[0]==-1: - continue - CMC = CMC + CMC_tmp - ap += ap_tmp + print(query_feature.shape) + CMC = torch.IntTensor(len(gallery_label)).zero_() + ap = 0.0 -CMC = CMC.float() -CMC = CMC/len(query_label) #average CMC -print(args.dataset+' Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f'%(CMC[0], CMC[4], CMC[9], ap/len(query_label))) -print('-'*15) -print() + for i in range(len(query_label)): + ap_tmp, CMC_tmp = evaluate(query_feature[i], query_label[i], query_cam[i], gallery_feature, gallery_label, gallery_cam) + if CMC_tmp[0]==-1: + continue + CMC = CMC + CMC_tmp + ap += ap_tmp + + CMC = CMC.float() + CMC = CMC/len(query_label) #average CMC + print(dataset+' Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f'%(CMC[0], CMC[4], CMC[9], ap/len(query_label))) + print('-'*15) + print() diff --git a/finch.py b/finch.py new file mode 100644 index 0000000..d5dcd1e --- /dev/null +++ b/finch.py @@ -0,0 +1,173 @@ +import time +import argparse +import numpy as np +from sklearn import metrics +import scipy.sparse as sp +import warnings + +try: + from pyflann import * + + pyflann_available = True +except Exception as e: + warnings.warn('pyflann not installed: {}'.format(e)) + pyflann_available = False + pass + +RUN_FLANN = 70000 + + +def clust_rank(mat, initial_rank=None, distance='cosine'): + s = mat.shape[0] + if initial_rank is not None: + orig_dist = [] + elif s <= RUN_FLANN: + orig_dist = metrics.pairwise.pairwise_distances(mat, mat, metric=distance) + np.fill_diagonal(orig_dist, 1000.0) + initial_rank = np.argmin(orig_dist, axis=1) + else: + if not pyflann_available: + raise MemoryError("You should use pyflann for inputs larger than {} samples.".format(RUN_FLANN)) + print('Using flann to compute 1st-neighbours at this step ...') + flann = FLANN() + result, dists = flann.nn(mat, mat, num_neighbors=2, algorithm="kdtree", trees=8, checks=128) + initial_rank = result[:, 1] + orig_dist = [] + print('Step flann done ...') + + # The Clustering Equation + A = sp.csr_matrix((np.ones_like(initial_rank, dtype=np.float32), (np.arange(0, s), initial_rank)), shape=(s, s)) + A = A + sp.eye(s, dtype=np.float32, format='csr') + A = A @ A.T + + A = A.tolil() + A.setdiag(0) + return A, orig_dist + + +def get_clust(a, orig_dist, min_sim=None): + if min_sim is not None: + a[np.where((orig_dist * a.toarray()) > min_sim)] = 0 + + num_clust, u = sp.csgraph.connected_components(csgraph=a, directed=True, connection='weak', return_labels=True) + return u, num_clust + + +def cool_mean(M, u): + _, nf = np.unique(u, return_counts=True) + idx = np.argsort(u) + M = M[idx, :] + M = np.vstack((np.zeros((1, M.shape[1])), M)) + + np.cumsum(M, axis=0, out=M) + cnf = np.cumsum(nf) + nf1 = np.insert(cnf, 0, 0) + nf1 = nf1[:-1] + + M = M[cnf, :] - M[nf1, :] + M = M / nf[:, None] + return M + + +def get_merge(c, u, data): + if len(c) != 0: + _, ig = np.unique(c, return_inverse=True) + c = u[ig] + else: + c = u + + mat = cool_mean(data, c) + return c, mat + + +def update_adj(adj, d): + # Update adj, keep one merge at a time + idx = adj.nonzero() + v = np.argsort(d[idx]) + v = v[:2] + x = [idx[0][v[0]], idx[0][v[1]]] + y = [idx[1][v[0]], idx[1][v[1]]] + a = sp.lil_matrix(adj.get_shape()) + a[x, y] = 1 + return a + + +def req_numclust(c, data, req_clust, distance): + iter_ = len(np.unique(c)) - req_clust + c_, mat = get_merge([], c, data) + for i in range(iter_): + adj, orig_dist = clust_rank(mat, initial_rank=None, distance=distance) + adj = update_adj(adj, orig_dist) + u, _ = get_clust(adj, [], min_sim=None) + c_, mat = get_merge(c_, u, data) + return c_ + + +def FINCH(data, min_sim=None, initial_rank=None, req_clust=None, distance='cosine', verbose=True): + """ FINCH clustering algorithm. + :param data: Input matrix with features in rows. + :param initial_rank: Nx1 first integer neighbor indices (optional). + :param req_clust: Set output number of clusters (optional). Not recommended. + :param distance: One of ['cityblock', 'cosine', 'euclidean', 'l1', 'l2', 'manhattan'] Recommended 'cosine'. + :param verbose: Print verbose output. + :return: + c: NxP matrix where P is the partition. Cluster label for every partition. + num_clust: Number of clusters. + req_c: Labels of required clusters (Nx1). Only set if `req_clust` is not None. + The code implements the FINCH algorithm described in our CVPR 2019 paper + Sarfraz et al. "Efficient Parameter-free Clustering Using First Neighbor Relations", CVPR2019 + https://arxiv.org/abs/1902.11266 + For academic purpose only. The code or its re-implementation should not be used for commercial use. + Please contact the author below for licensing information. + Copyright + M. Saquib Sarfraz (saquib.sarfraz@kit.edu) + Karlsruhe Institute of Technology (KIT) + """ + # Cast input data to float32 + data = data.astype(np.float32) + + # min_sim = None + adj, orig_dist = clust_rank(data, initial_rank, distance) + initial_rank = None + group, num_clust = get_clust(adj, [], min_sim) + c, mat = get_merge([], group, data) + + if verbose: + print('Partition 0: {} clusters'.format(num_clust)) + if len(orig_dist) != 0: + min_sim = np.max(orig_dist * adj.toarray()) + + exit_clust = 2 + c_ = c + k = 1 + num_clust = [num_clust] + + while exit_clust > 1: + adj, orig_dist = clust_rank(mat, initial_rank, distance) + u, num_clust_curr = get_clust(adj, orig_dist, min_sim) + c_, mat = get_merge(c_, u, data) + + num_clust.append(num_clust_curr) + c = np.column_stack((c, c_)) + exit_clust = num_clust[-2] - num_clust_curr + + if num_clust_curr == 1 or exit_clust < 1: + num_clust = num_clust[:-1] + c = c[:, :-1] + break + + if verbose: + print('Partition {}: {} clusters'.format(k, num_clust[k])) + k += 1 + + if req_clust is not None: + if req_clust not in num_clust: + ind = [i for i, v in enumerate(num_clust) if v >= req_clust] + req_c = req_numclust(c[:, ind[-1]], data, req_clust, distance) + else: + req_c = c[:, num_clust.index(req_clust)] + else: + req_c = None + + return c, num_clust, req_c + diff --git a/finch_dis.py b/finch_dis.py new file mode 100644 index 0000000..cd664de --- /dev/null +++ b/finch_dis.py @@ -0,0 +1,89 @@ +import numpy as np +from tqdm import tqdm +from collections import defaultdict +from sklearn import metrics +from sklearn.preprocessing import normalize + + +def finch(feats, finch_step, finch_dis, metric="cosine", do_normalize=True): + if do_normalize: + feats = normalize(feats, norm='l2').astype('float32') + num_track = feats.shape[0] + clusters = np.arange(num_track) + + for step in range(finch_step): + print('Step {}'.format(step)) + pre_ids = list(set(clusters)) + pre_ids.sort() + if step >= 3: + print(pre_ids[-10:]) + if len(pre_ids) <= 3: + break + pre_map = defaultdict(list) + for i, x in tqdm(enumerate(clusters)): + pre_map[x].append(i) + # if step>=3: + # print("pre_map before convert: ",pre_map[-10:]) + pre_map = {k: np.array(v) for k, v in pre_map.items()} + + print('Calculate center features') + if step == 0: + feats_now = feats.copy() + else: + feats_now = np.array([np.sum(feats[pre_map[i]], axis=0) / pre_map[i].size + for i in tqdm(pre_ids)]) + print('Search top1') + print("feature_shape_now: ", feats_now.shape) + num_track_now = feats_now.shape[0] + + feats_now = normalize(feats_now, norm='l2').astype('float32') + + orig_dist = metrics.pairwise.pairwise_distances(feats_now, feats_now, metric=metric) + np.fill_diagonal(orig_dist, float('inf')) + topk_idx = np.argmin(orig_dist, axis=1) + topk_scores = [orig_dist[i][topk_idx[i]] for i in range(len(orig_dist))] + print("orig_dist: {}, topk_scores:{}".format(orig_dist, topk_scores)) + clusters_now = [] + used = [False for _ in range(num_track_now)] + + def dfs(root): + used[root] = True + res = set([root]) + for idx in graph[root]: + if not used[idx]: + res |= dfs(idx) + return res + + graph = [[] for i in range(num_track_now)] + print('==Building graph==') + for i in tqdm(range(num_track_now)): + if finch_dis < 0 or topk_scores[i] < finch_dis: + graph[i].append(topk_idx[i]) + graph[topk_idx[i]].append(i) + + print('DFS: ') + for i in tqdm(range(num_track_now)): + if used[i]: + continue + clusters_now.append(list(dfs(i))) + + print('Merge cluster') + if step >= 3: + print("what is pre_map before merge?", pre_map[0]) + new_id_cnt = len(pre_map) + for cluster in tqdm(clusters_now): + tmp_ids = np.array([]) + for i in cluster: + tmp_ids = np.concatenate((tmp_ids, pre_map[i])) + pre_map.pop(i) + pre_map[new_id_cnt] = tmp_ids.astype('int32') + new_id_cnt += 1 + + for i, key in enumerate(pre_map.keys()): + clusters[pre_map[key]] = i + + print('Done') + print("final_cluster_num: {}, clusters: {}".format(len(set(clusters)), clusters)) + + return clusters + diff --git a/kmeans.py b/kmeans.py new file mode 100644 index 0000000..56870a5 --- /dev/null +++ b/kmeans.py @@ -0,0 +1,12 @@ +import numpy as np +from sklearn.cluster import KMeans +from sklearn.preprocessing import normalize + +def kmeans(feats, n_clusters=2, do_normalize=True): + if do_normalize: + feats = normalize(feats, norm='l2').astype('float32') + print(feats.shape) + print(feats) + kmeans = KMeans(n_clusters=n_clusters, random_state=0) + kmeans.fit(feats) + return kmeans.labels_ \ No newline at end of file diff --git a/main.py b/main.py index f68f142..cc8e358 100644 --- a/main.py +++ b/main.py @@ -23,6 +23,7 @@ from utils import set_random_seed from data_utils import Data + mp.set_start_method('spawn', force=True) sys.setrecursionlimit(10000) version = torch.__version__ @@ -32,7 +33,7 @@ parser.add_argument('--model_name',default='ft_ResNet50', type=str, help='output model name') parser.add_argument('--project_dir',default='.', type=str, help='project path') parser.add_argument('--data_dir',default='data',type=str, help='training dir path') -parser.add_argument('--datasets',default='Market,DukeMTMC-reID,cuhk03-np-detected,cuhk01,MSMT17,viper,prid,3dpes,ilids',type=str, help='datasets used') +parser.add_argument('--datasets',default='Market,DukeMTMC-reID,cuhk03-np-detected,cuhk01,MSMT17,viper,prid,3dpes,ilids,cuhk02',type=str, help='datasets used') parser.add_argument('--train_all', action='store_true', help='use all training data' ) parser.add_argument('--stride', default=2, type=int, help='stride') parser.add_argument('--lr', default=0.05, type=float, help='learning rate') @@ -53,16 +54,79 @@ parser.add_argument('--multiple_scale',default='1', type=str,help='multiple_scale: e.g. 1 1,1.1 1,1.1,1.2') parser.add_argument('--test_dir',default='all',type=str, help='./test_data') +parser.add_argument('--resume_epoch', default=0, type=int, help='resume from which epoch, if 0, no resume') +parser.add_argument('--experiment_index', default=0, type=int, help='index of training time') # arguments for optimization parser.add_argument('--cdw', action='store_true', help='use cosine distance weight for model aggregation, default false' ) -parser.add_argument('--kd', action='store_true', help='apply knowledge distillation, default false' ) +parser.add_argument('--kd', action='store_true', help='apply knowledge distillation, default false') +parser.add_argument('--kd_method', default='cluster', type=str, help='whole or cluster') parser.add_argument('--regularization', action='store_true', help='use regularization during distillation, default false' ) +parser.add_argument('--clustering', action='store_true', help='use clustering to aggregate models, fault false') +parser.add_argument('--clustering_method', default='finch', type=str, help='method used for clustering, finch or kmeans') +parser.add_argument('--max_distance', default=0.9, type=float, help='maximum distance in finch algorithm') +parser.add_argument('--n_cluster', default=2, type=int, help='number of cluster in Kmeans') + + +def save_checkpoint(server, clients, client_list, cpk_dir, epoch): + torch.save({ + 'epoch': epoch, + 'server_state_dict': server.federated_model.state_dict(), + 'client_list': [clients[c].cid for c in client_list], + 'client_classifier': [clients[c].classifier.state_dict() for c in client_list], + 'client_model': [clients[c].model.state_dict() for c in client_list] + }, os.path.join(cpk_dir, "{}.pth".format(epoch))) + + +def load_checkpoint(path): + cpk = torch.load(path) + epoch = cpk['epoch'] + server_state_dict = cpk['server_state_dict'] + client_list = cpk['client_list'] + client_classifier = cpk['client_classifier'] + client_model = cpk['client_model'] + return epoch, server_state_dict, client_list, client_classifier, client_model + def train(): args = parser.parse_args() print(args) - + if args.clustering: + clu = "clu" + else: + clu = "Nclu" + + if args.cdw: + cdw = "cdw" + else: + cdw = "Ncdw" + if args.kd: + kd = "kd" + else: + kd = "Nkd" + if args.regularization: + reg = "reg" + else: + reg = "Nreg" + + kd_method = args.kd_method + assert (kd_method == 'whole' or kd_method == 'cluster') + if args.clustering: + if args.clustering_method == "kmeans": + cluster_description = "kmeans_{}".format(args.n_cluster) + else: + cluster_description = "finch_{}".format(args.max_distance) + else: + cluster_description = "No_cluster" + + cpk_dir = "checkpoints/{}_{}_{}_{}_{}_{}_{}".format(clu, cdw, kd, kd_method, reg, + cluster_description, args.experiment_index) + cpk_dir = os.path.join(args.project_dir, cpk_dir) + if not os.path.isdir(cpk_dir): + os.makedirs(cpk_dir) + + epoch = args.resume_epoch + use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") @@ -70,7 +134,7 @@ def train(): data = Data(args.datasets, args.data_dir, args.batch_size, args.erasing_p, args.color_jitter, args.train_all) data.preprocess() - + clients = {} for cid in data.client_list: clients[cid] = Client( @@ -83,7 +147,8 @@ def train(): args.lr, args.batch_size, args.drop_rate, - args.stride) + args.stride, + args.clustering) server = Server( clients, @@ -95,26 +160,41 @@ def train(): args.lr, args.drop_rate, args.stride, - args.multiple_scale) - - dir_name = os.path.join(args.project_dir, 'model', args.model_name) - if not os.path.isdir(dir_name): - os.mkdir(dir_name) + args.multiple_scale, + args.clustering, + args.clustering_method, + args.max_distance, + args.n_cluster) + + if epoch != 0: + print("======= loading checkpoint, epoch: {}".format(epoch)) + path = os.path.join(cpk_dir, "{}.pth".format(epoch)) + cpk_epoch, server_state_dict, client_list, client_classifier, client_model = load_checkpoint(path) + assert (epoch == cpk_epoch) + server.federated_model.load_state_dict(server_state_dict) + for i in range(len(client_list)): + cid = client_list[i] + clients[cid].classifier.load_state_dict(client_classifier[i]) + clients[cid].model.load_state_dict(client_model[i]) + print("all models loaded, training from {}".format(epoch)) print("=====training start!========") - rounds = 800 - for i in range(rounds): + rounds = 500 + rounds = rounds // args.local_epoch + for i in range(epoch, rounds): + save_checkpoint(server, clients, data.client_list, cpk_dir, i) print('='*10) print("Round Number {}".format(i)) print('='*10) server.train(i, args.cdw, use_cuda) - save_path = os.path.join(dir_name, 'federated_model.pth') - torch.save(server.federated_model.cpu().state_dict(), save_path) - if (i+1)%10 == 0: - server.test(use_cuda) + # if not args.clustering: + # save_path = os.path.join(dir_name, 'federated_model.pth') + # torch.save(server.federated_model.cpu().state_dict(), save_path) + if (i+1) % 10 == 0: + server.test(use_cuda, use_fed=True) if args.kd: - server.knowledge_distillation(args.regularization) - server.test(use_cuda) + server.knowledge_distillation(args.regularization, kd_method) + server.test(use_cuda, use_fed=True) server.draw_curve() if __name__ == '__main__': diff --git a/optimization.py b/optimization.py index 4d6c155..c278f8b 100644 --- a/optimization.py +++ b/optimization.py @@ -31,3 +31,13 @@ def kd_generate_soft_label(self, model, data, regularization): if regularization: result = F.normalize(result, dim=1, p=2) return result + + def generate_custom_data_feature(self, model, inputs): + with torch.no_grad(): + out = model(inputs) + features = [] + for i in out: + features.append(i) + features = torch.cat(features, 0) + return features + diff --git a/running.sh b/running.sh new file mode 100644 index 0000000..7546c9b --- /dev/null +++ b/running.sh @@ -0,0 +1,4 @@ +export PYTHONPATH=$PYTHONPATH:$pwd +srun -u --partition=Sensetime --job-name=base \ + -n1 --gres=gpu:1 --ntasks-per-node=1 \ + python main.py --data_dir /mnt/lustre/ganxin/fedreid_data/data --train_all | tee fed_reid_base_cpk.log & \ No newline at end of file diff --git a/running_cdw.sh b/running_cdw.sh new file mode 100644 index 0000000..24ca311 --- /dev/null +++ b/running_cdw.sh @@ -0,0 +1,4 @@ +export PYTHONPATH=$PYTHONPATH:$pwd +srun -u --partition=Sensetime --job-name=cdw \ + -n1 --gres=gpu:1 --ntasks-per-node=1 \ + python main.py --data_dir /mnt/lustre/ganxin/fedreid_data/data --train_all --cdw | tee fed_reid_base_cdw_cpk.log & \ No newline at end of file diff --git a/running_cdw_kd.sh b/running_cdw_kd.sh new file mode 100644 index 0000000..260e172 --- /dev/null +++ b/running_cdw_kd.sh @@ -0,0 +1,4 @@ +export PYTHONPATH=$PYTHONPATH:$pwd +srun -u --partition=Sensetime --job-name=cdw_kd \ + -n1 --gres=gpu:1 --ntasks-per-node=1 \ + python main.py --data_dir /mnt/lustre/ganxin/fedreid_data/data --train_all --cdw --kd --regularization| tee fed_reid_base_cdw_kd_cpk.log & \ No newline at end of file diff --git a/running_clustering.sh b/running_clustering.sh new file mode 100644 index 0000000..88ac2de --- /dev/null +++ b/running_clustering.sh @@ -0,0 +1,4 @@ +export PYTHONPATH=$PYTHONPATH:$pwd +srun -u --partition=Sensetime --job-name=clu \ + -n1 --gres=gpu:1 --ntasks-per-node=1 \ + python main.py --data_dir /mnt/lustre/ganxin/fedreid_data/data --train_all --clustering | tee fed_reid_clustering_cpk.log & \ No newline at end of file diff --git a/running_clustering_cdw.sh b/running_clustering_cdw.sh new file mode 100644 index 0000000..976442d --- /dev/null +++ b/running_clustering_cdw.sh @@ -0,0 +1,4 @@ +export PYTHONPATH=$PYTHONPATH:$pwd +srun -u --partition=Sensetime --job-name=clu_cdw \ + -n1 --gres=gpu:1 --ntasks-per-node=1 \ + python main.py --data_dir /mnt/lustre/ganxin/fedreid_data/data --train_all --clustering --cdw --resume_epoch 41 | tee fed_reid_clustering_cdw_cpk_from_41.log & \ No newline at end of file diff --git a/running_clustering_cdw_kd.sh b/running_clustering_cdw_kd.sh new file mode 100644 index 0000000..c1826e8 --- /dev/null +++ b/running_clustering_cdw_kd.sh @@ -0,0 +1,4 @@ +export PYTHONPATH=$PYTHONPATH:$pwd +srun -u --partition=Sensetime --job-name=clu_cdw_kd \ + -n1 --gres=gpu:1 --ntasks-per-node=1 \ + python main.py --data_dir /mnt/lustre/ganxin/fedreid_data/data --train_all --clustering --cdw --kd --regularization --resume_epoch 39 | tee fed_reid_clustering_cdw_kd_cpk_from_39.log & \ No newline at end of file diff --git a/running_clustering_cdw_kd_whole.sh b/running_clustering_cdw_kd_whole.sh new file mode 100644 index 0000000..89c2b8d --- /dev/null +++ b/running_clustering_cdw_kd_whole.sh @@ -0,0 +1,4 @@ +export PYTHONPATH=$PYTHONPATH:$pwd +srun -u --partition=Sensetime --job-name=clu_cdw_kd \ + -n1 --gres=gpu:1 --ntasks-per-node=1 \ + python main.py --data_dir /mnt/lustre/ganxin/fedreid_data/data --train_all --clustering --cdw --kd --regularization --kd_method whole | tee fed_reid_clustering_cdw_kd_whole.log & \ No newline at end of file diff --git a/running_clustering_finch.sh b/running_clustering_finch.sh new file mode 100644 index 0000000..fdcb62b --- /dev/null +++ b/running_clustering_finch.sh @@ -0,0 +1,5 @@ +▽ +export PYTHONPATH=$PYTHONPATH:$pwd +srun -u --partition=innova --job-name=f0.9 \ + -n1 --gres=gpu:1 --ntasks-per-node=1 \ + python main.py --data_dir /mnt/lustre/ganxin/fedreid_data/data --train_all --clustering --clustering_method finch --max_distance 0.9 --resume_epoch 119 | tee fed_reid_clustering_cpk_finch_0.9_from_119.log & \ No newline at end of file diff --git a/running_clustering_kd.sh b/running_clustering_kd.sh new file mode 100644 index 0000000..02d3886 --- /dev/null +++ b/running_clustering_kd.sh @@ -0,0 +1,4 @@ +export PYTHONPATH=$PYTHONPATH:$pwd +srun -u --partition=Sensetime --job-name=clu_kd \ + -n1 --gres=gpu:1 --ntasks-per-node=1 \ + python main.py --data_dir /mnt/lustre/ganxin/fedreid_data/data --train_all --clustering --kd --regularization | tee fed_reid_clustering_kd.log & \ No newline at end of file diff --git a/running_clustering_kd_whole.sh b/running_clustering_kd_whole.sh new file mode 100644 index 0000000..85b0e40 --- /dev/null +++ b/running_clustering_kd_whole.sh @@ -0,0 +1,4 @@ +export PYTHONPATH=$PYTHONPATH:$pwd +srun -u --partition=Sensetime --job-name=clu_kd_whole \ + -n1 --gres=gpu:1 --ntasks-per-node=1 \ + python main.py --data_dir /mnt/lustre/ganxin/fedreid_data/data --train_all --clustering --kd --regularization --kd_method whole | tee fed_reid_clustering_kd_whole.log & \ No newline at end of file diff --git a/running_clustering_kmeans.sh b/running_clustering_kmeans.sh new file mode 100644 index 0000000..79b4302 --- /dev/null +++ b/running_clustering_kmeans.sh @@ -0,0 +1,4 @@ +export PYTHONPATH=$PYTHONPATH:$pwd +srun -u --partition=innova --job-name=kmeans_2 \ + -n1 --gres=gpu:1 --ntasks-per-node=1 \ + python main.py --data_dir /mnt/lustre/ganxin/fedreid_data/data --train_all --clustering --clustering_method kmeans --n_cluster 2 --resume_epoch 149 | tee fed_reid_clustering_cpk_kmeans_2_from_149.log & \ No newline at end of file diff --git a/running_customised.sh b/running_customised.sh new file mode 100644 index 0000000..1884eac --- /dev/null +++ b/running_customised.sh @@ -0,0 +1,4 @@ +export PYTHONPATH=$PYTHONPATH:$pwd +srun -u --partition=irdcRD --job-name=le${1}_bs${2}_time${3} \ + -n1 --gres=gpu:1 --ntasks-per-node=1 \ + python main.py --data_dir /mnt/lustre/ganxin/fedreid_data/data --train_all --local_epoch ${1} --batch_size ${2} --experiment_index ${3} --resume_epoch ${4} 2>&1 | tee ablation/fed_reid_base_bs${2}_le${1}_time${3}_${4}.log & \ No newline at end of file diff --git a/server.py b/server.py index 2cd46e9..6a1ff43 100644 --- a/server.py +++ b/server.py @@ -1,16 +1,20 @@ import os import math -import json +import random import matplotlib.pyplot as plt from utils import get_model, extract_feature import torch.nn as nn import torch import scipy.io import copy +import numpy as np from data_utils import ImageDataset -import random import torch.optim as optim from torchvision import datasets +# from finch import FINCH +from finch_dis import finch +from kmeans import kmeans +from evaluate import testing_model def add_model(dst_model, src_model, dst_no_data, src_no_data): if dst_model is None: @@ -51,7 +55,9 @@ def aggregate_models(models, weights): class Server(): - def __init__(self, clients, data, device, project_dir, model_name, num_of_clients, lr, drop_rate, stride, multiple_scale): + def __init__(self, clients, data, device, project_dir, model_name, num_of_clients, lr, + drop_rate, stride, multiple_scale, clustering=False, clustering_method="finch", + max_distance=2, n_cluster=2): self.project_dir = project_dir self.data = data self.device = device @@ -68,45 +74,97 @@ def __init__(self, clients, data, device, project_dir, model_name, num_of_client for s in multiple_scale.split(','): self.multiple_scale.append(math.sqrt(float(s))) - self.full_model = get_model(750, drop_rate, stride).to(device) - self.full_model.classifier.classifier = nn.Sequential() - self.federated_model=self.full_model + self.federated_model = get_model(750, drop_rate, stride).to(device) + self.federated_model.classifier.classifier = nn.Sequential() self.federated_model.eval() self.train_loss = [] + self.use_clustering = clustering + self.clustering_group_for_kd = None + self.cdw = None + self.clients_using = None + self.clients_weights = None + self.clustering_method = clustering_method + self.max_dis = max_distance + self.n_cluster = n_cluster + def train(self, epoch, cdw, use_cuda): - models = [] + self.cdw = cdw loss = [] - cos_distance_weights = [] + if not self.use_clustering: + models = [] + cos_distance_weights = [] data_sizes = [] current_client_list = random.sample(self.client_list, self.num_of_clients) + self.clients_using = current_client_list + feature_lists = [] for i in current_client_list: - self.clients[i].train(self.federated_model, use_cuda) - cos_distance_weights.append(self.clients[i].get_cos_distance_weight()) + if not self.use_clustering: + self.clients[i].train(self.federated_model, use_cuda) + cos_distance_weights.append(self.clients[i].get_cos_distance_weight()) + models.append(self.clients[i].get_model()) + else: + self.clients[i].train(None, use_cuda=use_cuda) loss.append(self.clients[i].get_train_loss()) - models.append(self.clients[i].get_model()) data_sizes.append(self.clients[i].get_data_sizes()) + if (epoch + 1) % 10 == 0: + print("before aggregation, local testing") + self.test(use_cuda) + if self.use_clustering: + for _, (inputs, targets) in enumerate(self.data.train_loaders['cuhk02']): + inputs, target = inputs.to(self.device), targets.to(self.device) + break + for i in current_client_list: + feature_lists.append(self.clients[i].generate_custom_data_feature(inputs).cpu().detach().numpy()) + feature_lists = np.array(feature_lists) + # c, num_clust, _ = FINCH(feature_lists, min_sim=self.max_dis) + # self.clustering_method, self.max_dis, self.n_cluster = n_cluster + if self.clustering_method == "kmeans": + clusters = kmeans(feature_lists, n_clusters=self.n_cluster, do_normalize=True) + else: + clusters = finch(feature_lists, finch_step=1, finch_dis=self.max_dis, metric="cosine", do_normalize=True) + id_groups = self.clustering(clusters, current_client_list) + print("id_groups", id_groups) + + if epoch==0: self.L0 = torch.Tensor(loss) avg_loss = sum(loss) / self.num_of_clients print("==============================") - print("number of clients used:", len(models)) + # print("number of clients used:", len(models)) print('Train Epoch: {}, AVG Train Loss among clients of lost epoch: {:.6f}'.format(epoch, avg_loss)) print() - + self.train_loss.append(avg_loss) - - weights = data_sizes - - if cdw: - print("cos distance weights:", cos_distance_weights) - weights = cos_distance_weights + if self.use_clustering: + for i in id_groups.keys(): + models = [] + data_num = [] + cos_distance = [] + for j in id_groups[i]: + models.append(self.clients[j].get_model()) + data_num.append(self.clients[j].get_data_sizes()) + cos_distance.append(self.clients[j].get_cos_distance_weight()) + if cdw: + federated_model = aggregate_models(models, cos_distance) + else: + federated_model = aggregate_models(models, data_num) + for j in id_groups[i]: + self.clients[j].set_model(federated_model) + print("using clustering, client models set") + self.clustering_group_for_kd = id_groups + + else: + weights = data_sizes + if cdw: + print("cos distance weights:", cos_distance_weights) + weights = cos_distance_weights + self.federated_model = aggregate_models(models, weights) - self.federated_model = aggregate_models(models, weights) def draw_curve(self): plt.figure() @@ -119,20 +177,32 @@ def draw_curve(self): plt.savefig(os.path.join(dir_name, 'train.png')) plt.close('all') - def test(self, use_cuda): + def test(self, use_cuda, use_fed=False): print("="*10) - print("Start Tesing!") + print("Start Testing!") print("="*10) - print('We use the scale: %s'%self.multiple_scale) + print('We use the scale: %s' % self.multiple_scale) for dataset in self.data.datasets: - self.federated_model = self.federated_model.eval() - if use_cuda: - self.federated_model = self.federated_model.cuda() - + # if self.use_clustering: + if use_fed and not self.use_clustering: + print("Using federated model") + client_model = self.federated_model.eval() + if use_cuda: + client_model = self.federated_model.cuda() + else: + print("Using local model") + client_model = self.clients[dataset].get_model().eval() # self.federated_model.eval() + if use_cuda: + client_model = client_model.cuda() # self.federated_model.cuda() + # else: + # self.federated_model = self.federated_model.eval() + # if use_cuda: + # self.federated_model = self.federated_model.cuda() + with torch.no_grad(): - gallery_feature = extract_feature(self.federated_model, self.data.test_loaders[dataset]['gallery'], self.multiple_scale) - query_feature = extract_feature(self.federated_model, self.data.test_loaders[dataset]['query'], self.multiple_scale) + gallery_feature = extract_feature(client_model, self.data.test_loaders[dataset]['gallery'], self.multiple_scale) + query_feature = extract_feature(client_model, self.data.test_loaders[dataset]['query'], self.multiple_scale) result = { 'gallery_f': gallery_feature.numpy(), @@ -140,38 +210,111 @@ def test(self, use_cuda): 'gallery_cam': self.data.gallery_meta[dataset]['cameras'], 'query_f': query_feature.numpy(), 'query_label': self.data.query_meta[dataset]['labels'], - 'query_cam': self.data.query_meta[dataset]['cameras']} - - scipy.io.savemat(os.path.join(self.project_dir, - 'model', - self.model_name, - 'pytorch_result.mat'), - result) - + 'query_cam': self.data.query_meta[dataset]['cameras'] + } + print("====== before loading =======") + # for i in result: + # print(i, np.array(result[i]).shape, result[i][:3]) + # file_path = os.path.join(self.project_dir, + # 'model', + # self.model_name, + # 'pytorch_result_{}_{}.mat'.format(dataset, random.randint(0, 100000000))) + # scipy.io.savemat(file_path, result) + print(self.model_name) print(dataset) + testing_model(result, dataset) + # os.system('python evaluate.py --result_dir {} --dataset {}'.format(os.path.join(self.project_dir, 'model', self.model_name), dataset)) + + def knowledge_distillation(self, regularization, kd_method): + if self.use_clustering and kd_method == 'cluster': + print("personlaization with kd_method cluster") + for i in self.clustering_group_for_kd: + print("grouping {} for kd".format(self.clustering_group_for_kd[i])) + first_client_id = self.clustering_group_for_kd[i][0] + model = self.clients[first_client_id].get_model() + federated_model = self.cluster_knowledge_distillation(model, + self.clustering_group_for_kd[i], + regularization) + for j in self.clustering_group_for_kd[i]: + self.clients[j].set_model(federated_model) + elif self.use_clustering and kd_method == 'whole': + print("personlaization with kd_method whole") + models = [] + cos_distance_weights = [] + data_sizes = [] + for i in self.clients_using: + cos_distance_weights.append(self.clients[i].get_cos_distance_weight()) + models.append(self.clients[i].get_model()) + data_sizes.append(self.clients[i].get_data_sizes()) + weights = data_sizes + if self.cdw: + print("cos distance weights:", cos_distance_weights) + weights = cos_distance_weights + self.federated_model = aggregate_models(models, weights) - os.system('python evaluate.py --result_dir {} --dataset {}'.format(os.path.join(self.project_dir, 'model', self.model_name), dataset)) + federated_model = self.cluster_knowledge_distillation(self.federated_model, + self.client_list, + regularization) + for j in self.client_list: + self.clients[j].set_model(federated_model) + else: + MSEloss = nn.MSELoss().to(self.device) + optimizer = optim.SGD(self.federated_model.parameters(), lr=self.lr*0.01, weight_decay=5e-4, momentum=0.9, nesterov=True) + self.federated_model.train() - def knowledge_distillation(self, regularization): + for _, (x, target) in enumerate(self.data.kd_loader): + x, target = x.to(self.device), target.to(self.device) + # target=target.long() + optimizer.zero_grad() + soft_target = torch.Tensor([[0]*512]*len(x)).to(self.device) + + for i in self.client_list: + i_label = (self.clients[i].generate_soft_label(x, regularization)) + soft_target += i_label + soft_target /= len(self.client_list) + + output = self.federated_model(x) + + loss = MSEloss(output, soft_target) + loss.backward() + optimizer.step() + print("train_loss_fine_tuning", loss.data) + + def cluster_knowledge_distillation(self, model, c_list, regularization): MSEloss = nn.MSELoss().to(self.device) - optimizer = optim.SGD(self.federated_model.parameters(), lr=self.lr*0.01, weight_decay=5e-4, momentum=0.9, nesterov=True) - self.federated_model.train() + optimizer = optim.SGD(model.parameters(), lr=self.lr * 0.01, weight_decay=5e-4, momentum=0.9, + nesterov=True) + model.train() - for _, (x, target) in enumerate(self.data.kd_loader): + for _, (x, target) in enumerate(self.data.kd_loader): x, target = x.to(self.device), target.to(self.device) # target=target.long() optimizer.zero_grad() - soft_target = torch.Tensor([[0]*512]*len(x)).to(self.device) - - for i in self.client_list: + soft_target = torch.Tensor([[0] * 512] * len(x)).to(self.device) + + for i in c_list: i_label = (self.clients[i].generate_soft_label(x, regularization)) soft_target += i_label soft_target /= len(self.client_list) - - output = self.federated_model(x) - + + output = model(x) + loss = MSEloss(output, soft_target) loss.backward() optimizer.step() - print("train_loss_fine_tuning", loss.data) \ No newline at end of file + print("train_loss_fine_tuning of {} is {}".format(c_list, loss.data)) + return model + + def clustering(self, indexs, client_list): + id_groups = {} # dict.fromkeys([i for i in range(num_of_cluster[0])],[]) + assert len(indexs) == len(client_list) + for i in range(len(client_list)): + if indexs[i] not in id_groups.keys(): + id_groups[indexs[i]] = [client_list[i]] + else: + id_groups[indexs[i]].append(client_list[i]) + return id_groups + + +