-
Notifications
You must be signed in to change notification settings - Fork 1
/
federated.py
80 lines (57 loc) · 2.11 KB
/
federated.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from copy import deepcopy
from threading import Thread, Lock
import torch
from dataset import mnist_dataset
from model.layers import CNN
from model.train import test, train
# Trained models reported back by the client threads of the current round;
# cleared by `federated` at the start of each round.
clientModels: list = list()
# Serialises the appends that concurrent `clientTraining` threads make to
# `clientModels`.
clientModelsLock = Lock()
def clientTraining(serverModel, clientDatasets, client, round):
    """Train a private copy of the server model on one client's data for one round.

    Meant to be run as a thread target (``federated`` starts one thread per
    client). The trained model is appended to the module-level
    ``clientModels`` list while holding ``clientModelsLock``.

    Args:
        serverModel: Current global model. It is deep-copied before training, so
            this function does not hand ``serverModel`` itself to ``train``.
        clientDatasets: Indexed as ``clientDatasets[client][round]`` to obtain
            this client's training subset for the given round.
        client: Zero-based client index (printed 1-based).
        round: Zero-based round index. The name shadows the builtin ``round``;
            it is kept unchanged for caller compatibility.
    """
    clientTrainingSet = clientDatasets[client][round]
    trainLoader = mnist_dataset.get_dataloader(clientTrainingSet)
    # Each client trains its own deep copy, so concurrent threads never train
    # the same parameter tensors.
    clientModel = deepcopy(serverModel)
    trainedClientModel = train(clientModel, trainLoader)
    # `with` guarantees the lock is released even if the append raises. No
    # `global` statement is needed because the list is mutated, never rebound.
    with clientModelsLock:
        clientModels.append(trainedClientModel)
    print(f"Client {client+1} done")
def _averageableTensors(model):
    """Return the tensors of ``model`` that ``fedAvg`` averages, in a stable order.

    This is every parameter plus every floating-point (or complex) buffer, such
    as BatchNorm running statistics. Integer buffers such as batch counters are
    excluded, because an arithmetic mean is not meaningful for them and in-place
    true division fails on integer dtypes.
    """
    tensors = list(model.parameters())
    tensors.extend(
        buf for buf in model.buffers() if buf.is_floating_point() or buf.is_complex()
    )
    return tensors


def fedAvg(clientModels):
    """Return a new model whose weights are the element-wise mean of ``clientModels``.

    Every client contributes equally. This is an unweighted FedAvg; no weighting
    by client sample counts is applied. The first model is deep-copied to hold
    the result, so none of the input models are modified. Parameters and
    floating-point buffers are averaged. Integer buffers keep the first model's
    values.

    Args:
        clientModels: Non-empty sequence of ``torch.nn.Module`` instances that
            share the same architecture.

    Returns:
        A new module of the same type as ``clientModels[0]`` holding the
        averaged tensors.

    Raises:
        ValueError: If ``clientModels`` is empty, or if the models do not expose
            the same number of averageable tensors.
    """
    if not clientModels:
        raise ValueError("fedAvg requires at least one client model")
    numModels = len(clientModels)
    averagedModel = deepcopy(clientModels[0])
    targets = _averageableTensors(averagedModel)
    with torch.no_grad():
        for model in clientModels[1:]:
            sources = _averageableTensors(model)
            # zip() would silently truncate if the architectures differ.
            if len(sources) != len(targets):
                raise ValueError("fedAvg: client models have mismatched architectures")
            for target, source in zip(targets, sources):
                target.add_(source)
        for target in targets:
            target.div_(numModels)
    return averagedModel
class federatedConfig:
    """Fixed settings for the simulation driven by ``federated``."""

    # How many simulated clients are trained (one thread each) per round.
    clientNum: int = 3
    # How many train-clients / average / evaluate rounds to run.
    trainingRounds: int = 2
def _trainClientsInParallel(serverModel, clientDatasets, numClients, roundIdx):
    """Run ``clientTraining`` for every client concurrently and return their models.

    Clears the shared ``clientModels`` list, starts one thread per client, waits
    for all of them and returns a snapshot list of the collected models.

    Raises:
        RuntimeError: If fewer than ``numClients`` trained models were reported,
            i.e. at least one client thread failed.
    """
    clientModels.clear()
    clientThreads = []
    for client in range(numClients):
        t = Thread(
            target=clientTraining, args=(serverModel, clientDatasets, client, roundIdx)
        )
        t.start()
        clientThreads.append(t)
    for t in clientThreads:
        t.join()
    # An exception inside a Thread target is only reported by
    # threading.excepthook. Without this check the round would silently average
    # fewer models, or crash later inside fedAvg.
    if len(clientModels) != numClients:
        raise RuntimeError(
            f"Round {roundIdx+1}: only {len(clientModels)} of {numClients} "
            "clients returned a trained model"
        )
    return list(clientModels)


def federated():
    """Run the threaded federated-averaging simulation end to end.

    Loads the train and test sets via ``mnist_dataset``, splits the training set
    with ``split_client_datasets`` (indexed later as ``[client][round]``), then
    for each round:

    1. trains every client in parallel from the current server model,
    2. replaces the server model with ``fedAvg`` of the client models,
    3. prints the result of ``test`` on the test loader.
    """
    config = federatedConfig()
    trainSet = mnist_dataset.load_dataset(isTrainDataset=True)
    clientDatasets = mnist_dataset.split_client_datasets(
        trainSet, config.clientNum, config.trainingRounds
    )
    testSet = mnist_dataset.load_dataset(isTrainDataset=False)
    testLoader = mnist_dataset.get_dataloader(testSet)
    serverModel = CNN()
    # `roundIdx` rather than `round` to avoid shadowing the builtin.
    for roundIdx in range(config.trainingRounds):
        print(f"Round {roundIdx+1} started")
        roundModels = _trainClientsInParallel(
            serverModel, clientDatasets, config.clientNum, roundIdx
        )
        serverModel = fedAvg(roundModels)
        testAcc = test(serverModel, testLoader)
        print(f"Round {roundIdx+1} done\tAccuracy = {testAcc}")


# Run the simulation only when executed as a script, not when imported.
if __name__ == "__main__":
    federated()