Google JAX,是Google開發的用於變換數值函數的Python機器學習框架[3][4][5]。它結合了修改版本的Autograd(自動通過函數的微分獲得其梯度函數)[6],和TensorFlow的XLA(加速線性代數)[7]。它被設計為儘可能的遵從NumPy的結構和工作流程,並協同工作於各種現存的框架如TensorFlow和PyTorch[8][9]。
主要功能
JAX的主要功能是[3]:
- grad:自動微分,
- jit:即時編譯,
- vmap:自動向量化,
- pmap:SPMD編程。
grad
下面的代碼演示grad
函數的自動微分。
# 导入库
from jax import grad
import jax.numpy as jnp
# 定义logistic函数
def logistic(x):
return jnp.exp(x) / (jnp.exp(x) + 1)
# 获得logistic函数的梯度函数
grad_logistic = grad(logistic)
# 求值logistic函数在x = 1处的梯度
grad_log_out = grad_logistic(1.0)
print(grad_log_out)
最終的輸出為:
0.19661194
jit
下面的代碼演示jit
函數的優化。
# 导入库
from jax import jit
import jax.numpy as jnp
# 定义cube函数
def cube(x):
return x * x * x
# 生成数据
x = jnp.ones((10000, 10000))
# 创建cube函数的jit版本
jit_cube = jit(cube)
# 应用cube函数和jit_cube函数于相同数据来比较其速度
cube(x)
jit_cube(x)
可見jit_cube
的運行時間顯著的短於cube
。
vmap
下面的代碼展示vmap
函數的通過SIMD的向量化。
# 导入库
from functools import partial
from jax import vmap
import jax.numpy as jnp
# 定义函数
def grads(self, inputs):
in_grad_partial = partial(self._net_grads, self._net_params)
grad_vmap = vmap(in_grad_partial)
rich_grads = grad_vmap(inputs)
flat_grads = np.asarray(self._flatten_batch(rich_grads))
assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
return flat_grads
pmap
下面的代碼展示pmap
函數的對矩陣乘法的並行化。
# 从JAX导入pmap和random;导入JAX NumPy
from jax import pmap, random
import jax.numpy as jnp
# 生成2个维度为5000 x 6000的随机数矩阵,每设备一个
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)
# 没有数据传输,并行的在每个CPU/GPU上进行局部矩阵乘法
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)
# 没有数据传输,并行的在每个CPU/GPU上分别求取这两个矩阵的均值
means = pmap(jnp.mean)(outputs)
print(means)
最終的輸出為:
[1.1566595 1.1805978]
Remove ads
使用JAX的庫
一些Python庫使用JAX作為後端,這包括:
- Flax,最初由Google Brain開發的高層人工神經網絡庫[10]。
- Equinox,將參數化函數(包括人工神經網絡)表示為PyTree的庫。它由Patrick Kidger創建[11]。
- Diffrax,用於求微分方程的數值解的庫,比如解常微分方程和隨機微分方程[12]。
- Optax,DeepMind開發的用於梯度處理和最優化的庫[13]。
- Lineax,用於解線性方程組和線性最小二乘法[14]。
- RLax,DeepMind開發的用於強化學習的庫[15]
- jraph,DeepMind開發的圖神經網絡庫[16]。
- jaxtyping,用於為陣列或張量的形狀和數據類型增加類型標註的庫[17]。
- NumPyro,概率編程庫[18]。
- Brax,物理引擎[19]。
參見
引用
外部連結
Wikiwand in your browser!
Seamless Wikipedia browsing. On steroids.
Every time you click a link to Wikipedia, Wiktionary or Wikiquote in your browser's search results, it will show the modern Wikiwand interface.
Wikiwand extension is a five stars, simple, with minimum permission required to keep your browsing private, safe and transparent.
Remove ads