JAX : Tutorials : JAX クイックスタート (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 04/17/2020 (0.1.63)
* 本ページは、JAX の以下のページを翻訳した上で適宜、補足説明したものです:
- Tutorials : JAX Quickstart
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
Tutorials : JAX クイックスタート
JAX は高パフォーマンス機械学習研究のための素晴らしい自動微分を持つ、CPU, GPU と TPU 上の NumPy です。
Autograd の更新バージョンで、JAX は native Python と NumPy コードを自動的に微分できます。それはループ、if、再帰そしてクロージャを含む、Python の特徴の巨大なサブセットを通して微分できて、そしてそれは導関数の導関数の導関数さえ取ることができます (訳注: 原文ママ)。それはフォワードモード微分に加えてリバースモードをサポートし、そして 2 つは任意の階数に構成することができます。
新しいものとして JAX は GPU と TPU のような、アクセラレータ上で貴方の NumPy コードをコンパイルして実行するために XLA を使用します。ライブリ呼び出しが just-in-time コンパイルと実行されて、デフォルトで内部ではコンパイルが発生します。しかし JAX は one-function API を使用して貴方自身の Python 関数を XLA-最適化カーネルに貴方に just-in-time コンパイルさせることさえします。コンパイルと自動微分は恣意的に構成できますので、洗練されたアルゴリズムを表現して Python を離れることなく最大限のパフォーマンスを得ることができます。
import jax.numpy as np from jax import grad, jit, vmap from jax import random
行列を乗算する
以下のサンプルでランダムデータを生成していきます。NumPy と JAX の間の一つの大きな違いは乱数をどのように生成するかです。より多くの詳細については、readme を見てください。
key = random.PRNGKey(0) x = random.normal(key, (10,)) print(x)
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.7/site-packages/jax/lib/xla_bridge.py:123: UserWarning: No GPU/TPU found, falling back to CPU. warnings.warn('No GPU/TPU found, falling back to CPU.')
[-0.372111 0.2642311 -0.18252774 -0.7368198 -0.44030386 -0.15214427 -0.6713536 -0.59086424 0.73168874 0.56730247]
飛び込んで 2 つの大きな行列を乗算します。
size = 3000 x = random.normal(key, (size, size), dtype=np.float32) %timeit np.dot(x, x.T).block_until_ready() # runs on the GPU
557 ms ± 4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
block_until_ready を追加しました、JAX はデフォルトで非同期実行を使用する ためです。
JAX NumPy 関数は通常の NumPy 配列上でも動作します。
import numpy as onp # original CPU-backed NumPy x = onp.random.normal(size=(size, size)).astype(onp.float32) %timeit np.dot(x, x.T).block_until_ready()
534 ms ± 11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
それは遅いです、何故ならばそれはデータを GPU に毎回転送しなければならないためです。device_put を使用して NDArray がデバイスメモリにより支援されることを確実にできます。
from jax import device_put x = onp.random.normal(size=(size, size)).astype(onp.float32) x = device_put(x) %timeit np.dot(x, x.T).block_until_ready()
447 ms ± 8.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
device_put の出力は依然として NDArray のように動作しますが、それらがプリント、プロット、ディスクへのセーブ、分岐等のために必要とされるときそれは値を CPU にコピーし戻すだけです。device_put の動作は関数 jit(lambda x: x) に等値ですが、それはより高速です。
GPU (or TPU!) を持つ場合、これらの呼び出しはアクセラレータ上で実行されて CPU 上よりも遥かに高速である可能性を持ちます。
x = onp.random.normal(size=(size, size)).astype(onp.float32) %timeit onp.dot(x, x.T)
445 ms ± 11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX は単なる GPU-支援 NumPy を遥かに越えています。それはまた数値コードを書くときに有用な 2, 3 のプログラム変換も装備しています。今のところ、3 つの主要なものがあります :
- jit: 貴方のコードを高速化するため。
- grad: 導関数を取るため。
- vmap: 自動ベクトル化 or バッチ処理。
これらを一つずつ、調べましょう。私達はまたこれらを興味深い方法で編成することで終えます。
関数を高速化するために jit を使用する
JAX は GPU (or CPU, 一つ持たないのであれば、そして TPU coming soon!) 上で透過的に実行します。けれども、上のサンプルでは、JAX はカーネルを GPU に一度に 1 演算ディスパッチします。演算のシークエンスを持つ場合、XLA を使用して複数の演算をまとめてコンパイルするために @jit デコレータを使用できます。それを試しましょう。
def selu(x, alpha=1.67, lmbda=1.05): return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha) x = random.normal(key, (1000000,)) %timeit selu(x).block_until_ready()
4.58 ms ± 167 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
それを @jit で高速化できます、これは最初に selu が呼び出されたときに jit-compile してその後でキャッシュされます。
selu_jit = jit(selu) %timeit selu_jit(x).block_until_ready()
1.04 ms ± 17.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
grad で導関数を取る
数値関数を評価することに加えて、それらを変換することも望みます。一つの変換は自動微分です。JAX では、丁度 Autograd のように、grad 関数で勾配を計算できます。
def sum_logistic(x): return np.sum(1.0 / (1.0 + np.exp(-x))) x_small = np.arange(3.) derivative_fn = grad(sum_logistic) print(derivative_fn(x_small))
[0.25 0.19661197 0.10499357]
結果が正しいことを有限差分で検証しましょう。
def first_finite_differences(f, x): eps = 1e-3 return np.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps) for v in np.eye(len(x))]) print(first_finite_differences(sum_logistic, x_small))
[0.24998187 0.1964569 0.10502338]
導関数を取ることは grad を呼び出すように容易です。grad と jit は構成して任意に混在できます。上のサンプルでは sum_logistic を jit してからその導関数を取りました。更に進むことができます :
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
-0.03532558
より進んだ autodiff については、リバースモード・ベクトル-Jacobian 積のために jax.vjp をそしてフォワードモード Jacobian-ベクトル積のために jax.jvp を使用できます。2 つは他の一つと、そして他の JAX 変換と任意に構成できます。ここに full ヘッセ行列を効率的に計算する関数を作成するためにこれらを構成する一つの方法があります :
from jax import jacfwd, jacrev def hessian(fun): return jit(jacfwd(jacrev(fun)))
vmap による自動ベクトル化
JAX はその API でもう一つの変換を持ちます、貴方はこれが有用であることを見出すかもしれません : vmap, ベクトル化マップです。それは配列軸に沿って関数をマッピングするお馴染みのセマンティクスを持ちますが、外側でループを保持する代わりに、より良いパフォーマンスのためにループを関数のプリミティブ演算に押し込めます。jit とともに構成されるとき、それは丁度手動でバッチ次元を追加するように高速であり得ます。
単純なサンプルで作業していきます、そして vmap を使用して行列-ベクトル積を行列-行列積に promote します。この特定のケースでは手動で行なうことは容易ですが、同じテクニックがより複雑な関数に適用できます。
mat = random.normal(key, (150, 100)) batched_x = random.normal(key, (10, 100)) def apply_matrix(v): return np.dot(mat, v)
apply_matrix のような関数が与えられたとき、Python でバッチ次元に渡りループできますが、通常はそれを行なうパフォーマンスは貧弱です。
def naively_batched_apply_matrix(v_batched): return np.stack([apply_matrix(v) for v in v_batched]) print('Naively batched') %timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched 3.18 ms ± 67.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
この演算をどのように手動でバッチ処理するかを知っています。この場合、np.dot は特別なバッチ次元を透過的に処理します。
@jit def batched_apply_matrix(v_batched): return np.dot(v_batched, mat.T) print('Manually batched') %timeit batched_apply_matrix(batched_x).block_until_ready()
Manually batched 136 µs ± 1.93 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
けれども、バッチ処理サポートがないより複雑な関数を持ったと仮定します。バッチ処理サポートを自動的に追加する vmap が利用できます。
@jit def vmap_batched_apply_matrix(v_batched): return vmap(apply_matrix)(v_batched) print('Auto-vectorized with vmap') %timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Auto-vectorized with vmap 170 µs ± 2.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
もちろん、vmap は jit, grad と任意の他の JAX 変換とともに任意に構成できます。
This is just a taste of what JAX can do. We’re really excited to see what you do with it!
以上