TensorFlow : How To : 新しい Op を追加する (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
更新日時 : 09/16/2017
作成日時 : 03/04/2016
* 本ページは、TensorFlow 本家サイトの Extend – Adding a New Op を翻訳した上で
適宜、補足説明したものです:
* (obsolete, リンク切れ) 本ページは、TensorFlow の本家サイトの How To – Adding a New Op を翻訳した上で
適宜、補足説明したものです:
https://www.tensorflow.org/versions/master/how_tos/adding_an_op/index.html#adding-a-new-op
* サンプルコードの動作確認はしておりますが、適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
既存のライブラリでカバーされていない演算を組み入れたい (incorporate) のであれば、
カスタム Op を作成することがてきます。カスタム Op を組み入れるためには、以下が必要となるでしょう :
- C++ ファイルの新しい Op を登録します。Op 登録は実装からは独立で、どのように Op が呼び出される (invoke) かの semantics を記述します。例えば、それは Op 名を定義して、その入出力を指定します。
- Op を C++ で実装します。この実装は “カーネル” と呼ばれ、異なるアーキテクチャ(例えば CPU、GPU)や入出力タイプのために複数のカーネルがあってもかまいません。
- オプションで、Python ラッパーを作成します。このラッパーは Op を作成するための public API です。デフォルト・ラッパーは Op 登録から生成され、これは直接使用できますし、追加することもできます。
- オプションで、Op のための勾配を計算する関数を書きます。
- オプションで、Op のための入出力形状 (shape) を記述する関数を書きます。これは貴方の Op と動作するための形状推論を可能にします。
- Op をテストします、典型的には Python で。もし貴方が勾配を定義するならば、Python GradientChecker でそれらを検証できます。
Op のインターフェイスを定義する
TensorFlow システムに登録することで Op のインターフェイスを定義します。登録においては、貴方の Op の名前、その入力(型と名前)と出力(型と名前)、そして Op が必要とするかもしれない docstrings と attrs を指定します。
これがどのように動作するかを見るために、int32s のテンソルを取り、最初の要素を除いてゼロに設定されたテンソルのコピーを出力する Op を作成したいと仮定しましょう。ファイル tensorflow/core/user_ops/zero_out.cc を作成してそのような Op のためのインターフェイスを定義する REGISTER_OP マクロへの呼び出しを追加します。
#include "tensorflow/core/framework/op.h" REGISTER_OP("ZeroOut") .Input("to_zero: int32") .Output("zeroed: int32");
この ZeroOut Op は入力として 32-bit 整数の一つのテンソル to_zero を取り、32-bit 整数のテンソル zeroed を出力します。
ネーミングについてのノート : Ops の名前は一意 (unique) でキャメルケースであるべきです。アンダースコア (_) で始まる名前は内部ユースのために予約されています。
Op のためのカーネルを実装する
インターフェイスを定義した後は、Op の一つまたはそれ以上の実装を提供します。これらのカーネルを作成するためには、OpKernel を拡張したクラスを作成してCompute メソッドをオーバーライドします。Compute メソッドは型 OpKernelContext* の一つの context 引数を提供し、そこから入出力テンソルのような有用なものにアクセスできます。
上で作成したファイルにカーネルを追加します。カーネルはこのようなものに見えるでしょう:
#include "tensorflow/core/framework/op_kernel.h" using namespace tensorflow; class ZeroOutOp : public OpKernel { public: explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { // 入力テンソルを取得します const Tensor& input_tensor = context->input(0); auto input = input_tensor.flat<int32>(); // 出力テンソルを作成します Tensor* output_tensor = NULL; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); auto output = output_tensor->template flat<int32>(); // 出力テンソルの最初の要素以外の全ては 0 にセットされます。 const int N = input.size(); for (int i = 1; i < N; i++) { output(i) = 0; } // 可能ならば最初の入力値を保存します。 if (N > 0) output(0) = input(0); } };
カーネルを実装した後は、それを TensorFlow システムで登録します。登録においては、異なる束縛 (constraints)、そこでこのカーネルが動作します、を指定します。例えば、CPU のために作られた一つのカーネルがあり、そして GPU のために別の一つを持つかもしれません。
ZeroOut op のためにこれを行なうには、zero_out.cc に次を追加します :
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
Op ライブラリをビルドする
With TensorFlow バイナリ・インストレーション
貴方のシステム上で利用可能な g++ あるいは clang のような C++ コンパイラで zero_out.cc をコンパイルできるべきです。バイナリ PIP パッケージは、Op をコンパイルするのに必要なヘッダファイルとライブラリをシステム固有の場所にインストールします。しかしながら、TensorFlow python ライブラリはヘッダとライブラリ・ディレクトリをそれぞれ取得するための関数 get_include と get_lib を提供します。
Ubuntu マシン上のそれらの関数の出力をここに示します。
$ python >>> import tensorflow as tf >>> tf.sysconfig.get_include() '/usr/local/lib/python2.7/site-packages/tensorflow/include' >>> tf.sysconfig.get_lib() '/usr/local/lib/python2.7/site-packages/tensorflow/core' >>>
g++ がインストールされていると仮定して、ここに、Op をダイナミックライブラリにコンパイルするために使用できるコマンドのシークエンスを示します。
$ TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())') $ TF_LIB=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())') $ g++ -std=c++11 -shared zero_out.cc -o zero_out.so \ -I $TF_INC -l tensorflow_framework -L $TF_LIB \ -fPIC -Wl,-rpath $TF_LIB
With TensorFlow ソース・インストレーション
TensorFlow ソース・インストールをしているのであれば、Op をコンパイルするために TensorFlow のビルドシステムが使用できます。次の Bazel ビルド・ルールを持つ BUILD ファイルを tensorflow/core/user_ops ディレクトリに配置してください。
cc_binary( name = "zero_out.so", srcs = ["zero_out.cc"], linkopts = [ "-Wl,-Bsymbolic", "-lm", ], linkshared = 1, linkstatic = 1, deps = [ "//third_party/tensorflow/core:framework", ], )
zero_out.so をビルドするために次のコマンドを実行します。
$ bazel build -c opt //tensorflow/core/user_ops:zero_out.so
Python で Op を使用する
TensorFlow Python API は ダイナミックライブラリをロードして Op を TensorFlow フレームワークで登録するために load_op_library 関数を提供します。load_op_library は Python モジュールを返し、これは Op への Python ラッパーを含みます。こうして、ひとたび Op をビルドすれば、Python からそれを実行するために次を行なうことができます :
import tensorflow as tf zero_out_module = tf.load_op_library('zero_out.so') with tf.Session(''): zero_out_module.zero_out([[1, 2], [3, 4]]).eval() # Prints array([[1, 0], [0, 0]], dtype=int32)
ノート : 生成された関数は(PEP8 に応じた)snake_case 名が与えられます。そのためもし op が C++ ファイルで ZeroOut と命名されているのであれば、python 関数は zero_out と呼称されます。
Op を Python モジュールから import 可能な標準関数として利用可能にするためには、
次のように Python ソースファイルで load_op_library 呼び出しを持つことは多分有用でしょう(zero_out_op_1.py を見てください):
import tensorflow as tf _zero_out_module = tf.load_op_library('zero_out_op_kernel_1.so') zero_out = _zero_out_module.zero_out
動作することを検証する
Op が成功的に実装されたかを検証する良い方法はそのためのテストを書くとです。次の内容でファイル tensorflow/python/kernel_tests/zero_out_op_test.py を作成します :
import tensorflow as tf class ZeroOutTest(tf.test.TestCase): def testZeroOut(self): zero_out_module = tf.load_op_library('zero_out.so') with self.test_session(): result = zero_out_module.zero_out([5, 4, 3, 2, 1]) self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
そしてテストを実行します :
$ bazel test tensorflow/python:zero_out_op_test
妥当性確認 (Validation)
上の例は Op が任意の形状に適用されることを仮定していました。それがベクタだけに適用されるとしたらどうでしょう?これは上の OpKernel 実装にチェックを追加することを意味しています。
void Compute(OpKernelContext* context) override { // 入力テンソルを取得する const Tensor& input_tensor = context->input(0); OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()), errors::InvalidArgument("ZeroOut expects a 1-D vector.")); // ... }
これは入力がベクタであると assert し、もしそうでないならば InvalidArgument ステータスを設定して返します。OP_REQUIRES マクロ は3つの引数を取ります :
- コンテキスト (context)、これは OpKernelContext か OpKernelConstruction ポインタ (tensorflow/core/framework/op_kernel.h 参照) で、SetStatus() メソッドのためです。
- 条件。例えば、テンソルの形状を確認するための関数が tensorflow/core/framework/tensor_shape.h にあります。
- エラー自身、これは Status オブジェクトで表されます、tensorflow/core/lib/core/status.h を見てください。Status は型(しばしば InvalidArgument、しかし型リスト参照)とメッセージの両方を持ちます。エラーを構築するための関数は tensorflow/core/lib/core/errors.h で見つかるでしょう。
他の選択肢として、もしある関数から返された Status オブジェクトがエラーかどうかテストして、もうそうならばそれを返したいのであれば、OP_REQUIRES_OK を使います。これらマクロ両方はエラー時の関数から返ります。
Op 登録
Attrs
Ops は attrs を持てます、この値は Op がグラフに追加された時に設定されます。これらは Op を構成するために使用されて、そしてその値はカーネル実装の内部と Op 登録の入出力の型の中の両方でアクセスできます。可能な時には attr の代わりに入力を使用することが好ましいです、何故なら入力がより柔軟だからです。それらはfeed を使用するように設定する等、全てのステップを変更できます。attr は入力ではできない事柄に対して使用されます : signature(入出力の数や型)に影響したり step-to-step からは変更できない任意の構成です。
Op を登録する時に、Attr メソッドを使用してその名前と型を指定することで attr を定義します、これは次の形式の仕様が期待されます :
<name>: <attr-type-expr>
ここで <name> は文字で始まり、英数字とアンダースコアからなります、そして <attr-type-expr> は下で記述される形式の型の式 (type expression) です。
例えば、ZeroOut Op をユーザ指定 index として preserve (保存) する場合には、0 番目の要素のみの代わりに、Op を次のように登録できます :
REGISTER_OP("ZeroOut") .Attr("preserve_index: int") .Input("to_zero: int32") .Output("zeroed: int32");
そうすれば貴方のカーネルはそのコンストラクタで context パラメータを通してこの attr にアクセスすることができます:
class ZeroOutOp : public OpKernel { public: explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) { // 保存 (preserve) するための値の index を得る。 OP_REQUIRES_OK(context, context->GetAttr("preserve_index", &preserve_index_)); // preserve_index が正値であることを確認する。 OP_REQUIRES(context, preserve_index_ >= 0, errors::InvalidArgument("Need preserve_index >= 0, got ", preserve_index_)); } void Compute(OpKernelContext* context) override { // ... } private: int preserve_index_; };
そしてそれは Compute メソッドで使用できます :
void Compute(OpKernelContext* context) override { // ... // preserve_index is が範囲にあることを確認する。 OP_REQUIRES(context, preserve_index_ < input.dimension(0), errors::InvalidArgument("preserve_index out of range")); // 出力テンソルの全ての要素を 0 に設定する。 const int N = input.size(); for (int i = 0; i < N; i++) { output_flat(i) = 0; } // 要求された入力値を保存します。 output_flat(preserve_index_) = input(preserve_index_); }
後方互換性 を守るためには、既存の op に attr を追加する時 デフォルト値 を指定すべきです :
REGISTER_OP("ZeroOut") .Attr("preserve_index: int = 0") .Input("to_zero: int32") .Output("zeroed: int32");
Attr 型
attr においては次の型がサポートされます:
string
: 任意のバイト列 (Any sequence of bytes) (UTF8 である必要はない)。int
: 符号付き整数 (signed integer)float
: 浮動小数点数値。bool
: True または false.type
: One of the (non-ref) values ofDataType
.shape
: ATensorShapeProto
.tensor
: ATensorProto
.list(<type>)
: A list of<type>
, where<type>
is one of the above types.
Note thatlist(list(<type>))
is invalid.
See also: op_def_builder.cc:FinalizeAttr
for a definitive list.
Default values & constraints
Attrs may have default values, and some types of attrs can have constraints. To
define an attr with constraints, you can use the following <attr-type-expr>
s:
-
{'<string1>', '<string2>'}
: The value must be a string that has either the
value<string1>
or<string2>
. The name of the type,string
, is implied
when you use this syntax. This emulates an enum:REGISTER_OP("EnumExample") .Attr("e: {'apple', 'orange'}");
-
{<type1>, <type2>}
: The value is of typetype
, and must be one of
<type1>
or<type2>
, where<type1>
and<type2>
are supported
tensor types. You don't specify
that the type of the attr istype
. This is implied when you have a list of
types in{...}
. For example, in this case the attrt
is a type that must
be anint32
, afloat
, or abool
:REGISTER_OP("RestrictedTypeExample") .Attr("t: {int32, float, bool}");
-
There are shortcuts for common type constraints:
numbertype
: Typetype
restricted to the numeric (non-string and
non-bool) types.realnumbertype
: Likenumbertype
without complex types.quantizedtype
: Likenumbertype
but just the quantized number types.
The specific lists of types allowed by these are defined by the functions
(likeNumberTypes()
) in
tensorflow/core/framework/types.h
.
In this example the attrt
must be one of the numeric types:REGISTER_OP("NumberType") .Attr("t: numbertype");
For this op:
tf.number_type(t=tf.int32) # Valid tf.number_type(t=tf.bool) # Invalid
-
int >= <n>
: The value must be an int whose value is greater than or equal to
<n>
, where<n>
is a natural number.For example, the following Op registration specifies that the attr
a
must
have a value that is at least2
:REGISTER_OP("MinIntExample") .Attr("a: int >= 2");
-
list(<type>) >= <n>
: A list of type<type>
whose length is greater than
or equal to<n>
.For example, the following Op registration specifies that the attr
a
is a
list of types (eitherint32
orfloat
), and that there must be at least 3
of them:REGISTER_OP("TypeListExample") .Attr("a: list({int32, float}) >= 3");
To set a default value for an attr (making it optional in the generated code),
add = <default>
to the end, as in:
REGISTER_OP("AttrDefaultExample")
.Attr("i: int = 0");
The supported syntax of the default value is what would be used in the proto
representation of the resulting GraphDef definition.
Here are examples for how to specify a default for all types:
REGISTER_OP("AttrDefaultExampleForAllTypes")
.Attr("s: string = 'foo'")
.Attr("i: int = 0")
.Attr("f: float = 1.0")
.Attr("b: bool = true")
.Attr("ty: type = DT_INT32")
.Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
.Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
.Attr("l_empty: list(int) = []")
.Attr("l_int: list(int) = [2, 3, 5, 7]");
Note in particular that the values of type type
use the DT_*
names
for the types.
以上