deeplearn.js : チュートリアル : TensorFlow モデルの移植
作成 : (株)クラスキャット セールスインフォメーション
日時 : 08/23/2017
* 本ページは、github.io の deeplearn.js サイトの Tutorials – Port TensorFlow models を翻訳した上で
適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、適宜、追加改変している場合もあります。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
このチュートリアルでは訓練と TensorFlow モデルの deeplearn.js への移植をデモします。このチュートリアルで使用されるコードとすべての必要なリソースは demos/mnist にストアされています。
MNIST データセットから手書き数字を予想する完全結合ニューラルネットワークを使用します。コードは公式 TensorFlow MNIST チュートリアル から fork されています。
NOTE : deeplearn.js repo のベース・ディレクトリを $BASE として参照します。
最初に、deeplearn.js レポジトリを clone して TensorFlow がインストールされていることを確認します。$BASE 内に cd して以下を実行することでモデルを訓練します :
python demos/mnist/fully_connected_feed.py
訓練は ~1 分かかりそしてモデル・チェックポイントを /tmp/tensorflow/mnist/tensorflow/mnist/logs/fully_connected_feed/ にストアするでしょう。
次に、TensorFlow チェックポイントから deeplearn.js へ重みを移植する必要があります。これを行なうスクリプトを提供しています。それを $BASE ディレクトリから実行します :
python scripts/dump_checkpoint_vars.py --output_dir=demos/mnist/ \ --checkpoint_file=/tmp/tensorflow/mnist/logs/fully_connected_feed/model.ckpt-1999
(訳注: 以下は dump_checkpoint_vars.py 実行時のコンソールログです : )
Writing variable softmax_linear/biases... Writing variable softmax_linear/weights... Writing variable hidden2/weights... Writing variable hidden1/weights... Writing variable hidden2/biases... Writing variable hidden1/biases... Ignoring global_step Writing manifest to demos/mnist/manifest.json Done!
スクリプトはファイルのセット (variable 毎に一つのファイル、そして manifest.json) を demos/mnist ディレクトリに保存します。manifest.json は単なる辞書で variable 名をファイルにそして shape を map します :
{ ..., "hidden1/weights": { "filename": "hidden1_weights", "shape": [784, 128] }, ... }
コーディングを開始する前の最後の一つのこと – $BASE ディレクトリから静的 HTTP サーバを実行する必要があります :
npm run prep ./node_modules/.bin/http-server >> Starting up http-server, serving ./ >> Available on: >> http://127.0.0.1:8080 >> Hit CTRL-C to stop the server
ブラウザで http://localhost:8080/demos/mnist/manifest.json を見ることにより HTTP 経由で manifest.json にアクセスできることを確認してください。
幾つかの deeplearn.js コードを書くための準備ができました!
NOTE : TypeScript で書くことを選択した場合、コードを JavaScript にコンパイルして静的 HTTP サーバ経由でそれをサーブすることを確認しましょう。
重みを読みためには、CheckpointLoader を作成してそれが manifest ファイルを指すようにする必要があります。そして loader.getAllVariables() を呼び出します、これは variable 名を NDArray にマップする辞書を返します。その時点で、私たちのモデルを書く準備ができます。CheckpointLoader の使用方法を示すスニペットがここにあります :
import {CheckpointLoader, Graph} from 'deeplearnjs'; // manifest.json is in the same dir as index.html. const reader = new CheckpointReader('.'); // 訳注: 原文のコード行は誤り、正しくは new CheckpointLoader('.'); reader.getAllVariables().then(vars => { // Write your model here. const g = new Graph(); const input = g.placeholder('input', [784]); const hidden1W = g.constant(vars['hidden1/weights']); const hidden1B = g.constant(vars['hidden1/biases']); const hidden1 = g.relu(g.add(g.matmul(input, hidden1W), hidden1B)); ... ... const math = new NDArrayMathGPU(); const sess = new Session(g, math); math.scope(() => { const result = sess.eval(...); console.log(result.getValues()); }); });
完全なモデル・コードについての詳細は demos/mnist/mnist.ts を見てください。デモは、3つの異なる API を使用する MNIST モデルの正確な実装を提供します :
- buildModelGraphAPI() は Graph API を使用します、これは TensorFlow API を模倣していて、feed と fetch による遅延実行 (lazy execution) を提供します。ユーザは入力データ以外の GPU 関連のメモリリークを心配する必要はありません。
- buildModelLayerAPI() は Graph.layers と連動する Graph API を使用します、これは Keras layers API を模倣しています。
- buildModelMathAPI() は Math API を使用します。これは deeplearn.js の低位 API でユーザに最大限の制御を与えます。Math コマンドは直ちに実行されます、numpy のように。Math コマンドは math.scope() でラップされその結果、中間的な math コマンドで作成された NDArray は自動的にクリーンアップされます。
mnist デモを実行するために、watch-demo スクリプトを提供しています、これは typescript コードを監視してそれが変更された時にリコンパイルします。更に、スクリプトは単純な HTTP サーバを 8080 上で実行します、これは静的な html/js ファイルをサーブします。watch-demo を実行する前に、チュートリアルの前の方で起動した HTTP サーバを 8080 ポートを開放するために必ず kill してください。それから $BASE から web app デモのエントリ・ポイント、demos/mnist/mnist.ts を指すように watch-demo を実行します :
./scripts/watch-demo demos/mnist/mnist.ts >> Starting up http-server, serving ./ >> Available on: >> http://127.0.0.1:8080 >> http://192.168.1.5:8080 >> Hit CTRL-C to stop the server >> 1410084 bytes written to demos/mnist/bundle.js (0.91 seconds) at 5:17:45 PM
http://localhost:8080/demos/mnist/ を見てください、demos/mnist/sample_data.json にストアされた 50 mnist 画像のテストセットを使用して計測されたテスト精度 ~90% を示す単純なページを見るはずです。デモで自由に遊んでください (e.g. それを対話的にするとか)、そして pull リクエストを私たちに送ってください!
(訳注 : demos/mnist の実行結果のスナップショットです : )
以上