Installing Jax with GPU Support
I’ve been struggling with installing Jax with GPU support. Something seems to go wrong each time I try. Sometimes it can’t find the GPU. Other times some dynamic libraries will be missing. More frustratingly, sometimes things get messed up when I install new libraries such as Optax. Finally, I think I’ve got it all working. So here’s how I did it.
Basic environment
I use conda virtual environments for Python because they keep my workstation configuration clean, especially when it comes to working with GPUs. With conda, I can manage the required CUDA toolkits and libraries without making permanent installations on my workstation. In this tutorial, I will be installing Jax in a conda virtual environment.
conda create -n jax python=3.9 pip
Activate the virtual environment, then proceed with the following steps:
conda activate jax
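Every step below installs into whichever environment is active, so I like to confirm the right one is active first. Here is a small Python sketch of that check. It relies on the CONDA_DEFAULT_ENV variable that conda activate sets. It also assumes the env name jax and Python 3.9 from the command above. check_conda_env is a hypothetical name I made up for illustration.

```python
import os
import sys


def check_conda_env(expected_env="jax", environ=None, python_version=None,
                    expected_python=(3, 9)):
    """Return a list of problems with the current setup (empty means OK).

    Hypothetical helper: the defaults mirror `conda create -n jax python=3.9`.
    """
    # `conda activate <name>` sets CONDA_DEFAULT_ENV to the active env's name.
    environ = os.environ if environ is None else environ
    if python_version is None:
        python_version = sys.version_info[:2]
    python_version = tuple(python_version)

    problems = []
    active = environ.get("CONDA_DEFAULT_ENV")
    if active != expected_env:
        problems.append(f"active conda env is {active!r}, expected {expected_env!r}")
    if python_version != tuple(expected_python):
        major, minor = python_version
        want = ".".join(str(v) for v in expected_python)
        problems.append(f"Python is {major}.{minor}, expected {want}")
    return problems


if __name__ == "__main__":
    issues = check_conda_env()
    print("Environment looks right." if not issues else "\n".join(issues))
```

Running it with python from inside the activated environment should print "Environment looks right." If it prints a different env name, pip would most likely install into the wrong place.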
1. Installing nvcc
According to the Jax installation guide, Jax requires ptxas, which is part of the cuda-nvcc package on conda. On my workstation, the GPU driver version is 510.108.03. The CUDA version reported by nvidia-smi (the newest CUDA the driver supports) is 11.6. To keep everything consistent, I will install the cuda-nvcc build for CUDA 11.6 using conda as follows:
conda install cuda-nvcc -c "nvidia/label/cuda-11.6.2"
For other CUDA versions, check the labels available for cuda-nvcc on the nvidia channel’s conda page.
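The driver and CUDA versions mentioned above appear in the banner at the top of nvidia-smi’s output. If you’d rather pull them out programmatically, here is a rough sketch. parse_nvidia_smi is a hypothetical helper, and it assumes the banner uses the "Driver Version:" and "CUDA Version:" labels. The CUDA version shown there is the newest one the driver supports, so it is a sensible upper bound when picking a cuda-nvcc label.

```python
import re
import shutil
import subprocess


def parse_nvidia_smi(text):
    """Extract (driver_version, cuda_version) from nvidia-smi's banner text.

    Hypothetical helper; returns None for any field it cannot find.
    """
    driver = re.search(r"Driver Version:\s*([\d.]+)", text)
    cuda = re.search(r"CUDA Version:\s*([\d.]+)", text)
    return (driver.group(1) if driver else None,
            cuda.group(1) if cuda else None)


if __name__ == "__main__":
    if shutil.which("nvidia-smi") is None:
        print("nvidia-smi not found; is the NVIDIA driver installed?")
    else:
        # Run nvidia-smi with no arguments and parse its default output.
        out = subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout
        driver, cuda = parse_nvidia_smi(out)
        print(f"driver {driver}, CUDA {cuda}")
```

On my workstation, this should report driver 510.108.03 and CUDA 11.6, which is why I chose a CUDA 11.6 label for cuda-nvcc.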
Note that the installation guide installs Jax and cuda-nvcc together with a single conda command. That didn’t work for me, and I suspect it has something to do with how conda resolves package dependencies.
I will be installing Jax using pip at a later step.
2. Installing cudnn
After all, why not?
conda install cudnn=8.2 -c conda-forge
This is also required by Jax later. Pinning cuDNN 8.2 here matches the cudnn82 build of Jax installed in the next step.
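After steps 1 and 2, I find it useful to double-check which cuda-nvcc and cudnn versions actually ended up in the environment, in case conda resolved to something unexpected. conda list shows the same information. This sketch reads the per-package JSON records conda keeps in $CONDA_PREFIX/conda-meta, assuming each record has name and version fields. pick_versions and conda_package_versions are hypothetical helper names.

```python
import json
import os
from pathlib import Path


def pick_versions(records, names):
    """Map each requested package name to its version (None if absent)."""
    found = {name: None for name in names}
    for meta in records:
        if meta.get("name") in found:
            found[meta["name"]] = meta.get("version")
    return found


def conda_package_versions(names, prefix=None):
    """Read <prefix>/conda-meta/*.json; prefix defaults to the active env."""
    prefix = Path(prefix or os.environ["CONDA_PREFIX"])
    records = (json.loads(p.read_text())
               for p in (prefix / "conda-meta").glob("*.json"))
    return pick_versions(records, names)


if __name__ == "__main__":
    if "CONDA_PREFIX" in os.environ:
        print(conda_package_versions(["cuda-nvcc", "cudnn"]))
    else:
        print("No active conda environment (CONDA_PREFIX is not set).")
```

If cudnn shows up as anything other than 8.2.x, the cudnn82 build in the next step would no longer match.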
3. Installing Jax using pip
I have found that using pip works better for me.
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Note that the cuda11_cudnn82 extra matches the CUDA (11.x) and cuDNN (8.2) versions we installed in the virtual environment.
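The extra name in the pip command encodes the CUDA and cuDNN versions. This sketch builds it from version strings, assuming the pattern cuda<CUDA major>_cudnn<cuDNN major><cuDNN minor> seen in cuda11_cudnn82. That pattern is my inference from this one example, not a documented rule, so check the release page for the extras actually published for your Jax version. jax_cuda_extra and pip_command are hypothetical helpers.

```python
RELEASES_URL = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"


def jax_cuda_extra(cuda_version, cudnn_version):
    """Build an extra like "cuda11_cudnn82" (assumed naming pattern)."""
    cuda_major = cuda_version.split(".")[0]
    cudnn_major, cudnn_minor = cudnn_version.split(".")[:2]
    return f"cuda{cuda_major}_cudnn{cudnn_major}{cudnn_minor}"


def pip_command(cuda_version, cudnn_version):
    """Return the pip command used in this post for the given versions."""
    extra = jax_cuda_extra(cuda_version, cudnn_version)
    return f'pip install "jax[{extra}]" -f {RELEASES_URL}'


if __name__ == "__main__":
    # CUDA 11.6 from step 1 and cuDNN 8.2 from step 2.
    print(pip_command("11.6", "8.2"))
```

With the versions from steps 1 and 2, this reproduces the exact command above.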
4. Installing Optax using pip
Let’s face it: if you are going to do machine learning work, you’re most likely going to use optimizers. Might as well install optax, the default library of optimizers for Jax.
pip install optax
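Installing new libraries is exactly where things used to break for me. So after pip install optax, I find a quick look at the installed versions helpful. This sketch uses only the standard library’s importlib.metadata. The +cuda check on jaxlib rests on an assumption: that CUDA wheels from the jax_cuda_releases page carry a local version label such as +cuda11.cudnn82. Confirm that against your own install before relying on it. installed_versions and looks_like_cuda_jaxlib are hypothetical names.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "optax")):
    """Return {package: installed version, or None if not installed}."""
    report = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


def looks_like_cuda_jaxlib(ver):
    """Assumption: CUDA jaxlib wheels have a "+cuda..." local version label."""
    return ver is not None and "+cuda" in ver


if __name__ == "__main__":
    report = installed_versions()
    for name, ver in report.items():
        print(f"{name:8s} {ver or 'not installed'}")
    if not looks_like_cuda_jaxlib(report["jaxlib"]):
        print("jaxlib does not look like a CUDA build; rerun step 3?")
```

If jaxlib has lost its CUDA suffix after installing optax, rerunning the pip command from step 3 is one plausible fix.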
That’s it! It should work and detect the GPU properly.
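A quick way to confirm that Jax actually sees the GPU is to list its devices. This sketch filters jax.devices() for devices whose platform is "gpu" and, if one is found, runs a tiny computation. The helper name gpu_devices is hypothetical. It deliberately returns an empty list when Jax can’t be imported, so the check itself never crashes.

```python
def gpu_devices():
    """Return the GPU devices Jax can see, or [] if none (or no Jax)."""
    try:
        import jax
    except ImportError:
        return []
    # jax.devices() lists devices for the default backend; CUDA GPUs
    # report their platform as "gpu".
    return [d for d in jax.devices() if d.platform == "gpu"]


if __name__ == "__main__":
    gpus = gpu_devices()
    if gpus:
        import jax.numpy as jnp

        # With a GPU backend available, Jax places this array on the GPU by default.
        x = jnp.arange(10.0)
        print("GPUs:", gpus, "sum:", float(x.sum()))
    else:
        print("No GPU visible to Jax (or Jax is not installed).")
```

If this prints an empty result on a machine that does have a GPU, the missing-library and version-mismatch issues from the earlier steps are the first places I’d look.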