mstrand1/Jax-logistic-regression

Jax-logistic-regression

A logistic regression classifier built on Google's JAX to support GPU acceleration.

This class updates a logistic regression class I used in my introductory machine learning course. The major difference is how gradient descent is handled: those operations were rewritten using JAX's grad, jit, and vmap functions. The goal of this project is speed. In my measurements, JaxReg with GPU acceleration runs roughly 29x faster than the original class. I used Google Colab's free GPU when measuring the speedup (see 'Time Comparison').
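The JaxReg class itself is not reproduced here. The following is a minimal sketch of how grad, jit, and vmap typically fit together in batch gradient descent for logistic regression. It assumes binary labels in {0, 1}, a mean binary cross-entropy loss, and a fixed learning rate; the function names (`logit`, `batched_logit`, `loss`, `gd_step`, `fit`) are hypothetical and not taken from the repository.

```python
import jax
import jax.numpy as jnp


def logit(w, b, x):
    """Linear score w.x + b for a single example x of shape [n_features]."""
    return jnp.dot(x, w) + b


# vmap turns the single-example function into a batched one:
# w and b are shared (in_axes=None), x is mapped over its leading axis.
batched_logit = jax.vmap(logit, in_axes=(None, None, 0))


def loss(params, X, y):
    """Mean binary cross-entropy, computed from logits for numerical stability."""
    w, b = params
    z = batched_logit(w, b, X)
    # log(sigmoid(z)) for positives, log(1 - sigmoid(z)) = log(sigmoid(-z)) for negatives
    return -jnp.mean(y * jax.nn.log_sigmoid(z) + (1.0 - y) * jax.nn.log_sigmoid(-z))


@jax.jit
def gd_step(params, X, y, lr):
    """One gradient descent step; jit compiles it (for the GPU, if one is available)."""
    grads = jax.grad(loss)(params, X, y)  # gradient w.r.t. the (w, b) tuple
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


def fit(X, y, lr=0.1, n_steps=500):
    """Run plain batch gradient descent from zero-initialised parameters."""
    params = (jnp.zeros(X.shape[1]), jnp.array(0.0))
    for _ in range(n_steps):
        params = gd_step(params, X, y, lr)
    return params


if __name__ == "__main__":
    # Toy 1-D data: negatives on the left, positives on the right.
    X = jnp.array([[-2.0], [-1.0], [1.0], [2.0]])
    y = jnp.array([0.0, 0.0, 1.0, 1.0])
    w, b = fit(X, y)
    probs = jax.nn.sigmoid(batched_logit(w, b, X))
    print("weights:", w, "bias:", b)
    print("predicted probabilities:", probs)
```

Only the first call to `gd_step` pays the compilation cost; later iterations reuse the compiled kernel, which is typically where jit-based speedups come from. When timing JAX code, note that dispatch is asynchronous, so results should be synchronized (for example with `block_until_ready()`) before stopping a timer.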