High performance machine learning with JAX

Fri September 10, 10:15 AM–10:45 AM
Session type: Live
Start time: 10:15
End time: 10:45

JAX provides an elegant interface to XLA with automatic differentiation, allowing extremely high-performance machine learning on modern accelerators, all from within Python. In this talk we'll give an overview of the fundamentals of JAX and an introduction to some of the libraries being developed on top of it.
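As a rough sketch of what this looks like in practice (not taken from the talk materials), the snippet below differentiates a toy function with jax.grad and compiles it with jax.jit; the function and values are made up purely for illustration.

```python
import jax
import jax.numpy as jnp

def loss(w):
    # toy scalar function of a parameter vector, for illustration only
    return jnp.sum(jnp.tanh(w) ** 2)

grad_loss = jax.grad(loss)        # automatic differentiation
fast_grad = jax.jit(grad_loss)    # compiled via XLA for CPU/GPU/TPU

print(fast_grad(jnp.arange(3.0)))  # gradients from the compiled function
```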

JAX is a next-generation machine learning library developed by Google Research. It provides a pure Python interface to XLA (Accelerated Linear Algebra), a domain-specific compiler that targets a range of accelerator hardware including GPUs and TPUs (Google's custom ASICs for machine learning). JAX is also the successor to the autograd library, providing a rich set of optimisation tooling for numerous modern machine learning approaches, with an emphasis on training neural networks with gradient descent. This talk will outline the fundamentals of JAX programming, demonstrate some of the TPU-specific capabilities for large-scale distributed training, and give a short review of the higher-level libraries built on top of JAX.
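To make the gradient-descent angle concrete, here is a minimal, hypothetical sketch of a training step in JAX; the linear model, synthetic data and hyperparameters are illustrative assumptions, not material from the talk.

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    w, b = params
    return x @ w + b

def mse(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def update(params, x, y, lr=0.1):
    # differentiate the loss w.r.t. the parameters, then take one
    # gradient-descent step on every leaf of the parameter tree
    grads = jax.grad(mse)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# synthetic data with known weights, purely for illustration
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 3.0

params = (jnp.zeros(3), jnp.zeros(()))
for _ in range(200):
    params = update(params, x, y)
print(params)  # should approach w = [1, -2, 0.5], b = 3
```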

Mat Kelcey he/him

Mat is an ML Research Engineer at Edge Impulse. He has worked across a range of machine learning domains over the last 20 years, including work at ThoughtWorks, Google Brain, Wavii & AWS. Mat blogs at http://matpalm.com/