TensorFlow 2.0 : 上級 Tutorials : 構造化データ :- 不均衡なデータ上の分類 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 11/18/2019
* 本ページは、TensorFlow org サイトの TF 2.0 – Advanced Tutorials – Structured data の以下のページを翻訳した上で
適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
Facebook: https://www.facebook.com/ClassCatJP/ |
構造化データ :- 不均衡なデータ上の分類
このチュートリアルでは非常に不均衡なデータセットをどのように分類するかを実演します、そこでは一つのクラスのサンプル数がもう一つの (クラスの) サンプルに大いに数で上回ります。Kaggle でホストされている Credit Card Fraud Detection データセットで作業します。目的は全部で 284,807 トランザクションから僅か 492 の詐欺トランザクションを検出することです。不均衡なデータからモデルが学習することを助けるためにモデルと クラス重み を定義するために Keras を使用します。
このチュートリアルは以下を行なうための完全なコードを含みます :
- Pandas を使用して CSV ファイルをロードする。
- 訓練、検証とテストセットを作成する。
- (クラス重みを設定することを含む) Keras を使用してモデルを定義して訓練する。
- (精度 (= precision) と recall を含む) 様々なメトリクスを使用してモデルを評価する。
- 次のような不均衡なデータを扱うための一般的なテクニックを試す :
- クラス重み付け (= Class weighting)
- Oversampling
セットアップ
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf from tensorflow import keras import os import tempfile import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns import sklearn from sklearn.metrics import confusion_matrix from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler
mpl.rcParams['figure.figsize'] = (12, 10) colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
データ前処理と調査
Kaggle Credit Card Fraud データセットをダウンロードする
Pandas は構造化データをロードしてそれで作業するための多くの役立つユティリティを持つ Python ライブラリで CSV を dataframe にダウンロードするために使用できます。
★ Note: このデータセットはビッグ・データマイニングと詐欺検出上の Worldline と ULB (Université Libre de Bruxelles) の 機械学習グループ の研究コラボレーションの間に収集されて解析されました。関連トピックについて現在と過去のプロジェクトのより詳細は ここ と DefeatFraud プロジェクトのページで利用可能です。
file = tf.keras.utils raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv') raw_df.head()
Time | V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | … | V21 | V22 | V23 | V24 | V25 | V26 | V27 | V28 | Amount | Class | |
0 | 0.0 | -1.359807 | -0.072781 | 2.536347 | 1.378155 | -0.338321 | 0.462388 | 0.239599 | 0.098698 | 0.363787 | … | -0.018307 | 0.277838 | -0.110474 | 0.066928 | 0.128539 | -0.189115 | 0.133558 | -0.021053 | 149.62 | 0 |
1 | 0.0 | 1.191857 | 0.266151 | 0.166480 | 0.448154 | 0.060018 | -0.082361 | -0.078803 | 0.085102 | -0.255425 | … | -0.225775 | -0.638672 | 0.101288 | -0.339846 | 0.167170 | 0.125895 | -0.008983 | 0.014724 | 2.69 | 0 |
2 | 1.0 | -1.358354 | -1.340163 | 1.773209 | 0.379780 | -0.503198 | 1.800499 | 0.791461 | 0.247676 | -1.514654 | … | 0.247998 | 0.771679 | 0.909412 | -0.689281 | -0.327642 | -0.139097 | -0.055353 | -0.059752 | 378.66 | 0 |
3 | 1.0 | -0.966272 | -0.185226 | 1.792993 | -0.863291 | -0.010309 | 1.247203 | 0.237609 | 0.377436 | -1.387024 | … | -0.108300 | 0.005274 | -0.190321 | -1.175575 | 0.647376 | -0.221929 | 0.062723 | 0.061458 | 123.50 | 0 |
4 | 2.0 | -1.158233 | 0.877737 | 1.548718 | 0.403034 | -0.407193 | 0.095921 | 0.592941 | -0.270533 | 0.817739 | … | -0.009431 | 0.798278 | -0.137458 | 0.141267 | -0.206010 | 0.502292 | 0.219422 | 0.215153 | 69.99 | 0 |
5 rows × 31 columns
raw_df[['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V26', 'V27', 'V28', 'Amount', 'Class']].describe()
Time | V1 | V2 | V3 | V4 | V5 | V26 | V27 | V28 | Amount | クラス | |
カウント | 284807.000000 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 284807.000000 | 284807.000000 |
平均 | 94813.859575 | 1.165980e-15 | 3.416908e-16 | -1.373150e-15 | 2.086869e-15 | 9.604066e-16 | 1.687098e-15 | -3.666453e-16 | -1.220404e-16 | 88.349619 | 0.001727 |
標準偏差 | 47488.145955 | 1.958696e+00 | 1.651309e+00 | 1.516255e+00 | 1.415869e+00 | 1.380247e+00 | 4.822270e-01 | 4.036325e-01 | 3.300833e-01 | 250.120109 | 0.041527 |
最小値 | 0.000000 | -5.640751e+01 | -7.271573e+01 | -4.832559e+01 | -5.683171e+00 | -1.137433e+02 | -2.604551e+00 | -2.256568e+01 | -1.543008e+01 | 0.000000 | 0.000000 |
25% | 54201.500000 | -9.203734e-01 | -5.985499e-01 | -8.903648e-01 | -8.486401e-01 | -6.915971e-01 | -3.269839e-01 | -7.083953e-02 | -5.295979e-02 | 5.600000 | 0.000000 |
50% | 84692.000000 | 1.810880e-02 | 6.548556e-02 | 1.798463e-01 | -1.984653e-02 | -5.433583e-02 | -5.213911e-02 | 1.342146e-03 | 1.124383e-02 | 22.000000 | 0.000000 |
75% | 139320.500000 | 1.315642e+00 | 8.037239e-01 | 1.027196e+00 | 7.433413e-01 | 6.119264e-01 | 2.409522e-01 | 9.104512e-02 | 7.827995e-02 | 77.165000 | 0.000000 |
最大値 | 172792.000000 | 2.454930e+00 | 2.205773e+01 | 9.382558e+00 | 1.687534e+01 | 3.480167e+01 | 3.517346e+00 | 3.161220e+01 | 3.384781e+01 | 25691.160000 | 1.000000 |
クラスラベルが不均衡であることを調べる
不均衡なデータセットを見てみましょう :
neg, pos = np.bincount(raw_df['Class']) total = neg + pos print('Examples:\n Total: {}\n Positive: {} ({:.2f}% of total)\n'.format( total, pos, 100 * pos / total))
Examples: Total: 284807 Positive: 492 (0.17% of total)
これはポジティブなサンプルの小さい割合を示します。
データをクリーンアップ、分割して正規化する
生データは 2, 3 の問題を持ちます。最初に Time と Amount カラムは直接使用するには変化し過ぎです。Time カラムは破棄して (何故ならばそれが何を意味するか明確でないからです) そして範囲を減じるために Amount カラムの対数を取ります。
cleaned_df = raw_df.copy() # You don't want the `Time` column. cleaned_df.pop('Time') # The `Amount` column covers a huge range. Convert to log-space. eps=0.001 # 0 => 0.1¢ cleaned_df['Log Ammount'] = np.log(cleaned_df.pop('Amount')+eps)
データセットを訓練、検証とテストセットに分割します。検証セットは損失と任意のメトリクスを評価するためにモデル fitting の間に使用されます、けれどもモデルはこのデータでは fit されません。テストセットは訓練段階の間には全く使用されません、そしてモデルが新しいデータにどのくらい上手く一般化されたかを評価するために最後に使用されるだけです。これは不均衡なデータセットでは特に重要です、そこでは overfitting は訓練データの欠落から本質的な関心事です。
# Use a utility from sklearn to split and shuffle our dataset. train_df, test_df = train_test_split(cleaned_df, test_size=0.2) train_df, val_df = train_test_split(train_df, test_size=0.2) # Form np arrays of labels and features. train_labels = np.array(train_df.pop('Class')) bool_train_labels = train_labels != 0 val_labels = np.array(val_df.pop('Class')) test_labels = np.array(test_df.pop('Class')) train_features = np.array(train_df) val_features = np.array(val_df) test_features = np.array(test_df)
sklearn StandardScaler を使用して入力特徴を正規化します。これは平均を 0 そして標準偏差を 1 に設定します。
★ Note: StandardScaler はモデルが検証やテストセットを覗き見ないことを確実にするために train_features を使用して fit するだけです。
scaler = StandardScaler() train_features = scaler.fit_transform(train_features) val_features = scaler.transform(val_features) test_features = scaler.transform(test_features) train_features = np.clip(train_features, -5, 5) val_features = np.clip(val_features, -5, 5) test_features = np.clip(test_features, -5, 5) print('Training labels shape:', train_labels.shape) print('Validation labels shape:', val_labels.shape) print('Test labels shape:', test_labels.shape) print('Training features shape:', train_features.shape) print('Validation features shape:', val_features.shape) print('Test features shape:', test_features.shape)
Training labels shape: (182276,) Validation labels shape: (45569,) Test labels shape: (56962,) Training features shape: (182276, 29) Validation features shape: (45569, 29) Test features shape: (56962, 29)
データ分布を見る
次に 2,3 の特徴に渡りポジティブとネガティブ・サンプルの分布を比較します。この時点で貴方自身に尋ねる良い質問は :
- これらの分布は意味があるでしょうか?
- はい。入力を正規化してそしてこれらは +/- 2 範囲に殆ど集中しています。
- 分布の間に違いを見ることができるでしょうか?
- はい、ポジティブ・サンプルは遥かに高いレートの極値 (= extreme values) を含みます。
pos_df = pd.DataFrame(train_features[ bool_train_labels], columns = train_df.columns) neg_df = pd.DataFrame(train_features[~bool_train_labels], columns = train_df.columns) sns.jointplot(pos_df['V5'], pos_df['V6'], kind='hex', xlim = (-5,5), ylim = (-5,5)) plt.suptitle("Positive distribution") sns.jointplot(neg_df['V5'], neg_df['V6'], kind='hex', xlim = (-5,5), ylim = (-5,5)) _ = plt.suptitle("Negative distribution")
モデルとメトリクスを定義する
密に接続された隠れ層、overfitting を減じる dropout 層、そしてトランザクションが不正である確率を返す出力 sigmoid 層を持つ単純なネットワークを作成する関数を定義します :
METRICS = [ keras.metrics.TruePositives(name='tp'), keras.metrics.FalsePositives(name='fp'), keras.metrics.TrueNegatives(name='tn'), keras.metrics.FalseNegatives(name='fn'), keras.metrics.BinaryAccuracy(name='accuracy'), keras.metrics.Precision(name='precision'), keras.metrics.Recall(name='recall'), keras.metrics.AUC(name='auc'), ] def make_model(metrics = METRICS, output_bias=None): if output_bias is not None: output_bias = tf.keras.initializers.Constant(output_bias) model = keras.Sequential([ keras.layers.Dense( 16, activation='relu', input_shape=(train_features.shape[-1],)), keras.layers.Dropout(0.5), keras.layers.Dense(1, activation='sigmoid', bias_initializer=output_bias), ]) model.compile( optimizer=keras.optimizers.Adam(lr=1e-3), loss=keras.losses.BinaryCrossentropy(), metrics=metrics) return model
有用なメトリクスを理解する
パフォーマンスを評価するときに有用である、モデルにより計算可能な上で定義された 2, 3 のメトリクスがあることに注意してください。
- false negative と false positive は 間違って 分類されたサンプルです
- true negative と true positive は 正しく 分類されたサンプルです
- accuracy は正しく分類されたサンプルのパーセンテージです > $\frac{\text{true samples}}{\text{total samples}}$
- precision は正しく分類された 予測された ポジティブのパーセンテージです > $\frac{\text{true positives}}{\text{true positives + false positives}}$
- recall は正しく分類された 実際の ポジティブのパーセンテージです > $\frac{\text{true positives}}{\text{true positives + false negatives}}$
- AUC は ROC曲線下の面積 (ROC-AUC, Area Under the Curve of a Receiver Operating Characteristic curve) を参照します。このメトリックは、分類器がランダムなポジティブ・サンプルをランダムなネガティブ・サンプルよりも高く位置付ける確率に等しいです。
★ Note: accuracy (精度) はこのタスクのために有用なメトリクスではありません。常に False を予測することによりこのタスク上 99.8%+ 精度を得られるでしょう。
Read more: * True vs. False と Positive vs. Negative * Accuracy * Precision と Recall * ROC-AUC
ベースライン・モデル
モデルを構築する
今は先に定義された関数を使用してモデルを作成して訓練します。モデルは 2048 のデフォルトよりも大きいバッチサイズを使用して fit することに注意してください、これは各バッチが 2, 3 のポジティブ・サンプルを含む妥当な機会を持つことを確実にするために重要です。バッチサイズが小さ過ぎれば、そこから学習するための詐欺的なトランザクションを持たない傾向になるでしょう。
Note: このモデルは不均衡なクラスを上手く扱いません。このチュートリアルの後でそれを改良します。
EPOCHS = 100 BATCH_SIZE = 2048 early_stopping = tf.keras.callbacks.EarlyStopping( monitor='val_auc', verbose=1, patience=10, mode='max', restore_best_weights=True)
model = make_model() model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 16) 480 _________________________________________________________________ dropout (Dropout) (None, 16) 0 _________________________________________________________________ dense_1 (Dense) (None, 1) 17 ================================================================= Total params: 497 Trainable params: 497 Non-trainable params: 0
Test run the model:
model.predict(train_features[:10])
array([[0.35366294], [0.17252561], [0.08079773], [0.13621533], [0.18569201], [0.22863552], [0.3243227 ], [0.25937212], [0.27047133], [0.1558134 ]], dtype=float32)
オプション: 正しい初期バイアスを設定する
これらは初期推論は良くはありません。貴方はデータセットは不均衡であることを知っています。それを反映するように出力層のバイアスを設定します (参照: A Recipe for Training Neural Networks: “init well”)。これは初期収束に役立つことができます。
デフォルトのバイアス初期化では損失はおよそ math.log(2) = 0.69314 になるはずです
results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0) print("Loss: {:0.4f}".format(results[0]))
Loss: 0.2624
設定する正しいバイアスは以下から導出できます :
initial_bias = np.log([pos/neg]) initial_bias
array([-6.35935934])
それを初期バイアスとして設定すると、モデルは遥かに合理的な初期推論を与えます。
それは pos/total = 0.0018 近くになるはずです
model = make_model(output_bias = initial_bias) model.predict(train_features[:10])
array([[0.00416172], [0.0006738 ], [0.00036496], [0.00125635], [0.00200993], [0.0012227 ], [0.01357868], [0.00356525], [0.00580611], [0.02072738]], dtype=float32)
この初期化で初期損失はおよそ次になるはずです :
results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0) print("Loss: {:0.4f}".format(results[0]))
Loss: 0.0164
この初期損失は素朴な初期化であるよりもおよそ 50 倍少ないです。
このようにしてモデルはポジティブサンプルが可能性が低いことを単に学習する、最初の数エポックを消費する必要がありません。これはまた訓練の間に損失のプロットを読むことを容易にします。
初期重みをチェックポイントする
様々な訓練をより比較可能に実行するために、この初期モデルの重みをチェックポイント・ファイルに保持します、そしてそれらを訓練前に各モデルにロードします。
initial_weights = os.path.join(tempfile.mkdtemp(),'initial_weights') model.save_weights(initial_weights)
バイアス修正 (= fix) が役立つことをを確認する
進む前に、注意深いバイアス初期化が実際に役立つことを素早く確認します。
この注意深い初期化ありとなしで、20 エポックの間モデルを訓練します、そして損失を比較します :
model = make_model() model.load_weights(initial_weights) model.layers[-1].bias.assign([0.0]) zero_bias_history = model.fit( train_features, train_labels, batch_size=BATCH_SIZE, epochs=20, validation_data=(val_features, val_labels), verbose=0)
model = make_model() model.load_weights(initial_weights) careful_bias_history = model.fit( train_features, train_labels, batch_size=BATCH_SIZE, epochs=20, validation_data=(val_features, val_labels), verbose=0)
def plot_loss(history, label, n): # Use a log scale to show the wide range of values. plt.semilogy(history.epoch, history.history['loss'], color=colors[n], label='Train '+label) plt.semilogy(history.epoch, history.history['val_loss'], color=colors[n], label='Val '+label, linestyle="--") plt.xlabel('Epoch') plt.ylabel('Loss') plt.legend()
plot_loss(zero_bias_history, "Zero Bias", 0) plot_loss(careful_bias_history, "Careful Bias", 1)
上の図は次を明白にします: この問題上、検証損失の観点から、この注意深い初期化は明白な優位を与えます。
モデルを訓練する
model = make_model() model.load_weights(initial_weights) baseline_history = model.fit( train_features, train_labels, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks = [early_stopping], validation_data=(val_features, val_labels))
Train on 182276 samples, validate on 45569 samples Epoch 1/100 182276/182276 [==============================] - 3s 15us/sample - loss: 0.0148 - tp: 33.0000 - fp: 107.0000 - tn: 181864.0000 - fn: 272.0000 - accuracy: 0.9979 - precision: 0.2357 - recall: 0.1082 - auc: 0.6736 - val_loss: 0.0072 - val_tp: 3.0000 - val_fp: 2.0000 - val_tn: 45485.0000 - val_fn: 79.0000 - val_accuracy: 0.9982 - val_precision: 0.6000 - val_recall: 0.0366 - val_auc: 0.9058 Epoch 2/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.0087 - tp: 94.0000 - fp: 34.0000 - tn: 181937.0000 - fn: 211.0000 - accuracy: 0.9987 - precision: 0.7344 - recall: 0.3082 - auc: 0.8194 - val_loss: 0.0049 - val_tp: 40.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 42.0000 - val_accuracy: 0.9989 - val_precision: 0.8511 - val_recall: 0.4878 - val_auc: 0.9264 Epoch 3/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.0075 - tp: 122.0000 - fp: 33.0000 - tn: 181938.0000 - fn: 183.0000 - accuracy: 0.9988 - precision: 0.7871 - recall: 0.4000 - auc: 0.8638 - val_loss: 0.0043 - val_tp: 47.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 35.0000 - val_accuracy: 0.9991 - val_precision: 0.8704 - val_recall: 0.5732 - val_auc: 0.9266 Epoch 4/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.0065 - tp: 131.0000 - fp: 32.0000 - tn: 181939.0000 - fn: 174.0000 - accuracy: 0.9989 - precision: 0.8037 - recall: 0.4295 - auc: 0.8855 - val_loss: 0.0040 - val_tp: 57.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 25.0000 - val_accuracy: 0.9993 - val_precision: 0.8906 - val_recall: 0.6951 - val_auc: 0.9327 Epoch 5/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.0055 - tp: 172.0000 - fp: 28.0000 - tn: 181943.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8600 - recall: 0.5639 - auc: 0.9170 - val_loss: 0.0038 - val_tp: 59.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 23.0000 - val_accuracy: 0.9993 - val_precision: 0.8939 - val_recall: 0.7195 - val_auc: 0.9326 Epoch 6/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.0056 - tp: 158.0000 - fp: 32.0000 - tn: 181939.0000 - fn: 147.0000 - accuracy: 0.9990 - precision: 0.8316 - recall: 0.5180 - auc: 0.9008 - val_loss: 0.0036 - val_tp: 59.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 23.0000 - val_accuracy: 0.9993 - val_precision: 0.8939 - val_recall: 0.7195 - val_auc: 0.9326 Epoch 7/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.0055 - tp: 157.0000 - fp: 31.0000 - tn: 181940.0000 - fn: 148.0000 - accuracy: 0.9990 - precision: 0.8351 - recall: 0.5148 - auc: 0.9112 - val_loss: 0.0035 - val_tp: 62.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8986 - val_recall: 0.7561 - val_auc: 0.9326 Epoch 8/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.0048 - tp: 172.0000 - fp: 33.0000 - tn: 181938.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8390 - recall: 0.5639 - auc: 0.9130 - val_loss: 0.0034 - val_tp: 62.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8986 - val_recall: 0.7561 - val_auc: 0.9326 Epoch 9/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.0051 - tp: 159.0000 - fp: 28.0000 - tn: 181943.0000 - fn: 146.0000 - accuracy: 0.9990 - precision: 0.8503 - recall: 0.5213 - auc: 0.9081 - val_loss: 0.0033 - val_tp: 62.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8986 - val_recall: 0.7561 - val_auc: 0.9326 Epoch 10/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.0050 - tp: 169.0000 - fp: 33.0000 - tn: 181938.0000 - fn: 136.0000 - accuracy: 0.9991 - precision: 0.8366 - recall: 0.5541 - auc: 0.9215 - val_loss: 0.0033 - val_tp: 62.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8986 - val_recall: 0.7561 - val_auc: 0.9326 Epoch 11/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.0051 - tp: 157.0000 - fp: 29.0000 - tn: 181942.0000 - fn: 148.0000 - accuracy: 0.9990 - precision: 0.8441 - recall: 0.5148 - auc: 0.9233 - val_loss: 0.0032 - val_tp: 62.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8986 - val_recall: 0.7561 - val_auc: 0.9326 Epoch 12/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.0045 - tp: 178.0000 - fp: 37.0000 - tn: 181934.0000 - fn: 127.0000 - accuracy: 0.9991 - precision: 0.8279 - recall: 0.5836 - auc: 0.9317 - val_loss: 0.0031 - val_tp: 62.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8986 - val_recall: 0.7561 - val_auc: 0.9326 Epoch 13/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.0047 - tp: 160.0000 - fp: 32.0000 - tn: 181939.0000 - fn: 145.0000 - accuracy: 0.9990 - precision: 0.8333 - recall: 0.5246 - auc: 0.9185 - val_loss: 0.0031 - val_tp: 63.0000 - val_fp: 8.0000 - val_tn: 45479.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8873 - val_recall: 0.7683 - val_auc: 0.9326 Epoch 14/100 161792/182276 [=========================>....] - ETA: 0s - loss: 0.0045 - tp: 155.0000 - fp: 31.0000 - tn: 161488.0000 - fn: 118.0000 - accuracy: 0.9991 - precision: 0.8333 - recall: 0.5678 - auc: 0.9256Restoring model weights from the end of the best epoch. 182276/182276 [==============================] - 1s 3us/sample - loss: 0.0044 - tp: 175.0000 - fp: 33.0000 - tn: 181938.0000 - fn: 130.0000 - accuracy: 0.9991 - precision: 0.8413 - recall: 0.5738 - auc: 0.9251 - val_loss: 0.0031 - val_tp: 64.0000 - val_fp: 8.0000 - val_tn: 45479.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8889 - val_recall: 0.7805 - val_auc: 0.9326 Epoch 00014: early stopping
訓練履歴を確認する
このセクションでは、モデルの精度と訓練と検証セット上の損失のプロットを生成します。これは overfitting を確認するために有用です、それについてはこの チュートリアル で更に学習できます。
更に、上で作成した任意のメトリクスのためにこれらのプロットを生成できます。false negative がサンプルとして含まれます。
def plot_metrics(history): metrics = ['loss', 'auc', 'precision', 'recall'] for n, metric in enumerate(metrics): name = metric.replace("_"," ").capitalize() plt.subplot(2,2,n+1) plt.plot(history.epoch, history.history[metric], color=colors[0], label='Train') plt.plot(history.epoch, history.history['val_'+metric], color=colors[0], linestyle="--", label='Val') plt.xlabel('Epoch') plt.ylabel(name) if metric == 'loss': plt.ylim([0, plt.ylim()[1]]) elif metric == 'auc': plt.ylim([0.8,1]) else: plt.ylim([0,1]) plt.legend()
plot_metrics(baseline_history)
★ Note: 検証カーブは訓練カーブよりも一般により良く遂行します。これは主として、モデルを評価するとき dropout 層が有効ではないという事実によります。
メトリクスを評価する
実際 (= actual) vs 予測されたラベルを要約するために 混同行列 を使用できます、そこでは X 軸は予測されたラベルで Y 軸は実際のラベルです。
train_predictions_baseline = model.predict(train_features, batch_size=BATCH_SIZE) test_predictions_baseline = model.predict(test_features, batch_size=BATCH_SIZE)
def plot_cm(labels, predictions, p=0.5): cm = confusion_matrix(labels, predictions > p) plt.figure(figsize=(5,5)) sns.heatmap(cm, annot=True, fmt="d") plt.title('Confusion matrix @{:.2f}'.format(p)) plt.ylabel('Actual label') plt.xlabel('Predicted label') print('Legitimate Transactions Detected (True Negatives): ', cm[0][0]) print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1]) print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0]) print('Fraudulent Transactions Detected (True Positives): ', cm[1][1]) print('Total Fraudulent Transactions: ', np.sum(cm[1]))
テストデータセット上のモデルを評価して上で作成したメトリクスのための結果を表示します。
baseline_results = model.evaluate(test_features, test_labels, batch_size=BATCH_SIZE, verbose=0) for name, value in zip(model.metrics_names, baseline_results): print(name, ': ', value) print() plot_cm(test_labels, test_predictions_baseline)
loss : 0.004311129764585171 tp : 68.0 fp : 11.0 tn : 56846.0 fn : 37.0 accuracy : 0.9991573 precision : 0.8607595 recall : 0.64761907 auc : 0.92348534 Legitimate Transactions Detected (True Negatives): 56846 Legitimate Transactions Incorrectly Detected (False Positives): 11 Fraudulent Transactions Missed (False Negatives): 37 Fraudulent Transactions Detected (True Positives): 68 Total Fraudulent Transactions: 105
モデルが総てを完全に予測したのであれば、これは 対角行列 になるでしょう、そこでは (正しくない予測を示す) 主対角線から離れた値はゼロになるでしょう。この場合行列は比較的少ない false positive を持つことを示します、これは誤ってフラグを建てられた比較的少ない正当なトランザクションがあったことを意味します。けれども、false positive の数が増加しているコストにもかかわらず、貴方はより少ない false negative でさえも望みがちでしょう。このトレードオフは好ましいかもしれません、何故ならば false positive はカスタマーにカードの動きを検証することを要求するための電子メールが送られることを引き起こすかもしれない一方で、false negative は詐欺的なトランザクションが通り抜けることを可能にするからです。
ROC をプロットする
今は ROC をプロットします。このプロットは有用です、何故ならばそれはひと目で、単に出力閾値を調整することによりモデルが到達できるパフォーマンスの範囲を示すからです。
def plot_roc(name, labels, predictions, **kwargs): fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions) plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs) plt.xlabel('False positives [%]') plt.ylabel('True positives [%]') plt.xlim([-0.5,20]) plt.ylim([80,100.5]) plt.grid(True) ax = plt.gca() ax.set_aspect('equal')
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0]) plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--') plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fc6ba4dcbe0>
それは precision は比較的高いように見えますが、recall と ROC 曲線下の面積 (AUC) は貴方が好むほどには高くありません。分類器は precision と recall の両者を最大化しようとするときしばしば課題に直面します、これは不均衡なデータセットで作業するときに特に真です。関心がある問題のコンテキストで異なるタイプのエラーのコストを考えることは重要です。この例では、false negative (詐欺的なトランザクションが取り逃がされる) がファイナンシャルコストを持つかもしれない一方で、false positive (トランザクションは誤って詐欺的とフラグが立てられる) はユーザの幸せを減らすかもしれません。
クラス重み
クラス重みを計算する
ゴールは詐欺的トランザクションを識別することですが、作業するためのそれらの非常に多くの positive サンプルを持ちませんので、分類器に利用可能な少ないサンプルに大きく重み付けすることを望むでしょう。Keras にパラメータを通して各クラスのための重みを渡すことによりこれを行なうことができます。これらはモデルに under-represented クラスからのサンプルに「より大きな注意を払う」ことをさせます。
# Scaling by total/2 helps keep the loss to a similar magnitude. # The sum of the weights of all examples stays the same. weight_for_0 = (1 / neg)*(total)/2.0 weight_for_1 = (1 / pos)*(total)/2.0 class_weight = {0: weight_for_0, 1: weight_for_1} print('Weight for class 0: {:.2f}'.format(weight_for_0)) print('Weight for class 1: {:.2f}'.format(weight_for_1))
Weight for class 0: 0.50 Weight for class 1: 289.44
クラス重みでモデルを訓練する
今は予測にどのように影響するかを見るためにクラス重みでモデルを再訓練して評価してみます。
★ Note: class_weights の使用は損失の範囲を変更します。これは optimizer に依拠して訓練の安定性に影響するかもしれません。optimizers.SGD のような、ステップサイズが勾配の大きさに依拠する optimizer は失敗するかもしれません。ここで使用される optimizer、optimizers.Adam はスケーリング変更の影響を受けません。また重み付けゆえに、トータル損失は 2 つのモデル間で比較可能ではないことに注意してください。
weighted_model = make_model() weighted_model.load_weights(initial_weights) weighted_history = weighted_model.fit( train_features, train_labels, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks = [early_stopping], validation_data=(val_features, val_labels), # The class weights go here class_weight=class_weight)
Train on 182276 samples, validate on 45569 samples Epoch 1/100 182276/182276 [==============================] - 3s 17us/sample - loss: 2.3320 - tp: 49.0000 - fp: 550.0000 - tn: 181421.0000 - fn: 256.0000 - accuracy: 0.9956 - precision: 0.0818 - recall: 0.1607 - auc: 0.6815 - val_loss: 0.7987 - val_tp: 40.0000 - val_fp: 14.0000 - val_tn: 45473.0000 - val_fn: 42.0000 - val_accuracy: 0.9988 - val_precision: 0.7407 - val_recall: 0.4878 - val_auc: 0.9440 Epoch 2/100 182276/182276 [==============================] - 1s 3us/sample - loss: 1.0274 - tp: 151.0000 - fp: 1122.0000 - tn: 180849.0000 - fn: 154.0000 - accuracy: 0.9930 - precision: 0.1186 - recall: 0.4951 - auc: 0.8471 - val_loss: 0.4155 - val_tp: 65.0000 - val_fp: 28.0000 - val_tn: 45459.0000 - val_fn: 17.0000 - val_accuracy: 0.9990 - val_precision: 0.6989 - val_recall: 0.7927 - val_auc: 0.9664 Epoch 3/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.6664 - tp: 196.0000 - fp: 1788.0000 - tn: 180183.0000 - fn: 109.0000 - accuracy: 0.9896 - precision: 0.0988 - recall: 0.6426 - auc: 0.8987 - val_loss: 0.3144 - val_tp: 69.0000 - val_fp: 58.0000 - val_tn: 45429.0000 - val_fn: 13.0000 - val_accuracy: 0.9984 - val_precision: 0.5433 - val_recall: 0.8415 - val_auc: 0.9748 Epoch 4/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.4983 - tp: 221.0000 - fp: 2856.0000 - tn: 179115.0000 - fn: 84.0000 - accuracy: 0.9839 - precision: 0.0718 - recall: 0.7246 - auc: 0.9235 - val_loss: 0.2678 - val_tp: 70.0000 - val_fp: 109.0000 - val_tn: 45378.0000 - val_fn: 12.0000 - val_accuracy: 0.9973 - val_precision: 0.3911 - val_recall: 0.8537 - val_auc: 0.9746 Epoch 5/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.4914 - tp: 228.0000 - fp: 3722.0000 - tn: 178249.0000 - fn: 77.0000 - accuracy: 0.9792 - precision: 0.0577 - recall: 0.7475 - auc: 0.9132 - val_loss: 0.2342 - val_tp: 71.0000 - val_fp: 215.0000 - val_tn: 45272.0000 - val_fn: 11.0000 - val_accuracy: 0.9950 - val_precision: 0.2483 - val_recall: 0.8659 - val_auc: 0.9767 Epoch 6/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.4245 - tp: 240.0000 - fp: 4810.0000 - tn: 177161.0000 - fn: 65.0000 - accuracy: 0.9733 - precision: 0.0475 - recall: 0.7869 - auc: 0.9247 - val_loss: 0.2168 - val_tp: 71.0000 - val_fp: 414.0000 - val_tn: 45073.0000 - val_fn: 11.0000 - val_accuracy: 0.9907 - val_precision: 0.1464 - val_recall: 0.8659 - val_auc: 0.9772 Epoch 7/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.3187 - tp: 256.0000 - fp: 5810.0000 - tn: 176161.0000 - fn: 49.0000 - accuracy: 0.9679 - precision: 0.0422 - recall: 0.8393 - auc: 0.9448 - val_loss: 0.2047 - val_tp: 71.0000 - val_fp: 535.0000 - val_tn: 44952.0000 - val_fn: 11.0000 - val_accuracy: 0.9880 - val_precision: 0.1172 - val_recall: 0.8659 - val_auc: 0.9772 Epoch 8/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.3666 - tp: 256.0000 - fp: 6358.0000 - tn: 175613.0000 - fn: 49.0000 - accuracy: 0.9649 - precision: 0.0387 - recall: 0.8393 - auc: 0.9320 - val_loss: 0.2008 - val_tp: 71.0000 - val_fp: 595.0000 - val_tn: 44892.0000 - val_fn: 11.0000 - val_accuracy: 0.9867 - val_precision: 0.1066 - val_recall: 0.8659 - val_auc: 0.9779 Epoch 9/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.3080 - tp: 254.0000 - fp: 6611.0000 - tn: 175360.0000 - fn: 51.0000 - accuracy: 0.9635 - precision: 0.0370 - recall: 0.8328 - auc: 0.9521 - val_loss: 0.1941 - val_tp: 71.0000 - val_fp: 685.0000 - val_tn: 44802.0000 - val_fn: 11.0000 - val_accuracy: 0.9847 - val_precision: 0.0939 - val_recall: 0.8659 - val_auc: 0.9781 Epoch 10/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.2565 - tp: 263.0000 - fp: 7082.0000 - tn: 174889.0000 - fn: 42.0000 - accuracy: 0.9609 - precision: 0.0358 - recall: 0.8623 - auc: 0.9594 - val_loss: 0.1907 - val_tp: 73.0000 - val_fp: 746.0000 - val_tn: 44741.0000 - val_fn: 9.0000 - val_accuracy: 0.9834 - val_precision: 0.0891 - val_recall: 0.8902 - val_auc: 0.9790 Epoch 11/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.2785 - tp: 263.0000 - fp: 7027.0000 - tn: 174944.0000 - fn: 42.0000 - accuracy: 0.9612 - precision: 0.0361 - recall: 0.8623 - auc: 0.9541 - val_loss: 0.1873 - val_tp: 73.0000 - val_fp: 790.0000 - val_tn: 44697.0000 - val_fn: 9.0000 - val_accuracy: 0.9825 - val_precision: 0.0846 - val_recall: 0.8902 - val_auc: 0.9790 Epoch 12/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.3415 - tp: 256.0000 - fp: 7570.0000 - tn: 174401.0000 - fn: 49.0000 - accuracy: 0.9582 - precision: 0.0327 - recall: 0.8393 - auc: 0.9374 - val_loss: 0.1861 - val_tp: 73.0000 - val_fp: 863.0000 - val_tn: 44624.0000 - val_fn: 9.0000 - val_accuracy: 0.9809 - val_precision: 0.0780 - val_recall: 0.8902 - val_auc: 0.9792 Epoch 13/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.2324 - tp: 272.0000 - fp: 7138.0000 - tn: 174833.0000 - fn: 33.0000 - accuracy: 0.9607 - precision: 0.0367 - recall: 0.8918 - auc: 0.9638 - val_loss: 0.1861 - val_tp: 73.0000 - val_fp: 821.0000 - val_tn: 44666.0000 - val_fn: 9.0000 - val_accuracy: 0.9818 - val_precision: 0.0817 - val_recall: 0.8902 - val_auc: 0.9800 Epoch 14/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.2341 - tp: 271.0000 - fp: 7125.0000 - tn: 174846.0000 - fn: 34.0000 - accuracy: 0.9607 - precision: 0.0366 - recall: 0.8885 - auc: 0.9662 - val_loss: 0.1851 - val_tp: 74.0000 - val_fp: 830.0000 - val_tn: 44657.0000 - val_fn: 8.0000 - val_accuracy: 0.9816 - val_precision: 0.0819 - val_recall: 0.9024 - val_auc: 0.9800 Epoch 15/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.2603 - tp: 269.0000 - fp: 6786.0000 - tn: 175185.0000 - fn: 36.0000 - accuracy: 0.9626 - precision: 0.0381 - recall: 0.8820 - auc: 0.9581 - val_loss: 0.1843 - val_tp: 74.0000 - val_fp: 824.0000 - val_tn: 44663.0000 - val_fn: 8.0000 - val_accuracy: 0.9817 - val_precision: 0.0824 - val_recall: 0.9024 - val_auc: 0.9806 Epoch 16/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.2448 - tp: 267.0000 - fp: 6749.0000 - tn: 175222.0000 - fn: 38.0000 - accuracy: 0.9628 - precision: 0.0381 - recall: 0.8754 - auc: 0.9605 - val_loss: 0.1836 - val_tp: 74.0000 - val_fp: 810.0000 - val_tn: 44677.0000 - val_fn: 8.0000 - val_accuracy: 0.9820 - val_precision: 0.0837 - val_recall: 0.9024 - val_auc: 0.9812 Epoch 17/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.2779 - tp: 263.0000 - fp: 7040.0000 - tn: 174931.0000 - fn: 42.0000 - accuracy: 0.9611 - precision: 0.0360 - recall: 0.8623 - auc: 0.9531 - val_loss: 0.1799 - val_tp: 74.0000 - val_fp: 866.0000 - val_tn: 44621.0000 - val_fn: 8.0000 - val_accuracy: 0.9808 - val_precision: 0.0787 - val_recall: 0.9024 - val_auc: 0.9814 Epoch 18/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.2188 - tp: 272.0000 - fp: 7279.0000 - tn: 174692.0000 - fn: 33.0000 - accuracy: 0.9599 - precision: 0.0360 - recall: 0.8918 - auc: 0.9712 - val_loss: 0.1791 - val_tp: 74.0000 - val_fp: 902.0000 - val_tn: 44585.0000 - val_fn: 8.0000 - val_accuracy: 0.9800 - val_precision: 0.0758 - val_recall: 0.9024 - val_auc: 0.9816 Epoch 19/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.2562 - tp: 268.0000 - fp: 6972.0000 - tn: 174999.0000 - fn: 37.0000 - accuracy: 0.9615 - precision: 0.0370 - recall: 0.8787 - auc: 0.9599 - val_loss: 0.1789 - val_tp: 74.0000 - val_fp: 864.0000 - val_tn: 44623.0000 - val_fn: 8.0000 - val_accuracy: 0.9809 - val_precision: 0.0789 - val_recall: 0.9024 - val_auc: 0.9817 Epoch 20/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.2241 - tp: 274.0000 - fp: 6996.0000 - tn: 174975.0000 - fn: 31.0000 - accuracy: 0.9614 - precision: 0.0377 - recall: 0.8984 - auc: 0.9671 - val_loss: 0.1800 - val_tp: 74.0000 - val_fp: 862.0000 - val_tn: 44625.0000 - val_fn: 8.0000 - val_accuracy: 0.9809 - val_precision: 0.0791 - val_recall: 0.9024 - val_auc: 0.9818 Epoch 21/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.2015 - tp: 273.0000 - fp: 6637.0000 - tn: 175334.0000 - fn: 32.0000 - accuracy: 0.9634 - precision: 0.0395 - recall: 0.8951 - auc: 0.9752 - val_loss: 0.1765 - val_tp: 74.0000 - val_fp: 855.0000 - val_tn: 44632.0000 - val_fn: 8.0000 - val_accuracy: 0.9811 - val_precision: 0.0797 - val_recall: 0.9024 - val_auc: 0.9824 Epoch 22/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.2209 - tp: 271.0000 - fp: 6428.0000 - tn: 175543.0000 - fn: 34.0000 - accuracy: 0.9645 - precision: 0.0405 - recall: 0.8885 - auc: 0.9674 - val_loss: 0.1760 - val_tp: 74.0000 - val_fp: 835.0000 - val_tn: 44652.0000 - val_fn: 8.0000 - val_accuracy: 0.9815 - val_precision: 0.0814 - val_recall: 0.9024 - val_auc: 0.9826 Epoch 23/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.1654 - tp: 282.0000 - fp: 6397.0000 - tn: 175574.0000 - fn: 23.0000 - accuracy: 0.9648 - precision: 0.0422 - recall: 0.9246 - auc: 0.9817 - val_loss: 0.1763 - val_tp: 74.0000 - val_fp: 773.0000 - val_tn: 44714.0000 - val_fn: 8.0000 - val_accuracy: 0.9829 - val_precision: 0.0874 - val_recall: 0.9024 - val_auc: 0.9827 Epoch 24/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.2263 - tp: 271.0000 - fp: 6441.0000 - tn: 175530.0000 - fn: 34.0000 - accuracy: 0.9645 - precision: 0.0404 - recall: 0.8885 - auc: 0.9670 - val_loss: 0.1740 - val_tp: 74.0000 - val_fp: 802.0000 - val_tn: 44685.0000 - val_fn: 8.0000 - val_accuracy: 0.9822 - val_precision: 0.0845 - val_recall: 0.9024 - val_auc: 0.9829 Epoch 25/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.2225 - tp: 271.0000 - fp: 6317.0000 - tn: 175654.0000 - fn: 34.0000 - accuracy: 0.9652 - precision: 0.0411 - recall: 0.8885 - auc: 0.9678 - val_loss: 0.1715 - val_tp: 74.0000 - val_fp: 831.0000 - val_tn: 44656.0000 - val_fn: 8.0000 - val_accuracy: 0.9816 - val_precision: 0.0818 - val_recall: 0.9024 - val_auc: 0.9808 Epoch 26/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.2039 - tp: 279.0000 - fp: 6674.0000 - tn: 175297.0000 - fn: 26.0000 - accuracy: 0.9632 - precision: 0.0401 - recall: 0.9148 - auc: 0.9691 - val_loss: 0.1716 - val_tp: 74.0000 - val_fp: 798.0000 - val_tn: 44689.0000 - val_fn: 8.0000 - val_accuracy: 0.9823 - val_precision: 0.0849 - val_recall: 0.9024 - val_auc: 0.9833 Epoch 27/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.2030 - tp: 276.0000 - fp: 6050.0000 - tn: 175921.0000 - fn: 29.0000 - accuracy: 0.9666 - precision: 0.0436 - recall: 0.9049 - auc: 0.9722 - val_loss: 0.1748 - val_tp: 74.0000 - val_fp: 722.0000 - val_tn: 44765.0000 - val_fn: 8.0000 - val_accuracy: 0.9840 - val_precision: 0.0930 - val_recall: 0.9024 - val_auc: 0.9813 Epoch 28/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.2273 - tp: 270.0000 - fp: 6034.0000 - tn: 175937.0000 - fn: 35.0000 - accuracy: 0.9667 - precision: 0.0428 - recall: 0.8852 - auc: 0.9656 - val_loss: 0.1730 - val_tp: 74.0000 - val_fp: 748.0000 - val_tn: 44739.0000 - val_fn: 8.0000 - val_accuracy: 0.9834 - val_precision: 0.0900 - val_recall: 0.9024 - val_auc: 0.9814 Epoch 29/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.1998 - tp: 274.0000 - fp: 5768.0000 - tn: 176203.0000 - fn: 31.0000 - accuracy: 0.9682 - precision: 0.0453 - recall: 0.8984 - auc: 0.9758 - val_loss: 0.1718 - val_tp: 74.0000 - val_fp: 783.0000 - val_tn: 44704.0000 - val_fn: 8.0000 - val_accuracy: 0.9826 - val_precision: 0.0863 - val_recall: 0.9024 - val_auc: 0.9813 Epoch 30/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.1844 - tp: 277.0000 - fp: 5926.0000 - tn: 176045.0000 - fn: 28.0000 - accuracy: 0.9673 - precision: 0.0447 - recall: 0.9082 - auc: 0.9767 - val_loss: 0.1700 - val_tp: 74.0000 - val_fp: 783.0000 - val_tn: 44704.0000 - val_fn: 8.0000 - val_accuracy: 0.9826 - val_precision: 0.0863 - val_recall: 0.9024 - val_auc: 0.9815 Epoch 31/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.1811 - tp: 276.0000 - fp: 6118.0000 - tn: 175853.0000 - fn: 29.0000 - accuracy: 0.9663 - precision: 0.0432 - recall: 0.9049 - auc: 0.9785 - val_loss: 0.1679 - val_tp: 74.0000 - val_fp: 791.0000 - val_tn: 44696.0000 - val_fn: 8.0000 - val_accuracy: 0.9825 - val_precision: 0.0855 - val_recall: 0.9024 - val_auc: 0.9816 Epoch 32/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.1889 - tp: 276.0000 - fp: 6230.0000 - tn: 175741.0000 - fn: 29.0000 - accuracy: 0.9657 - precision: 0.0424 - recall: 0.9049 - auc: 0.9774 - val_loss: 0.1677 - val_tp: 74.0000 - val_fp: 777.0000 - val_tn: 44710.0000 - val_fn: 8.0000 - val_accuracy: 0.9828 - val_precision: 0.0870 - val_recall: 0.9024 - val_auc: 0.9817 Epoch 33/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.1986 - tp: 280.0000 - fp: 5610.0000 - tn: 176361.0000 - fn: 25.0000 - accuracy: 0.9691 - precision: 0.0475 - recall: 0.9180 - auc: 0.9717 - val_loss: 0.1702 - val_tp: 74.0000 - val_fp: 679.0000 - val_tn: 44808.0000 - val_fn: 8.0000 - val_accuracy: 0.9849 - val_precision: 0.0983 - val_recall: 0.9024 - val_auc: 0.9817 Epoch 34/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.1848 - tp: 277.0000 - fp: 5707.0000 - tn: 176264.0000 - fn: 28.0000 - accuracy: 0.9685 - precision: 0.0463 - recall: 0.9082 - auc: 0.9729 - val_loss: 0.1689 - val_tp: 74.0000 - val_fp: 713.0000 - val_tn: 44774.0000 - val_fn: 8.0000 - val_accuracy: 0.9842 - val_precision: 0.0940 - val_recall: 0.9024 - val_auc: 0.9818 Epoch 35/100 182276/182276 [==============================] - 1s 3us/sample - loss: 0.1910 - tp: 273.0000 - fp: 5492.0000 - tn: 176479.0000 - fn: 32.0000 - accuracy: 0.9697 - precision: 0.0474 - recall: 0.8951 - auc: 0.9747 - val_loss: 0.1665 - val_tp: 74.0000 - val_fp: 748.0000 - val_tn: 44739.0000 - val_fn: 8.0000 - val_accuracy: 0.9834 - val_precision: 0.0900 - val_recall: 0.9024 - val_auc: 0.9820 Epoch 36/100 163840/182276 [=========================>....] - ETA: 0s - loss: 0.1844 - tp: 245.0000 - fp: 5415.0000 - tn: 158154.0000 - fn: 26.0000 - accuracy: 0.9668 - precision: 0.0433 - recall: 0.9041 - auc: 0.9783Restoring model weights from the end of the best epoch. 182276/182276 [==============================] - 1s 3us/sample - loss: 0.1787 - tp: 277.0000 - fp: 6024.0000 - tn: 175947.0000 - fn: 28.0000 - accuracy: 0.9668 - precision: 0.0440 - recall: 0.9082 - auc: 0.9795 - val_loss: 0.1624 - val_tp: 74.0000 - val_fp: 790.0000 - val_tn: 44697.0000 - val_fn: 8.0000 - val_accuracy: 0.9825 - val_precision: 0.0856 - val_recall: 0.9024 - val_auc: 0.9821 Epoch 00036: early stopping
訓練履歴を確認する
plot_metrics(weighted_history)
メトリクスを評価する
train_predictions_weighted = weighted_model.predict(train_features, batch_size=BATCH_SIZE) test_predictions_weighted = weighted_model.predict(test_features, batch_size=BATCH_SIZE)
weighted_results = weighted_model.evaluate(test_features, test_labels, batch_size=BATCH_SIZE, verbose=0) for name, value in zip(weighted_model.metrics_names, weighted_results): print(name, ': ', value) print() plot_cm(test_labels, test_predictions_weighted)
loss : 0.07607786336474939 tp : 92.0 fp : 937.0 tn : 55920.0 fn : 13.0 accuracy : 0.9833222 precision : 0.08940719 recall : 0.8761905 auc : 0.9757879 Legitimate Transactions Detected (True Negatives): 55920 Legitimate Transactions Incorrectly Detected (False Positives): 937 Fraudulent Transactions Missed (False Negatives): 13 Fraudulent Transactions Detected (True Positives): 92 Total Fraudulent Transactions: 105
ここでクラス重みで accuracy と precision はより低いことを見て取れます、何故ならばより多くの false positive があるからです、しかし反対に recall と AUC はより高いです、何故ならばモデルはまたより多くの true positive を見つけたからです。より低い accuracy を持つにもかかわらず、このモデルはより高い recall を持ちます (そしてより多くの詐欺的トランザクションを識別します)。もちろん、両者のタイプのエラーへのコストはあります (非常に多くの正当なトランザクションを詐欺的としてフラグを立ててユーザを困らせることもまた望まないでしょう)。貴方のアプリケーションのためにこれらの異なるタイプのエラーの間のトレードオフを注意深く考えてください。
ROC をプロットする
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0]) plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--') plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1]) plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--') plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fc6b41b67f0>
オーバーサンプリング
少数 (= minotiry) クラスをオーバーサンプリングする
関連するアプローチは少数クラスをオーバーサンプリングすることによりデータセットを再サンプリングすることです。
pos_features = train_features[bool_train_labels] neg_features = train_features[~bool_train_labels] pos_labels = train_labels[bool_train_labels] neg_labels = train_labels[~bool_train_labels]
NumPy を使用する
positive サンプルからランダムインデックスの適切な数を選択してデータセットを手動で平衡にすることができます :
ids = np.arange(len(pos_features)) choices = np.random.choice(ids, len(neg_features)) res_pos_features = pos_features[choices] res_pos_labels = pos_labels[choices] res_pos_features.shape
(181971, 29)
resampled_features = np.concatenate([res_pos_features, neg_features], axis=0) resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0) order = np.arange(len(resampled_labels)) np.random.shuffle(order) resampled_features = resampled_features[order] resampled_labels = resampled_labels[order] resampled_features.shape
(363942, 29)
tf.data を使用する
tf.data を使用している場合、バランスの取れたサンプルを生成する最も容易な方法は positive と negative データセットから始めて、それらをマージすることです。より多くのサンプルについては tf.data ガイド を見てください。
BUFFER_SIZE = 100000 def make_ds(features, labels): ds = tf.data.Dataset.from_tensor_slices((features, labels))#.cache() ds = ds.shuffle(BUFFER_SIZE).repeat() return ds pos_ds = make_ds(pos_features, pos_labels) neg_ds = make_ds(neg_features, neg_labels)
各データセットは (feature, label) ペアを生成します :
for features, label in pos_ds.take(1): print("Features:\n", features.numpy()) print() print("Label: ", label.numpy())
Features: [-2.17940656 1.76763197 -3.92031926 4.88344634 -2.52559973 -0.87402282 -5. 2.39895859 -2.28526198 -4.70820898 5. -5. 2.80944748 -5. -2.29981762 -5. -5. -5. 0.62439703 -0.30217952 2.21481335 2.15973127 -0.93161994 -0.09709966 -3.50901907 -0.15144646 0.34072986 -1.87496516 -0.45141544] Label: 1
experimental.sample_from_datasets を使用して 2 つを一緒にマージします :
resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5]) resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
for features, label in resampled_ds.take(1): print(label.numpy().mean())
0.51220703125
このデータセットを使用するために、エポック毎のステップ数を必要とします。
この場合の「エポック」の定義は曖昧です。例えばそれは各 negative サンプルを一度見るのに必要なバッチ数としましょう :
resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE) resampled_steps_per_epoch
278.0
オーバーサンプリングされたデータ上で訓練する
今はこれらのメソッドがどのように比較できるかを見るためにクラス重みを使用する代わりに再サンプリングされたデータセットでモデルを訓練してみます。
★ Note: positive サンプルを複製することによりデータはバランスが取られましたので、総計のデータセットサイズは大きくなり、そして各エポックはより多くの訓練ステップのために実行されます。
resampled_model = make_model() resampled_model.load_weights(initial_weights) # Reset the bias to zero, since this dataset is balanced. output_layer = resampled_model.layers[-1] output_layer.bias.assign([0]) val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache() val_ds = val_ds.batch(BATCH_SIZE).prefetch(2) resampled_history = resampled_model.fit( resampled_ds, epochs=EPOCHS, steps_per_epoch=resampled_steps_per_epoch, callbacks = [early_stopping], validation_data=val_ds)
Train for 278.0 steps, validate for 23 steps Epoch 1/100 278/278 [==============================] - 11s 38ms/step - loss: 0.5141 - tp: 235679.0000 - fp: 89290.0000 - tn: 195069.0000 - fn: 49306.0000 - accuracy: 0.7566 - precision: 0.7252 - recall: 0.8270 - auc: 0.8606 - val_loss: 0.2508 - val_tp: 74.0000 - val_fp: 1342.0000 - val_tn: 44145.0000 - val_fn: 8.0000 - val_accuracy: 0.9704 - val_precision: 0.0523 - val_recall: 0.9024 - val_auc: 0.9747 Epoch 2/100 278/278 [==============================] - 8s 30ms/step - loss: 0.2221 - tp: 256642.0000 - fp: 18273.0000 - tn: 266245.0000 - fn: 28184.0000 - accuracy: 0.9184 - precision: 0.9335 - recall: 0.9010 - auc: 0.9667 - val_loss: 0.1341 - val_tp: 74.0000 - val_fp: 976.0000 - val_tn: 44511.0000 - val_fn: 8.0000 - val_accuracy: 0.9784 - val_precision: 0.0705 - val_recall: 0.9024 - val_auc: 0.9772 Epoch 3/100 278/278 [==============================] - 8s 29ms/step - loss: 0.1627 - tp: 263207.0000 - fp: 10666.0000 - tn: 273652.0000 - fn: 21819.0000 - accuracy: 0.9429 - precision: 0.9611 - recall: 0.9234 - auc: 0.9824 - val_loss: 0.0934 - val_tp: 73.0000 - val_fp: 822.0000 - val_tn: 44665.0000 - val_fn: 9.0000 - val_accuracy: 0.9818 - val_precision: 0.0816 - val_recall: 0.8902 - val_auc: 0.9768 Epoch 4/100 278/278 [==============================] - 8s 30ms/step - loss: 0.1373 - tp: 266123.0000 - fp: 8945.0000 - tn: 275364.0000 - fn: 18912.0000 - accuracy: 0.9511 - precision: 0.9675 - recall: 0.9337 - auc: 0.9876 - val_loss: 0.0773 - val_tp: 74.0000 - val_fp: 807.0000 - val_tn: 44680.0000 - val_fn: 8.0000 - val_accuracy: 0.9821 - val_precision: 0.0840 - val_recall: 0.9024 - val_auc: 0.9776 Epoch 5/100 278/278 [==============================] - 8s 30ms/step - loss: 0.1242 - tp: 266967.0000 - fp: 8276.0000 - tn: 276477.0000 - fn: 17624.0000 - accuracy: 0.9545 - precision: 0.9699 - recall: 0.9381 - auc: 0.9902 - val_loss: 0.0669 - val_tp: 74.0000 - val_fp: 771.0000 - val_tn: 44716.0000 - val_fn: 8.0000 - val_accuracy: 0.9829 - val_precision: 0.0876 - val_recall: 0.9024 - val_auc: 0.9766 Epoch 6/100 278/278 [==============================] - 8s 30ms/step - loss: 0.1150 - tp: 268272.0000 - fp: 7786.0000 - tn: 276538.0000 - fn: 16748.0000 - accuracy: 0.9569 - precision: 0.9718 - recall: 0.9412 - auc: 0.9918 - val_loss: 0.0604 - val_tp: 74.0000 - val_fp: 727.0000 - val_tn: 44760.0000 - val_fn: 8.0000 - val_accuracy: 0.9839 - val_precision: 0.0924 - val_recall: 0.9024 - val_auc: 0.9778 Epoch 7/100 278/278 [==============================] - 8s 30ms/step - loss: 0.1076 - tp: 268596.0000 - fp: 7335.0000 - tn: 277522.0000 - fn: 15891.0000 - accuracy: 0.9592 - precision: 0.9734 - recall: 0.9441 - auc: 0.9930 - val_loss: 0.0548 - val_tp: 73.0000 - val_fp: 668.0000 - val_tn: 44819.0000 - val_fn: 9.0000 - val_accuracy: 0.9851 - val_precision: 0.0985 - val_recall: 0.8902 - val_auc: 0.9786 Epoch 8/100 278/278 [==============================] - 9s 31ms/step - loss: 0.1011 - tp: 269934.0000 - fp: 7025.0000 - tn: 277658.0000 - fn: 14727.0000 - accuracy: 0.9618 - precision: 0.9746 - recall: 0.9483 - auc: 0.9940 - val_loss: 0.0505 - val_tp: 73.0000 - val_fp: 621.0000 - val_tn: 44866.0000 - val_fn: 9.0000 - val_accuracy: 0.9862 - val_precision: 0.1052 - val_recall: 0.8902 - val_auc: 0.9792 Epoch 9/100 278/278 [==============================] - 8s 30ms/step - loss: 0.0959 - tp: 270679.0000 - fp: 6711.0000 - tn: 277837.0000 - fn: 14117.0000 - accuracy: 0.9634 - precision: 0.9758 - recall: 0.9504 - auc: 0.9948 - val_loss: 0.0483 - val_tp: 73.0000 - val_fp: 603.0000 - val_tn: 44884.0000 - val_fn: 9.0000 - val_accuracy: 0.9866 - val_precision: 0.1080 - val_recall: 0.8902 - val_auc: 0.9794 Epoch 10/100 278/278 [==============================] - 8s 30ms/step - loss: 0.0913 - tp: 271592.0000 - fp: 6656.0000 - tn: 277917.0000 - fn: 13179.0000 - accuracy: 0.9652 - precision: 0.9761 - recall: 0.9537 - auc: 0.9953 - val_loss: 0.0440 - val_tp: 73.0000 - val_fp: 528.0000 - val_tn: 44959.0000 - val_fn: 9.0000 - val_accuracy: 0.9882 - val_precision: 0.1215 - val_recall: 0.8902 - val_auc: 0.9795 Epoch 11/100 278/278 [==============================] - 8s 30ms/step - loss: 0.0875 - tp: 272110.0000 - fp: 6604.0000 - tn: 278063.0000 - fn: 12567.0000 - accuracy: 0.9663 - precision: 0.9763 - recall: 0.9559 - auc: 0.9957 - val_loss: 0.0408 - val_tp: 73.0000 - val_fp: 491.0000 - val_tn: 44996.0000 - val_fn: 9.0000 - val_accuracy: 0.9890 - val_precision: 0.1294 - val_recall: 0.8902 - val_auc: 0.9799 Epoch 12/100 278/278 [==============================] - 8s 29ms/step - loss: 0.0832 - tp: 272899.0000 - fp: 6475.0000 - tn: 278107.0000 - fn: 11863.0000 - accuracy: 0.9678 - precision: 0.9768 - recall: 0.9583 - auc: 0.9962 - val_loss: 0.0392 - val_tp: 73.0000 - val_fp: 481.0000 - val_tn: 45006.0000 - val_fn: 9.0000 - val_accuracy: 0.9892 - val_precision: 0.1318 - val_recall: 0.8902 - val_auc: 0.9803 Epoch 13/100 278/278 [==============================] - 8s 30ms/step - loss: 0.0795 - tp: 274008.0000 - fp: 6417.0000 - tn: 278228.0000 - fn: 10691.0000 - accuracy: 0.9700 - precision: 0.9771 - recall: 0.9624 - auc: 0.9966 - val_loss: 0.0358 - val_tp: 73.0000 - val_fp: 442.0000 - val_tn: 45045.0000 - val_fn: 9.0000 - val_accuracy: 0.9901 - val_precision: 0.1417 - val_recall: 0.8902 - val_auc: 0.9806 Epoch 14/100 278/278 [==============================] - 8s 30ms/step - loss: 0.0747 - tp: 275586.0000 - fp: 6254.0000 - tn: 277642.0000 - fn: 9862.0000 - accuracy: 0.9717 - precision: 0.9778 - recall: 0.9655 - auc: 0.9970 - val_loss: 0.0336 - val_tp: 73.0000 - val_fp: 432.0000 - val_tn: 45055.0000 - val_fn: 9.0000 - val_accuracy: 0.9903 - val_precision: 0.1446 - val_recall: 0.8902 - val_auc: 0.9808 Epoch 15/100 278/278 [==============================] - 8s 31ms/step - loss: 0.0719 - tp: 274707.0000 - fp: 6051.0000 - tn: 279108.0000 - fn: 9478.0000 - accuracy: 0.9727 - precision: 0.9784 - recall: 0.9666 - auc: 0.9971 - val_loss: 0.0306 - val_tp: 73.0000 - val_fp: 407.0000 - val_tn: 45080.0000 - val_fn: 9.0000 - val_accuracy: 0.9909 - val_precision: 0.1521 - val_recall: 0.8902 - val_auc: 0.9808 Epoch 16/100 278/278 [==============================] - 8s 30ms/step - loss: 0.0697 - tp: 275176.0000 - fp: 6024.0000 - tn: 278945.0000 - fn: 9199.0000 - accuracy: 0.9733 - precision: 0.9786 - recall: 0.9677 - auc: 0.9972 - val_loss: 0.0305 - val_tp: 73.0000 - val_fp: 400.0000 - val_tn: 45087.0000 - val_fn: 9.0000 - val_accuracy: 0.9910 - val_precision: 0.1543 - val_recall: 0.8902 - val_auc: 0.9808 Epoch 17/100 278/278 [==============================] - 8s 30ms/step - loss: 0.0681 - tp: 275922.0000 - fp: 6047.0000 - tn: 278615.0000 - fn: 8760.0000 - accuracy: 0.9740 - precision: 0.9786 - recall: 0.9692 - auc: 0.9973 - val_loss: 0.0277 - val_tp: 73.0000 - val_fp: 368.0000 - val_tn: 45119.0000 - val_fn: 9.0000 - val_accuracy: 0.9917 - val_precision: 0.1655 - val_recall: 0.8902 - val_auc: 0.9808 Epoch 18/100 278/278 [==============================] - 8s 30ms/step - loss: 0.0671 - tp: 276085.0000 - fp: 6258.0000 - tn: 278545.0000 - fn: 8456.0000 - accuracy: 0.9742 - precision: 0.9778 - recall: 0.9703 - auc: 0.9974 - val_loss: 0.0264 - val_tp: 73.0000 - val_fp: 351.0000 - val_tn: 45136.0000 - val_fn: 9.0000 - val_accuracy: 0.9921 - val_precision: 0.1722 - val_recall: 0.8902 - val_auc: 0.9752 Epoch 19/100 278/278 [==============================] - 8s 30ms/step - loss: 0.0652 - tp: 276364.0000 - fp: 6108.0000 - tn: 278761.0000 - fn: 8111.0000 - accuracy: 0.9750 - precision: 0.9784 - recall: 0.9715 - auc: 0.9975 - val_loss: 0.0263 - val_tp: 73.0000 - val_fp: 354.0000 - val_tn: 45133.0000 - val_fn: 9.0000 - val_accuracy: 0.9920 - val_precision: 0.1710 - val_recall: 0.8902 - val_auc: 0.9699 Epoch 20/100 278/278 [==============================] - 8s 30ms/step - loss: 0.0640 - tp: 277072.0000 - fp: 6092.0000 - tn: 278461.0000 - fn: 7719.0000 - accuracy: 0.9757 - precision: 0.9785 - recall: 0.9729 - auc: 0.9975 - val_loss: 0.0247 - val_tp: 73.0000 - val_fp: 333.0000 - val_tn: 45154.0000 - val_fn: 9.0000 - val_accuracy: 0.9925 - val_precision: 0.1798 - val_recall: 0.8902 - val_auc: 0.9645 Epoch 21/100 278/278 [==============================] - 8s 29ms/step - loss: 0.0635 - tp: 277050.0000 - fp: 6229.0000 - tn: 278537.0000 - fn: 7528.0000 - accuracy: 0.9758 - precision: 0.9780 - recall: 0.9735 - auc: 0.9976 - val_loss: 0.0247 - val_tp: 73.0000 - val_fp: 338.0000 - val_tn: 45149.0000 - val_fn: 9.0000 - val_accuracy: 0.9924 - val_precision: 0.1776 - val_recall: 0.8902 - val_auc: 0.9590 Epoch 22/100 278/278 [==============================] - 8s 30ms/step - loss: 0.0619 - tp: 276909.0000 - fp: 6199.0000 - tn: 278826.0000 - fn: 7410.0000 - accuracy: 0.9761 - precision: 0.9781 - recall: 0.9739 - auc: 0.9977 - val_loss: 0.0242 - val_tp: 73.0000 - val_fp: 328.0000 - val_tn: 45159.0000 - val_fn: 9.0000 - val_accuracy: 0.9926 - val_precision: 0.1820 - val_recall: 0.8902 - val_auc: 0.9534 Epoch 23/100 278/278 [==============================] - 8s 30ms/step - loss: 0.0610 - tp: 278124.0000 - fp: 6195.0000 - tn: 278074.0000 - fn: 6951.0000 - accuracy: 0.9769 - precision: 0.9782 - recall: 0.9756 - auc: 0.9977 - val_loss: 0.0231 - val_tp: 73.0000 - val_fp: 311.0000 - val_tn: 45176.0000 - val_fn: 9.0000 - val_accuracy: 0.9930 - val_precision: 0.1901 - val_recall: 0.8902 - val_auc: 0.9536 Epoch 24/100 278/278 [==============================] - 8s 30ms/step - loss: 0.0605 - tp: 277101.0000 - fp: 6262.0000 - tn: 278966.0000 - fn: 7015.0000 - accuracy: 0.9767 - precision: 0.9779 - recall: 0.9753 - auc: 0.9977 - val_loss: 0.0222 - val_tp: 73.0000 - val_fp: 312.0000 - val_tn: 45175.0000 - val_fn: 9.0000 - val_accuracy: 0.9930 - val_precision: 0.1896 - val_recall: 0.8902 - val_auc: 0.9538 Epoch 25/100 278/278 [==============================] - 8s 30ms/step - loss: 0.0596 - tp: 277655.0000 - fp: 6154.0000 - tn: 278556.0000 - fn: 6979.0000 - accuracy: 0.9769 - precision: 0.9783 - recall: 0.9755 - auc: 0.9977 - val_loss: 0.0212 - val_tp: 73.0000 - val_fp: 295.0000 - val_tn: 45192.0000 - val_fn: 9.0000 - val_accuracy: 0.9933 - val_precision: 0.1984 - val_recall: 0.8902 - val_auc: 0.9482 Epoch 26/100 278/278 [==============================] - 8s 30ms/step - loss: 0.0586 - tp: 278013.0000 - fp: 6165.0000 - tn: 278169.0000 - fn: 6997.0000 - accuracy: 0.9769 - precision: 0.9783 - recall: 0.9754 - auc: 0.9978 - val_loss: 0.0217 - val_tp: 72.0000 - val_fp: 292.0000 - val_tn: 45195.0000 - val_fn: 10.0000 - val_accuracy: 0.9934 - val_precision: 0.1978 - val_recall: 0.8780 - val_auc: 0.9483 Epoch 27/100 277/278 [============================>.] - ETA: 0s - loss: 0.0581 - tp: 276446.0000 - fp: 6154.0000 - tn: 277889.0000 - fn: 6807.0000 - accuracy: 0.9772 - precision: 0.9782 - recall: 0.9760 - auc: 0.9978Restoring model weights from the end of the best epoch. 278/278 [==============================] - 8s 30ms/step - loss: 0.0581 - tp: 277470.0000 - fp: 6173.0000 - tn: 278871.0000 - fn: 6830.0000 - accuracy: 0.9772 - precision: 0.9782 - recall: 0.9760 - auc: 0.9978 - val_loss: 0.0209 - val_tp: 72.0000 - val_fp: 281.0000 - val_tn: 45206.0000 - val_fn: 10.0000 - val_accuracy: 0.9936 - val_precision: 0.2040 - val_recall: 0.8780 - val_auc: 0.9485 Epoch 00027: early stopping
訓練プロセスが各勾配更新でデータセット全体を考えていた場合、このオーバーサンプリングは基本的にはクラス重み付けと同値です。
しかしここで行なったように、モデルをバッチ-wise に訓練するとき、オーバーサンプリングされたデータはより滑らかな勾配信号を提供します : 各 positive サンプルが一つのバッチで巨大な重み上で見せられる代わりに、それらは小さい重みで各回に多くの異なるバッチで示されます。
この滑らかな勾配信号はモデルを訓練することを容易にします。
訓練履歴を確認する
ここではメトリクスの分布は異なることに注意してください、何故ならば訓練データは検証とテストデータからは大きく異なる分布を持つからです。
plot_metrics(resampled_history )
再訓練
訓練はバランスの取れたデータ上では容易ですので、上の訓練手続きは素早く overfit するかもしれません。
そこで callbacks.EarlyStopping にいつ訓練を停止するかについてより良い制御を与えるためにエポックを分割します。
resampled_model = make_model() resampled_model.load_weights(initial_weights) # Reset the bias to zero, since this dataset is balanced. output_layer = resampled_model.layers[-1] output_layer.bias.assign([0]) resampled_history = resampled_model.fit( resampled_ds, # These are not real epochs steps_per_epoch = 20, epochs=10*EPOCHS, callbacks = [early_stopping], validation_data=(val_ds))
Train for 20 steps, validate for 23 steps Epoch 1/1000 20/20 [==============================] - 3s 168ms/step - loss: 1.2909 - tp: 10811.0000 - fp: 11937.0000 - tn: 8418.0000 - fn: 9794.0000 - accuracy: 0.4695 - precision: 0.4753 - recall: 0.5247 - auc: 0.4887 - val_loss: 0.9277 - val_tp: 78.0000 - val_fp: 28484.0000 - val_tn: 17003.0000 - val_fn: 4.0000 - val_accuracy: 0.3748 - val_precision: 0.0027 - val_recall: 0.9512 - val_auc: 0.8413 Epoch 2/1000 20/20 [==============================] - 1s 33ms/step - loss: 0.8973 - tp: 14267.0000 - fp: 11350.0000 - tn: 9081.0000 - fn: 6262.0000 - accuracy: 0.5700 - precision: 0.5569 - recall: 0.6950 - auc: 0.6600 - val_loss: 0.8451 - val_tp: 80.0000 - val_fp: 25378.0000 - val_tn: 20109.0000 - val_fn: 2.0000 - val_accuracy: 0.4430 - val_precision: 0.0031 - val_recall: 0.9756 - val_auc: 0.9362 Epoch 3/1000 20/20 [==============================] - 1s 35ms/step - loss: 0.7107 - tp: 15904.0000 - fp: 10547.0000 - tn: 9941.0000 - fn: 4568.0000 - accuracy: 0.6310 - precision: 0.6013 - recall: 0.7769 - auc: 0.7582 - val_loss: 0.7505 - val_tp: 79.0000 - val_fp: 21165.0000 - val_tn: 24322.0000 - val_fn: 3.0000 - val_accuracy: 0.5355 - val_precision: 0.0037 - val_recall: 0.9634 - val_auc: 0.9478 Epoch 4/1000 20/20 [==============================] - 1s 35ms/step - loss: 0.5944 - tp: 16759.0000 - fp: 9572.0000 - tn: 10917.0000 - fn: 3712.0000 - accuracy: 0.6757 - precision: 0.6365 - recall: 0.8187 - auc: 0.8177 - val_loss: 0.6589 - val_tp: 77.0000 - val_fp: 16659.0000 - val_tn: 28828.0000 - val_fn: 5.0000 - val_accuracy: 0.6343 - val_precision: 0.0046 - val_recall: 0.9390 - val_auc: 0.9519 Epoch 5/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.5151 - tp: 17257.0000 - fp: 8350.0000 - tn: 12146.0000 - fn: 3207.0000 - accuracy: 0.7178 - precision: 0.6739 - recall: 0.8433 - auc: 0.8568 - val_loss: 0.5801 - val_tp: 77.0000 - val_fp: 12473.0000 - val_tn: 33014.0000 - val_fn: 5.0000 - val_accuracy: 0.7262 - val_precision: 0.0061 - val_recall: 0.9390 - val_auc: 0.9552 Epoch 6/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.4546 - tp: 17829.0000 - fp: 7203.0000 - tn: 13200.0000 - fn: 2728.0000 - accuracy: 0.7575 - precision: 0.7122 - recall: 0.8673 - auc: 0.8886 - val_loss: 0.5138 - val_tp: 76.0000 - val_fp: 9021.0000 - val_tn: 36466.0000 - val_fn: 6.0000 - val_accuracy: 0.8019 - val_precision: 0.0084 - val_recall: 0.9268 - val_auc: 0.9587 Epoch 7/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.4176 - tp: 17998.0000 - fp: 6179.0000 - tn: 14182.0000 - fn: 2601.0000 - accuracy: 0.7856 - precision: 0.7444 - recall: 0.8737 - auc: 0.9035 - val_loss: 0.4574 - val_tp: 76.0000 - val_fp: 6272.0000 - val_tn: 39215.0000 - val_fn: 6.0000 - val_accuracy: 0.8622 - val_precision: 0.0120 - val_recall: 0.9268 - val_auc: 0.9618 Epoch 8/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.3846 - tp: 18020.0000 - fp: 5174.0000 - tn: 15298.0000 - fn: 2468.0000 - accuracy: 0.8134 - precision: 0.7769 - recall: 0.8795 - auc: 0.9166 - val_loss: 0.4109 - val_tp: 75.0000 - val_fp: 4351.0000 - val_tn: 41136.0000 - val_fn: 7.0000 - val_accuracy: 0.9044 - val_precision: 0.0169 - val_recall: 0.9146 - val_auc: 0.9651 Epoch 9/1000 20/20 [==============================] - 1s 37ms/step - loss: 0.3631 - tp: 17972.0000 - fp: 4520.0000 - tn: 16011.0000 - fn: 2457.0000 - accuracy: 0.8297 - precision: 0.7990 - recall: 0.8797 - auc: 0.9230 - val_loss: 0.3704 - val_tp: 74.0000 - val_fp: 3021.0000 - val_tn: 42466.0000 - val_fn: 8.0000 - val_accuracy: 0.9335 - val_precision: 0.0239 - val_recall: 0.9024 - val_auc: 0.9674 Epoch 10/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.3378 - tp: 17994.0000 - fp: 3772.0000 - tn: 16815.0000 - fn: 2379.0000 - accuracy: 0.8498 - precision: 0.8267 - recall: 0.8832 - auc: 0.9308 - val_loss: 0.3381 - val_tp: 74.0000 - val_fp: 2323.0000 - val_tn: 43164.0000 - val_fn: 8.0000 - val_accuracy: 0.9488 - val_precision: 0.0309 - val_recall: 0.9024 - val_auc: 0.9693 Epoch 11/1000 20/20 [==============================] - 1s 33ms/step - loss: 0.3213 - tp: 18268.0000 - fp: 3260.0000 - tn: 17041.0000 - fn: 2391.0000 - accuracy: 0.8620 - precision: 0.8486 - recall: 0.8843 - auc: 0.9355 - val_loss: 0.3113 - val_tp: 74.0000 - val_fp: 1888.0000 - val_tn: 43599.0000 - val_fn: 8.0000 - val_accuracy: 0.9584 - val_precision: 0.0377 - val_recall: 0.9024 - val_auc: 0.9715 Epoch 12/1000 20/20 [==============================] - 1s 34ms/step - loss: 0.3040 - tp: 17977.0000 - fp: 2813.0000 - tn: 17832.0000 - fn: 2338.0000 - accuracy: 0.8742 - precision: 0.8647 - recall: 0.8849 - auc: 0.9420 - val_loss: 0.2880 - val_tp: 74.0000 - val_fp: 1621.0000 - val_tn: 43866.0000 - val_fn: 8.0000 - val_accuracy: 0.9643 - val_precision: 0.0437 - val_recall: 0.9024 - val_auc: 0.9733 Epoch 13/1000 20/20 [==============================] - 1s 34ms/step - loss: 0.2907 - tp: 18262.0000 - fp: 2594.0000 - tn: 17824.0000 - fn: 2280.0000 - accuracy: 0.8810 - precision: 0.8756 - recall: 0.8890 - auc: 0.9460 - val_loss: 0.2679 - val_tp: 74.0000 - val_fp: 1426.0000 - val_tn: 44061.0000 - val_fn: 8.0000 - val_accuracy: 0.9685 - val_precision: 0.0493 - val_recall: 0.9024 - val_auc: 0.9744 Epoch 14/1000 20/20 [==============================] - 1s 31ms/step - loss: 0.2803 - tp: 18314.0000 - fp: 2190.0000 - tn: 18192.0000 - fn: 2264.0000 - accuracy: 0.8913 - precision: 0.8932 - recall: 0.8900 - auc: 0.9491 - val_loss: 0.2503 - val_tp: 74.0000 - val_fp: 1330.0000 - val_tn: 44157.0000 - val_fn: 8.0000 - val_accuracy: 0.9706 - val_precision: 0.0527 - val_recall: 0.9024 - val_auc: 0.9752 Epoch 15/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.2679 - tp: 18275.0000 - fp: 1967.0000 - tn: 18471.0000 - fn: 2247.0000 - accuracy: 0.8971 - precision: 0.9028 - recall: 0.8905 - auc: 0.9528 - val_loss: 0.2350 - val_tp: 74.0000 - val_fp: 1258.0000 - val_tn: 44229.0000 - val_fn: 8.0000 - val_accuracy: 0.9722 - val_precision: 0.0556 - val_recall: 0.9024 - val_auc: 0.9759 Epoch 16/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.2575 - tp: 18217.0000 - fp: 1735.0000 - tn: 18790.0000 - fn: 2218.0000 - accuracy: 0.9035 - precision: 0.9130 - recall: 0.8915 - auc: 0.9558 - val_loss: 0.2215 - val_tp: 73.0000 - val_fp: 1194.0000 - val_tn: 44293.0000 - val_fn: 9.0000 - val_accuracy: 0.9736 - val_precision: 0.0576 - val_recall: 0.8902 - val_auc: 0.9766 Epoch 17/1000 20/20 [==============================] - 1s 31ms/step - loss: 0.2514 - tp: 18223.0000 - fp: 1700.0000 - tn: 18896.0000 - fn: 2141.0000 - accuracy: 0.9062 - precision: 0.9147 - recall: 0.8949 - auc: 0.9582 - val_loss: 0.2085 - val_tp: 73.0000 - val_fp: 1136.0000 - val_tn: 44351.0000 - val_fn: 9.0000 - val_accuracy: 0.9749 - val_precision: 0.0604 - val_recall: 0.8902 - val_auc: 0.9767 Epoch 18/1000 20/20 [==============================] - 1s 37ms/step - loss: 0.2424 - tp: 18402.0000 - fp: 1466.0000 - tn: 18902.0000 - fn: 2190.0000 - accuracy: 0.9107 - precision: 0.9262 - recall: 0.8936 - auc: 0.9601 - val_loss: 0.1980 - val_tp: 73.0000 - val_fp: 1105.0000 - val_tn: 44382.0000 - val_fn: 9.0000 - val_accuracy: 0.9756 - val_precision: 0.0620 - val_recall: 0.8902 - val_auc: 0.9771 Epoch 19/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.2358 - tp: 18413.0000 - fp: 1449.0000 - tn: 18924.0000 - fn: 2174.0000 - accuracy: 0.9115 - precision: 0.9270 - recall: 0.8944 - auc: 0.9624 - val_loss: 0.1889 - val_tp: 73.0000 - val_fp: 1074.0000 - val_tn: 44413.0000 - val_fn: 9.0000 - val_accuracy: 0.9762 - val_precision: 0.0636 - val_recall: 0.8902 - val_auc: 0.9772 Epoch 20/1000 20/20 [==============================] - 1s 33ms/step - loss: 0.2254 - tp: 18587.0000 - fp: 1249.0000 - tn: 19002.0000 - fn: 2122.0000 - accuracy: 0.9177 - precision: 0.9370 - recall: 0.8975 - auc: 0.9653 - val_loss: 0.1801 - val_tp: 73.0000 - val_fp: 1040.0000 - val_tn: 44447.0000 - val_fn: 9.0000 - val_accuracy: 0.9770 - val_precision: 0.0656 - val_recall: 0.8902 - val_auc: 0.9771 Epoch 21/1000 20/20 [==============================] - 1s 34ms/step - loss: 0.2200 - tp: 18517.0000 - fp: 1188.0000 - tn: 19205.0000 - fn: 2050.0000 - accuracy: 0.9209 - precision: 0.9397 - recall: 0.9003 - auc: 0.9675 - val_loss: 0.1720 - val_tp: 73.0000 - val_fp: 1024.0000 - val_tn: 44463.0000 - val_fn: 9.0000 - val_accuracy: 0.9773 - val_precision: 0.0665 - val_recall: 0.8902 - val_auc: 0.9775 Epoch 22/1000 20/20 [==============================] - 1s 34ms/step - loss: 0.2154 - tp: 18254.0000 - fp: 1078.0000 - tn: 19556.0000 - fn: 2072.0000 - accuracy: 0.9231 - precision: 0.9442 - recall: 0.8981 - auc: 0.9685 - val_loss: 0.1645 - val_tp: 73.0000 - val_fp: 1007.0000 - val_tn: 44480.0000 - val_fn: 9.0000 - val_accuracy: 0.9777 - val_precision: 0.0676 - val_recall: 0.8902 - val_auc: 0.9773 Epoch 23/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.2092 - tp: 18465.0000 - fp: 1074.0000 - tn: 19455.0000 - fn: 1966.0000 - accuracy: 0.9258 - precision: 0.9450 - recall: 0.9038 - auc: 0.9706 - val_loss: 0.1588 - val_tp: 74.0000 - val_fp: 1022.0000 - val_tn: 44465.0000 - val_fn: 8.0000 - val_accuracy: 0.9774 - val_precision: 0.0675 - val_recall: 0.9024 - val_auc: 0.9776 Epoch 24/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.2034 - tp: 18590.0000 - fp: 1025.0000 - tn: 19414.0000 - fn: 1931.0000 - accuracy: 0.9278 - precision: 0.9477 - recall: 0.9059 - auc: 0.9720 - val_loss: 0.1532 - val_tp: 74.0000 - val_fp: 1011.0000 - val_tn: 44476.0000 - val_fn: 8.0000 - val_accuracy: 0.9776 - val_precision: 0.0682 - val_recall: 0.9024 - val_auc: 0.9772 Epoch 25/1000 20/20 [==============================] - 1s 33ms/step - loss: 0.1999 - tp: 18484.0000 - fp: 1076.0000 - tn: 19531.0000 - fn: 1869.0000 - accuracy: 0.9281 - precision: 0.9450 - recall: 0.9082 - auc: 0.9736 - val_loss: 0.1472 - val_tp: 74.0000 - val_fp: 1009.0000 - val_tn: 44478.0000 - val_fn: 8.0000 - val_accuracy: 0.9777 - val_precision: 0.0683 - val_recall: 0.9024 - val_auc: 0.9774 Epoch 26/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.1961 - tp: 18723.0000 - fp: 1048.0000 - tn: 19357.0000 - fn: 1832.0000 - accuracy: 0.9297 - precision: 0.9470 - recall: 0.9109 - auc: 0.9739 - val_loss: 0.1415 - val_tp: 74.0000 - val_fp: 989.0000 - val_tn: 44498.0000 - val_fn: 8.0000 - val_accuracy: 0.9781 - val_precision: 0.0696 - val_recall: 0.9024 - val_auc: 0.9772 Epoch 27/1000 20/20 [==============================] - 1s 38ms/step - loss: 0.1890 - tp: 18580.0000 - fp: 990.0000 - tn: 19537.0000 - fn: 1853.0000 - accuracy: 0.9306 - precision: 0.9494 - recall: 0.9093 - auc: 0.9760 - val_loss: 0.1367 - val_tp: 74.0000 - val_fp: 976.0000 - val_tn: 44511.0000 - val_fn: 8.0000 - val_accuracy: 0.9784 - val_precision: 0.0705 - val_recall: 0.9024 - val_auc: 0.9773 Epoch 28/1000 20/20 [==============================] - 1s 33ms/step - loss: 0.1835 - tp: 18707.0000 - fp: 921.0000 - tn: 19609.0000 - fn: 1723.0000 - accuracy: 0.9354 - precision: 0.9531 - recall: 0.9157 - auc: 0.9777 - val_loss: 0.1316 - val_tp: 74.0000 - val_fp: 967.0000 - val_tn: 44520.0000 - val_fn: 8.0000 - val_accuracy: 0.9786 - val_precision: 0.0711 - val_recall: 0.9024 - val_auc: 0.9774 Epoch 29/1000 20/20 [==============================] - 1s 34ms/step - loss: 0.1820 - tp: 18835.0000 - fp: 862.0000 - tn: 19527.0000 - fn: 1736.0000 - accuracy: 0.9366 - precision: 0.9562 - recall: 0.9156 - auc: 0.9776 - val_loss: 0.1272 - val_tp: 74.0000 - val_fp: 936.0000 - val_tn: 44551.0000 - val_fn: 8.0000 - val_accuracy: 0.9793 - val_precision: 0.0733 - val_recall: 0.9024 - val_auc: 0.9774 Epoch 30/1000 20/20 [==============================] - 1s 35ms/step - loss: 0.1764 - tp: 18796.0000 - fp: 844.0000 - tn: 19677.0000 - fn: 1643.0000 - accuracy: 0.9393 - precision: 0.9570 - recall: 0.9196 - auc: 0.9791 - val_loss: 0.1234 - val_tp: 74.0000 - val_fp: 931.0000 - val_tn: 44556.0000 - val_fn: 8.0000 - val_accuracy: 0.9794 - val_precision: 0.0736 - val_recall: 0.9024 - val_auc: 0.9769 Epoch 31/1000 20/20 [==============================] - 1s 33ms/step - loss: 0.1727 - tp: 18808.0000 - fp: 793.0000 - tn: 19706.0000 - fn: 1653.0000 - accuracy: 0.9403 - precision: 0.9595 - recall: 0.9192 - auc: 0.9800 - val_loss: 0.1200 - val_tp: 74.0000 - val_fp: 909.0000 - val_tn: 44578.0000 - val_fn: 8.0000 - val_accuracy: 0.9799 - val_precision: 0.0753 - val_recall: 0.9024 - val_auc: 0.9773 Epoch 32/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.1704 - tp: 18934.0000 - fp: 824.0000 - tn: 19610.0000 - fn: 1592.0000 - accuracy: 0.9410 - precision: 0.9583 - recall: 0.9224 - auc: 0.9807 - val_loss: 0.1171 - val_tp: 74.0000 - val_fp: 907.0000 - val_tn: 44580.0000 - val_fn: 8.0000 - val_accuracy: 0.9799 - val_precision: 0.0754 - val_recall: 0.9024 - val_auc: 0.9772 Epoch 33/1000 18/20 [==========================>...] - ETA: 0s - loss: 0.1659 - tp: 17032.0000 - fp: 717.0000 - tn: 17676.0000 - fn: 1439.0000 - accuracy: 0.9415 - precision: 0.9596 - recall: 0.9221 - auc: 0.9817Restoring model weights from the end of the best epoch. 20/20 [==============================] - 1s 34ms/step - loss: 0.1654 - tp: 18936.0000 - fp: 788.0000 - tn: 19642.0000 - fn: 1594.0000 - accuracy: 0.9418 - precision: 0.9600 - recall: 0.9224 - auc: 0.9818 - val_loss: 0.1142 - val_tp: 74.0000 - val_fp: 908.0000 - val_tn: 44579.0000 - val_fn: 8.0000 - val_accuracy: 0.9799 - val_precision: 0.0754 - val_recall: 0.9024 - val_auc: 0.9771 Epoch 00033: early stopping
訓練履歴を再確認する
plot_metrics(resampled_history)
メトリクスを評価する
train_predictions_resampled = resampled_model.predict(train_features, batch_size=BATCH_SIZE) test_predictions_resampled = resampled_model.predict(test_features, batch_size=BATCH_SIZE)
resampled_results = resampled_model.evaluate(test_features, test_labels, batch_size=BATCH_SIZE, verbose=0) for name, value in zip(resampled_model.metrics_names, resampled_results): print(name, ': ', value) print() plot_cm(test_labels, test_predictions_weighted)
loss : 0.1573299011001384 tp : 92.0 fp : 1193.0 tn : 55664.0 fn : 13.0 accuracy : 0.978828 precision : 0.07159533 recall : 0.8761905 auc : 0.96230674 Legitimate Transactions Detected (True Negatives): 55920 Legitimate Transactions Incorrectly Detected (False Positives): 937 Fraudulent Transactions Missed (False Negatives): 13 Fraudulent Transactions Detected (True Positives): 92 Total Fraudulent Transactions: 105
ROC をプロットする
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0]) plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--') plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1]) plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--') plot_roc("Train Resampled", train_labels, train_predictions_resampled, color=colors[2]) plot_roc("Test Resampled", test_labels, test_predictions_resampled, color=colors[2], linestyle='--') plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fc67a751588>
このチュートリアルを貴方の問題に適用する
不均衡なデータ分類は本質的に困難なタスクです、何故ならばそこから学習する非常に少ないサンプルがあるからです。貴方は常に最初にデータから開始してできる限り多くのサンプルを収集することに最善をつくしてそしてどの特徴が関連するかもしれないのかをしっかりと念頭におくべきです、そうすればモデルは貴方の最小クラスを最大限に活用できます。ある点で貴方のモデルは改良されて望む結果を生成するために努力するかもしれません、そして貴方の問題のコンテキストとエラーの異なるタイプの間のトレードオフを念頭に置くことは重要です。
以上