Keras 2 : examples : NLP – デシジョンツリーと事前訓練済み埋め込みを使用したテキスト分類 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 05/27/2022 (keras 2.9.0)
* 本ページは、Keras の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
- Code examples : Natural Language Processing : Text classification using Decision Forests and pretrained embeddings (Author: Gitesh Chawda)
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Web: www.classcat.com ; ClassCatJP
Keras 2 : examples : 自然言語処理 – デシジョンツリーと事前訓練済み埋め込みを使用したテキスト分類
Description : テキスト分類のために Tensorflow Decision Forests を使用する。
イントロダクション
TensorFlow Decision Forests (TF-DF) は、Keras API と互換な Decision Forest モデルのための最先端のアルゴリズムのコレクションです。このモジュールはランダムフォレスト, 勾配ブースティング木と CART を含み、そして回帰、分類とランキングタスクのために使用できます。
この例では、災害関連のツイートを分類するために事前訓練済みの埋め込みとともに勾配ブースティング木を使用します。
See also:
次のコマンドを使用して Tensorflow Decision Forest をインストールします : pip install tensorflow_decision_forests
インポート
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_hub as hub
from tensorflow.keras import layers
import tensorflow_decision_forests as tfdf
import matplotlib.pyplot as plt
データの取得
データセットは Kaggle で利用可能です。
データセットの説明 :
ファイル :
- train.csv: 訓練セット
カラム :
- id : 各ツイートに対する一意な識別子
- text : ツイートのテキスト
- location : ツイートが送られた場所 (空白かもしれません)
- keyword : ツイートの特定のキーワード (空白かもしれません)
- target: train.csv のみで、これはツイートが実際の災害についてか (1) そうでないか (0) を表します。
# Turn .csv files into pandas DataFrame's
df = pd.read_csv(
"https://raw.githubusercontent.com/IMvision12/Tweets-Classification-NLP/main/train.csv"
)
print(df.head())
id keyword location text \ 0 1 NaN NaN Our Deeds are the Reason of this #earthquake M... 1 4 NaN NaN Forest fire near La Ronge Sask. Canada 2 5 NaN NaN All residents asked to 'shelter in place' are ... 3 6 NaN NaN 13,000 people receive #wildfires evacuation or... 4 7 NaN NaN Just got sent this photo from Ruby #Alaska as ...
target 0 1 1 1 2 1 3 1 4 1
データセットは 5 カラムを持つ 7613 サンプルを含みます :
print(f"Training dataset shape: {df.shape}")
Training dataset shape: (7613, 5)
シャッフルして不要なカラムを破棄します :
df_shuffled = df.sample(frac=1, random_state=42)
# Dropping id, keyword and location columns as these columns consists of mostly nan values
# we will be using only text and target columns
df_shuffled.drop(["id", "keyword", "location"], axis=1, inplace=True)
df_shuffled.reset_index(inplace=True, drop=True)
print(df_shuffled.head())
text target 0 So you have a new weapon that can cause un-ima... 1 1 The f$&@ing things I do for #GISHWHES Just... 0 2 DT @georgegalloway: RT @Galloway4Mayor: ÛÏThe... 1 3 Aftershock back to school kick off was great. ... 0 4 in response to trauma Children of Addicts deve... 0
シャッフルされたデータフレームの情報をプリントします :
print(df_shuffled.info())
<class 'pandas.core.frame.DataFrame'> RangeIndex: 7613 entries, 0 to 7612 Data columns (total 2 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 text 7613 non-null object 1 target 7613 non-null int64 dtypes: int64(1), object(1) memory usage: 119.1+ KB None
“disaster” と “non-disaster” ツイートの総数は :
print(
"Total Number of disaster and non-disaster tweets: "
f"{df_shuffled.target.value_counts()}"
)
Total Number of disaster and non-disaster tweets: 0 4342 1 3271 Name: target, dtype: int64
幾つかのサンプルをプレビューしてみましょう :
for index, example in df_shuffled[:5].iterrows():
print(f"Example #{index}")
print(f"\tTarget : {example['target']}")
print(f"\tText : {example['text']}")
Example #0 Target : 1 Text : So you have a new weapon that can cause un-imaginable destruction. Example #1 Target : 0 Text : The f$&@ing things I do for #GISHWHES Just got soaked in a deluge going for pads and tampons. Thx @mishacollins @/@ Example #2 Target : 1 Text : DT @georgegalloway: RT @Galloway4Mayor: ÛÏThe CoL police can catch a pickpocket in Liverpool Stree... http://t.co/vXIn1gOq4Q Example #3 Target : 0 Text : Aftershock back to school kick off was great. I want to thank everyone for making it possible. What a great night. Example #4 Target : 0 Text : in response to trauma Children of Addicts develop a defensive self - one that decreases vulnerability. (3
データセットを訓練とテストセットに分割します :
test_df = df_shuffled.sample(frac=0.1, random_state=42)
train_df = df_shuffled.drop(test_df.index)
print(f"Using {len(train_df)} samples for training and {len(test_df)} for validation")
Using 6852 samples for training and 761 for validation
訓練データの “disaster” と “non-disaster” ツイートの総数は :
print(train_df["target"].value_counts())
0 3929 1 2923 Name: target, dtype: int64
テストデータの “disaster” と “non-disaster” ツイートの総数は :
print(test_df["target"].value_counts())
0 413 1 348 Name: target, dtype: int64
データを tf.data.Dataset に変換する
def create_dataset(dataframe):
dataset = tf.data.Dataset.from_tensor_slices(
(df["text"].to_numpy(), df["target"].to_numpy())
)
dataset = dataset.batch(100)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
return dataset
train_ds = create_dataset(train_df)
test_ds = create_dataset(test_df)
事前訓練済み埋め込みのダウンロード
Universal Sentence Encoder 埋め込みはテキストを高次元ベクトルにエンコードします、これはテキスト分類、セマンティック類似度、クラスタリングと他の自然言語タスクのために使用できます。それらは様々なデータソースと様々なタスクで訓練されます。入力は可変長の英語テキストで出力は 512 次元ベクトルです。
これらの事前訓練済み埋め込みについて学習するには、Universal Sentence Encoder を見てください。
sentence_encoder_layer = hub.KerasLayer(
"https://tfhub.dev/google/universal-sentence-encoder/4"
)
モデルの作成
2 つのモデルを作成します。最初のモデル (model_1) では、raw テキストは最初に事前訓練済み埋め込みでエンコードされてから分類のために勾配ブースティング木モデルに渡されます。2 番目のモデル (model_2) では、raw テキストは勾配ブースティング木モデルに直接渡されます。
Building model_1
inputs = layers.Input(shape=(), dtype=tf.string)
outputs = sentence_encoder_layer(inputs)
preprocessor = keras.Model(inputs=inputs, outputs=outputs)
model_1 = tfdf.keras.GradientBoostedTreesModel(preprocessing=preprocessor)
Use /tmp/tmpkpl10aj9 as temporary training directory
Building model_2
model_2 = tfdf.keras.GradientBoostedTreesModel()
Use /tmp/tmpysfsq6o0 as temporary training directory
モデルの訓練
メトリクス metrics Accuracy, Recall, Precision と AUC を渡すことによりモデルをコンパイルします。損失については、TF-DF はタスク (分類または回帰) に対する最善な損失を自動的に検出します。それはモデル summary でプリントされます。
また、TF-DF モデルはミニバッチ勾配降下モデルではなくバッチ訓練モデルですので、過剰適合を監視したり、訓練を早期に停止するための検証データセットは必要としません。幾つかのアルゴリズムは検証データセットを使用しません (e.g. Random Forest)、一方で幾つかの他のものは使用します (e.g. 勾配ブースティング木)。検証データセットが必要な場合には、訓練データセットから自動的に抽出されます。
# Compiling model_1
model_1.compile(metrics=["Accuracy", "Recall", "Precision", "AUC"])
# Here we do not specify epochs as, TF-DF trains exactly one epoch of the dataset
model_1.fit(train_ds)
# Compiling model_2
model_2.compile(metrics=["Accuracy", "Recall", "Precision", "AUC"])
# Here we do not specify epochs as, TF-DF trains exactly one epoch of the dataset
model_2.fit(train_ds)
Starting reading the dataset 77/77 [==============================] - ETA: 0s Dataset read in 0:00:15.844516 Training model Model trained in 0:02:30.922245 Compiling model 77/77 [==============================] - 167s 2s/step Starting reading the dataset 55/77 [====================>.........] - ETA: 0s Dataset read in 0:00:00.219258 Training model Model trained in 0:00:00.289591 Compiling model 77/77 [==============================] - 1s 6ms/step <keras.callbacks.History at 0x7f453f9349d0>
訓練メトリクスのプロット
def plot_curve(logs):
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot([log.num_trees for log in logs], [log.evaluation.accuracy for log in logs])
plt.xlabel("Number of trees")
plt.ylabel("Accuracy")
plt.subplot(1, 2, 2)
plt.plot([log.num_trees for log in logs], [log.evaluation.loss for log in logs])
plt.xlabel("Number of trees")
plt.ylabel("Loss")
plt.show()
plot_curve(logs_1)
plot_curve(logs_2)
テストデータ上の評価
results = model_1.evaluate(test_ds, return_dict=True, verbose=0)
print("model_1 Evaluation: \n")
for name, value in results.items():
print(f"{name}: {value:.4f}")
results = model_2.evaluate(test_ds, return_dict=True, verbose=0)
print("model_2 Evaluation: \n")
for name, value in results.items():
print(f"{name}: {value:.4f}")
model_1 Evaluation:
loss: 0.0000 Accuracy: 0.9631 recall: 0.9425 precision: 0.9707 auc: 0.9890 model_2 Evaluation:
loss: 0.0000 Accuracy: 0.5731 recall: 0.0064 precision: 1.0000 auc: 0.5035
検証データ上の予測
test_df.reset_index(inplace=True, drop=True)
for index, row in test_df.iterrows():
text = tf.expand_dims(row["text"], axis=0)
preds = model_1.predict_step(text)
preds = tf.squeeze(tf.round(preds))
print(f"Text: {row['text']}")
print(f"Prediction: {int(preds)}")
print(f"Ground Truth : {row['target']}")
if index == 10:
break
Text: DFR EP016 Monthly Meltdown - On Dnbheaven 2015.08.06 http://t.co/EjKRf8N8A8 #Drum and Bass #heavy #nasty http://t.co/SPHWE6wFI5 Prediction: 0 Ground Truth : 0 Text: FedEx no longer to transport bioterror germs in wake of anthrax lab mishaps http://t.co/qZQc8WWwcN via @usatoday Prediction: 1 Ground Truth : 0 Text: Gunmen kill four in El Salvador bus attack: Suspected Salvadoran gang members killed four people and wounded s... http://t.co/CNtwB6ScZj Prediction: 1 Ground Truth : 1 Text: @camilacabello97 Internally and externally screaming Prediction: 1 Ground Truth : 1 Text: Radiation emergency #preparedness starts with knowing to: get inside stay inside and stay tuned http://t.co/RFFPqBAz2F via @CDCgov Prediction: 1 Ground Truth : 1 Text: Investigators rule catastrophic structural failure resulted in 2014 Virg.. Related Articles: http://t.co/Cy1LFeNyV8 Prediction: 1 Ground Truth : 1 Text: How the West was burned: Thousands of wildfires ablaze in #California alone http://t.co/iCSjGZ9tE1 #climate #energy http://t.co/9FxmN0l0Bd Prediction: 1 Ground Truth : 1 Text: Map: Typhoon Soudelor's predicted path as it approaches Taiwan; expected to make landfall over southern China by SÛ_ http://t.co/JDVSGVhlIs Prediction: 1 Ground Truth : 1 Text: Ûª93 blasts accused Yeda Yakub dies in Karachi of heart attack http://t.co/mfKqyxd8XG #Mumbai Prediction: 1 Ground Truth : 1 Text: My ears are bleeding https://t.co/k5KnNwugwT Prediction: 0 Ground Truth : 0 Text: @RedCoatJackpot *As it was typical for them their bullets collided and none managed to reach their targets; such was the ''curse'' of a -- Prediction: 0 Ground Truth : 0
結びの言葉 (= Concluding remarks)
TensorFlow Decision Forests パッケージは構造化データで特に上手く動作するパワフルなモデルを提供します。この実験では、事前訓練済み埋め込みを持つ勾配ブースティング木モデルは 96.31% テスト精度を獲得する一方で、普通の勾配ブースティング木モデルは 57.31% 精度でした。
以上