TensorFlow 2.0 : Beginner Tutorials : Estimator :- Keras モデルから Estimator を作成する (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/11/2019
* 本ページは、TensorFlow org サイトの TF 2.0 – Beginner Tutorials – Estimator の以下のページを
翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
| 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
| E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
| Facebook: https://www.facebook.com/ClassCatJP/ |
Estimator :- Keras モデルから Estimator を作成する
概要
TensorFlow Estimator は TensorFlow で完全にサポートされます、そして新しいそして既存の tf.keras モデルから作成できます。このチュートリアルはそのプロセスの完全で、最小限のサンプルを含みます。
セットアップ
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf import numpy as np import tensorflow_datasets as tfds
単純な Keras モデルを作成する
Keras では、モデルを構築するために層をアセンブルします。モデルは (通常は) 層のグラフです。モデルの最も一般的なタイプは層のスタック : tf.keras.Sequential モデルです。
単純な、完全結合ネットワーク (i.e. 多層パーセプトロン) を構築するには :
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(1, activation='sigmoid')
])
モデルをコンパイルして要約を得ます。
model.compile(loss='categorical_crossentropy', optimizer='adam') model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 16) 80 _________________________________________________________________ dropout (Dropout) (None, 16) 0 _________________________________________________________________ dense_1 (Dense) (None, 1) 17 ================================================================= Total params: 97 Trainable params: 97 Non-trainable params: 0 _________________________________________________________________
入力関数を作成する
巨大なデータセットやマルチデバイス訓練にスケールするために Datasets API を使用します。
Estimator はそれらの入力パイプラインがいつそしてどのように構築されるかの制御を必要とします。これを可能にするために、それらは「入力関数」あるいは input_fn を必要とします。Estimator はこの関数を引数なしで呼び出します。input_fn は tf.data.Dataset を返さなければなりません。
def input_fn():
split = tfds.Split.TRAIN
dataset = tfds.load('iris', split=split, as_supervised=True)
dataset = dataset.map(lambda features, labels: ({'dense_input':features}, labels))
dataset = dataset.batch(32).repeat()
return dataset
貴方の input_fn を試します。
for features_batch, labels_batch in input_fn().take(1): print(features_batch) print(labels_batch)
Downloading and preparing dataset iris (4.44 KiB) to /home/kbuilder/tensorflow_datasets/iris/1.0.0...
HBox(children=(IntProgress(value=1, bar_style='info', description='Dl Completed...', max=1, style=ProgressStyl…
HBox(children=(IntProgress(value=1, bar_style='info', description='Dl Size...', max=1, style=ProgressStyle(des…
/home/kbuilder/.local/lib/python3.5/site-packages/urllib3/connectionpool.py:1004: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings
InsecureRequestWarning,
HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))
HBox(children=(IntProgress(value=0, description='Shuffling...', max=1, style=ProgressStyle(description_width='…
WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.5/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and:
`tf.data.TFRecordDataset(path)`
WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.5/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and:
`tf.data.TFRecordDataset(path)`
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=150, style=ProgressStyle(description_width='…
Dataset iris downloaded and prepared to /home/kbuilder/tensorflow_datasets/iris/1.0.0. Subsequent calls will reuse this data.
{'dense_input': }
tf.Tensor([1 0 2 1 1 0 1 2 1 1 2 0 0 0 0 1 2 1 1 2 0 2 0 2 2 2 2 2 0 0 0 0], shape=(32,), dtype=int64)
tf.keras モデルから Estimator を作成する
tf.keras.Model は tf.keras.estimator.model_to_estimator でモデルを tf.estimator.Estimator オブジェクトに変換することにより tf.estimator API で訓練できます。
model_dir = "/tmp/tfkeras_example/"
keras_estimator = tf.keras.estimator.model_to_estimator(
keras_model=model, model_dir=model_dir)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
INFO:tensorflow:Using the Keras model provided.
INFO:tensorflow:Using the Keras model provided.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Using config: {'_num_ps_replicas': 0, '_session_creation_timeout_secs': 7200, '_global_id_in_cluster': 0, '_save_checkpoints_secs': 600, '_save_checkpoints_steps': None, '_protocol': None, '_log_step_count_steps': 100, '_num_worker_replicas': 1, '_master': '', '_task_id': 0, '_train_distribute': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_is_chief': True, '_experimental_max_worker_delay_secs': None, '_experimental_distribute': None, '_service': None, '_task_type': 'worker', '_eval_distribute': None, '_tf_random_seed': None, '_evaluation_master': '', '_cluster_spec': , '_device_fn': None, '_save_summary_steps': 100, '_model_dir': '/tmp/tfkeras_example/', '_session_config': allow_soft_placement: true
graph_options {
rewrite_options {
meta_optimizer_iterations: ONE
}
}
}
estimator を訓練して評価します。
keras_estimator.train(input_fn=input_fn, steps=25)
eval_result = keras_estimator.evaluate(input_fn=input_fn, steps=10)
print('Eval result: {}'.format(eval_result))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tfkeras_example/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tfkeras_example/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmp/tfkeras_example/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting from: /tmp/tfkeras_example/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 4 variables.
INFO:tensorflow:Warm-started 4 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tfkeras_example/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tfkeras_example/model.ckpt.
INFO:tensorflow:loss = 118.05543, step = 0
INFO:tensorflow:loss = 118.05543, step = 0
INFO:tensorflow:Saving checkpoints for 25 into /tmp/tfkeras_example/model.ckpt.
INFO:tensorflow:Saving checkpoints for 25 into /tmp/tfkeras_example/model.ckpt.
INFO:tensorflow:Loss for final step: 84.80701.
INFO:tensorflow:Loss for final step: 84.80701.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2019-10-01T01:26:35Z
INFO:tensorflow:Starting evaluation at 2019-10-01T01:26:35Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tfkeras_example/model.ckpt-25
INFO:tensorflow:Restoring parameters from /tmp/tfkeras_example/model.ckpt-25
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Finished evaluation at 2019-10-01-01:26:36
INFO:tensorflow:Finished evaluation at 2019-10-01-01:26:36
INFO:tensorflow:Saving dict for global step 25: global_step = 25, loss = 101.27602
INFO:tensorflow:Saving dict for global step 25: global_step = 25, loss = 101.27602
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tfkeras_example/model.ckpt-25
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tfkeras_example/model.ckpt-25
Eval result: {'global_step': 25, 'loss': 101.27602}
以上