
Simulating FedAvg and the Federated Learning protocol to train a CNN against the MNIST dataset

Federated Learning Simulator

Overview

Introduced in "Communication-Efficient Learning of Deep Networks from Decentralized Data" by Brendan McMahan et al., Federated Learning flips the paradigm of machine learning.

When training a machine learning model on client data, instead of the clients sending their data to the server, the server sends the model to the clients, which train it locally.

This project provides a multi-threaded simulation of the Federated Averaging Algorithm presented in the paper, using PyTorch and Python.
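The multi-threaded client simulation can be sketched as follows. This is a minimal illustration using a thread pool; `train_client` and the data layout are hypothetical stand-ins, not the repo's actual code, which trains PyTorch models in federated.py:

```python
# Sketch: run each client's local update concurrently in a thread pool.
# train_client is a placeholder for local training; here it just
# returns the mean of the client's data as a stand-in "model".
from concurrent.futures import ThreadPoolExecutor

def train_client(client_id, data):
    return client_id, sum(data) / len(data)

# Three simulated clients with local datasets of different sizes.
client_data = {0: [1.0, 2.0], 1: [3.0], 2: [4.0, 5.0, 6.0]}

with ThreadPoolExecutor(max_workers=3) as pool:
    results = list(pool.map(lambda kv: train_client(*kv), client_data.items()))

print(sorted(results))
```

Because local updates are independent until the averaging step, they parallelize naturally; the server only synchronizes when it collects the results.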

The framework is tested against the MNIST dataset (in federated.py) and compared with a regular baseline (in base.py).

In-Depth

The following is a simplified overview of the Federated Averaging algorithm presented in McMahan et al.:

  1. The server initializes a plain machine learning model with given weights

Server with Plain Model, and Clients

  2. The server sends the model to each client

Server and Clients with Plain Model

  3. The clients update their copy of the model using their local datasets
  4. The clients send the updated model back to the server

Clients send back updated model

  5. The server averages the weights of all the models it has received back, weighted by the size of each client's local dataset. This average becomes the weights of the trained model

Server aggregates received models

  6. Steps 2-5 are then repeated whenever the clients obtain new data
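The steps above can be sketched in plain Python. This is a minimal sketch with hypothetical names (`local_update`, `fedavg`); the repo's actual implementation in federated.py operates on PyTorch models, but the weighted-averaging logic is the same:

```python
# One simulated FedAvg round over three clients.

def local_update(weights, data):
    # Stand-in for a client's local training: one gradient step on a
    # 1-D least-squares objective w*x ≈ y, learning rate 0.1.
    w = weights[0]
    grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
    return [w - 0.1 * grad]

def fedavg(client_weights, client_sizes):
    # Average each parameter across clients, weighted by each
    # client's local dataset size (step 5 above).
    total = sum(client_sizes)
    return [
        sum(w[i] * n for w, n in zip(client_weights, client_sizes)) / total
        for i in range(len(client_weights[0]))
    ]

global_weights = [0.0]                     # step 1: initial model
clients = [
    [(1.0, 2.0), (2.0, 4.0)],              # client with 2 samples
    [(1.0, 2.1)],                          # client with 1 sample
    [(3.0, 6.0), (1.0, 1.9), (2.0, 4.2)],  # client with 3 samples
]

# Steps 2-4: each client updates its copy of the model locally.
updated = [local_update(global_weights, d) for d in clients]

# Step 5: server aggregates, weighted by dataset size.
global_weights = fedavg(updated, [len(d) for d in clients])
print(global_weights)
```

Repeating the update/aggregate cycle as new data arrives gives step 6 of the protocol.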

Results

The Federated Learning simulator achieved 92% accuracy on the MNIST dataset, compared to 95% for the standard centralized baseline. While some accuracy is lost, the gap is small relative to the privacy-preserving advantages this framework offers.
