ML | JAX
JAX
JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.
With its updated version of Autograd, JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
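As a small sketch of the composability described above (the function `f` here is just an illustrative example, not anything from the JAX docs):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x) ** 2   # an ordinary Python/NumPy-style function

df   = jax.grad(f)            # reverse-mode differentiation (backpropagation)
d2f  = jax.grad(jax.grad(f))  # derivatives of derivatives, composed freely
dfwd = jax.jacfwd(f)          # forward-mode differentiation

print(df(1.0), d2f(1.0), dfwd(1.0))
```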
Installing JAX
https://jax.readthedocs.io/en/latest/installation.html
TL;DR For most users, a typical JAX installation may look something like this:
- CPU-only (Linux/macOS/Windows)
- GPU (NVIDIA, CUDA 12)
- TPU (Google Cloud TPU VM)
Install (NVIDIA)
```bash
# NVIDIA GPU wheels built against CUDA 12, per the installation docs linked above
pip install -U "jax[cuda12]"
```
Version Issue
If the GPU build fails at runtime with a CUDA or driver version error, the usual cause is an NVIDIA driver older than the minimum required by the installed CUDA toolkit.
Minimum required driver versions for CUDA minor version compatibility:

| CUDA Toolkit | Linux x86_64 Driver Version | Windows x86_64 Driver Version |
|---|---|---|
| CUDA 12.1.x | 525.60.13 | 527.41 |
| CUDA 12.0.x | 525.60.13 | 527.41 |
| CUDA 11.8.x | 450.80.02 | 452.39 |
| CUDA 11.7.x | 450.80.02 | 452.39 |
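When chasing down a mismatch, `jax.print_environment_info()` (available in recent JAX releases) reports what JAX actually sees:

```python
import jax

# Prints the jax/jaxlib/Python versions and, on GPU machines, the
# nvidia-smi output - handy for spotting driver/toolkit mismatches
# like the ones in the table above.
jax.print_environment_info()
```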
Install (CPU only)
```bash
# CPU-only wheels, per the installation docs linked above
pip install -U jax
```
Test 1:
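A minimal sanity check (a sketch; the exact test can vary) confirms which backend JAX found and that it actually runs:

```python
import jax
import jax.numpy as jnp

# Which devices did JAX pick up? (CpuDevice, CudaDevice, or TpuDevice)
print(jax.devices())

# A tiny computation to confirm the chosen backend works.
x = jnp.arange(5.0)
print(jnp.dot(x, x))  # expected: 30.0
```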