-
Notifications
You must be signed in to change notification settings - Fork 6
/
init_devices.py
91 lines (73 loc) · 4.76 KB
/
init_devices.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
'''
Description: Initialize the federated-learning clients and the server for an experiment.
Version: 1.0
Author: ZhangHongYu
Date: 2022-03-13 21:16:56
LastEditors: ZhangHongYu
LastEditTime: 2022-04-06 20:56:29
'''
from models import ConvNet, MobileNet, NextCharacterLSTM
from utils.data_utils import rotate_data
from custom_ds.subsets import CustomSubset
from utils.plots import display_sample
from fl_devices import Client, Server
from method.ditto.client import DittoClient
from method.clustered.server import ClusteredServer
from method.my.server import CommunityServer
import torch
def init_clients_and_server(args, dataset, client_train_idcs, client_test_idcs, client_val_idcs, data_info):
    """Build the federated clients and the server for one experiment.

    Args:
        args: Experiment options. Reads ``args.dataset`` and ``args.method``;
            for datasets other than "Shakespeare" also ``args.n_clients`` and
            ``args.n_clusters`` (forwarded to ``rotate_data``).
        dataset: The full dataset; every client's subsets index into it.
        client_train_idcs: One index collection per client for its train split.
        client_test_idcs: One index collection per client for its test split.
        client_val_idcs: One index collection per client for its val split.
        data_info: Mapping with keys ``'n_channels'``, ``'classes'``,
            ``'input_sz'`` and ``'num_cls'``. Only read when
            ``args.dataset != "Shakespeare"``.

    Returns:
        ``(clients, server)``. Each client has ``weight`` set to its number of
        training samples divided by the total over all clients.
    """
    is_shakespeare = args.dataset == "Shakespeare"

    # Pick the model factory once; clients and server all build models from it.
    if not is_shakespeare:
        # Training-set properties. `classes` is only used by the disabled
        # display_sample call at the end.
        n_channels, classes, input_sz, num_cls = \
            data_info['n_channels'], data_info['classes'], data_info['input_sz'], data_info['num_cls']

        def model_fn():
            # Swap in MobileNet(num_classes=num_cls) here to change the architecture.
            return ConvNet(input_size=input_sz, channels=n_channels, num_classes=num_cls)
    else:
        def model_fn():
            return NextCharacterLSTM()

    client_train_data = [CustomSubset(dataset, idcs) for idcs in client_train_idcs]
    client_test_data = [CustomSubset(dataset, idcs) for idcs in client_test_idcs]
    client_val_data = [CustomSubset(dataset, idcs) for idcs in client_val_idcs]

    if not is_shakespeare:
        # Rotate all data to simulate a clustered structure: the data are split
        # iid first and then grouped into clusters, and a given client keeps the
        # same rotation pattern. Cluster-partitioning algorithms can withstand
        # this perturbation, but it lowers the accuracy of conventional FL
        # algorithms, so the setup is somewhat contrived.
        rotate_data(client_train_data, client_test_data, client_val_data, args.n_clients, args.n_clusters)

    client_cls = DittoClient if args.method == "Ditto" else Client

    # The optimizer is wrapped in a factory so the client can supply its own
    # parameters while the remaining hyper-parameters are fixed here.
    def optimizer_fn(params):
        return torch.optim.SGD(params, lr=0.1, momentum=0.9)

    clients = [
        client_cls(model_fn, optimizer_fn, train_data=train_dat, test_data=test_dat, val_data=val_dat, idnum=i)
        for i, (train_dat, test_dat, val_dat) in enumerate(zip(client_train_data, client_test_data, client_val_data))
        # Shakespeare clients without training samples are skipped; `idnum`
        # keeps the client's position from before this filtering.
        if not is_shakespeare or len(train_dat) > 0
    ]

    if args.method == "Clustered":
        server_cls = ClusteredServer
    elif args.method in ("My", "Overlap"):
        server_cls = CommunityServer
    else:
        server_cls = Server
    server = server_cls(model_fn)

    # Weight each client by its fraction of all training samples. `samples_sum`
    # is a 0-dim tensor, so each weight is a 0-dim tensor too.
    client_n_samples = torch.tensor([client.n_train_samples for client in clients])
    samples_sum = client_n_samples.sum()
    for client in clients:
        client.weight = client.n_train_samples / samples_sum

    # To visualize some samples from the clients (non-Shakespeare only):
    # display_sample(clients, classes)
    return clients, server