ML | JAX

JAX

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.
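As a rough illustration of what "program transformation" means here, the sketch below applies two of JAX's transformations, jax.jit and jax.vmap, to an ordinary NumPy-style function. The function scaled_norm and its inputs are made up for this example and are not from the original post.

```python
import jax
import jax.numpy as jnp

def scaled_norm(x, scale):
    # Plain NumPy-style code; nothing here is specific to an accelerator.
    return scale * jnp.sqrt(jnp.sum(x ** 2))

# jit compiles the function with XLA for whichever backend is available.
fast_norm = jax.jit(scaled_norm)

# vmap maps the function over the leading axis of x while keeping scale fixed.
batched_norm = jax.vmap(scaled_norm, in_axes=(0, None))

x = jnp.array([[3.0, 4.0], [6.0, 8.0]])
single = fast_norm(x[0], 2.0)   # norm of [3, 4] times 2
batch = batched_norm(x, 2.0)    # one result per row of x
print(single)
print(batch)
```

The same Python function is reused unchanged in both cases; only the wrapper differs.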

With its updated version of Autograd, JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
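The claims above can be checked with a small sketch. It nests grad three times, differentiates through a Python loop, and composes forward mode (jax.jacfwd) over reverse mode (jax.jacrev) to build a Hessian. The functions f, power_loop, and g are illustrative choices, not part of the original post.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x)

# Derivatives of derivatives of derivatives: nest grad three times.
d1 = jax.grad(f)
d2 = jax.grad(d1)
d3 = jax.grad(d2)

def power_loop(x, n=3):
    # A native Python loop; JAX differentiates through it.
    y = 1.0
    for _ in range(n):
        y = y * x
    return y

def g(v):
    return jnp.sum(v ** 3)

# Forward-over-reverse composition gives the Hessian of g.
hessian_g = jax.jacfwd(jax.jacrev(g))

x0 = 0.5
t = jnp.tanh(x0)
third = d3(x0)                       # analytic form: (1 - t^2) * (6 t^2 - 2)
loop_grad = jax.grad(power_loop)(2.0)  # d/dx x^3 = 3 x^2
H = hessian_g(jnp.array([1.0, 2.0]))   # diagonal matrix with entries 6 * v
print(third, loop_grad)
print(H)
```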

Installing JAX

https://jax.readthedocs.io/en/latest/installation.html

TL;DR For most users, a typical JAX installation may look something like this:

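  • CPU-only (Linux/macOS/Windows): pip install -U jax
  • GPU (NVIDIA, CUDA 12): pip install -U "jax[cuda12]"
  • TPU (Google Cloud TPU VM): pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html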

Install (NVIDIA)

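# Activate the target environment (wpsze_venv is the author's environment name)
$ micromamba activate wpsze_venv
# Install the CUDA-enabled jaxlib build; the quotes keep the shell from expanding the * wildcards
$ micromamba install "jaxlib=*=*cuda*" jax cuda-nvcc -c conda-forge -c nvidia

# Check the driver version and confirm that the GPU is visible
$ nvidia-smi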
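After the install, it may help to confirm from Python which backend JAX actually picked up, since JAX falls back to the CPU when the CUDA backend fails to initialize. The following is a minimal sketch using jax.default_backend() and jax.devices(); the variable name on_gpu is only for illustration.

```python
import jax

# Report the installed version and the backend JAX selected at startup.
print("jax version:", jax.__version__)
print("default backend:", jax.default_backend())

# List the devices JAX can see on the default backend.
devices = jax.devices()
print("devices:", devices)

# True only if at least one visible device is a GPU.
on_gpu = any(d.platform == "gpu" for d in devices)
print("running on GPU:", on_gpu)
```

If this reports a CPU backend on a machine with an NVIDIA GPU, the installed CUDA components or driver are a likely place to look.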

Version Issue

CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:
Outdated CUDA installation found.
Version JAX was built against: 11080
Minimum supported: 12010
Installed version: 11080
The local installation version must be no lower than 12010.
--------------------------------------------------
Outdated cuBLAS installation found.
Version JAX was built against: 111103
Minimum supported: 120100
Installed version: 111103
The local installation version must be no lower than 120100.
--------------------------------------------------
Outdated cuSPARSE installation found.
Version JAX was built against: 11705
Minimum supported: 12100
Installed version: 11705
The local installation version must be no lower than 12100..(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[0. 1.05 2.1 3.1499999 4.2 ]
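The integers in the log above are packed version codes. As a sketch, assuming the CUDA runtime convention of major * 1000 + minor * 10 (so 12010 means CUDA 12.1 and 11080 means CUDA 11.8), they can be decoded as follows. The helper names are hypothetical, and the cuBLAS and cuSPARSE lines use their own encodings, which this sketch does not cover.

```python
def decode_cuda_version(code: int) -> tuple[int, int]:
    # Assumes the CUDA runtime encoding: major * 1000 + minor * 10.
    major = code // 1000
    minor = (code % 1000) // 10
    return major, minor

def meets_minimum(installed: int, minimum: int) -> bool:
    # Tuples compare element by element, so (11, 8) < (12, 1).
    return decode_cuda_version(installed) >= decode_cuda_version(minimum)

installed = decode_cuda_version(11080)  # "Installed version" from the log
minimum = decode_cuda_version(12010)    # "Minimum supported" from the log
ok = meets_minimum(11080, 12010)
print(f"installed: {installed[0]}.{installed[1]}, "
      f"minimum: {minimum[0]}.{minimum[1]}, ok: {ok}")
# prints: installed: 11.8, minimum: 12.1, ok: False
```

Read this way, the log says the environment contains CUDA 11.8 while this JAX build requires at least CUDA 12.1. The final array line shows the program still ran after the CUDA backend failed to initialize.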

Driver/CUDA version

CUDA Toolkit    Minimum Required Driver Version for CUDA Minor Version Compatibility*
                Linux x86_64 Driver Version    Windows x86_64 Driver Version
CUDA 12.1.x     525.60.13                      527.41
CUDA 12.0.x     525.60.13                      527.41
CUDA 11.8.x     450.80.02                      452.39
CUDA 11.7.x     450.80.02                      452.39
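The table can be turned into a quick check. The sketch below hard-codes the Linux column and compares a driver version string, such as the one shown in the nvidia-smi header, against the minimum. The dictionary, the helper names, and the sample driver versions are illustrative assumptions.

```python
# Minimum Linux x86_64 driver versions, copied from the table above.
MIN_LINUX_DRIVER = {
    "12.1": "525.60.13",
    "12.0": "525.60.13",
    "11.8": "450.80.02",
    "11.7": "450.80.02",
}

def parse_version(text: str) -> tuple[int, ...]:
    # "525.60.13" -> (525, 60, 13), so versions compare numerically.
    return tuple(int(part) for part in text.split("."))

def driver_supports(driver: str, cuda: str) -> bool:
    # True if the driver is at least the minimum listed for this CUDA release.
    return parse_version(driver) >= parse_version(MIN_LINUX_DRIVER[cuda])

# Hypothetical driver versions, used only to exercise the check.
print(driver_supports("535.104.05", "12.1"))  # True
print(driver_supports("470.57.02", "12.1"))   # False
print(driver_supports("470.57.02", "11.8"))   # True
```

Comparing tuples of integers avoids the pitfalls of comparing version strings lexically.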

Install (CPU only)

conda install jax -c conda-forge

Test 1:

#!/usr/bin/env python

# Import libraries
from jax import grad
import jax.numpy as jnp

# Define the logistic function
def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

# Get the gradient function of the logistic function
grad_logistic = grad(logistic)

# Evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)

print("Expected result is 0.19661194")
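The printed value can be cross-checked against the closed form: the derivative of the logistic function s(x) is s(x) * (1 - s(x)). The sketch below compares grad with that formula over several points by using jax.vmap, and also with the gradient of the built-in jax.nn.sigmoid. The sample points xs and the helper logistic_grad_analytic are assumptions made for this illustration.

```python
import jax
import jax.numpy as jnp

def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

def logistic_grad_analytic(x):
    # Closed-form derivative: s(x) * (1 - s(x)).
    s = logistic(x)
    return s * (1 - s)

# grad works on scalar-valued functions, so vmap applies it point by point.
xs = jnp.linspace(-3.0, 3.0, 7)
auto = jax.vmap(jax.grad(logistic))(xs)
exact = logistic_grad_analytic(xs)
max_err = jnp.max(jnp.abs(auto - exact))

# The built-in sigmoid should give the same gradients on this range.
builtin = jax.vmap(jax.grad(jax.nn.sigmoid))(xs)

at_one = jax.grad(logistic)(1.0)
print("gradient at x = 1:", at_one)
print("max abs error vs closed form:", max_err)
```

For large inputs jnp.exp(x) overflows in float32, so the hand-written formula can return nan there; jax.nn.sigmoid is likely the safer choice outside small test cases.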

https://waipangsze.github.io/2024/06/25/PINN-JAX/
Author: wpsze
Posted on: June 25, 2024
Updated on: October 28, 2024
Licensed under