Skip to content

ClasCat® AI Research

クラスキャット – 生成 AI, AI エージェント, MCP

Menu
  • ホーム
    • ClassCat® AI Research ホーム
    • クラスキャット・ホーム
  • OpenAI API
    • OpenAI Python ライブラリ 1.x : 概要
    • OpenAI ブログ
      • GPT の紹介
      • GPT ストアの紹介
      • ChatGPT Team の紹介
    • OpenAI platform 1.x
      • Get Started : イントロダクション
      • Get Started : クイックスタート (Python)
      • Get Started : クイックスタート (Node.js)
      • Get Started : モデル
      • 機能 : 埋め込み
      • 機能 : 埋め込み (ユースケース)
      • ChatGPT : アクション – イントロダクション
      • ChatGPT : アクション – Getting started
      • ChatGPT : アクション – アクション認証
    • OpenAI ヘルプ : ChatGPT
      • ChatGPTとは何ですか?
      • ChatGPT は真実を語っていますか?
      • GPT の作成
      • GPT FAQ
      • GPT vs アシスタント
      • GPT ビルダー
    • OpenAI ヘルプ : ChatGPT > メモリ
      • FAQ
    • OpenAI ヘルプ : GPT ストア
      • 貴方の GPT をフィーチャーする
    • OpenAI Python ライブラリ 0.27 : 概要
    • OpenAI platform
      • Get Started : イントロダクション
      • Get Started : クイックスタート
      • Get Started : モデル
      • ガイド : GPT モデル
      • ガイド : 画像生成 (DALL·E)
      • ガイド : GPT-3.5 Turbo 対応 微調整
      • ガイド : 微調整 1.イントロダクション
      • ガイド : 微調整 2. データセットの準備 / ケーススタディ
      • ガイド : 埋め込み
      • ガイド : 音声テキスト変換
      • ガイド : モデレーション
      • ChatGPT プラグイン : イントロダクション
    • OpenAI Cookbook
      • 概要
      • API 使用方法 : レート制限の操作
      • API 使用方法 : tiktoken でトークンを数える方法
      • GPT : ChatGPT モデルへの入力をフォーマットする方法
      • GPT : 補完をストリームする方法
      • GPT : 大規模言語モデルを扱う方法
      • 埋め込み : 埋め込みの取得
      • GPT-3 の微調整 : 分類サンプルの微調整
      • DALL-E : DALL·E で 画像を生成して編集する方法
      • DALL·E と Segment Anything で動的マスクを作成する方法
      • Whisper プロンプティング・ガイド
  • Gemini API
    • Tutorials : クイックスタート with Python (1) テキスト-to-テキスト生成
    • (2) マルチモーダル入力 / 日本語チャット
    • (3) 埋め込みの使用
    • (4) 高度なユースケース
    • クイックスタート with Node.js
    • クイックスタート with Dart or Flutter (1) 日本語動作確認
    • Gemma
      • 概要 (README)
      • Tutorials : サンプリング
      • Tutorials : KerasNLP による Getting Started
  • Keras 3
    • 新しいマルチバックエンド Keras
    • Keras 3 について
    • Getting Started : エンジニアのための Keras 入門
    • Google Colab 上のインストールと Stable Diffusion デモ
    • コンピュータビジョン – ゼロからの画像分類
    • コンピュータビジョン – 単純な MNIST convnet
    • コンピュータビジョン – EfficientNet を使用した微調整による画像分類
    • コンピュータビジョン – Vision Transformer による画像分類
    • コンピュータビジョン – 最新の MLPモデルによる画像分類
    • コンピュータビジョン – コンパクトな畳込み Transformer
    • Keras Core
      • Keras Core 0.1
        • 新しいマルチバックエンド Keras (README)
        • Keras for TensorFlow, JAX, & PyTorch
        • 開発者ガイド : Getting started with Keras Core
        • 開発者ガイド : 関数型 API
        • 開発者ガイド : シーケンシャル・モデル
        • 開発者ガイド : サブクラス化で新しい層とモデルを作成する
        • 開発者ガイド : 独自のコールバックを書く
      • Keras Core 0.1.1 & 0.1.2 : リリースノート
      • 開発者ガイド
      • Code examples
      • Keras Stable Diffusion
        • 概要
        • 基本的な使い方 (テキスト-to-画像 / 画像-to-画像変換)
        • 混合精度のパフォーマンス
        • インペインティングの簡易アプリケーション
        • (参考) KerasCV – Stable Diffusion を使用した高性能画像生成
  • TensorFlow
    • TF 2 : 初級チュートリアル
    • TF 2 : 上級チュートリアル
    • TF 2 : ガイド
    • TF 1 : チュートリアル
    • TF 1 : ガイド
  • その他
    • 🦜️🔗 LangChain ドキュメント / ユースケース
    • Stable Diffusion WebUI
      • Google Colab で Stable Diffusion WebUI 入門
      • HuggingFace モデル / VAE の導入
      • LoRA の利用
    • Diffusion Models / 拡散モデル
  • クラスキャット
    • 会社案内
    • お問合せ
    • Facebook
    • ClassCat® Blog
Menu

TensorFlow 2.0 : 上級 Tutorials : 構造化データ :- 不均衡なデータ上の分類

Posted on 11/18/2019 by Sales Information

TensorFlow 2.0 : 上級 Tutorials : 構造化データ :- 不均衡なデータ上の分類 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 11/18/2019

* 本ページは、TensorFlow org サイトの TF 2.0 – Advanced Tutorials – Structured data の以下のページを翻訳した上で
適宜、補足説明したものです:

  • Classification on imbalanced data

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

★ 無料セミナー開催中 ★ クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。

◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/
Facebook: https://www.facebook.com/ClassCatJP/

 

構造化データ :- 不均衡なデータ上の分類

このチュートリアルでは非常に不均衡なデータセットをどのように分類するかを実演します、そこでは一つのクラスのサンプル数がもう一つの (クラスの) サンプルに大いに数で上回ります。Kaggle でホストされている Credit Card Fraud Detection データセットで作業します。目的は全部で 284,807 トランザクションから僅か 492 の詐欺トランザクションを検出することです。不均衡なデータからモデルが学習することを助けるためにモデルと クラス重み を定義するために Keras を使用します。

このチュートリアルは以下を行なうための完全なコードを含みます :

  • Pandas を使用して CSV ファイルをロードする。
  • 訓練、検証とテストセットを作成する。
  • (クラス重みを設定することを含む) Keras を使用してモデルを定義して訓練する。
  • (精度 (= precision) と recall を含む) 様々なメトリクスを使用してモデルを評価する。
  • 次のような不均衡なデータを扱うための一般的なテクニックを試す :
    • クラス重み付け (= Class weighting)
    • Oversampling

 

セットアップ

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras

import os
import tempfile

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

 

データ前処理と調査

Kaggle Credit Card Fraud データセットをダウンロードする

Pandas は構造化データをロードしてそれで作業するための多くの役立つユティリティを持つ Python ライブラリで CSV を dataframe にダウンロードするために使用できます。

★ Note: このデータセットはビッグ・データマイニングと詐欺検出上の Worldline と ULB (Université Libre de Bruxelles) の 機械学習グループ の研究コラボレーションの間に収集されて解析されました。関連トピックについて現在と過去のプロジェクトのより詳細は ここ と DefeatFraud プロジェクトのページで利用可能です。

file = tf.keras.utils
raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')
raw_df.head()

Time V1 V2 V3 V4 V5 V6 V7 V8 V9 … V21 V22 V23 V24 V25 V26 V27 V28 Amount Class
0 0.0 -1.359807 -0.072781 2.536347 1.378155 -0.338321 0.462388 0.239599 0.098698 0.363787 … -0.018307 0.277838 -0.110474 0.066928 0.128539 -0.189115 0.133558 -0.021053 149.62 0
1 0.0 1.191857 0.266151 0.166480 0.448154 0.060018 -0.082361 -0.078803 0.085102 -0.255425 … -0.225775 -0.638672 0.101288 -0.339846 0.167170 0.125895 -0.008983 0.014724 2.69 0
2 1.0 -1.358354 -1.340163 1.773209 0.379780 -0.503198 1.800499 0.791461 0.247676 -1.514654 … 0.247998 0.771679 0.909412 -0.689281 -0.327642 -0.139097 -0.055353 -0.059752 378.66 0
3 1.0 -0.966272 -0.185226 1.792993 -0.863291 -0.010309 1.247203 0.237609 0.377436 -1.387024 … -0.108300 0.005274 -0.190321 -1.175575 0.647376 -0.221929 0.062723 0.061458 123.50 0
4 2.0 -1.158233 0.877737 1.548718 0.403034 -0.407193 0.095921 0.592941 -0.270533 0.817739 … -0.009431 0.798278 -0.137458 0.141267 -0.206010 0.502292 0.219422 0.215153 69.99 0

5 rows × 31 columns
raw_df[['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V26', 'V27', 'V28', 'Amount', 'Class']].describe()

Time V1 V2 V3 V4 V5 V26 V27 V28 Amount クラス
カウント 284807.000000 2.848070e+05 2.848070e+05 2.848070e+05 2.848070e+05 2.848070e+05 2.848070e+05 2.848070e+05 2.848070e+05 284807.000000 284807.000000
平均 94813.859575 1.165980e-15 3.416908e-16 -1.373150e-15 2.086869e-15 9.604066e-16 1.687098e-15 -3.666453e-16 -1.220404e-16 88.349619 0.001727
標準偏差 47488.145955 1.958696e+00 1.651309e+00 1.516255e+00 1.415869e+00 1.380247e+00 4.822270e-01 4.036325e-01 3.300833e-01 250.120109 0.041527
最小値 0.000000 -5.640751e+01 -7.271573e+01 -4.832559e+01 -5.683171e+00 -1.137433e+02 -2.604551e+00 -2.256568e+01 -1.543008e+01 0.000000 0.000000
25% 54201.500000 -9.203734e-01 -5.985499e-01 -8.903648e-01 -8.486401e-01 -6.915971e-01 -3.269839e-01 -7.083953e-02 -5.295979e-02 5.600000 0.000000
50% 84692.000000 1.810880e-02 6.548556e-02 1.798463e-01 -1.984653e-02 -5.433583e-02 -5.213911e-02 1.342146e-03 1.124383e-02 22.000000 0.000000
75% 139320.500000 1.315642e+00 8.037239e-01 1.027196e+00 7.433413e-01 6.119264e-01 2.409522e-01 9.104512e-02 7.827995e-02 77.165000 0.000000
最大値 172792.000000 2.454930e+00 2.205773e+01 9.382558e+00 1.687534e+01 3.480167e+01 3.517346e+00 3.161220e+01 3.384781e+01 25691.160000 1.000000

 

クラスラベルが不均衡であることを調べる

不均衡なデータセットを見てみましょう :

neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print('Examples:\n    Total: {}\n    Positive: {} ({:.2f}% of total)\n'.format(
    total, pos, 100 * pos / total))
Examples:
    Total: 284807
    Positive: 492 (0.17% of total)

これはポジティブなサンプルの小さい割合を示します。

 

データをクリーンアップ、分割して正規化する

生データは 2, 3 の問題を持ちます。最初に Time と Amount カラムは直接使用するには変化し過ぎです。Time カラムは破棄して (何故ならばそれが何を意味するか明確でないからです) そして範囲を減じるために Amount カラムの対数を取ります。

cleaned_df = raw_df.copy()

# You don't want the `Time` column.
cleaned_df.pop('Time')

# The `Amount` column covers a huge range. Convert to log-space.
eps=0.001 # 0 => 0.1¢
cleaned_df['Log Ammount'] = np.log(cleaned_df.pop('Amount')+eps)

データセットを訓練、検証とテストセットに分割します。検証セットは損失と任意のメトリクスを評価するためにモデル fitting の間に使用されます、けれどもモデルはこのデータでは fit されません。テストセットは訓練段階の間には全く使用されません、そしてモデルが新しいデータにどのくらい上手く一般化されたかを評価するために最後に使用されるだけです。これは不均衡なデータセットでは特に重要です、そこでは overfitting は訓練データの欠落から本質的な関心事です。

# Use a utility from sklearn to split and shuffle our dataset.
train_df, test_df = train_test_split(cleaned_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)

# Form np arrays of labels and features.
train_labels = np.array(train_df.pop('Class'))
bool_train_labels = train_labels != 0
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))

train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)

sklearn StandardScaler を使用して入力特徴を正規化します。これは平均を 0 そして標準偏差を 1 に設定します。

★ Note: StandardScaler はモデルが検証やテストセットを覗き見ないことを確実にするために train_features を使用して fit するだけです。

scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)

val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)

train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)


print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)

print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)
Training labels shape: (182276,)
Validation labels shape: (45569,)
Test labels shape: (56962,)
Training features shape: (182276, 29)
Validation features shape: (45569, 29)
Test features shape: (56962, 29)
警告: モデルを配備することを望む場合、前処理の計算を保持することは重要です。それらを実装するための最も容易な方法は層として、そしてそれらを貴方のモデルにエクスポートする前にアタッチします。

 

データ分布を見る

次に 2,3 の特徴に渡りポジティブとネガティブ・サンプルの分布を比較します。この時点で貴方自身に尋ねる良い質問は :

  • これらの分布は意味があるでしょうか?
    • はい。入力を正規化してそしてこれらは +/- 2 範囲に殆ど集中しています。
  • 分布の間に違いを見ることができるでしょうか?
    • はい、ポジティブ・サンプルは遥かに高いレートの極値 (= extreme values) を含みます。
pos_df = pd.DataFrame(train_features[ bool_train_labels], columns = train_df.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns = train_df.columns)

sns.jointplot(pos_df['V5'], pos_df['V6'],
              kind='hex', xlim = (-5,5), ylim = (-5,5))
plt.suptitle("Positive distribution")

sns.jointplot(neg_df['V5'], neg_df['V6'],
              kind='hex', xlim = (-5,5), ylim = (-5,5))
_ = plt.suptitle("Negative distribution")

 

モデルとメトリクスを定義する

密に接続された隠れ層、overfitting を減じる dropout 層、そしてトランザクションが不正である確率を返す出力 sigmoid 層を持つ単純なネットワークを作成する関数を定義します :

METRICS = [
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'), 
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc'),
]

def make_model(metrics = METRICS, output_bias=None):
  if output_bias is not None:
    output_bias = tf.keras.initializers.Constant(output_bias)
  model = keras.Sequential([
      keras.layers.Dense(
          16, activation='relu',
          input_shape=(train_features.shape[-1],)),
      keras.layers.Dropout(0.5),
      keras.layers.Dense(1, activation='sigmoid',
                         bias_initializer=output_bias),
  ])

  model.compile(
      optimizer=keras.optimizers.Adam(lr=1e-3),
      loss=keras.losses.BinaryCrossentropy(),
      metrics=metrics)

  return model

 

有用なメトリクスを理解する

パフォーマンスを評価するときに有用である、モデルにより計算可能な上で定義された 2, 3 のメトリクスがあることに注意してください。

  • false negative と false positive は 間違って 分類されたサンプルです
  • true negative と true positive は 正しく 分類されたサンプルです
  • accuracy は正しく分類されたサンプルのパーセンテージです > $\frac{\text{true samples}}{\text{total samples}}$
  • precision は正しく分類された 予測された ポジティブのパーセンテージです > $\frac{\text{true positives}}{\text{true positives + false positives}}$
  • recall は正しく分類された 実際の ポジティブのパーセンテージです > $\frac{\text{true positives}}{\text{true positives + false negatives}}$
  • AUC は ROC曲線下の面積 (ROC-AUC, Area Under the Curve of a Receiver Operating Characteristic curve) を参照します。このメトリックは、分類器がランダムなポジティブ・サンプルをランダムなネガティブ・サンプルよりも高く位置付ける確率に等しいです。

★ Note: accuracy (精度) はこのタスクのために有用なメトリクスではありません。常に False を予測することによりこのタスク上 99.8%+ 精度を得られるでしょう。

Read more: * True vs. False と Positive vs. Negative * Accuracy * Precision と Recall * ROC-AUC

 

ベースライン・モデル

モデルを構築する

今は先に定義された関数を使用してモデルを作成して訓練します。モデルは 2048 のデフォルトよりも大きいバッチサイズを使用して fit することに注意してください、これは各バッチが 2, 3 のポジティブ・サンプルを含む妥当な機会を持つことを確実にするために重要です。バッチサイズが小さ過ぎれば、そこから学習するための詐欺的なトランザクションを持たない傾向になるでしょう。

Note: このモデルは不均衡なクラスを上手く扱いません。このチュートリアルの後でそれを改良します。

EPOCHS = 100
BATCH_SIZE = 2048

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc', 
    verbose=1,
    patience=10,
    mode='max',
    restore_best_weights=True)
model = make_model()
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 16)                480       
_________________________________________________________________
dropout (Dropout)            (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 17        
=================================================================
Total params: 497
Trainable params: 497
Non-trainable params: 0

Test run the model:

model.predict(train_features[:10])
array([[0.35366294],
       [0.17252561],
       [0.08079773],
       [0.13621533],
       [0.18569201],
       [0.22863552],
       [0.3243227 ],
       [0.25937212],
       [0.27047133],
       [0.1558134 ]], dtype=float32)

 

オプション: 正しい初期バイアスを設定する

これらは初期推論は良くはありません。貴方はデータセットは不均衡であることを知っています。それを反映するように出力層のバイアスを設定します (参照: A Recipe for Training Neural Networks: “init well”)。これは初期収束に役立つことができます。

デフォルトのバイアス初期化では損失はおよそ math.log(2) = 0.69314 になるはずです

results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 0.2624

設定する正しいバイアスは以下から導出できます :

$$ p_0 = pos/(pos + neg) = 1/(1+e^{-b_0}) $$
$$ b_0 = -log_e(1/p_0 – 1) $$
$$ b_0 = log_e(pos/neg)$$
initial_bias = np.log([pos/neg])
initial_bias
array([-6.35935934])

それを初期バイアスとして設定すると、モデルは遥かに合理的な初期推論を与えます。

それは pos/total = 0.0018 近くになるはずです

model = make_model(output_bias = initial_bias)
model.predict(train_features[:10])
array([[0.00416172],
       [0.0006738 ],
       [0.00036496],
       [0.00125635],
       [0.00200993],
       [0.0012227 ],
       [0.01357868],
       [0.00356525],
       [0.00580611],
       [0.02072738]], dtype=float32)

この初期化で初期損失はおよそ次になるはずです :

$$-p_0log(p_0)-(1-p_0)log(1-p_0) = 0.01317$$
results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 0.0164

この初期損失は素朴な初期化であるよりもおよそ 50 倍少ないです。

このようにしてモデルはポジティブサンプルが可能性が低いことを単に学習する、最初の数エポックを消費する必要がありません。これはまた訓練の間に損失のプロットを読むことを容易にします。

 

初期重みをチェックポイントする

様々な訓練をより比較可能に実行するために、この初期モデルの重みをチェックポイント・ファイルに保持します、そしてそれらを訓練前に各モデルにロードします。

initial_weights = os.path.join(tempfile.mkdtemp(),'initial_weights')
model.save_weights(initial_weights)

 

バイアス修正 (= fix) が役立つことをを確認する

進む前に、注意深いバイアス初期化が実際に役立つことを素早く確認します。

この注意深い初期化ありとなしで、20 エポックの間モデルを訓練します、そして損失を比較します :

model = make_model()
model.load_weights(initial_weights)
model.layers[-1].bias.assign([0.0])
zero_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
model = make_model()
model.load_weights(initial_weights)
careful_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
def plot_loss(history, label, n):
  # Use a log scale to show the wide range of values.
  plt.semilogy(history.epoch,  history.history['loss'],
               color=colors[n], label='Train '+label)
  plt.semilogy(history.epoch,  history.history['val_loss'],
          color=colors[n], label='Val '+label,
          linestyle="--")
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
  
  plt.legend()
plot_loss(zero_bias_history, "Zero Bias", 0)
plot_loss(careful_bias_history, "Careful Bias", 1)

上の図は次を明白にします: この問題上、検証損失の観点から、この注意深い初期化は明白な優位を与えます。

 

モデルを訓練する

model = make_model()
model.load_weights(initial_weights)
baseline_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_features, val_labels))
Train on 182276 samples, validate on 45569 samples
Epoch 1/100
182276/182276 [==============================] - 3s 15us/sample - loss: 0.0148 - tp: 33.0000 - fp: 107.0000 - tn: 181864.0000 - fn: 272.0000 - accuracy: 0.9979 - precision: 0.2357 - recall: 0.1082 - auc: 0.6736 - val_loss: 0.0072 - val_tp: 3.0000 - val_fp: 2.0000 - val_tn: 45485.0000 - val_fn: 79.0000 - val_accuracy: 0.9982 - val_precision: 0.6000 - val_recall: 0.0366 - val_auc: 0.9058
Epoch 2/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0087 - tp: 94.0000 - fp: 34.0000 - tn: 181937.0000 - fn: 211.0000 - accuracy: 0.9987 - precision: 0.7344 - recall: 0.3082 - auc: 0.8194 - val_loss: 0.0049 - val_tp: 40.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 42.0000 - val_accuracy: 0.9989 - val_precision: 0.8511 - val_recall: 0.4878 - val_auc: 0.9264
Epoch 3/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0075 - tp: 122.0000 - fp: 33.0000 - tn: 181938.0000 - fn: 183.0000 - accuracy: 0.9988 - precision: 0.7871 - recall: 0.4000 - auc: 0.8638 - val_loss: 0.0043 - val_tp: 47.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 35.0000 - val_accuracy: 0.9991 - val_precision: 0.8704 - val_recall: 0.5732 - val_auc: 0.9266
Epoch 4/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0065 - tp: 131.0000 - fp: 32.0000 - tn: 181939.0000 - fn: 174.0000 - accuracy: 0.9989 - precision: 0.8037 - recall: 0.4295 - auc: 0.8855 - val_loss: 0.0040 - val_tp: 57.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 25.0000 - val_accuracy: 0.9993 - val_precision: 0.8906 - val_recall: 0.6951 - val_auc: 0.9327
Epoch 5/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0055 - tp: 172.0000 - fp: 28.0000 - tn: 181943.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8600 - recall: 0.5639 - auc: 0.9170 - val_loss: 0.0038 - val_tp: 59.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 23.0000 - val_accuracy: 0.9993 - val_precision: 0.8939 - val_recall: 0.7195 - val_auc: 0.9326
Epoch 6/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0056 - tp: 158.0000 - fp: 32.0000 - tn: 181939.0000 - fn: 147.0000 - accuracy: 0.9990 - precision: 0.8316 - recall: 0.5180 - auc: 0.9008 - val_loss: 0.0036 - val_tp: 59.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 23.0000 - val_accuracy: 0.9993 - val_precision: 0.8939 - val_recall: 0.7195 - val_auc: 0.9326
Epoch 7/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0055 - tp: 157.0000 - fp: 31.0000 - tn: 181940.0000 - fn: 148.0000 - accuracy: 0.9990 - precision: 0.8351 - recall: 0.5148 - auc: 0.9112 - val_loss: 0.0035 - val_tp: 62.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8986 - val_recall: 0.7561 - val_auc: 0.9326
Epoch 8/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0048 - tp: 172.0000 - fp: 33.0000 - tn: 181938.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8390 - recall: 0.5639 - auc: 0.9130 - val_loss: 0.0034 - val_tp: 62.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8986 - val_recall: 0.7561 - val_auc: 0.9326
Epoch 9/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0051 - tp: 159.0000 - fp: 28.0000 - tn: 181943.0000 - fn: 146.0000 - accuracy: 0.9990 - precision: 0.8503 - recall: 0.5213 - auc: 0.9081 - val_loss: 0.0033 - val_tp: 62.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8986 - val_recall: 0.7561 - val_auc: 0.9326
Epoch 10/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0050 - tp: 169.0000 - fp: 33.0000 - tn: 181938.0000 - fn: 136.0000 - accuracy: 0.9991 - precision: 0.8366 - recall: 0.5541 - auc: 0.9215 - val_loss: 0.0033 - val_tp: 62.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8986 - val_recall: 0.7561 - val_auc: 0.9326
Epoch 11/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0051 - tp: 157.0000 - fp: 29.0000 - tn: 181942.0000 - fn: 148.0000 - accuracy: 0.9990 - precision: 0.8441 - recall: 0.5148 - auc: 0.9233 - val_loss: 0.0032 - val_tp: 62.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8986 - val_recall: 0.7561 - val_auc: 0.9326
Epoch 12/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0045 - tp: 178.0000 - fp: 37.0000 - tn: 181934.0000 - fn: 127.0000 - accuracy: 0.9991 - precision: 0.8279 - recall: 0.5836 - auc: 0.9317 - val_loss: 0.0031 - val_tp: 62.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8986 - val_recall: 0.7561 - val_auc: 0.9326
Epoch 13/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0047 - tp: 160.0000 - fp: 32.0000 - tn: 181939.0000 - fn: 145.0000 - accuracy: 0.9990 - precision: 0.8333 - recall: 0.5246 - auc: 0.9185 - val_loss: 0.0031 - val_tp: 63.0000 - val_fp: 8.0000 - val_tn: 45479.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8873 - val_recall: 0.7683 - val_auc: 0.9326
Epoch 14/100
161792/182276 [=========================>....] - ETA: 0s - loss: 0.0045 - tp: 155.0000 - fp: 31.0000 - tn: 161488.0000 - fn: 118.0000 - accuracy: 0.9991 - precision: 0.8333 - recall: 0.5678 - auc: 0.9256Restoring model weights from the end of the best epoch.
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0044 - tp: 175.0000 - fp: 33.0000 - tn: 181938.0000 - fn: 130.0000 - accuracy: 0.9991 - precision: 0.8413 - recall: 0.5738 - auc: 0.9251 - val_loss: 0.0031 - val_tp: 64.0000 - val_fp: 8.0000 - val_tn: 45479.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8889 - val_recall: 0.7805 - val_auc: 0.9326
Epoch 00014: early stopping

 

訓練履歴を確認する

このセクションでは、モデルの精度と訓練と検証セット上の損失のプロットを生成します。これは overfitting を確認するために有用です、それについてはこの チュートリアル で更に学習できます。

更に、上で作成した任意のメトリクスのためにこれらのプロットを生成できます。false negative がサンプルとして含まれます。

def plot_metrics(history):
  metrics =  ['loss', 'auc', 'precision', 'recall']
  for n, metric in enumerate(metrics):
    name = metric.replace("_"," ").capitalize()
    plt.subplot(2,2,n+1)
    plt.plot(history.epoch,  history.history[metric], color=colors[0], label='Train')
    plt.plot(history.epoch, history.history['val_'+metric],
             color=colors[0], linestyle="--", label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(name)
    if metric == 'loss':
      plt.ylim([0, plt.ylim()[1]])
    elif metric == 'auc':
      plt.ylim([0.8,1])
    else:
      plt.ylim([0,1])

    plt.legend()
plot_metrics(baseline_history)

★ Note: 検証カーブは訓練カーブよりも一般により良く遂行します。これは主として、モデルを評価するとき dropout 層が有効ではないという事実によります。

 

メトリクスを評価する

実際 (= actual) vs 予測されたラベルを要約するために 混同行列 を使用できます、そこでは X 軸は予測されたラベルで Y 軸は実際のラベルです。

train_predictions_baseline = model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model.predict(test_features, batch_size=BATCH_SIZE)
def plot_cm(labels, predictions, p=0.5):
  cm = confusion_matrix(labels, predictions > p)
  plt.figure(figsize=(5,5))
  sns.heatmap(cm, annot=True, fmt="d")
  plt.title('Confusion matrix @{:.2f}'.format(p))
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')

  print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
  print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
  print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
  print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
  print('Total Fraudulent Transactions: ', np.sum(cm[1]))

テストデータセット上のモデルを評価して上で作成したメトリクスのための結果を表示します。

baseline_results = model.evaluate(test_features, test_labels,
                                  batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model.metrics_names, baseline_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_baseline)
loss :  0.004311129764585171
tp :  68.0
fp :  11.0
tn :  56846.0
fn :  37.0
accuracy :  0.9991573
precision :  0.8607595
recall :  0.64761907
auc :  0.92348534

Legitimate Transactions Detected (True Negatives):  56846
Legitimate Transactions Incorrectly Detected (False Positives):  11
Fraudulent Transactions Missed (False Negatives):  37
Fraudulent Transactions Detected (True Positives):  68
Total Fraudulent Transactions:  105

モデルが総てを完全に予測したのであれば、これは 対角行列 になるでしょう、そこでは (正しくない予測を示す) 主対角線から離れた値はゼロになるでしょう。この場合行列は比較的少ない false positive を持つことを示します、これは誤ってフラグを建てられた比較的少ない正当なトランザクションがあったことを意味します。けれども、false positive の数が増加しているコストにもかかわらず、貴方はより少ない false negative でさえも望みがちでしょう。このトレードオフは好ましいかもしれません、何故ならば false positive はカスタマーにカードの動きを検証することを要求するための電子メールが送られることを引き起こすかもしれない一方で、false negative は詐欺的なトランザクションが通り抜けることを可能にするからです。

 

ROC をプロットする

今は ROC をプロットします。このプロットは有用です、何故ならばそれはひと目で、単に出力閾値を調整することによりモデルが到達できるパフォーマンスの範囲を示すからです。

def plot_roc(name, labels, predictions, **kwargs):
  fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)

  plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.xlim([-0.5,20])
  plt.ylim([80,100.5])
  plt.grid(True)
  ax = plt.gca()
  ax.set_aspect('equal')
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fc6ba4dcbe0>

それは precision は比較的高いように見えますが、recall と ROC 曲線下の面積 (AUC) は貴方が好むほどには高くありません。分類器は precision と recall の両者を最大化しようとするときしばしば課題に直面します、これは不均衡なデータセットで作業するときに特に真です。関心がある問題のコンテキストで異なるタイプのエラーのコストを考えることは重要です。この例では、false negative (詐欺的なトランザクションが取り逃がされる) がファイナンシャルコストを持つかもしれない一方で、false positive (トランザクションは誤って詐欺的とフラグが立てられる) はユーザの幸せを減らすかもしれません。

 

クラス重み

クラス重みを計算する

ゴールは詐欺的トランザクションを識別することですが、作業するためのそれらの非常に多くの positive サンプルを持ちませんので、分類器に利用可能な少ないサンプルに大きく重み付けすることを望むでしょう。Keras にパラメータを通して各クラスのための重みを渡すことによりこれを行なうことができます。これらはモデルに under-represented クラスからのサンプルに「より大きな注意を払う」ことをさせます。

# Scaling by total/2 helps keep the loss to a similar magnitude.
# The sum of the weights of all examples stays the same.
weight_for_0 = (1 / neg)*(total)/2.0 
weight_for_1 = (1 / pos)*(total)/2.0

class_weight = {0: weight_for_0, 1: weight_for_1}

print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))
Weight for class 0: 0.50
Weight for class 1: 289.44

 

クラス重みでモデルを訓練する

今は予測にどのように影響するかを見るためにクラス重みでモデルを再訓練して評価してみます。

★ Note: class_weights の使用は損失の範囲を変更します。これは optimizer に依拠して訓練の安定性に影響するかもしれません。optimizers.SGD のような、ステップサイズが勾配の大きさに依拠する optimizer は失敗するかもしれません。ここで使用される optimizer、optimizers.Adam はスケーリング変更の影響を受けません。また重み付けゆえに、トータル損失は 2 つのモデル間で比較可能ではないことに注意してください。

weighted_model = make_model()
weighted_model.load_weights(initial_weights)

weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_features, val_labels),
    # The class weights go here
    class_weight=class_weight) 
Train on 182276 samples, validate on 45569 samples
Epoch 1/100
182276/182276 [==============================] - 3s 17us/sample - loss: 2.3320 - tp: 49.0000 - fp: 550.0000 - tn: 181421.0000 - fn: 256.0000 - accuracy: 0.9956 - precision: 0.0818 - recall: 0.1607 - auc: 0.6815 - val_loss: 0.7987 - val_tp: 40.0000 - val_fp: 14.0000 - val_tn: 45473.0000 - val_fn: 42.0000 - val_accuracy: 0.9988 - val_precision: 0.7407 - val_recall: 0.4878 - val_auc: 0.9440
Epoch 2/100
182276/182276 [==============================] - 1s 3us/sample - loss: 1.0274 - tp: 151.0000 - fp: 1122.0000 - tn: 180849.0000 - fn: 154.0000 - accuracy: 0.9930 - precision: 0.1186 - recall: 0.4951 - auc: 0.8471 - val_loss: 0.4155 - val_tp: 65.0000 - val_fp: 28.0000 - val_tn: 45459.0000 - val_fn: 17.0000 - val_accuracy: 0.9990 - val_precision: 0.6989 - val_recall: 0.7927 - val_auc: 0.9664
Epoch 3/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.6664 - tp: 196.0000 - fp: 1788.0000 - tn: 180183.0000 - fn: 109.0000 - accuracy: 0.9896 - precision: 0.0988 - recall: 0.6426 - auc: 0.8987 - val_loss: 0.3144 - val_tp: 69.0000 - val_fp: 58.0000 - val_tn: 45429.0000 - val_fn: 13.0000 - val_accuracy: 0.9984 - val_precision: 0.5433 - val_recall: 0.8415 - val_auc: 0.9748
Epoch 4/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.4983 - tp: 221.0000 - fp: 2856.0000 - tn: 179115.0000 - fn: 84.0000 - accuracy: 0.9839 - precision: 0.0718 - recall: 0.7246 - auc: 0.9235 - val_loss: 0.2678 - val_tp: 70.0000 - val_fp: 109.0000 - val_tn: 45378.0000 - val_fn: 12.0000 - val_accuracy: 0.9973 - val_precision: 0.3911 - val_recall: 0.8537 - val_auc: 0.9746
Epoch 5/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.4914 - tp: 228.0000 - fp: 3722.0000 - tn: 178249.0000 - fn: 77.0000 - accuracy: 0.9792 - precision: 0.0577 - recall: 0.7475 - auc: 0.9132 - val_loss: 0.2342 - val_tp: 71.0000 - val_fp: 215.0000 - val_tn: 45272.0000 - val_fn: 11.0000 - val_accuracy: 0.9950 - val_precision: 0.2483 - val_recall: 0.8659 - val_auc: 0.9767
Epoch 6/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.4245 - tp: 240.0000 - fp: 4810.0000 - tn: 177161.0000 - fn: 65.0000 - accuracy: 0.9733 - precision: 0.0475 - recall: 0.7869 - auc: 0.9247 - val_loss: 0.2168 - val_tp: 71.0000 - val_fp: 414.0000 - val_tn: 45073.0000 - val_fn: 11.0000 - val_accuracy: 0.9907 - val_precision: 0.1464 - val_recall: 0.8659 - val_auc: 0.9772
Epoch 7/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.3187 - tp: 256.0000 - fp: 5810.0000 - tn: 176161.0000 - fn: 49.0000 - accuracy: 0.9679 - precision: 0.0422 - recall: 0.8393 - auc: 0.9448 - val_loss: 0.2047 - val_tp: 71.0000 - val_fp: 535.0000 - val_tn: 44952.0000 - val_fn: 11.0000 - val_accuracy: 0.9880 - val_precision: 0.1172 - val_recall: 0.8659 - val_auc: 0.9772
Epoch 8/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.3666 - tp: 256.0000 - fp: 6358.0000 - tn: 175613.0000 - fn: 49.0000 - accuracy: 0.9649 - precision: 0.0387 - recall: 0.8393 - auc: 0.9320 - val_loss: 0.2008 - val_tp: 71.0000 - val_fp: 595.0000 - val_tn: 44892.0000 - val_fn: 11.0000 - val_accuracy: 0.9867 - val_precision: 0.1066 - val_recall: 0.8659 - val_auc: 0.9779
Epoch 9/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.3080 - tp: 254.0000 - fp: 6611.0000 - tn: 175360.0000 - fn: 51.0000 - accuracy: 0.9635 - precision: 0.0370 - recall: 0.8328 - auc: 0.9521 - val_loss: 0.1941 - val_tp: 71.0000 - val_fp: 685.0000 - val_tn: 44802.0000 - val_fn: 11.0000 - val_accuracy: 0.9847 - val_precision: 0.0939 - val_recall: 0.8659 - val_auc: 0.9781
Epoch 10/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2565 - tp: 263.0000 - fp: 7082.0000 - tn: 174889.0000 - fn: 42.0000 - accuracy: 0.9609 - precision: 0.0358 - recall: 0.8623 - auc: 0.9594 - val_loss: 0.1907 - val_tp: 73.0000 - val_fp: 746.0000 - val_tn: 44741.0000 - val_fn: 9.0000 - val_accuracy: 0.9834 - val_precision: 0.0891 - val_recall: 0.8902 - val_auc: 0.9790
Epoch 11/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2785 - tp: 263.0000 - fp: 7027.0000 - tn: 174944.0000 - fn: 42.0000 - accuracy: 0.9612 - precision: 0.0361 - recall: 0.8623 - auc: 0.9541 - val_loss: 0.1873 - val_tp: 73.0000 - val_fp: 790.0000 - val_tn: 44697.0000 - val_fn: 9.0000 - val_accuracy: 0.9825 - val_precision: 0.0846 - val_recall: 0.8902 - val_auc: 0.9790
Epoch 12/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.3415 - tp: 256.0000 - fp: 7570.0000 - tn: 174401.0000 - fn: 49.0000 - accuracy: 0.9582 - precision: 0.0327 - recall: 0.8393 - auc: 0.9374 - val_loss: 0.1861 - val_tp: 73.0000 - val_fp: 863.0000 - val_tn: 44624.0000 - val_fn: 9.0000 - val_accuracy: 0.9809 - val_precision: 0.0780 - val_recall: 0.8902 - val_auc: 0.9792
Epoch 13/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2324 - tp: 272.0000 - fp: 7138.0000 - tn: 174833.0000 - fn: 33.0000 - accuracy: 0.9607 - precision: 0.0367 - recall: 0.8918 - auc: 0.9638 - val_loss: 0.1861 - val_tp: 73.0000 - val_fp: 821.0000 - val_tn: 44666.0000 - val_fn: 9.0000 - val_accuracy: 0.9818 - val_precision: 0.0817 - val_recall: 0.8902 - val_auc: 0.9800
Epoch 14/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2341 - tp: 271.0000 - fp: 7125.0000 - tn: 174846.0000 - fn: 34.0000 - accuracy: 0.9607 - precision: 0.0366 - recall: 0.8885 - auc: 0.9662 - val_loss: 0.1851 - val_tp: 74.0000 - val_fp: 830.0000 - val_tn: 44657.0000 - val_fn: 8.0000 - val_accuracy: 0.9816 - val_precision: 0.0819 - val_recall: 0.9024 - val_auc: 0.9800
Epoch 15/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2603 - tp: 269.0000 - fp: 6786.0000 - tn: 175185.0000 - fn: 36.0000 - accuracy: 0.9626 - precision: 0.0381 - recall: 0.8820 - auc: 0.9581 - val_loss: 0.1843 - val_tp: 74.0000 - val_fp: 824.0000 - val_tn: 44663.0000 - val_fn: 8.0000 - val_accuracy: 0.9817 - val_precision: 0.0824 - val_recall: 0.9024 - val_auc: 0.9806
Epoch 16/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2448 - tp: 267.0000 - fp: 6749.0000 - tn: 175222.0000 - fn: 38.0000 - accuracy: 0.9628 - precision: 0.0381 - recall: 0.8754 - auc: 0.9605 - val_loss: 0.1836 - val_tp: 74.0000 - val_fp: 810.0000 - val_tn: 44677.0000 - val_fn: 8.0000 - val_accuracy: 0.9820 - val_precision: 0.0837 - val_recall: 0.9024 - val_auc: 0.9812
Epoch 17/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2779 - tp: 263.0000 - fp: 7040.0000 - tn: 174931.0000 - fn: 42.0000 - accuracy: 0.9611 - precision: 0.0360 - recall: 0.8623 - auc: 0.9531 - val_loss: 0.1799 - val_tp: 74.0000 - val_fp: 866.0000 - val_tn: 44621.0000 - val_fn: 8.0000 - val_accuracy: 0.9808 - val_precision: 0.0787 - val_recall: 0.9024 - val_auc: 0.9814
Epoch 18/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2188 - tp: 272.0000 - fp: 7279.0000 - tn: 174692.0000 - fn: 33.0000 - accuracy: 0.9599 - precision: 0.0360 - recall: 0.8918 - auc: 0.9712 - val_loss: 0.1791 - val_tp: 74.0000 - val_fp: 902.0000 - val_tn: 44585.0000 - val_fn: 8.0000 - val_accuracy: 0.9800 - val_precision: 0.0758 - val_recall: 0.9024 - val_auc: 0.9816
Epoch 19/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2562 - tp: 268.0000 - fp: 6972.0000 - tn: 174999.0000 - fn: 37.0000 - accuracy: 0.9615 - precision: 0.0370 - recall: 0.8787 - auc: 0.9599 - val_loss: 0.1789 - val_tp: 74.0000 - val_fp: 864.0000 - val_tn: 44623.0000 - val_fn: 8.0000 - val_accuracy: 0.9809 - val_precision: 0.0789 - val_recall: 0.9024 - val_auc: 0.9817
Epoch 20/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2241 - tp: 274.0000 - fp: 6996.0000 - tn: 174975.0000 - fn: 31.0000 - accuracy: 0.9614 - precision: 0.0377 - recall: 0.8984 - auc: 0.9671 - val_loss: 0.1800 - val_tp: 74.0000 - val_fp: 862.0000 - val_tn: 44625.0000 - val_fn: 8.0000 - val_accuracy: 0.9809 - val_precision: 0.0791 - val_recall: 0.9024 - val_auc: 0.9818
Epoch 21/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2015 - tp: 273.0000 - fp: 6637.0000 - tn: 175334.0000 - fn: 32.0000 - accuracy: 0.9634 - precision: 0.0395 - recall: 0.8951 - auc: 0.9752 - val_loss: 0.1765 - val_tp: 74.0000 - val_fp: 855.0000 - val_tn: 44632.0000 - val_fn: 8.0000 - val_accuracy: 0.9811 - val_precision: 0.0797 - val_recall: 0.9024 - val_auc: 0.9824
Epoch 22/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2209 - tp: 271.0000 - fp: 6428.0000 - tn: 175543.0000 - fn: 34.0000 - accuracy: 0.9645 - precision: 0.0405 - recall: 0.8885 - auc: 0.9674 - val_loss: 0.1760 - val_tp: 74.0000 - val_fp: 835.0000 - val_tn: 44652.0000 - val_fn: 8.0000 - val_accuracy: 0.9815 - val_precision: 0.0814 - val_recall: 0.9024 - val_auc: 0.9826
Epoch 23/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1654 - tp: 282.0000 - fp: 6397.0000 - tn: 175574.0000 - fn: 23.0000 - accuracy: 0.9648 - precision: 0.0422 - recall: 0.9246 - auc: 0.9817 - val_loss: 0.1763 - val_tp: 74.0000 - val_fp: 773.0000 - val_tn: 44714.0000 - val_fn: 8.0000 - val_accuracy: 0.9829 - val_precision: 0.0874 - val_recall: 0.9024 - val_auc: 0.9827
Epoch 24/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2263 - tp: 271.0000 - fp: 6441.0000 - tn: 175530.0000 - fn: 34.0000 - accuracy: 0.9645 - precision: 0.0404 - recall: 0.8885 - auc: 0.9670 - val_loss: 0.1740 - val_tp: 74.0000 - val_fp: 802.0000 - val_tn: 44685.0000 - val_fn: 8.0000 - val_accuracy: 0.9822 - val_precision: 0.0845 - val_recall: 0.9024 - val_auc: 0.9829
Epoch 25/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2225 - tp: 271.0000 - fp: 6317.0000 - tn: 175654.0000 - fn: 34.0000 - accuracy: 0.9652 - precision: 0.0411 - recall: 0.8885 - auc: 0.9678 - val_loss: 0.1715 - val_tp: 74.0000 - val_fp: 831.0000 - val_tn: 44656.0000 - val_fn: 8.0000 - val_accuracy: 0.9816 - val_precision: 0.0818 - val_recall: 0.9024 - val_auc: 0.9808
Epoch 26/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2039 - tp: 279.0000 - fp: 6674.0000 - tn: 175297.0000 - fn: 26.0000 - accuracy: 0.9632 - precision: 0.0401 - recall: 0.9148 - auc: 0.9691 - val_loss: 0.1716 - val_tp: 74.0000 - val_fp: 798.0000 - val_tn: 44689.0000 - val_fn: 8.0000 - val_accuracy: 0.9823 - val_precision: 0.0849 - val_recall: 0.9024 - val_auc: 0.9833
Epoch 27/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2030 - tp: 276.0000 - fp: 6050.0000 - tn: 175921.0000 - fn: 29.0000 - accuracy: 0.9666 - precision: 0.0436 - recall: 0.9049 - auc: 0.9722 - val_loss: 0.1748 - val_tp: 74.0000 - val_fp: 722.0000 - val_tn: 44765.0000 - val_fn: 8.0000 - val_accuracy: 0.9840 - val_precision: 0.0930 - val_recall: 0.9024 - val_auc: 0.9813
Epoch 28/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2273 - tp: 270.0000 - fp: 6034.0000 - tn: 175937.0000 - fn: 35.0000 - accuracy: 0.9667 - precision: 0.0428 - recall: 0.8852 - auc: 0.9656 - val_loss: 0.1730 - val_tp: 74.0000 - val_fp: 748.0000 - val_tn: 44739.0000 - val_fn: 8.0000 - val_accuracy: 0.9834 - val_precision: 0.0900 - val_recall: 0.9024 - val_auc: 0.9814
Epoch 29/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1998 - tp: 274.0000 - fp: 5768.0000 - tn: 176203.0000 - fn: 31.0000 - accuracy: 0.9682 - precision: 0.0453 - recall: 0.8984 - auc: 0.9758 - val_loss: 0.1718 - val_tp: 74.0000 - val_fp: 783.0000 - val_tn: 44704.0000 - val_fn: 8.0000 - val_accuracy: 0.9826 - val_precision: 0.0863 - val_recall: 0.9024 - val_auc: 0.9813
Epoch 30/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1844 - tp: 277.0000 - fp: 5926.0000 - tn: 176045.0000 - fn: 28.0000 - accuracy: 0.9673 - precision: 0.0447 - recall: 0.9082 - auc: 0.9767 - val_loss: 0.1700 - val_tp: 74.0000 - val_fp: 783.0000 - val_tn: 44704.0000 - val_fn: 8.0000 - val_accuracy: 0.9826 - val_precision: 0.0863 - val_recall: 0.9024 - val_auc: 0.9815
Epoch 31/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1811 - tp: 276.0000 - fp: 6118.0000 - tn: 175853.0000 - fn: 29.0000 - accuracy: 0.9663 - precision: 0.0432 - recall: 0.9049 - auc: 0.9785 - val_loss: 0.1679 - val_tp: 74.0000 - val_fp: 791.0000 - val_tn: 44696.0000 - val_fn: 8.0000 - val_accuracy: 0.9825 - val_precision: 0.0855 - val_recall: 0.9024 - val_auc: 0.9816
Epoch 32/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1889 - tp: 276.0000 - fp: 6230.0000 - tn: 175741.0000 - fn: 29.0000 - accuracy: 0.9657 - precision: 0.0424 - recall: 0.9049 - auc: 0.9774 - val_loss: 0.1677 - val_tp: 74.0000 - val_fp: 777.0000 - val_tn: 44710.0000 - val_fn: 8.0000 - val_accuracy: 0.9828 - val_precision: 0.0870 - val_recall: 0.9024 - val_auc: 0.9817
Epoch 33/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1986 - tp: 280.0000 - fp: 5610.0000 - tn: 176361.0000 - fn: 25.0000 - accuracy: 0.9691 - precision: 0.0475 - recall: 0.9180 - auc: 0.9717 - val_loss: 0.1702 - val_tp: 74.0000 - val_fp: 679.0000 - val_tn: 44808.0000 - val_fn: 8.0000 - val_accuracy: 0.9849 - val_precision: 0.0983 - val_recall: 0.9024 - val_auc: 0.9817
Epoch 34/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1848 - tp: 277.0000 - fp: 5707.0000 - tn: 176264.0000 - fn: 28.0000 - accuracy: 0.9685 - precision: 0.0463 - recall: 0.9082 - auc: 0.9729 - val_loss: 0.1689 - val_tp: 74.0000 - val_fp: 713.0000 - val_tn: 44774.0000 - val_fn: 8.0000 - val_accuracy: 0.9842 - val_precision: 0.0940 - val_recall: 0.9024 - val_auc: 0.9818
Epoch 35/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1910 - tp: 273.0000 - fp: 5492.0000 - tn: 176479.0000 - fn: 32.0000 - accuracy: 0.9697 - precision: 0.0474 - recall: 0.8951 - auc: 0.9747 - val_loss: 0.1665 - val_tp: 74.0000 - val_fp: 748.0000 - val_tn: 44739.0000 - val_fn: 8.0000 - val_accuracy: 0.9834 - val_precision: 0.0900 - val_recall: 0.9024 - val_auc: 0.9820
Epoch 36/100
163840/182276 [=========================>....] - ETA: 0s - loss: 0.1844 - tp: 245.0000 - fp: 5415.0000 - tn: 158154.0000 - fn: 26.0000 - accuracy: 0.9668 - precision: 0.0433 - recall: 0.9041 - auc: 0.9783Restoring model weights from the end of the best epoch.
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1787 - tp: 277.0000 - fp: 6024.0000 - tn: 175947.0000 - fn: 28.0000 - accuracy: 0.9668 - precision: 0.0440 - recall: 0.9082 - auc: 0.9795 - val_loss: 0.1624 - val_tp: 74.0000 - val_fp: 790.0000 - val_tn: 44697.0000 - val_fn: 8.0000 - val_accuracy: 0.9825 - val_precision: 0.0856 - val_recall: 0.9024 - val_auc: 0.9821
Epoch 00036: early stopping

 

訓練履歴を確認する

plot_metrics(weighted_history)

 

メトリクスを評価する

train_predictions_weighted = weighted_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weighted = weighted_model.predict(test_features, batch_size=BATCH_SIZE)
weighted_results = weighted_model.evaluate(test_features, test_labels,
                                           batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_model.metrics_names, weighted_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_weighted)
loss :  0.07607786336474939
tp :  92.0
fp :  937.0
tn :  55920.0
fn :  13.0
accuracy :  0.9833222
precision :  0.08940719
recall :  0.8761905
auc :  0.9757879

Legitimate Transactions Detected (True Negatives):  55920
Legitimate Transactions Incorrectly Detected (False Positives):  937
Fraudulent Transactions Missed (False Negatives):  13
Fraudulent Transactions Detected (True Positives):  92
Total Fraudulent Transactions:  105

ここでクラス重みで accuracy と precision はより低いことを見て取れます、何故ならばより多くの false positive があるからです、しかし反対に recall と AUC はより高いです、何故ならばモデルはまたより多くの true positive を見つけたからです。より低い accuracy を持つにもかかわらず、このモデルはより高い recall を持ちます (そしてより多くの詐欺的トランザクションを識別します)。もちろん、両者のタイプのエラーへのコストはあります (非常に多くの正当なトランザクションを詐欺的としてフラグを立ててユーザを困らせることもまた望まないでしょう)。貴方のアプリケーションのためにこれらの異なるタイプのエラーの間のトレードオフを注意深く考えてください。

 

ROC をプロットする

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')


plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fc6b41b67f0>

 

オーバーサンプリング

少数 (= minotiry) クラスをオーバーサンプリングする

関連するアプローチは少数クラスをオーバーサンプリングすることによりデータセットを再サンプリングすることです。

pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]

pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]

 
NumPy を使用する

positive サンプルからランダムインデックスの適切な数を選択してデータセットを手動で平衡にすることができます :

ids = np.arange(len(pos_features))
choices = np.random.choice(ids, len(neg_features))

res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]

res_pos_features.shape
(181971, 29)
resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)

order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]

resampled_features.shape
(363942, 29)

 
tf.data を使用する

tf.data を使用している場合、バランスの取れたサンプルを生成する最も容易な方法は positive と negative データセットから始めて、それらをマージすることです。より多くのサンプルについては tf.data ガイド を見てください。

BUFFER_SIZE = 100000

def make_ds(features, labels):
  ds = tf.data.Dataset.from_tensor_slices((features, labels))#.cache()
  ds = ds.shuffle(BUFFER_SIZE).repeat()
  return ds

pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)

各データセットは (feature, label) ペアを生成します :

for features, label in pos_ds.take(1):
  print("Features:\n", features.numpy())
  print()
  print("Label: ", label.numpy())
Features:
 [-2.17940656  1.76763197 -3.92031926  4.88344634 -2.52559973 -0.87402282
 -5.          2.39895859 -2.28526198 -4.70820898  5.         -5.
  2.80944748 -5.         -2.29981762 -5.         -5.         -5.
  0.62439703 -0.30217952  2.21481335  2.15973127 -0.93161994 -0.09709966
 -3.50901907 -0.15144646  0.34072986 -1.87496516 -0.45141544]

Label:  1

experimental.sample_from_datasets を使用して 2 つを一緒にマージします :

resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
for features, label in resampled_ds.take(1):
  print(label.numpy().mean())
0.51220703125

このデータセットを使用するために、エポック毎のステップ数を必要とします。

この場合の「エポック」の定義は曖昧です。例えばそれは各 negative サンプルを一度見るのに必要なバッチ数としましょう :

resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)
resampled_steps_per_epoch
278.0

 

オーバーサンプリングされたデータ上で訓練する

今はこれらのメソッドがどのように比較できるかを見るためにクラス重みを使用する代わりに再サンプリングされたデータセットでモデルを訓練してみます。

★ Note: positive サンプルを複製することによりデータはバランスが取られましたので、総計のデータセットサイズは大きくなり、そして各エポックはより多くの訓練ステップのために実行されます。

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(2) 

resampled_history = resampled_model.fit(
    resampled_ds,
    epochs=EPOCHS,
    steps_per_epoch=resampled_steps_per_epoch,
    callbacks = [early_stopping],
    validation_data=val_ds)
Train for 278.0 steps, validate for 23 steps
Epoch 1/100
278/278 [==============================] - 11s 38ms/step - loss: 0.5141 - tp: 235679.0000 - fp: 89290.0000 - tn: 195069.0000 - fn: 49306.0000 - accuracy: 0.7566 - precision: 0.7252 - recall: 0.8270 - auc: 0.8606 - val_loss: 0.2508 - val_tp: 74.0000 - val_fp: 1342.0000 - val_tn: 44145.0000 - val_fn: 8.0000 - val_accuracy: 0.9704 - val_precision: 0.0523 - val_recall: 0.9024 - val_auc: 0.9747
Epoch 2/100
278/278 [==============================] - 8s 30ms/step - loss: 0.2221 - tp: 256642.0000 - fp: 18273.0000 - tn: 266245.0000 - fn: 28184.0000 - accuracy: 0.9184 - precision: 0.9335 - recall: 0.9010 - auc: 0.9667 - val_loss: 0.1341 - val_tp: 74.0000 - val_fp: 976.0000 - val_tn: 44511.0000 - val_fn: 8.0000 - val_accuracy: 0.9784 - val_precision: 0.0705 - val_recall: 0.9024 - val_auc: 0.9772
Epoch 3/100
278/278 [==============================] - 8s 29ms/step - loss: 0.1627 - tp: 263207.0000 - fp: 10666.0000 - tn: 273652.0000 - fn: 21819.0000 - accuracy: 0.9429 - precision: 0.9611 - recall: 0.9234 - auc: 0.9824 - val_loss: 0.0934 - val_tp: 73.0000 - val_fp: 822.0000 - val_tn: 44665.0000 - val_fn: 9.0000 - val_accuracy: 0.9818 - val_precision: 0.0816 - val_recall: 0.8902 - val_auc: 0.9768
Epoch 4/100
278/278 [==============================] - 8s 30ms/step - loss: 0.1373 - tp: 266123.0000 - fp: 8945.0000 - tn: 275364.0000 - fn: 18912.0000 - accuracy: 0.9511 - precision: 0.9675 - recall: 0.9337 - auc: 0.9876 - val_loss: 0.0773 - val_tp: 74.0000 - val_fp: 807.0000 - val_tn: 44680.0000 - val_fn: 8.0000 - val_accuracy: 0.9821 - val_precision: 0.0840 - val_recall: 0.9024 - val_auc: 0.9776
Epoch 5/100
278/278 [==============================] - 8s 30ms/step - loss: 0.1242 - tp: 266967.0000 - fp: 8276.0000 - tn: 276477.0000 - fn: 17624.0000 - accuracy: 0.9545 - precision: 0.9699 - recall: 0.9381 - auc: 0.9902 - val_loss: 0.0669 - val_tp: 74.0000 - val_fp: 771.0000 - val_tn: 44716.0000 - val_fn: 8.0000 - val_accuracy: 0.9829 - val_precision: 0.0876 - val_recall: 0.9024 - val_auc: 0.9766
Epoch 6/100
278/278 [==============================] - 8s 30ms/step - loss: 0.1150 - tp: 268272.0000 - fp: 7786.0000 - tn: 276538.0000 - fn: 16748.0000 - accuracy: 0.9569 - precision: 0.9718 - recall: 0.9412 - auc: 0.9918 - val_loss: 0.0604 - val_tp: 74.0000 - val_fp: 727.0000 - val_tn: 44760.0000 - val_fn: 8.0000 - val_accuracy: 0.9839 - val_precision: 0.0924 - val_recall: 0.9024 - val_auc: 0.9778
Epoch 7/100
278/278 [==============================] - 8s 30ms/step - loss: 0.1076 - tp: 268596.0000 - fp: 7335.0000 - tn: 277522.0000 - fn: 15891.0000 - accuracy: 0.9592 - precision: 0.9734 - recall: 0.9441 - auc: 0.9930 - val_loss: 0.0548 - val_tp: 73.0000 - val_fp: 668.0000 - val_tn: 44819.0000 - val_fn: 9.0000 - val_accuracy: 0.9851 - val_precision: 0.0985 - val_recall: 0.8902 - val_auc: 0.9786
Epoch 8/100
278/278 [==============================] - 9s 31ms/step - loss: 0.1011 - tp: 269934.0000 - fp: 7025.0000 - tn: 277658.0000 - fn: 14727.0000 - accuracy: 0.9618 - precision: 0.9746 - recall: 0.9483 - auc: 0.9940 - val_loss: 0.0505 - val_tp: 73.0000 - val_fp: 621.0000 - val_tn: 44866.0000 - val_fn: 9.0000 - val_accuracy: 0.9862 - val_precision: 0.1052 - val_recall: 0.8902 - val_auc: 0.9792
Epoch 9/100
278/278 [==============================] - 8s 30ms/step - loss: 0.0959 - tp: 270679.0000 - fp: 6711.0000 - tn: 277837.0000 - fn: 14117.0000 - accuracy: 0.9634 - precision: 0.9758 - recall: 0.9504 - auc: 0.9948 - val_loss: 0.0483 - val_tp: 73.0000 - val_fp: 603.0000 - val_tn: 44884.0000 - val_fn: 9.0000 - val_accuracy: 0.9866 - val_precision: 0.1080 - val_recall: 0.8902 - val_auc: 0.9794
Epoch 10/100
278/278 [==============================] - 8s 30ms/step - loss: 0.0913 - tp: 271592.0000 - fp: 6656.0000 - tn: 277917.0000 - fn: 13179.0000 - accuracy: 0.9652 - precision: 0.9761 - recall: 0.9537 - auc: 0.9953 - val_loss: 0.0440 - val_tp: 73.0000 - val_fp: 528.0000 - val_tn: 44959.0000 - val_fn: 9.0000 - val_accuracy: 0.9882 - val_precision: 0.1215 - val_recall: 0.8902 - val_auc: 0.9795
Epoch 11/100
278/278 [==============================] - 8s 30ms/step - loss: 0.0875 - tp: 272110.0000 - fp: 6604.0000 - tn: 278063.0000 - fn: 12567.0000 - accuracy: 0.9663 - precision: 0.9763 - recall: 0.9559 - auc: 0.9957 - val_loss: 0.0408 - val_tp: 73.0000 - val_fp: 491.0000 - val_tn: 44996.0000 - val_fn: 9.0000 - val_accuracy: 0.9890 - val_precision: 0.1294 - val_recall: 0.8902 - val_auc: 0.9799
Epoch 12/100
278/278 [==============================] - 8s 29ms/step - loss: 0.0832 - tp: 272899.0000 - fp: 6475.0000 - tn: 278107.0000 - fn: 11863.0000 - accuracy: 0.9678 - precision: 0.9768 - recall: 0.9583 - auc: 0.9962 - val_loss: 0.0392 - val_tp: 73.0000 - val_fp: 481.0000 - val_tn: 45006.0000 - val_fn: 9.0000 - val_accuracy: 0.9892 - val_precision: 0.1318 - val_recall: 0.8902 - val_auc: 0.9803
Epoch 13/100
278/278 [==============================] - 8s 30ms/step - loss: 0.0795 - tp: 274008.0000 - fp: 6417.0000 - tn: 278228.0000 - fn: 10691.0000 - accuracy: 0.9700 - precision: 0.9771 - recall: 0.9624 - auc: 0.9966 - val_loss: 0.0358 - val_tp: 73.0000 - val_fp: 442.0000 - val_tn: 45045.0000 - val_fn: 9.0000 - val_accuracy: 0.9901 - val_precision: 0.1417 - val_recall: 0.8902 - val_auc: 0.9806
Epoch 14/100
278/278 [==============================] - 8s 30ms/step - loss: 0.0747 - tp: 275586.0000 - fp: 6254.0000 - tn: 277642.0000 - fn: 9862.0000 - accuracy: 0.9717 - precision: 0.9778 - recall: 0.9655 - auc: 0.9970 - val_loss: 0.0336 - val_tp: 73.0000 - val_fp: 432.0000 - val_tn: 45055.0000 - val_fn: 9.0000 - val_accuracy: 0.9903 - val_precision: 0.1446 - val_recall: 0.8902 - val_auc: 0.9808
Epoch 15/100
278/278 [==============================] - 8s 31ms/step - loss: 0.0719 - tp: 274707.0000 - fp: 6051.0000 - tn: 279108.0000 - fn: 9478.0000 - accuracy: 0.9727 - precision: 0.9784 - recall: 0.9666 - auc: 0.9971 - val_loss: 0.0306 - val_tp: 73.0000 - val_fp: 407.0000 - val_tn: 45080.0000 - val_fn: 9.0000 - val_accuracy: 0.9909 - val_precision: 0.1521 - val_recall: 0.8902 - val_auc: 0.9808
Epoch 16/100
278/278 [==============================] - 8s 30ms/step - loss: 0.0697 - tp: 275176.0000 - fp: 6024.0000 - tn: 278945.0000 - fn: 9199.0000 - accuracy: 0.9733 - precision: 0.9786 - recall: 0.9677 - auc: 0.9972 - val_loss: 0.0305 - val_tp: 73.0000 - val_fp: 400.0000 - val_tn: 45087.0000 - val_fn: 9.0000 - val_accuracy: 0.9910 - val_precision: 0.1543 - val_recall: 0.8902 - val_auc: 0.9808
Epoch 17/100
278/278 [==============================] - 8s 30ms/step - loss: 0.0681 - tp: 275922.0000 - fp: 6047.0000 - tn: 278615.0000 - fn: 8760.0000 - accuracy: 0.9740 - precision: 0.9786 - recall: 0.9692 - auc: 0.9973 - val_loss: 0.0277 - val_tp: 73.0000 - val_fp: 368.0000 - val_tn: 45119.0000 - val_fn: 9.0000 - val_accuracy: 0.9917 - val_precision: 0.1655 - val_recall: 0.8902 - val_auc: 0.9808
Epoch 18/100
278/278 [==============================] - 8s 30ms/step - loss: 0.0671 - tp: 276085.0000 - fp: 6258.0000 - tn: 278545.0000 - fn: 8456.0000 - accuracy: 0.9742 - precision: 0.9778 - recall: 0.9703 - auc: 0.9974 - val_loss: 0.0264 - val_tp: 73.0000 - val_fp: 351.0000 - val_tn: 45136.0000 - val_fn: 9.0000 - val_accuracy: 0.9921 - val_precision: 0.1722 - val_recall: 0.8902 - val_auc: 0.9752
Epoch 19/100
278/278 [==============================] - 8s 30ms/step - loss: 0.0652 - tp: 276364.0000 - fp: 6108.0000 - tn: 278761.0000 - fn: 8111.0000 - accuracy: 0.9750 - precision: 0.9784 - recall: 0.9715 - auc: 0.9975 - val_loss: 0.0263 - val_tp: 73.0000 - val_fp: 354.0000 - val_tn: 45133.0000 - val_fn: 9.0000 - val_accuracy: 0.9920 - val_precision: 0.1710 - val_recall: 0.8902 - val_auc: 0.9699
Epoch 20/100
278/278 [==============================] - 8s 30ms/step - loss: 0.0640 - tp: 277072.0000 - fp: 6092.0000 - tn: 278461.0000 - fn: 7719.0000 - accuracy: 0.9757 - precision: 0.9785 - recall: 0.9729 - auc: 0.9975 - val_loss: 0.0247 - val_tp: 73.0000 - val_fp: 333.0000 - val_tn: 45154.0000 - val_fn: 9.0000 - val_accuracy: 0.9925 - val_precision: 0.1798 - val_recall: 0.8902 - val_auc: 0.9645
Epoch 21/100
278/278 [==============================] - 8s 29ms/step - loss: 0.0635 - tp: 277050.0000 - fp: 6229.0000 - tn: 278537.0000 - fn: 7528.0000 - accuracy: 0.9758 - precision: 0.9780 - recall: 0.9735 - auc: 0.9976 - val_loss: 0.0247 - val_tp: 73.0000 - val_fp: 338.0000 - val_tn: 45149.0000 - val_fn: 9.0000 - val_accuracy: 0.9924 - val_precision: 0.1776 - val_recall: 0.8902 - val_auc: 0.9590
Epoch 22/100
278/278 [==============================] - 8s 30ms/step - loss: 0.0619 - tp: 276909.0000 - fp: 6199.0000 - tn: 278826.0000 - fn: 7410.0000 - accuracy: 0.9761 - precision: 0.9781 - recall: 0.9739 - auc: 0.9977 - val_loss: 0.0242 - val_tp: 73.0000 - val_fp: 328.0000 - val_tn: 45159.0000 - val_fn: 9.0000 - val_accuracy: 0.9926 - val_precision: 0.1820 - val_recall: 0.8902 - val_auc: 0.9534
Epoch 23/100
278/278 [==============================] - 8s 30ms/step - loss: 0.0610 - tp: 278124.0000 - fp: 6195.0000 - tn: 278074.0000 - fn: 6951.0000 - accuracy: 0.9769 - precision: 0.9782 - recall: 0.9756 - auc: 0.9977 - val_loss: 0.0231 - val_tp: 73.0000 - val_fp: 311.0000 - val_tn: 45176.0000 - val_fn: 9.0000 - val_accuracy: 0.9930 - val_precision: 0.1901 - val_recall: 0.8902 - val_auc: 0.9536
Epoch 24/100
278/278 [==============================] - 8s 30ms/step - loss: 0.0605 - tp: 277101.0000 - fp: 6262.0000 - tn: 278966.0000 - fn: 7015.0000 - accuracy: 0.9767 - precision: 0.9779 - recall: 0.9753 - auc: 0.9977 - val_loss: 0.0222 - val_tp: 73.0000 - val_fp: 312.0000 - val_tn: 45175.0000 - val_fn: 9.0000 - val_accuracy: 0.9930 - val_precision: 0.1896 - val_recall: 0.8902 - val_auc: 0.9538
Epoch 25/100
278/278 [==============================] - 8s 30ms/step - loss: 0.0596 - tp: 277655.0000 - fp: 6154.0000 - tn: 278556.0000 - fn: 6979.0000 - accuracy: 0.9769 - precision: 0.9783 - recall: 0.9755 - auc: 0.9977 - val_loss: 0.0212 - val_tp: 73.0000 - val_fp: 295.0000 - val_tn: 45192.0000 - val_fn: 9.0000 - val_accuracy: 0.9933 - val_precision: 0.1984 - val_recall: 0.8902 - val_auc: 0.9482
Epoch 26/100
278/278 [==============================] - 8s 30ms/step - loss: 0.0586 - tp: 278013.0000 - fp: 6165.0000 - tn: 278169.0000 - fn: 6997.0000 - accuracy: 0.9769 - precision: 0.9783 - recall: 0.9754 - auc: 0.9978 - val_loss: 0.0217 - val_tp: 72.0000 - val_fp: 292.0000 - val_tn: 45195.0000 - val_fn: 10.0000 - val_accuracy: 0.9934 - val_precision: 0.1978 - val_recall: 0.8780 - val_auc: 0.9483
Epoch 27/100
277/278 [============================>.] - ETA: 0s - loss: 0.0581 - tp: 276446.0000 - fp: 6154.0000 - tn: 277889.0000 - fn: 6807.0000 - accuracy: 0.9772 - precision: 0.9782 - recall: 0.9760 - auc: 0.9978Restoring model weights from the end of the best epoch.
278/278 [==============================] - 8s 30ms/step - loss: 0.0581 - tp: 277470.0000 - fp: 6173.0000 - tn: 278871.0000 - fn: 6830.0000 - accuracy: 0.9772 - precision: 0.9782 - recall: 0.9760 - auc: 0.9978 - val_loss: 0.0209 - val_tp: 72.0000 - val_fp: 281.0000 - val_tn: 45206.0000 - val_fn: 10.0000 - val_accuracy: 0.9936 - val_precision: 0.2040 - val_recall: 0.8780 - val_auc: 0.9485
Epoch 00027: early stopping

訓練プロセスが各勾配更新でデータセット全体を考えていた場合、このオーバーサンプリングは基本的にはクラス重み付けと同値です。

しかしここで行なったように、モデルをバッチ-wise に訓練するとき、オーバーサンプリングされたデータはより滑らかな勾配信号を提供します : 各 positive サンプルが一つのバッチで巨大な重み上で見せられる代わりに、それらは小さい重みで各回に多くの異なるバッチで示されます。

この滑らかな勾配信号はモデルを訓練することを容易にします。

 

訓練履歴を確認する

ここではメトリクスの分布は異なることに注意してください、何故ならば訓練データは検証とテストデータからは大きく異なる分布を持つからです。

plot_metrics(resampled_history )

 

再訓練

訓練はバランスの取れたデータ上では容易ですので、上の訓練手続きは素早く overfit するかもしれません。

そこで callbacks.EarlyStopping にいつ訓練を停止するかについてより良い制御を与えるためにエポックを分割します。

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

resampled_history = resampled_model.fit(
    resampled_ds,
    # These are not real epochs
    steps_per_epoch = 20,
    epochs=10*EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_ds))
Train for 20 steps, validate for 23 steps
Epoch 1/1000
20/20 [==============================] - 3s 168ms/step - loss: 1.2909 - tp: 10811.0000 - fp: 11937.0000 - tn: 8418.0000 - fn: 9794.0000 - accuracy: 0.4695 - precision: 0.4753 - recall: 0.5247 - auc: 0.4887 - val_loss: 0.9277 - val_tp: 78.0000 - val_fp: 28484.0000 - val_tn: 17003.0000 - val_fn: 4.0000 - val_accuracy: 0.3748 - val_precision: 0.0027 - val_recall: 0.9512 - val_auc: 0.8413
Epoch 2/1000
20/20 [==============================] - 1s 33ms/step - loss: 0.8973 - tp: 14267.0000 - fp: 11350.0000 - tn: 9081.0000 - fn: 6262.0000 - accuracy: 0.5700 - precision: 0.5569 - recall: 0.6950 - auc: 0.6600 - val_loss: 0.8451 - val_tp: 80.0000 - val_fp: 25378.0000 - val_tn: 20109.0000 - val_fn: 2.0000 - val_accuracy: 0.4430 - val_precision: 0.0031 - val_recall: 0.9756 - val_auc: 0.9362
Epoch 3/1000
20/20 [==============================] - 1s 35ms/step - loss: 0.7107 - tp: 15904.0000 - fp: 10547.0000 - tn: 9941.0000 - fn: 4568.0000 - accuracy: 0.6310 - precision: 0.6013 - recall: 0.7769 - auc: 0.7582 - val_loss: 0.7505 - val_tp: 79.0000 - val_fp: 21165.0000 - val_tn: 24322.0000 - val_fn: 3.0000 - val_accuracy: 0.5355 - val_precision: 0.0037 - val_recall: 0.9634 - val_auc: 0.9478
Epoch 4/1000
20/20 [==============================] - 1s 35ms/step - loss: 0.5944 - tp: 16759.0000 - fp: 9572.0000 - tn: 10917.0000 - fn: 3712.0000 - accuracy: 0.6757 - precision: 0.6365 - recall: 0.8187 - auc: 0.8177 - val_loss: 0.6589 - val_tp: 77.0000 - val_fp: 16659.0000 - val_tn: 28828.0000 - val_fn: 5.0000 - val_accuracy: 0.6343 - val_precision: 0.0046 - val_recall: 0.9390 - val_auc: 0.9519
Epoch 5/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.5151 - tp: 17257.0000 - fp: 8350.0000 - tn: 12146.0000 - fn: 3207.0000 - accuracy: 0.7178 - precision: 0.6739 - recall: 0.8433 - auc: 0.8568 - val_loss: 0.5801 - val_tp: 77.0000 - val_fp: 12473.0000 - val_tn: 33014.0000 - val_fn: 5.0000 - val_accuracy: 0.7262 - val_precision: 0.0061 - val_recall: 0.9390 - val_auc: 0.9552
Epoch 6/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.4546 - tp: 17829.0000 - fp: 7203.0000 - tn: 13200.0000 - fn: 2728.0000 - accuracy: 0.7575 - precision: 0.7122 - recall: 0.8673 - auc: 0.8886 - val_loss: 0.5138 - val_tp: 76.0000 - val_fp: 9021.0000 - val_tn: 36466.0000 - val_fn: 6.0000 - val_accuracy: 0.8019 - val_precision: 0.0084 - val_recall: 0.9268 - val_auc: 0.9587
Epoch 7/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.4176 - tp: 17998.0000 - fp: 6179.0000 - tn: 14182.0000 - fn: 2601.0000 - accuracy: 0.7856 - precision: 0.7444 - recall: 0.8737 - auc: 0.9035 - val_loss: 0.4574 - val_tp: 76.0000 - val_fp: 6272.0000 - val_tn: 39215.0000 - val_fn: 6.0000 - val_accuracy: 0.8622 - val_precision: 0.0120 - val_recall: 0.9268 - val_auc: 0.9618
Epoch 8/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.3846 - tp: 18020.0000 - fp: 5174.0000 - tn: 15298.0000 - fn: 2468.0000 - accuracy: 0.8134 - precision: 0.7769 - recall: 0.8795 - auc: 0.9166 - val_loss: 0.4109 - val_tp: 75.0000 - val_fp: 4351.0000 - val_tn: 41136.0000 - val_fn: 7.0000 - val_accuracy: 0.9044 - val_precision: 0.0169 - val_recall: 0.9146 - val_auc: 0.9651
Epoch 9/1000
20/20 [==============================] - 1s 37ms/step - loss: 0.3631 - tp: 17972.0000 - fp: 4520.0000 - tn: 16011.0000 - fn: 2457.0000 - accuracy: 0.8297 - precision: 0.7990 - recall: 0.8797 - auc: 0.9230 - val_loss: 0.3704 - val_tp: 74.0000 - val_fp: 3021.0000 - val_tn: 42466.0000 - val_fn: 8.0000 - val_accuracy: 0.9335 - val_precision: 0.0239 - val_recall: 0.9024 - val_auc: 0.9674
Epoch 10/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.3378 - tp: 17994.0000 - fp: 3772.0000 - tn: 16815.0000 - fn: 2379.0000 - accuracy: 0.8498 - precision: 0.8267 - recall: 0.8832 - auc: 0.9308 - val_loss: 0.3381 - val_tp: 74.0000 - val_fp: 2323.0000 - val_tn: 43164.0000 - val_fn: 8.0000 - val_accuracy: 0.9488 - val_precision: 0.0309 - val_recall: 0.9024 - val_auc: 0.9693
Epoch 11/1000
20/20 [==============================] - 1s 33ms/step - loss: 0.3213 - tp: 18268.0000 - fp: 3260.0000 - tn: 17041.0000 - fn: 2391.0000 - accuracy: 0.8620 - precision: 0.8486 - recall: 0.8843 - auc: 0.9355 - val_loss: 0.3113 - val_tp: 74.0000 - val_fp: 1888.0000 - val_tn: 43599.0000 - val_fn: 8.0000 - val_accuracy: 0.9584 - val_precision: 0.0377 - val_recall: 0.9024 - val_auc: 0.9715
Epoch 12/1000
20/20 [==============================] - 1s 34ms/step - loss: 0.3040 - tp: 17977.0000 - fp: 2813.0000 - tn: 17832.0000 - fn: 2338.0000 - accuracy: 0.8742 - precision: 0.8647 - recall: 0.8849 - auc: 0.9420 - val_loss: 0.2880 - val_tp: 74.0000 - val_fp: 1621.0000 - val_tn: 43866.0000 - val_fn: 8.0000 - val_accuracy: 0.9643 - val_precision: 0.0437 - val_recall: 0.9024 - val_auc: 0.9733
Epoch 13/1000
20/20 [==============================] - 1s 34ms/step - loss: 0.2907 - tp: 18262.0000 - fp: 2594.0000 - tn: 17824.0000 - fn: 2280.0000 - accuracy: 0.8810 - precision: 0.8756 - recall: 0.8890 - auc: 0.9460 - val_loss: 0.2679 - val_tp: 74.0000 - val_fp: 1426.0000 - val_tn: 44061.0000 - val_fn: 8.0000 - val_accuracy: 0.9685 - val_precision: 0.0493 - val_recall: 0.9024 - val_auc: 0.9744
Epoch 14/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2803 - tp: 18314.0000 - fp: 2190.0000 - tn: 18192.0000 - fn: 2264.0000 - accuracy: 0.8913 - precision: 0.8932 - recall: 0.8900 - auc: 0.9491 - val_loss: 0.2503 - val_tp: 74.0000 - val_fp: 1330.0000 - val_tn: 44157.0000 - val_fn: 8.0000 - val_accuracy: 0.9706 - val_precision: 0.0527 - val_recall: 0.9024 - val_auc: 0.9752
Epoch 15/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2679 - tp: 18275.0000 - fp: 1967.0000 - tn: 18471.0000 - fn: 2247.0000 - accuracy: 0.8971 - precision: 0.9028 - recall: 0.8905 - auc: 0.9528 - val_loss: 0.2350 - val_tp: 74.0000 - val_fp: 1258.0000 - val_tn: 44229.0000 - val_fn: 8.0000 - val_accuracy: 0.9722 - val_precision: 0.0556 - val_recall: 0.9024 - val_auc: 0.9759
Epoch 16/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2575 - tp: 18217.0000 - fp: 1735.0000 - tn: 18790.0000 - fn: 2218.0000 - accuracy: 0.9035 - precision: 0.9130 - recall: 0.8915 - auc: 0.9558 - val_loss: 0.2215 - val_tp: 73.0000 - val_fp: 1194.0000 - val_tn: 44293.0000 - val_fn: 9.0000 - val_accuracy: 0.9736 - val_precision: 0.0576 - val_recall: 0.8902 - val_auc: 0.9766
Epoch 17/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2514 - tp: 18223.0000 - fp: 1700.0000 - tn: 18896.0000 - fn: 2141.0000 - accuracy: 0.9062 - precision: 0.9147 - recall: 0.8949 - auc: 0.9582 - val_loss: 0.2085 - val_tp: 73.0000 - val_fp: 1136.0000 - val_tn: 44351.0000 - val_fn: 9.0000 - val_accuracy: 0.9749 - val_precision: 0.0604 - val_recall: 0.8902 - val_auc: 0.9767
Epoch 18/1000
20/20 [==============================] - 1s 37ms/step - loss: 0.2424 - tp: 18402.0000 - fp: 1466.0000 - tn: 18902.0000 - fn: 2190.0000 - accuracy: 0.9107 - precision: 0.9262 - recall: 0.8936 - auc: 0.9601 - val_loss: 0.1980 - val_tp: 73.0000 - val_fp: 1105.0000 - val_tn: 44382.0000 - val_fn: 9.0000 - val_accuracy: 0.9756 - val_precision: 0.0620 - val_recall: 0.8902 - val_auc: 0.9771
Epoch 19/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2358 - tp: 18413.0000 - fp: 1449.0000 - tn: 18924.0000 - fn: 2174.0000 - accuracy: 0.9115 - precision: 0.9270 - recall: 0.8944 - auc: 0.9624 - val_loss: 0.1889 - val_tp: 73.0000 - val_fp: 1074.0000 - val_tn: 44413.0000 - val_fn: 9.0000 - val_accuracy: 0.9762 - val_precision: 0.0636 - val_recall: 0.8902 - val_auc: 0.9772
Epoch 20/1000
20/20 [==============================] - 1s 33ms/step - loss: 0.2254 - tp: 18587.0000 - fp: 1249.0000 - tn: 19002.0000 - fn: 2122.0000 - accuracy: 0.9177 - precision: 0.9370 - recall: 0.8975 - auc: 0.9653 - val_loss: 0.1801 - val_tp: 73.0000 - val_fp: 1040.0000 - val_tn: 44447.0000 - val_fn: 9.0000 - val_accuracy: 0.9770 - val_precision: 0.0656 - val_recall: 0.8902 - val_auc: 0.9771
Epoch 21/1000
20/20 [==============================] - 1s 34ms/step - loss: 0.2200 - tp: 18517.0000 - fp: 1188.0000 - tn: 19205.0000 - fn: 2050.0000 - accuracy: 0.9209 - precision: 0.9397 - recall: 0.9003 - auc: 0.9675 - val_loss: 0.1720 - val_tp: 73.0000 - val_fp: 1024.0000 - val_tn: 44463.0000 - val_fn: 9.0000 - val_accuracy: 0.9773 - val_precision: 0.0665 - val_recall: 0.8902 - val_auc: 0.9775
Epoch 22/1000
20/20 [==============================] - 1s 34ms/step - loss: 0.2154 - tp: 18254.0000 - fp: 1078.0000 - tn: 19556.0000 - fn: 2072.0000 - accuracy: 0.9231 - precision: 0.9442 - recall: 0.8981 - auc: 0.9685 - val_loss: 0.1645 - val_tp: 73.0000 - val_fp: 1007.0000 - val_tn: 44480.0000 - val_fn: 9.0000 - val_accuracy: 0.9777 - val_precision: 0.0676 - val_recall: 0.8902 - val_auc: 0.9773
Epoch 23/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2092 - tp: 18465.0000 - fp: 1074.0000 - tn: 19455.0000 - fn: 1966.0000 - accuracy: 0.9258 - precision: 0.9450 - recall: 0.9038 - auc: 0.9706 - val_loss: 0.1588 - val_tp: 74.0000 - val_fp: 1022.0000 - val_tn: 44465.0000 - val_fn: 8.0000 - val_accuracy: 0.9774 - val_precision: 0.0675 - val_recall: 0.9024 - val_auc: 0.9776
Epoch 24/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2034 - tp: 18590.0000 - fp: 1025.0000 - tn: 19414.0000 - fn: 1931.0000 - accuracy: 0.9278 - precision: 0.9477 - recall: 0.9059 - auc: 0.9720 - val_loss: 0.1532 - val_tp: 74.0000 - val_fp: 1011.0000 - val_tn: 44476.0000 - val_fn: 8.0000 - val_accuracy: 0.9776 - val_precision: 0.0682 - val_recall: 0.9024 - val_auc: 0.9772
Epoch 25/1000
20/20 [==============================] - 1s 33ms/step - loss: 0.1999 - tp: 18484.0000 - fp: 1076.0000 - tn: 19531.0000 - fn: 1869.0000 - accuracy: 0.9281 - precision: 0.9450 - recall: 0.9082 - auc: 0.9736 - val_loss: 0.1472 - val_tp: 74.0000 - val_fp: 1009.0000 - val_tn: 44478.0000 - val_fn: 8.0000 - val_accuracy: 0.9777 - val_precision: 0.0683 - val_recall: 0.9024 - val_auc: 0.9774
Epoch 26/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1961 - tp: 18723.0000 - fp: 1048.0000 - tn: 19357.0000 - fn: 1832.0000 - accuracy: 0.9297 - precision: 0.9470 - recall: 0.9109 - auc: 0.9739 - val_loss: 0.1415 - val_tp: 74.0000 - val_fp: 989.0000 - val_tn: 44498.0000 - val_fn: 8.0000 - val_accuracy: 0.9781 - val_precision: 0.0696 - val_recall: 0.9024 - val_auc: 0.9772
Epoch 27/1000
20/20 [==============================] - 1s 38ms/step - loss: 0.1890 - tp: 18580.0000 - fp: 990.0000 - tn: 19537.0000 - fn: 1853.0000 - accuracy: 0.9306 - precision: 0.9494 - recall: 0.9093 - auc: 0.9760 - val_loss: 0.1367 - val_tp: 74.0000 - val_fp: 976.0000 - val_tn: 44511.0000 - val_fn: 8.0000 - val_accuracy: 0.9784 - val_precision: 0.0705 - val_recall: 0.9024 - val_auc: 0.9773
Epoch 28/1000
20/20 [==============================] - 1s 33ms/step - loss: 0.1835 - tp: 18707.0000 - fp: 921.0000 - tn: 19609.0000 - fn: 1723.0000 - accuracy: 0.9354 - precision: 0.9531 - recall: 0.9157 - auc: 0.9777 - val_loss: 0.1316 - val_tp: 74.0000 - val_fp: 967.0000 - val_tn: 44520.0000 - val_fn: 8.0000 - val_accuracy: 0.9786 - val_precision: 0.0711 - val_recall: 0.9024 - val_auc: 0.9774
Epoch 29/1000
20/20 [==============================] - 1s 34ms/step - loss: 0.1820 - tp: 18835.0000 - fp: 862.0000 - tn: 19527.0000 - fn: 1736.0000 - accuracy: 0.9366 - precision: 0.9562 - recall: 0.9156 - auc: 0.9776 - val_loss: 0.1272 - val_tp: 74.0000 - val_fp: 936.0000 - val_tn: 44551.0000 - val_fn: 8.0000 - val_accuracy: 0.9793 - val_precision: 0.0733 - val_recall: 0.9024 - val_auc: 0.9774
Epoch 30/1000
20/20 [==============================] - 1s 35ms/step - loss: 0.1764 - tp: 18796.0000 - fp: 844.0000 - tn: 19677.0000 - fn: 1643.0000 - accuracy: 0.9393 - precision: 0.9570 - recall: 0.9196 - auc: 0.9791 - val_loss: 0.1234 - val_tp: 74.0000 - val_fp: 931.0000 - val_tn: 44556.0000 - val_fn: 8.0000 - val_accuracy: 0.9794 - val_precision: 0.0736 - val_recall: 0.9024 - val_auc: 0.9769
Epoch 31/1000
20/20 [==============================] - 1s 33ms/step - loss: 0.1727 - tp: 18808.0000 - fp: 793.0000 - tn: 19706.0000 - fn: 1653.0000 - accuracy: 0.9403 - precision: 0.9595 - recall: 0.9192 - auc: 0.9800 - val_loss: 0.1200 - val_tp: 74.0000 - val_fp: 909.0000 - val_tn: 44578.0000 - val_fn: 8.0000 - val_accuracy: 0.9799 - val_precision: 0.0753 - val_recall: 0.9024 - val_auc: 0.9773
Epoch 32/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1704 - tp: 18934.0000 - fp: 824.0000 - tn: 19610.0000 - fn: 1592.0000 - accuracy: 0.9410 - precision: 0.9583 - recall: 0.9224 - auc: 0.9807 - val_loss: 0.1171 - val_tp: 74.0000 - val_fp: 907.0000 - val_tn: 44580.0000 - val_fn: 8.0000 - val_accuracy: 0.9799 - val_precision: 0.0754 - val_recall: 0.9024 - val_auc: 0.9772
Epoch 33/1000
18/20 [==========================>...] - ETA: 0s - loss: 0.1659 - tp: 17032.0000 - fp: 717.0000 - tn: 17676.0000 - fn: 1439.0000 - accuracy: 0.9415 - precision: 0.9596 - recall: 0.9221 - auc: 0.9817Restoring model weights from the end of the best epoch.
20/20 [==============================] - 1s 34ms/step - loss: 0.1654 - tp: 18936.0000 - fp: 788.0000 - tn: 19642.0000 - fn: 1594.0000 - accuracy: 0.9418 - precision: 0.9600 - recall: 0.9224 - auc: 0.9818 - val_loss: 0.1142 - val_tp: 74.0000 - val_fp: 908.0000 - val_tn: 44579.0000 - val_fn: 8.0000 - val_accuracy: 0.9799 - val_precision: 0.0754 - val_recall: 0.9024 - val_auc: 0.9771
Epoch 00033: early stopping

 

訓練履歴を再確認する

plot_metrics(resampled_history)

 

メトリクスを評価する

train_predictions_resampled = resampled_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_resampled = resampled_model.predict(test_features, batch_size=BATCH_SIZE)
resampled_results = resampled_model.evaluate(test_features, test_labels,
                                             batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(resampled_model.metrics_names, resampled_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_weighted)
loss :  0.1573299011001384
tp :  92.0
fp :  1193.0
tn :  55664.0
fn :  13.0
accuracy :  0.978828
precision :  0.07159533
recall :  0.8761905
auc :  0.96230674

Legitimate Transactions Detected (True Negatives):  55920
Legitimate Transactions Incorrectly Detected (False Positives):  937
Fraudulent Transactions Missed (False Negatives):  13
Fraudulent Transactions Detected (True Positives):  92
Total Fraudulent Transactions:  105

 

ROC をプロットする

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')

plot_roc("Train Resampled", train_labels, train_predictions_resampled,  color=colors[2])
plot_roc("Test Resampled", test_labels, test_predictions_resampled,  color=colors[2], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fc67a751588>

 

このチュートリアルを貴方の問題に適用する

不均衡なデータ分類は本質的に困難なタスクです、何故ならばそこから学習する非常に少ないサンプルがあるからです。貴方は常に最初にデータから開始してできる限り多くのサンプルを収集することに最善をつくしてそしてどの特徴が関連するかもしれないのかをしっかりと念頭におくべきです、そうすればモデルは貴方の最小クラスを最大限に活用できます。ある点で貴方のモデルは改良されて望む結果を生成するために努力するかもしれません、そして貴方の問題のコンテキストとエラーの異なるタイプの間のトレードオフを念頭に置くことは重要です。

 

以上






クラスキャット

最近の投稿

  • LangGraph 0.5 : エージェント開発 : エージェント・アーキテクチャ
  • LangGraph 0.5 : エージェント開発 : ワークフローとエージェント
  • LangGraph 0.5 : エージェント開発 : エージェントの実行
  • LangGraph 0.5 : エージェント開発 : prebuilt コンポーネントを使用したエージェント開発
  • LangGraph 0.5 : Get started : ローカルサーバの実行

タグ

AutoGen (13) ClassCat Press Release (20) ClassCat TF/ONNX Hub (11) DGL 0.5 (14) Eager Execution (7) Edward (17) FLUX.1 (16) Gemini (20) HuggingFace Transformers 4.5 (10) HuggingFace Transformers 4.6 (7) HuggingFace Transformers 4.29 (9) Keras 2 Examples (98) Keras 2 Guide (16) Keras 3 (10) Keras Release Note (17) Kubeflow 1.0 (10) LangChain (45) LangGraph (24) LangGraph 0.5 (9) MediaPipe 0.8 (11) Model Context Protocol (16) NNI 1.5 (16) OpenAI Agents SDK (8) OpenAI Cookbook (13) OpenAI platform (10) OpenAI platform 1.x (10) OpenAI ヘルプ (8) TensorFlow 2.0 Advanced Tutorials (33) TensorFlow 2.0 Advanced Tutorials (Alpha) (15) TensorFlow 2.0 Advanced Tutorials (Beta) (16) TensorFlow 2.0 Guide (10) TensorFlow 2.0 Guide (Alpha) (16) TensorFlow 2.0 Guide (Beta) (9) TensorFlow 2.0 Release Note (12) TensorFlow 2.0 Tutorials (20) TensorFlow 2.0 Tutorials (Alpha) (14) TensorFlow 2.0 Tutorials (Beta) (12) TensorFlow 2.4 Guide (24) TensorFlow Deploy (8) TensorFlow Get Started (7) TensorFlow Probability (9) TensorFlow Programmer's Guide (22) TensorFlow Release Note (18) TensorFlow Tutorials (33) TF-Agents 0.4 (11)
2019年11月
月 火 水 木 金 土 日
 123
45678910
11121314151617
18192021222324
252627282930  
« 10月   12月 »
© 2025 ClasCat® AI Research | Powered by Minimalist Blog WordPress Theme