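TensorFlow 2.0 Beta : Beginner Tutorials : Text and sequences :- Text classification with an RNN (translation/commentary)
Translation: ClassCat Co., Ltd. Sales Information
Created: 07/01/2019
* This page is a translation, with supplementary explanations added where appropriate, of the following page from TF 2.0 Beta – Beginner Tutorials – Text and sequences on the official TensorFlow website:
* The sample code has been verified to run, but has been extended or modified where necessary.
* Feel free to link to this page, but we would appreciate a note at sales-info@classcat.com.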
Text and sequences :- Text classification with an RNN
This text classification tutorial trains a recurrent neural network on the IMDB large movie review dataset for sentiment analysis.
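```python
from __future__ import absolute_import, division, print_function, unicode_literals

# Notebook (Jupyter/Colab) shell command: installs the TF 2.0 beta used in this tutorial.
!pip install -q tensorflow-gpu==2.0.0-beta1
import tensorflow_datasets as tfds
import tensorflow as tf
```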
Import matplotlib and create a helper function for plotting graphs:
```python
import matplotlib.pyplot as plt

def plot_graphs(history, string):
    plt.plot(history.history[string])
    plt.plot(history.history['val_'+string])
    plt.xlabel("Epochs")
    plt.ylabel(string)
    plt.legend([string, 'val_'+string])
    plt.show()
```
Set up the input pipeline
The IMDB large movie review dataset is a binary classification dataset: every review has either a positive or a negative sentiment.
Download the dataset using TFDS. The dataset comes with a built-in subword tokenizer.
```python
dataset, info = tfds.load('imdb_reviews/subwords8k', with_info=True,
                          as_supervised=True)
train_dataset, test_dataset = dataset['train'], dataset['test']
```
Downloading and preparing dataset imdb_reviews (80.23 MiB) to /home/kbuilder/tensorflow_datasets/imdb_reviews/subwords8k/0.1.0...
WARNING: Logging before flag parsing goes to stderr.
W0628 05:40:53.455971 139635681199872 deprecation.py:323] From /home/kbuilder/.local/lib/python3.5/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: tf.data.TFRecordDataset(path)
Dataset imdb_reviews downloaded and prepared to /home/kbuilder/tensorflow_datasets/imdb_reviews/subwords8k/0.1.0. Subsequent calls will reuse this data.
Because this is a subword tokenizer, you can pass it any string and it will tokenize it.
```python
tokenizer = info.features['text'].encoder
print('Vocabulary size: {}'.format(tokenizer.vocab_size))
```
Vocabulary size: 8185
```python
sample_string = 'TensorFlow is cool.'

tokenized_string = tokenizer.encode(sample_string)
print('Tokenized string is {}'.format(tokenized_string))

original_string = tokenizer.decode(tokenized_string)
print('The original string: {}'.format(original_string))

assert original_string == sample_string
```
Tokenized string is [6307, 2327, 4043, 4265, 9, 2724, 7975]
The original string: TensorFlow is cool.
If a word is not in its dictionary, the tokenizer encodes the string by breaking it down into subwords.
```python
for ts in tokenized_string:
    print('{} ----> {}'.format(ts, tokenizer.decode([ts])))
```
6307 ----> Ten
2327 ----> sor
4043 ----> Fl
4265 ----> ow
9 ----> is
2724 ----> cool
7975 ----> .
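The breakdown above can be illustrated with a simplified sketch of subword splitting. It uses greedy longest-prefix matching against a hypothetical toy vocabulary; the names `greedy_subword_encode` and `toy_vocab` are illustrative only. This is not the TFDS `SubwordTextEncoder` implementation, which learns its vocabulary from the corpus and differs in details such as whitespace handling.

```python
def greedy_subword_encode(word, vocab):
    """Split `word` into subwords by greedy longest-prefix matching.

    Falls back to a single character when no longer vocabulary entry matches.
    (Toy illustration, not the TFDS algorithm.)
    """
    pieces = []
    start = 0
    while start < len(word):
        # Try the longest remaining substring first, shrinking until it is in the vocabulary
        # or only one character is left.
        end = len(word)
        while end > start + 1 and word[start:end] not in vocab:
            end -= 1
        pieces.append(word[start:end])
        start = end
    return pieces


# Hypothetical toy vocabulary; the real encoder has about 8k learned subwords.
toy_vocab = {"Ten", "sor", "Fl", "ow", "is", "cool", "."}
print(greedy_subword_encode("TensorFlow", toy_vocab))  # ['Ten', 'sor', 'Fl', 'ow']
```

Because every unknown stretch eventually falls back to single characters, any input string can be encoded, which is the property that lets a subword tokenizer accept arbitrary text.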
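```python
BUFFER_SIZE = 10000
BATCH_SIZE = 64

train_dataset = train_dataset.shuffle(BUFFER_SIZE)
# Zero-pad each batch to the length of the longest sequence in that batch.
# Dataset.output_shapes works with the TF 2.0 beta pinned above; in later 2.x releases
# tf.compat.v1.data.get_output_shapes(train_dataset) can be used instead.
train_dataset = train_dataset.padded_batch(BATCH_SIZE, train_dataset.output_shapes)
test_dataset = test_dataset.padded_batch(BATCH_SIZE, test_dataset.output_shapes)
```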
Create the model
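Build a tf.keras.Sequential model, starting with an embedding layer. An embedding layer stores one vector per word. When called, it converts a sequence of word indices into a sequence of vectors. These vectors are trainable. After training (on enough data), words with similar meanings often have similar vectors.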
This index lookup is much more efficient than the equivalent operation of passing a one-hot encoded vector through a tf.keras.layers.Dense layer.
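The equivalence can be checked with a small NumPy sketch. The sizes here are arbitrary toy values (assumed, not the tutorial's 8185 × 64). Selecting rows of the embedding matrix gives the same result as multiplying one-hot vectors by that matrix, which is what a Dense layer without bias or activation would compute, but the lookup avoids the vocab_size-wide matrix multiply.

```python
import numpy as np

vocab_size, embedding_dim = 10, 4  # toy sizes chosen for illustration
rng = np.random.default_rng(0)
# Embedding matrix: one row (vector) per token index.
embedding_matrix = rng.normal(size=(vocab_size, embedding_dim))

token_ids = np.array([3, 7, 3])

# Index lookup: what an embedding layer does.
looked_up = embedding_matrix[token_ids]

# Equivalent but costlier route: one-hot encode, then multiply by the same matrix.
one_hot = np.eye(vocab_size)[token_ids]  # shape (3, vocab_size)
via_dense = one_hot @ embedding_matrix

print(looked_up.shape)                    # (3, 4)
print(np.allclose(looked_up, via_dense))  # True
```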
A recurrent neural network (RNN) processes sequence input by iterating through its elements. An RNN passes the output from one timestep on as input to the next timestep.
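A minimal sketch of that loop, assuming a plain tanh ("simple") RNN rather than the LSTM the tutorial uses; an LSTM adds gates and a cell state, but iterates over the sequence in the same way. The function name and weight shapes are illustrative.

```python
import numpy as np

def simple_rnn(inputs, W_x, W_h, b):
    """Run a minimal tanh RNN over `inputs` of shape (timesteps, input_dim).

    Returns the output at every timestep, shape (timesteps, units).
    """
    h = np.zeros(W_h.shape[0])  # initial state
    outputs = []
    for x_t in inputs:  # iterate over the sequence elements one by one
        # The previous step's output h is fed back in together with the new input.
        h = np.tanh(x_t @ W_x + h @ W_h + b)
        outputs.append(h)
    return np.stack(outputs)

rng = np.random.default_rng(0)
timesteps, input_dim, units = 5, 3, 4
seq = rng.normal(size=(timesteps, input_dim))
W_x = rng.normal(size=(input_dim, units))
W_h = rng.normal(size=(units, units))
b = np.zeros(units)

out = simple_rnn(seq, W_x, W_h, b)
print(out.shape)  # (5, 4)
```

Because the state is carried forward, changing an early element of the sequence also changes the final output.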
The tf.keras.layers.Bidirectional wrapper can also be used with an RNN layer. It propagates the input forward and backward through the RNN layer and then concatenates the outputs. This helps the RNN learn long-range dependencies.
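The idea behind the wrapper can be sketched in NumPy, assuming a toy tanh RNN in place of the LSTM: one RNN reads the sequence forward, a second RNN with its own weights reads it reversed, and their final states are concatenated (Keras' default `merge_mode='concat'`). This is why `Bidirectional(LSTM(64))` produces a 128-dimensional output. The helper names are illustrative.

```python
import numpy as np

def rnn_last_state(inputs, W_x, W_h, b):
    """Minimal tanh RNN that returns only the final hidden state."""
    h = np.zeros(W_h.shape[0])
    for x_t in inputs:
        h = np.tanh(x_t @ W_x + h @ W_h + b)
    return h

def bidirectional_last(inputs, fwd_params, bwd_params):
    """Run one RNN forward and a second RNN over the reversed sequence,
    then concatenate their final states."""
    h_fwd = rnn_last_state(inputs, *fwd_params)
    h_bwd = rnn_last_state(inputs[::-1], *bwd_params)
    return np.concatenate([h_fwd, h_bwd])

rng = np.random.default_rng(1)
timesteps, input_dim, units = 6, 3, 4

def make_params():
    # Independent weights for each direction.
    return (rng.normal(size=(input_dim, units)),
            rng.normal(size=(units, units)),
            np.zeros(units))

seq = rng.normal(size=(timesteps, input_dim))
fwd_params, bwd_params = make_params(), make_params()
merged = bidirectional_last(seq, fwd_params, bwd_params)
print(merged.shape)  # (8,), i.e. 2 * units
```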
```python
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(tokenizer.vocab_size, 64),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
```
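As a worked check of this architecture, the parameter counts can be computed by hand. This sketch assumes the standard Keras LSTM parameterization (four gates, each with an input kernel, a recurrent kernel and a bias) and uses the vocabulary size 8185 printed earlier; the helper names are illustrative. If the assumptions hold, the total should agree with what `model.summary()` reports for the model above.

```python
def embedding_params(vocab_size, dim):
    return vocab_size * dim  # one vector per token

def lstm_params(input_dim, units):
    # 4 gates x (input kernel + recurrent kernel + bias)
    return 4 * (units * (input_dim + units) + units)

def dense_params(input_dim, units):
    return input_dim * units + units  # weights + bias

vocab_size, emb_dim, lstm_units = 8185, 64, 64
counts = {
    "embedding": embedding_params(vocab_size, emb_dim),
    # Bidirectional holds two independent LSTMs; their outputs are concatenated (2 * 64 = 128).
    "bidirectional_lstm": 2 * lstm_params(emb_dim, lstm_units),
    "dense_relu": dense_params(2 * lstm_units, 64),
    "dense_sigmoid": dense_params(64, 1),
}
for name, n in counts.items():
    print(f"{name:>20}: {n:,}")
print(f"{'total':>20}: {sum(counts.values()):,}")  # total: 598,209
```

The embedding table dominates: about 88% of the weights sit in the 8185 × 64 lookup table.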
Compile the Keras model to configure the training process:
```python
model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
```
Train the model
```python
history = model.fit(train_dataset, epochs=10,
                    validation_data=test_dataset)
```
Epoch 1/10
W0628 05:42:47.879820 139635681199872 deprecation.py:323] From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
391/391 [==============================] - 367s 939ms/step - loss: 0.5840 - accuracy: 0.6862 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00
Epoch 2/10
391/391 [==============================] - 138s 352ms/step - loss: 0.4049 - accuracy: 0.8229 - val_loss: 0.4430 - val_accuracy: 0.7974
Epoch 3/10
391/391 [==============================] - 102s 262ms/step - loss: 0.3484 - accuracy: 0.8620 - val_loss: 0.4643 - val_accuracy: 0.7994
Epoch 4/10
391/391 [==============================] - 98s 250ms/step - loss: 0.3774 - accuracy: 0.8357 - val_loss: 0.4392 - val_accuracy: 0.8188
Epoch 5/10
391/391 [==============================] - 86s 219ms/step - loss: 0.2587 - accuracy: 0.9017 - val_loss: 0.4777 - val_accuracy: 0.7760
Epoch 6/10
391/391 [==============================] - 81s 208ms/step - loss: 0.2229 - accuracy: 0.9178 - val_loss: 0.4470 - val_accuracy: 0.8356
Epoch 7/10
391/391 [==============================] - 81s 206ms/step - loss: 0.1814 - accuracy: 0.9357 - val_loss: 0.5000 - val_accuracy: 0.8382
Epoch 8/10
391/391 [==============================] - 76s 193ms/step - loss: 0.2660 - accuracy: 0.8936 - val_loss: 0.5101 - val_accuracy: 0.8174
Epoch 9/10
391/391 [==============================] - 80s 204ms/step - loss: 0.2428 - accuracy: 0.9040 - val_loss: 0.5396 - val_accuracy: 0.8051
Epoch 10/10
391/391 [==============================] - 75s 191ms/step - loss: 0.2763 - accuracy: 0.8853 - val_loss: 0.6413 - val_accuracy: 0.6517
```python
test_loss, test_acc = model.evaluate(test_dataset)

print('Test Loss: {}'.format(test_loss))
print('Test Accuracy: {}'.format(test_acc))
```
391/Unknown - 20s 51ms/step - loss: 0.6413 - accuracy: 0.6517
Test Loss: 0.6412561183695293
Test Accuracy: 0.6516799926757812
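The model above does not mask the padding applied to the sequences. This can lead to skew if the model is trained on padded sequences and then tested on unpadded sequences. Ideally the model would learn to ignore the padding, but as you can see below, the padding has only a small effect on the output.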
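Masking means skipping the padded timesteps so they cannot change the result. In Keras this is typically enabled with `mask_zero=True` on `tf.keras.layers.Embedding`, whose mask is then consumed by the RNN layers. The following NumPy sketch only illustrates the idea with a toy tanh RNN and a toy embedding table (all names and sizes are assumptions): at masked steps the previous state is carried over unchanged, so a zero-padded sequence ends in exactly the same state as the unpadded one.

```python
import numpy as np

def masked_rnn_last_state(inputs, mask, W_x, W_h, b):
    """Toy tanh RNN; where mask is False the previous state is carried over unchanged."""
    h = np.zeros(W_h.shape[0])
    for x_t, m_t in zip(inputs, mask):
        h_new = np.tanh(x_t @ W_x + h @ W_h + b)
        h = h_new if m_t else h  # skip the update on padding
    return h

rng = np.random.default_rng(2)
emb = rng.normal(size=(10, 3))  # toy embedding table; index 0 is the padding index
W_x, W_h = rng.normal(size=(3, 4)), rng.normal(size=(4, 4))
b = np.zeros(4)

tokens = [4, 2, 7]
padded = tokens + [0] * 5  # zero-padded, like pad_to_size or padded_batch

def run(ids, use_mask):
    ids = np.array(ids)
    mask = ids != 0 if use_mask else np.ones(len(ids), dtype=bool)
    return masked_rnn_last_state(emb[ids], mask, W_x, W_h, b)

print(np.allclose(run(tokens, True), run(padded, True)))    # True: padding ignored
print(np.allclose(run(tokens, False), run(padded, False)))  # typically False: padding leaks in
```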
If the prediction is >= 0.5, the review is positive; otherwise it is negative.
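The decision rule can be written as a tiny helper. `to_sentiment` is a hypothetical name, not part of the tutorial; it assumes the `(batch_size, 1)` sigmoid output shape that `model.predict` returns for this model. The example inputs are the prediction values printed in this tutorial.

```python
import numpy as np

def to_sentiment(predictions, threshold=0.5):
    """Map sigmoid outputs of shape (batch_size, 1) to 'positive' / 'negative' labels."""
    probs = np.asarray(predictions).reshape(-1)
    return ["positive" if p >= threshold else "negative" for p in probs]

print(to_sentiment([[0.63414526]]))  # ['positive']
print(to_sentiment([[0.00492772]]))  # ['negative']
```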
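```python
def pad_to_size(vec, size):
    # Append zeros (the padding index) until vec has `size` elements;
    # longer inputs are returned unchanged.
    zeros = [0] * (size - len(vec))
    vec.extend(zeros)
    return vec
```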
```python
def sample_predict(sentence, pad):
    # Encode the sentence passed in as an argument.
    tokenized_sample_pred_text = tokenizer.encode(sentence)

    if pad:
        tokenized_sample_pred_text = pad_to_size(tokenized_sample_pred_text, 64)

    # Add a batch dimension: the model expects input of shape (batch_size, timesteps).
    predictions = model.predict(tf.expand_dims(tokenized_sample_pred_text, 0))

    return predictions
```
```python
# Predict on a sample text without padding.
sample_pred_text = ('The movie was cool. The animation and the graphics '
                    'were out of this world. I would recommend this movie.')
predictions = sample_predict(sample_pred_text, pad=False)
print(predictions)
```
[[0.63414526]]
```python
# Predict on a sample text with padding.
sample_pred_text = ('The movie was cool. The animation and the graphics '
                    'were out of this world. I would recommend this movie.')
predictions = sample_predict(sample_pred_text, pad=True)
print(predictions)
```
[[0.66387755]]
```python
plot_graphs(history, 'accuracy')
```
```python
plot_graphs(history, 'loss')
```
Stack two or more LSTM layers
Keras recurrent layers have two available modes, controlled by the return_sequences constructor argument:
- Return either the full sequence of successive outputs for each timestep (a 3D tensor of shape (batch_size, timesteps, output_features)),
- or return only the last output for each input sequence (a 2D tensor of shape (batch_size, output_features)).
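The two modes, and why stacking needs `return_sequences=True` on every recurrent layer except the last, can be seen from the shapes alone. This NumPy sketch uses a toy tanh RNN (`rnn_layer` is an illustrative name, and the small random weights are assumptions) instead of a Keras LSTM. In the stacked model that follows, the first layer is also wrapped in `Bidirectional`, so its per-timestep features are 2 × 64 = 128 rather than 64.

```python
import numpy as np

def rnn_layer(batch, units, rng, return_sequences=False):
    """Toy tanh RNN over a batch of shape (batch_size, timesteps, features).

    return_sequences=True  -> (batch_size, timesteps, units)
    return_sequences=False -> (batch_size, units), the last timestep only
    """
    batch_size, timesteps, features = batch.shape
    W_x = rng.normal(size=(features, units)) * 0.1
    W_h = rng.normal(size=(units, units)) * 0.1
    h = np.zeros((batch_size, units))
    outputs = []
    for t in range(timesteps):
        h = np.tanh(batch[:, t, :] @ W_x + h @ W_h)
        outputs.append(h)
    seq = np.stack(outputs, axis=1)  # (batch_size, timesteps, units)
    return seq if return_sequences else seq[:, -1, :]

rng = np.random.default_rng(3)
x = rng.normal(size=(2, 7, 16))  # (batch_size, timesteps, features)

first = rnn_layer(x, 64, rng, return_sequences=True)  # 3D: one output per timestep
second = rnn_layer(first, 32, rng)                    # consumes the 3D sequence
print(first.shape, second.shape)  # (2, 7, 64) (2, 32)
```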
```python
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(tokenizer.vocab_size, 64),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(
        64, return_sequences=True)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

history = model.fit(train_dataset, epochs=10,
                    validation_data=test_dataset)
```
Epoch 1/10
391/391 [==============================] - 710s 2s/step - loss: 0.6758 - accuracy: 0.5653 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00
Epoch 2/10
391/391 [==============================] - 240s 615ms/step - loss: 0.6268 - accuracy: 0.6446 - val_loss: 0.5156 - val_accuracy: 0.7750
Epoch 3/10
391/391 [==============================] - 185s 473ms/step - loss: 0.3965 - accuracy: 0.8324 - val_loss: 0.4141 - val_accuracy: 0.8158
Epoch 4/10
391/391 [==============================] - 166s 424ms/step - loss: 0.3237 - accuracy: 0.8726 - val_loss: 0.4115 - val_accuracy: 0.8274
Epoch 5/10
391/391 [==============================] - 156s 398ms/step - loss: 0.2794 - accuracy: 0.8934 - val_loss: 0.3814 - val_accuracy: 0.8395
Epoch 6/10
391/391 [==============================] - 138s 353ms/step - loss: 0.2052 - accuracy: 0.9269 - val_loss: 0.3805 - val_accuracy: 0.8464
Epoch 7/10
391/391 [==============================] - 140s 359ms/step - loss: 0.1538 - accuracy: 0.9493 - val_loss: 0.4158 - val_accuracy: 0.8466
Epoch 8/10
391/391 [==============================] - 130s 334ms/step - loss: 0.1210 - accuracy: 0.9622 - val_loss: 0.4478 - val_accuracy: 0.8453
Epoch 9/10
391/391 [==============================] - 134s 343ms/step - loss: 0.0934 - accuracy: 0.9723 - val_loss: 0.4828 - val_accuracy: 0.8485
Epoch 10/10
391/391 [==============================] - 137s 351ms/step - loss: 0.0735 - accuracy: 0.9798 - val_loss: 0.5393 - val_accuracy: 0.8396
```python
test_loss, test_acc = model.evaluate(test_dataset)

print('Test Loss: {}'.format(test_loss))
print('Test Accuracy: {}'.format(test_acc))
```
391/Unknown - 35s 89ms/step - loss: 0.5393 - accuracy: 0.8396
Test Loss: 0.5392520240962962
Test Accuracy: 0.8395599722862244
```python
# Predict on a sample text without padding.
sample_pred_text = ('The movie was not good. The animation and the graphics '
                    'were terrible. I would not recommend this movie.')
predictions = sample_predict(sample_pred_text, pad=False)
print(predictions)
```
[[0.00492772]]
```python
# Predict on a sample text with padding.
sample_pred_text = ('The movie was not good. The animation and the graphics '
                    'were terrible. I would not recommend this movie.')
predictions = sample_predict(sample_pred_text, pad=True)
print(predictions)
```
[[0.00426524]]
```python
plot_graphs(history, 'accuracy')
```
```python
plot_graphs(history, 'loss')
```
Check out other existing recurrent layers such as GRU layers.
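A GRU uses two gates (update and reset) instead of the LSTM's three and keeps no separate cell state, so it has fewer parameters per unit. In the Keras models above it would be swapped in as `tf.keras.layers.GRU(64)` in place of `tf.keras.layers.LSTM(64)`. The NumPy sketch below shows one step of a textbook GRU cell for intuition only; it assumes the classic formulation, whereas Keras' GRU defaults to a slightly different variant (`reset_after=True`), so the numbers will not match Keras exactly. The parameter names in `p` are illustrative.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, p):
    """One step of a textbook GRU cell (illustrative, not numerically identical to Keras)."""
    z = sigmoid(x_t @ p["W_z"] + h_prev @ p["U_z"] + p["b_z"])  # update gate
    r = sigmoid(x_t @ p["W_r"] + h_prev @ p["U_r"] + p["b_r"])  # reset gate
    h_tilde = np.tanh(x_t @ p["W_h"] + (r * h_prev) @ p["U_h"] + p["b_h"])  # candidate state
    # Keras convention: z close to 1 keeps the previous state.
    return z * h_prev + (1.0 - z) * h_tilde

rng = np.random.default_rng(4)
input_dim, units = 3, 4
p = {}
for gate in ("z", "r", "h"):
    p[f"W_{gate}"] = rng.normal(size=(input_dim, units))
    p[f"U_{gate}"] = rng.normal(size=(units, units))
    p[f"b_{gate}"] = np.zeros(units)

h = np.zeros(units)
for x_t in rng.normal(size=(5, input_dim)):  # iterate over a 5-step toy sequence
    h = gru_step(x_t, h, p)
print(h.shape)  # (4,)
```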
That's all.