Keras 2 : examples : Point cloud segmentation with PointNet (translation / commentary)
Translation : ClassCat Co., Ltd. Sales Information
Created : 12/07/2021 (keras 2.7.0)
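* This page is a translation, with supplementary explanations where appropriate, of the following Keras document:
- Code examples : Computer Vision : Point cloud segmentation with PointNet (Author: Soumik Rakshit, Sayak Paul)
* The sample code has been checked to run, and has been extended or modified where necessary.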
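Description: Implementation of a PointNet-based model for segmenting point clouds.
Introduction
A "point cloud" is an important type of data structure for storing geometric shape data. Because of its irregular format, it is often converted into regular 3D voxel grids or collections of images before being used in deep learning applications, a step that makes the data unnecessarily large. The PointNet family of models solves this problem by consuming point clouds directly while respecting the permutation invariance of the point data. The family provides a simple, unified architecture for applications ranging from object classification and part segmentation to scene semantic parsing.
In this example, we demonstrate an implementation of the PointNet architecture for shape segmentation.
References
- PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation
- Point cloud classification with PointNet
- Spatial Transformer Networks
Imports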
import os
import json
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from glob import glob
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
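Downloading the dataset
The ShapeNet dataset is an ongoing effort to build a richly annotated, large-scale dataset of 3D shapes. ShapeNetCore is a subset of the full ShapeNet dataset that contains clean single 3D models with manually verified category and alignment annotations. It covers 55 common object categories and contains about 51,300 unique 3D models.
For this example, we use one of the 12 object categories of PASCAL 3D+, which is included as part of the ShapeNetCore dataset.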
dataset_url = "https://git.io/JiY4i"

dataset_path = keras.utils.get_file(
    fname="shapenet.zip",
    origin=dataset_url,
    cache_subdir="datasets",
    hash_algorithm="auto",
    extract=True,
    archive_format="auto",
    cache_dir="datasets",
)
Loading the dataset
We parse the dataset metadata so that model categories can easily be mapped to their directories, and segmentation classes to colors for visualization.
with open("/tmp/.keras/datasets/PartAnnotation/metadata.json") as json_file:
    metadata = json.load(json_file)

print(metadata)
{'Airplane': {'directory': '02691156', 'lables': ['wing', 'body', 'tail', 'engine'], 'colors': ['blue', 'green', 'red', 'pink']}, 'Bag': {'directory': '02773838', 'lables': ['handle', 'body'], 'colors': ['blue', 'green']}, 'Cap': {'directory': '02954340', 'lables': ['panels', 'peak'], 'colors': ['blue', 'green']}, 'Car': {'directory': '02958343', 'lables': ['wheel', 'hood', 'roof'], 'colors': ['blue', 'green', 'red']}, 'Chair': {'directory': '03001627', 'lables': ['leg', 'arm', 'back', 'seat'], 'colors': ['blue', 'green', 'red', 'pink']}, 'Earphone': {'directory': '03261776', 'lables': ['earphone', 'headband'], 'colors': ['blue', 'green']}, 'Guitar': {'directory': '03467517', 'lables': ['head', 'body', 'neck'], 'colors': ['blue', 'green', 'red']}, 'Knife': {'directory': '03624134', 'lables': ['handle', 'blade'], 'colors': ['blue', 'green']}, 'Lamp': {'directory': '03636649', 'lables': ['canopy', 'lampshade', 'base'], 'colors': ['blue', 'green', 'red']}, 'Laptop': {'directory': '03642806', 'lables': ['keyboard'], 'colors': ['blue']}, 'Motorbike': {'directory': '03790512', 'lables': ['wheel', 'handle', 'gas_tank', 'light', 'seat'], 'colors': ['blue', 'green', 'red', 'pink', 'yellow']}, 'Mug': {'directory': '03797390', 'lables': ['handle'], 'colors': ['blue']}, 'Pistol': {'directory': '03948459', 'lables': ['trigger_and_guard', 'handle', 'barrel'], 'colors': ['blue', 'green', 'red']}, 'Rocket': {'directory': '04099429', 'lables': ['nose', 'body', 'fin'], 'colors': ['blue', 'green', 'red']}, 'Skateboard': {'directory': '04225987', 'lables': ['wheel', 'deck'], 'colors': ['blue', 'green']}, 'Table': {'directory': '04379243', 'lables': ['leg', 'top'], 'colors': ['blue', 'green']}}
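As a small illustration of how this metadata can be used, the sketch below pairs each part label with its plotting color for a single category. The dictionary is a hand-copied excerpt of the output above, not read from disk, and `label_to_color` is a hypothetical helper name. Note that the key really is spelled "lables" in the dataset's metadata.

```python
# Excerpt of the metadata printed above (the dataset itself spells the key "lables").
metadata = {
    "Airplane": {
        "directory": "02691156",
        "lables": ["wing", "body", "tail", "engine"],
        "colors": ["blue", "green", "red", "pink"],
    }
}

entry = metadata["Airplane"]
# Hypothetical helper mapping: part label -> plotting color.
label_to_color = dict(zip(entry["lables"], entry["colors"]))
print(label_to_color["tail"])  # red
```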
In this example, we train PointNet to segment the parts of an Airplane model.
points_dir = "/tmp/.keras/datasets/PartAnnotation/{}/points".format(
    metadata["Airplane"]["directory"]
)
labels_dir = "/tmp/.keras/datasets/PartAnnotation/{}/points_label".format(
    metadata["Airplane"]["directory"]
)
LABELS = metadata["Airplane"]["lables"]
COLORS = metadata["Airplane"]["colors"]

VAL_SPLIT = 0.2
NUM_SAMPLE_POINTS = 1024
BATCH_SIZE = 32
EPOCHS = 60
INITIAL_LR = 1e-3
Structuring the dataset
We generate the following in-memory data structures from the Airplane point clouds and their labels:
- point_clouds is a list of np.array objects that represent the point cloud data as x, y and z coordinates. Axis 0 indexes the points in the point cloud, and axis 1 holds the coordinates.
- test_point_clouds has the same format as point_clouds, but its point clouds have no corresponding labels.
- all_labels is a list, aligned with point_clouds, that holds the label of each point as a string (needed mainly for visualization).
- point_cloud_labels is a list of np.array objects, aligned with point_clouds, that holds the label of each point in one-hot encoded form.
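To make the relationship between the string labels and the one-hot arrays concrete, here is a minimal NumPy-only sketch on toy data. The helper `one_hot_labels` and the toy label list are hypothetical. The sketch assumes, as the loading code in this example does, that points without any part label are assigned an extra last class.

```python
import numpy as np

LABELS = ["wing", "body", "tail", "engine"]  # Airplane part names from the metadata


def one_hot_labels(label_map, labels):
    """Map string labels to one-hot rows; "none" falls into the extra last class."""
    indices = [labels.index(name) if name in labels else len(labels) for name in label_map]
    return np.eye(len(labels) + 1, dtype="float32")[indices]


toy_label_map = ["wing", "none", "engine"]  # hypothetical per-point labels
encoded = one_hot_labels(toy_label_map, LABELS)
print(encoded.shape)  # (3, 5)
```

The last axis has len(LABELS) + 1 = 5 entries, which matches the label shape of the datasets built later in this example.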
point_clouds, test_point_clouds = [], []
point_cloud_labels, all_labels = [], []

points_files = glob(os.path.join(points_dir, "*.pts"))
for point_file in tqdm(points_files):
    point_cloud = np.loadtxt(point_file)
    if point_cloud.shape[0] < NUM_SAMPLE_POINTS:
        continue

    # Get the file-id of the current point cloud for parsing its
    # labels.
    file_id = point_file.split("/")[-1].split(".")[0]
    label_data, num_labels = {}, 0
    for label in LABELS:
        label_file = os.path.join(labels_dir, label, file_id + ".seg")
        if os.path.exists(label_file):
            label_data[label] = np.loadtxt(label_file).astype("float32")
            num_labels = len(label_data[label])

    # Point clouds having labels will be our training samples.
    try:
        label_map = ["none"] * num_labels
        for label in LABELS:
            for i, data in enumerate(label_data[label]):
                label_map[i] = label if data == 1 else label_map[i]
        label_data = [
            LABELS.index(label) if label != "none" else len(LABELS)
            for label in label_map
        ]
        # Apply one-hot encoding to the dense label representation.
        label_data = keras.utils.to_categorical(label_data, num_classes=len(LABELS) + 1)

        point_clouds.append(point_cloud)
        point_cloud_labels.append(label_data)
        all_labels.append(label_map)
    except KeyError:
        test_point_clouds.append(point_cloud)
100%|██████████| 4045/4045 [03:35<00:00, 18.76it/s]
Next, we look at a few samples from the in-memory arrays we just generated:
for _ in range(5):
    i = random.randint(0, len(point_clouds) - 1)
    # Report the shapes of the sampled cloud i (not always cloud 0).
    print(f"point_clouds[{i}].shape:", point_clouds[i].shape)
    print(f"point_cloud_labels[{i}].shape:", point_cloud_labels[i].shape)
    for j in range(5):
        print(
            f"all_labels[{i}][{j}]:",
            all_labels[i][j],
            f"\tpoint_cloud_labels[{i}][{j}]:",
            point_cloud_labels[i][j],
            "\n",
        )
Now, let's visualize some of the point clouds along with their labels.
def visualize_data(point_cloud, labels):
    df = pd.DataFrame(
        data={
            "x": point_cloud[:, 0],
            "y": point_cloud[:, 1],
            "z": point_cloud[:, 2],
            "label": labels,
        }
    )
    fig = plt.figure(figsize=(15, 10))
    ax = plt.axes(projection="3d")
    for index, label in enumerate(LABELS):
        c_df = df[df["label"] == label]
        try:
            ax.scatter(
                c_df["x"], c_df["y"], c_df["z"], label=label, alpha=0.5, c=COLORS[index]
            )
        except IndexError:
            pass
    ax.legend()
    plt.show()


visualize_data(point_clouds[0], all_labels[0])
visualize_data(point_clouds[300], all_labels[300])
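Preprocessing
The point clouds we have loaded all consist of a variable number of points, which makes it difficult to batch them together. To overcome this, we randomly sample a fixed number of points from each point cloud. We also normalize the point clouds to make the data scale-invariant.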
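The two steps described here, fixed-size sampling and normalization, can be sketched in isolation with NumPy. The random cloud and the helper name `normalize` are assumptions made for illustration. The normalization (subtract the centroid, divide by the largest distance from it) follows the preprocessing code in this example.

```python
import numpy as np


def normalize(points):
    """Center a cloud on its centroid and scale it into the unit sphere."""
    centered = points - points.mean(axis=0)
    return centered / np.linalg.norm(centered, axis=1).max()


rng = np.random.default_rng(0)
cloud = rng.normal(loc=5.0, scale=2.0, size=(2000, 3))  # hypothetical raw cloud

# Draw a fixed number of distinct points so every cloud has the same shape.
indices = rng.choice(len(cloud), size=1024, replace=False)
sampled = cloud[indices]

normalized = normalize(sampled)
rescaled = normalize(10.0 * sampled)  # same cloud at ten times the scale

print(normalized.shape)                   # (1024, 3)
print(np.allclose(normalized, rescaled))  # True
```

Because the scale factor cancels in the division, a cloud and a uniformly scaled copy of it map to the same normalized coordinates.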
for index in tqdm(range(len(point_clouds))):
    current_point_cloud = point_clouds[index]
    current_label_cloud = point_cloud_labels[index]
    current_labels = all_labels[index]
    num_points = len(current_point_cloud)
    # Randomly sampling respective indices.
    sampled_indices = random.sample(list(range(num_points)), NUM_SAMPLE_POINTS)
    # Sampling points corresponding to sampled indices.
    sampled_point_cloud = np.array([current_point_cloud[i] for i in sampled_indices])
    # Sampling corresponding one-hot encoded labels.
    sampled_label_cloud = np.array([current_label_cloud[i] for i in sampled_indices])
    # Sampling corresponding labels for visualization.
    sampled_labels = np.array([current_labels[i] for i in sampled_indices])
    # Normalizing sampled point cloud.
    norm_point_cloud = sampled_point_cloud - np.mean(sampled_point_cloud, axis=0)
    norm_point_cloud /= np.max(np.linalg.norm(norm_point_cloud, axis=1))
    point_clouds[index] = norm_point_cloud
    point_cloud_labels[index] = sampled_label_cloud
    all_labels[index] = sampled_labels
100%|██████████| 3694/3694 [00:07<00:00, 478.67it/s]
Let's visualize the sampled and normalized point clouds along with their corresponding labels.
visualize_data(point_clouds[0], all_labels[0])
visualize_data(point_clouds[300], all_labels[300])
Creating TensorFlow datasets
We create tf.data.Dataset objects for the training and validation data. We also augment the training point clouds by applying random jitter to them.
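The jitter augmentation amounts to adding a small random offset to every coordinate. The NumPy sketch below shows the idea on a made-up batch. The batch contents are hypothetical, and the noise range of ±0.005 is the one used by the augment function in this example.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = rng.uniform(-1.0, 1.0, size=(4, 1024, 3))  # hypothetical batch of normalized clouds

# Independent uniform noise for every coordinate of every point.
noise = rng.uniform(-0.005, 0.005, size=batch.shape)
jittered = batch + noise

# Largest displacement of any coordinate; it stays within the noise range.
print(np.abs(jittered - batch).max())
```

The labels are left untouched, since moving a point by such a small amount does not change the part it belongs to.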
def load_data(point_cloud_batch, label_cloud_batch):
    point_cloud_batch.set_shape([NUM_SAMPLE_POINTS, 3])
    label_cloud_batch.set_shape([NUM_SAMPLE_POINTS, len(LABELS) + 1])
    return point_cloud_batch, label_cloud_batch


def augment(point_cloud_batch, label_cloud_batch):
    # Jitter every coordinate with uniform noise in [-0.005, 0.005).
    noise = tf.random.uniform(
        tf.shape(point_cloud_batch), -0.005, 0.005, dtype=tf.float64
    )
    point_cloud_batch += noise
    return point_cloud_batch, label_cloud_batch


def generate_dataset(point_clouds, label_clouds, is_training=True):
    dataset = tf.data.Dataset.from_tensor_slices((point_clouds, label_clouds))
    dataset = dataset.shuffle(BATCH_SIZE * 100) if is_training else dataset
    dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size=BATCH_SIZE)
    dataset = (
        dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
        if is_training
        else dataset
    )
    return dataset


split_index = int(len(point_clouds) * (1 - VAL_SPLIT))
train_point_clouds = point_clouds[:split_index]
train_label_cloud = point_cloud_labels[:split_index]
total_training_examples = len(train_point_clouds)

val_point_clouds = point_clouds[split_index:]
val_label_cloud = point_cloud_labels[split_index:]

print("Num train point clouds:", len(train_point_clouds))
print("Num train point cloud labels:", len(train_label_cloud))
print("Num val point clouds:", len(val_point_clouds))
print("Num val point cloud labels:", len(val_label_cloud))

train_dataset = generate_dataset(train_point_clouds, train_label_cloud)
val_dataset = generate_dataset(val_point_clouds, val_label_cloud, is_training=False)

print("Train Dataset:", train_dataset)
print("Validation Dataset:", val_dataset)
Num train point clouds: 2955
Num train point cloud labels: 2955
Num val point clouds: 739
Num val point cloud labels: 739
Train Dataset: <ParallelMapDataset shapes: ((None, 1024, 3), (None, 1024, 5)), types: (tf.float64, tf.float32)>
Validation Dataset: <BatchDataset shapes: ((None, 1024, 3), (None, 1024, 5)), types: (tf.float64, tf.float32)>
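The counts in this output follow directly from the split arithmetic. The short sketch below reproduces them in plain Python, taking the total of 3694 labeled point clouds from the preprocessing progress bar. The variable names are chosen for illustration only.

```python
import math

num_clouds = 3694  # labeled point clouds kept after loading
val_split = 0.2
batch_size = 32

split_index = int(num_clouds * (1 - val_split))
num_train, num_val = split_index, num_clouds - split_index
print(num_train, num_val)  # 2955 739

# Number of batches per epoch when the last partial batch is kept.
print(math.ceil(num_train / batch_size))  # 93
```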
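PointNet model
The figure below depicts the internals of the PointNet model family:
Given that PointNet is meant to consume an unordered set of coordinates as its input data, its architecture needs to match the following characteristic properties of point cloud data:
Permutation invariance
Given the unstructured nature of point cloud data, a scan made up of n points has n! permutations. The subsequent data processing must be invariant to these different representations. To make PointNet invariant to input permutations, a symmetric function (such as max pooling) is applied once the n input points have been mapped to a higher-dimensional space. The result is a global feature vector that aims to capture an aggregate signature of the n input points. For segmentation, this global feature vector is used alongside the local point features.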
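The effect of a symmetric pooling function can be checked numerically. In the sketch below, random per-point features (a hypothetical stand-in for the output of the shared MLP) are max-pooled over the point axis before and after the points are shuffled.

```python
import numpy as np

rng = np.random.default_rng(0)
point_features = rng.normal(size=(8, 16))  # 8 points, 16 hypothetical feature channels

# Symmetric function: channel-wise maximum over the point axis.
global_feature = point_features.max(axis=0)

# Shuffle the order of the points and pool again.
permuted = point_features[rng.permutation(len(point_features))]
global_feature_permuted = permuted.max(axis=0)

print(np.array_equal(global_feature, global_feature_permuted))  # True
```

Any ordering of the same points yields the same global feature vector, which is the property the architecture relies on.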
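Transformation invariance
The segmentation output should not change when the object undergoes certain transformations, such as translation or scaling. For a given input point cloud, we apply an appropriate rigid or affine transformation to achieve pose normalization. Because each of the n input points is represented as a vector and is mapped to the embedding space independently, applying a geometric transformation simply amounts to matrix-multiplying each point with a transformation matrix. This is motivated by the concept of Spatial Transformer Networks.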
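As a minimal sketch of "transforming each point by a matrix multiplication", the code below applies one hand-picked 3 x 3 matrix to a tiny point set. The 90 degree rotation about the z-axis is an arbitrary choice for illustration; in the model, the matrix is predicted by the transformation network instead.

```python
import numpy as np

theta = np.pi / 2  # hypothetical choice: rotate by 90 degrees about the z-axis
transform = np.array(
    [
        [np.cos(theta), -np.sin(theta), 0.0],
        [np.sin(theta), np.cos(theta), 0.0],
        [0.0, 0.0, 1.0],
    ]
)

points = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.5]])  # shape (n, 3)

# A single matrix product transforms every point (each row) at once.
transformed = points @ transform

print(np.round(transformed, 6))
# A rigid transformation preserves the distance of each point from the origin.
print(np.allclose(np.linalg.norm(points, axis=1), np.linalg.norm(transformed, axis=1)))  # True
```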
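The operations that make up the T-Net are motivated by the higher-level architecture of PointNet. MLPs (or fully-connected layers) are used to map the input points independently and identically to a higher-dimensional space. Max pooling is used to encode a global feature vector, whose dimensionality is then reduced with fully-connected layers. The input-dependent features at the final fully-connected layer are then combined with globally trainable weights and biases, resulting in a 3 x 3 transformation matrix.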
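Point interactions
The interaction between neighboring points often carries useful information (i.e., a single point should not be treated in isolation). Whereas classification only needs to make use of global features, segmentation must be able to leverage local point features along with the global point features.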
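One simple way to give every point access to both kinds of information is to repeat the global feature vector for each point and concatenate it with that point's local features. The NumPy sketch below shows this on toy shapes; the feature sizes are hypothetical and much smaller than those in the model.

```python
import numpy as np

rng = np.random.default_rng(0)
num_points = 6
local_features = rng.normal(size=(num_points, 4))  # hypothetical per-point features

# Global feature: channel-wise maximum over all points, kept as one row.
global_feature = local_features.max(axis=0, keepdims=True)  # shape (1, 4)

# Repeat the global feature for every point, then join it to the local features.
tiled_global = np.tile(global_feature, (num_points, 1))  # shape (6, 4)
segmentation_input = np.concatenate([local_features, tiled_global], axis=1)

print(segmentation_input.shape)  # (6, 8)
```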
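Note: The figures presented in this section are taken from the original paper.
Now that we know the pieces that make up the PointNet model, we can implement it. We start by implementing the basic blocks, i.e., the convolutional block and the multi-layer perceptron block.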
def conv_block(x: tf.Tensor, filters: int, name: str) -> tf.Tensor:
    x = layers.Conv1D(filters, kernel_size=1, padding="valid", name=f"{name}_conv")(x)
    x = layers.BatchNormalization(momentum=0.0, name=f"{name}_batch_norm")(x)
    return layers.Activation("relu", name=f"{name}_relu")(x)


def mlp_block(x: tf.Tensor, filters: int, name: str) -> tf.Tensor:
    x = layers.Dense(filters, name=f"{name}_dense")(x)
    x = layers.BatchNormalization(momentum=0.0, name=f"{name}_batch_norm")(x)
    return layers.Activation("relu", name=f"{name}_relu")(x)
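A Conv1D layer with kernel_size=1 applies the same weights to every point independently, which is how the "shared MLP" of PointNet is realized here. The NumPy sketch below mimics that behavior with a plain matrix product and also computes the parameter count of such a layer. The random weights and the helper name `pointwise_conv_params` are assumptions made for illustration.

```python
import numpy as np


def pointwise_conv_params(in_channels: int, out_channels: int) -> int:
    """Weights plus biases of a Conv1D layer with kernel_size=1."""
    return in_channels * out_channels + out_channels


rng = np.random.default_rng(0)
points = rng.normal(size=(5, 3))   # 5 points with x, y, z
kernel = rng.normal(size=(3, 64))  # hypothetical shared weights
bias = np.zeros(64)

features = points @ kernel + bias  # shape (5, 64): same weights for every point

# Each output row depends only on its own input point.
print(np.allclose(features[2], points[2] @ kernel + bias))  # True
print(pointwise_conv_params(3, 64))    # 256
print(pointwise_conv_params(64, 128))  # 8320
```

The two counts agree with the first two convolution layers listed in the model summary of this example.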
We implement a regularizer (taken from this example) to enforce orthogonality in the feature space. This is needed to ensure that the magnitudes of the transformed features do not vary too much.
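The idea behind the penalty can be illustrated for a single matrix A: the term sum((A Aᵀ - I)²) vanishes exactly when A is orthogonal and grows as A departs from orthogonality. The NumPy sketch below is a simplified single-matrix version written for illustration, not the Keras regularizer itself. The function name `orthogonality_penalty` is hypothetical.

```python
import numpy as np


def orthogonality_penalty(matrix, l2reg=0.001):
    """l2reg * sum((A A^T - I)^2) for a single square matrix A."""
    identity = np.eye(matrix.shape[0])
    return l2reg * np.sum(np.square(matrix @ matrix.T - identity))


rotation = np.array([[0.0, -1.0], [1.0, 0.0]])  # orthogonal: pure rotation
scaling = 2.0 * np.eye(2)                        # not orthogonal: doubles lengths

print(orthogonality_penalty(rotation))  # 0.0
print(orthogonality_penalty(scaling))   # about 0.018
```

A rotation leaves feature magnitudes unchanged and incurs no penalty, whereas a matrix that rescales its input is penalized.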
class OrthogonalRegularizer(keras.regularizers.Regularizer):
    """Reference: https://keras.io/examples/vision/pointnet/#build-a-model"""

    def __init__(self, num_features, l2reg=0.001):
        self.num_features = num_features
        self.l2reg = l2reg
        self.identity = tf.eye(num_features)

    def __call__(self, x):
        x = tf.reshape(x, (-1, self.num_features, self.num_features))
        xxt = tf.tensordot(x, x, axes=(2, 2))
        xxt = tf.reshape(xxt, (-1, self.num_features, self.num_features))
        return tf.reduce_sum(self.l2reg * tf.square(xxt - self.identity))

    def get_config(self):
        # Return the constructor arguments so the regularizer can be re-created.
        return {"num_features": self.num_features, "l2reg": self.l2reg}
The next piece is the transformation network explained earlier.
def transformation_net(inputs: tf.Tensor, num_features: int, name: str) -> tf.Tensor:
    """
    Reference: https://keras.io/examples/vision/pointnet/#build-a-model.

    The `filters` values come from the original paper:
    https://arxiv.org/abs/1612.00593.
    """
    x = conv_block(inputs, filters=64, name=f"{name}_1")
    x = conv_block(x, filters=128, name=f"{name}_2")
    x = conv_block(x, filters=1024, name=f"{name}_3")
    x = layers.GlobalMaxPooling1D()(x)
    x = mlp_block(x, filters=512, name=f"{name}_1_1")
    x = mlp_block(x, filters=256, name=f"{name}_2_1")
    return layers.Dense(
        num_features * num_features,
        kernel_initializer="zeros",
        bias_initializer=keras.initializers.Constant(np.eye(num_features).flatten()),
        activity_regularizer=OrthogonalRegularizer(num_features),
        name=f"{name}_final",
    )(x)


def transformation_block(inputs: tf.Tensor, num_features: int, name: str) -> tf.Tensor:
    transformed_features = transformation_net(inputs, num_features, name=name)
    transformed_features = layers.Reshape((num_features, num_features))(
        transformed_features
    )
    return layers.Dot(axes=(2, 1), name=f"{name}_mm")([inputs, transformed_features])
Finally, we put the above blocks together and implement the segmentation model.
def get_shape_segmentation_model(num_points: int, num_classes: int) -> keras.Model:
    input_points = keras.Input(shape=(None, 3))

    # PointNet Classification Network.
    transformed_inputs = transformation_block(
        input_points, num_features=3, name="input_transformation_block"
    )
    features_64 = conv_block(transformed_inputs, filters=64, name="features_64")
    features_128_1 = conv_block(features_64, filters=128, name="features_128_1")
    features_128_2 = conv_block(features_128_1, filters=128, name="features_128_2")
    transformed_features = transformation_block(
        features_128_2, num_features=128, name="transformed_features"
    )
    features_512 = conv_block(transformed_features, filters=512, name="features_512")
    features_2048 = conv_block(features_512, filters=2048, name="pre_maxpool_block")
    global_features = layers.MaxPool1D(pool_size=num_points, name="global_features")(
        features_2048
    )
    global_features = tf.tile(global_features, [1, num_points, 1])

    # Segmentation head.
    segmentation_input = layers.Concatenate(name="segmentation_input")(
        [
            features_64,
            features_128_1,
            features_128_2,
            transformed_features,
            features_512,
            global_features,
        ]
    )
    segmentation_features = conv_block(
        segmentation_input, filters=128, name="segmentation_features"
    )
    outputs = layers.Conv1D(
        num_classes, kernel_size=1, activation="softmax", name="segmentation_head"
    )(segmentation_features)
    return keras.Model(input_points, outputs)
Instantiating the model
x, y = next(iter(train_dataset))
num_points = x.shape[1]
num_classes = y.shape[-1]
segmentation_model = get_shape_segmentation_model(num_points, num_classes)
segmentation_model.summary()
2021-10-25 01:26:33.563133: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
Model: "model"
__________________________________________________________________________________________________
Layer (type)                     Output Shape         Param #    Connected to
==================================================================================================
input_1 (InputLayer)             [(None, None, 3)]    0
input_transformation_block_1_co  (None, None, 64)     256        input_1[0][0]
input_transformation_block_1_ba  (None, None, 64)     256        input_transformation_block_1_conv
input_transformation_block_1_re  (None, None, 64)     0          input_transformation_block_1_batc
input_transformation_block_2_co  (None, None, 128)    8320       input_transformation_block_1_relu
input_transformation_block_2_ba  (None, None, 128)    512        input_transformation_block_2_conv
input_transformation_block_2_re  (None, None, 128)    0          input_transformation_block_2_batc
input_transformation_block_3_co  (None, None, 1024)   132096     input_transformation_block_2_relu
input_transformation_block_3_ba  (None, None, 1024)   4096       input_transformation_block_3_conv
input_transformation_block_3_re  (None, None, 1024)   0          input_transformation_block_3_batc
global_max_pooling1d (GlobalMax  (None, 1024)         0          input_transformation_block_3_relu
input_transformation_block_1_1_  (None, 512)          524800     global_max_pooling1d[0][0]
input_transformation_block_1_1_  (None, 512)          2048       input_transformation_block_1_1_de
input_transformation_block_1_1_  (None, 512)          0          input_transformation_block_1_1_ba
input_transformation_block_2_1_  (None, 256)          131328     input_transformation_block_1_1_re
input_transformation_block_2_1_  (None, 256)          1024       input_transformation_block_2_1_de
input_transformation_block_2_1_  (None, 256)          0          input_transformation_block_2_1_ba
input_transformation_block_fina  (None, 9)            2313       input_transformation_block_2_1_re
reshape (Reshape)                (None, 3, 3)         0          input_transformation_block_final[
input_transformation_block_mm (  (None, None, 3)      0          input_1[0][0]
                                                                 reshape[0][0]
features_64_conv (Conv1D)        (None, None, 64)     256        input_transformation_block_mm[0][
features_64_batch_norm (BatchNo  (None, None, 64)     256        features_64_conv[0][0]
features_64_relu (Activation)    (None, None, 64)     0          features_64_batch_norm[0][0]
features_128_1_conv (Conv1D)     (None, None, 128)    8320       features_64_relu[0][0]
features_128_1_batch_norm (Batc  (None, None, 128)    512        features_128_1_conv[0][0]
features_128_1_relu (Activation  (None, None, 128)    0          features_128_1_batch_norm[0][0]
features_128_2_conv (Conv1D)     (None, None, 128)    16512      features_128_1_relu[0][0]
features_128_2_batch_norm (Batc  (None, None, 128)    512        features_128_2_conv[0][0]
features_128_2_relu (Activation  (None, None, 128)    0          features_128_2_batch_norm[0][0]
transformed_features_1_conv (Co  (None, None, 64)     8256       features_128_2_relu[0][0]
transformed_features_1_batch_no  (None, None, 64)     256        transformed_features_1_conv[0][0]
transformed_features_1_relu (Ac  (None, None, 64)     0          transformed_features_1_batch_norm
transformed_features_2_conv (Co  (None, None, 128)    8320       transformed_features_1_relu[0][0]
transformed_features_2_batch_no  (None, None, 128)    512        transformed_features_2_conv[0][0]
transformed_features_2_relu (Ac  (None, None, 128)    0          transformed_features_2_batch_norm
transformed_features_3_conv (Co  (None, None, 1024)   132096     transformed_features_2_relu[0][0]
transformed_features_3_batch_no  (None, None, 1024)   4096       transformed_features_3_conv[0][0]
transformed_features_3_relu (Ac  (None, None, 1024)   0          transformed_features_3_batch_norm
global_max_pooling1d_1 (GlobalM  (None, 1024)         0          transformed_features_3_relu[0][0]
transformed_features_1_1_dense   (None, 512)          524800     global_max_pooling1d_1[0][0]
transformed_features_1_1_batch_  (None, 512)          2048       transformed_features_1_1_dense[0]
transformed_features_1_1_relu (  (None, 512)          0          transformed_features_1_1_batch_no
transformed_features_2_1_dense   (None, 256)          131328     transformed_features_1_1_relu[0][
transformed_features_2_1_batch_  (None, 256)          1024       transformed_features_2_1_dense[0]
transformed_features_2_1_relu (  (None, 256)          0          transformed_features_2_1_batch_no
transformed_features_final (Den  (None, 16384)        4210688    transformed_features_2_1_relu[0][
reshape_1 (Reshape)              (None, 128, 128)     0          transformed_features_final[0][0]
transformed_features_mm (Dot)    (None, None, 128)    0          features_128_2_relu[0][0]
                                                                 reshape_1[0][0]
features_512_conv (Conv1D)       (None, None, 512)    66048      transformed_features_mm[0][0]
features_512_batch_norm (BatchN  (None, None, 512)    2048       features_512_conv[0][0]
features_512_relu (Activation)   (None, None, 512)    0          features_512_batch_norm[0][0]
pre_maxpool_block_conv (Conv1D)  (None, None, 2048)   1050624    features_512_relu[0][0]
pre_maxpool_block_batch_norm (B  (None, None, 2048)   8192       pre_maxpool_block_conv[0][0]
pre_maxpool_block_relu (Activat  (None, None, 2048)   0          pre_maxpool_block_batch_norm[0][0
global_features (MaxPooling1D)   (None, None, 2048)   0          pre_maxpool_block_relu[0][0]
tf.tile (TFOpLambda)             (None, None, 2048)   0          global_features[0][0]
segmentation_input (Concatenate  (None, None, 3008)   0          features_64_relu[0][0]
                                                                 features_128_1_relu[0][0]
                                                                 features_128_2_relu[0][0]
                                                                 transformed_features_mm[0][0]
                                                                 features_512_relu[0][0]
                                                                 tf.tile[0][0]
segmentation_features_conv (Con  (None, None, 128)    385152     segmentation_input[0][0]
segmentation_features_batch_nor  (None, None, 128)    512        segmentation_features_conv[0][0]
segmentation_features_relu (Act  (None, None, 128)    0          segmentation_features_batch_norm[
segmentation_head (Conv1D)       (None, None, 5)      645        segmentation_features_relu[0][0]
==================================================================================================
Total params: 7,370,062
Trainable params: 7,356,110
Non-trainable params: 13,952
Training
For training, the authors recommend a learning rate schedule that halves the initial learning rate every 20 epochs. In this example, we halve it every 15 epochs instead.
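A piecewise-constant schedule of this kind can be sketched in plain Python. The helper `piecewise_constant_lr` is hypothetical and only mirrors the idea of a step-wise schedule. The numbers assume 92 optimizer steps per epoch (2955 training clouds with a batch size of 32, using integer division), an initial rate of 1e-3, and halving after 15 and 30 epochs.

```python
import bisect


def piecewise_constant_lr(step, boundaries, values):
    """Return the learning rate of the interval that `step` falls into."""
    # values[0] applies up to and including boundaries[0], values[1] after it, and so on.
    return values[bisect.bisect_left(boundaries, step)]


steps_per_epoch = 2955 // 32  # 92
boundaries = [steps_per_epoch * 15, steps_per_epoch * 30]
values = [1e-3, 1e-3 * 0.5, 1e-3 * 0.25]

print(boundaries)                                       # [1380, 2760]
print(piecewise_constant_lr(0, boundaries, values))     # 0.001
print(piecewise_constant_lr(1381, boundaries, values))  # 0.0005
print(piecewise_constant_lr(5000, boundaries, values))  # 0.00025
```

With 60 epochs in total, the last value stays in effect for the second half of training.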
training_step_size = total_training_examples // BATCH_SIZE
total_training_steps = training_step_size * EPOCHS
print(f"Total training steps: {total_training_steps}.")

lr_schedule = keras.optimizers.schedules.PiecewiseConstantDecay(
    # Halve the learning rate after 15 epochs and again after 30 epochs.
    boundaries=[training_step_size * 15, training_step_size * 30],
    values=[INITIAL_LR, INITIAL_LR * 0.5, INITIAL_LR * 0.25],
)

steps = tf.range(total_training_steps, dtype=tf.int32)
lrs = [lr_schedule(step) for step in steps]

plt.plot(lrs)
plt.xlabel("Steps")
plt.ylabel("Learning Rate")
plt.show()
Total training steps: 5520.
Finally, we implement a utility for running our experiments and launch model training.
def run_experiment(epochs):
    segmentation_model = get_shape_segmentation_model(num_points, num_classes)
    segmentation_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
        loss=keras.losses.CategoricalCrossentropy(),
        metrics=["accuracy"],
    )

    checkpoint_filepath = "/tmp/checkpoint"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_loss",
        save_best_only=True,
        save_weights_only=True,
    )

    history = segmentation_model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=[checkpoint_callback],
    )

    segmentation_model.load_weights(checkpoint_filepath)
    return segmentation_model, history


segmentation_model, history = run_experiment(epochs=EPOCHS)
Epoch 1/60
93/93 [==============================] - 28s 127ms/step - loss: 5.3556 - accuracy: 0.7448 - val_loss: 5.8386 - val_accuracy: 0.7471
Epoch 2/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7077 - accuracy: 0.8181 - val_loss: 5.2614 - val_accuracy: 0.7793
Epoch 3/60
93/93 [==============================] - 11s 118ms/step - loss: 4.6566 - accuracy: 0.8301 - val_loss: 4.7907 - val_accuracy: 0.8269
Epoch 4/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6059 - accuracy: 0.8406 - val_loss: 4.6031 - val_accuracy: 0.8482
Epoch 5/60
93/93 [==============================] - 11s 118ms/step - loss: 4.5828 - accuracy: 0.8444 - val_loss: 4.7692 - val_accuracy: 0.8220
Epoch 6/60
93/93 [==============================] - 11s 118ms/step - loss: 4.6150 - accuracy: 0.8408 - val_loss: 5.4460 - val_accuracy: 0.8192
Epoch 7/60
93/93 [==============================] - 11s 117ms/step - loss: 67.5943 - accuracy: 0.7378 - val_loss: 1617.1846 - val_accuracy: 0.5191
Epoch 8/60
93/93 [==============================] - 11s 117ms/step - loss: 15.2910 - accuracy: 0.6651 - val_loss: 8.1014 - val_accuracy: 0.7046
Epoch 9/60
93/93 [==============================] - 11s 117ms/step - loss: 6.8878 - accuracy: 0.7368 - val_loss: 14.2311 - val_accuracy: 0.6949
Epoch 10/60
93/93 [==============================] - 11s 117ms/step - loss: 5.8362 - accuracy: 0.7549 - val_loss: 14.6942 - val_accuracy: 0.6350
Epoch 11/60
93/93 [==============================] - 11s 117ms/step - loss: 5.4777 - accuracy: 0.7648 - val_loss: 44.1037 - val_accuracy: 0.6422
Epoch 12/60
93/93 [==============================] - 11s 117ms/step - loss: 5.2688 - accuracy: 0.7712 - val_loss: 4.9977 - val_accuracy: 0.7692
Epoch 13/60
93/93 [==============================] - 11s 117ms/step - loss: 5.1041 - accuracy: 0.7837 - val_loss: 6.0642 - val_accuracy: 0.7577
Epoch 14/60
93/93 [==============================] - 11s 117ms/step - loss: 5.0011 - accuracy: 0.7862 - val_loss: 4.9313 - val_accuracy: 0.7840
Epoch 15/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8910 - accuracy: 0.7953 - val_loss: 5.8368 - val_accuracy: 0.7725
Epoch 16/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8698 - accuracy: 0.8074 - val_loss: 73.0260 - val_accuracy: 0.7251
Epoch 17/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8299 - accuracy: 0.8109 - val_loss: 17.1503 - val_accuracy: 0.7415
Epoch 18/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8147 - accuracy: 0.8111 - val_loss: 62.2765 - val_accuracy: 0.7344
Epoch 19/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8316 - accuracy: 0.8141 - val_loss: 5.2200 - val_accuracy: 0.7890
Epoch 20/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7853 - accuracy: 0.8142 - val_loss: 5.7062 - val_accuracy: 0.7719
Epoch 21/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7753 - accuracy: 0.8157 - val_loss: 6.2089 - val_accuracy: 0.7839
Epoch 22/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7681 - accuracy: 0.8161 - val_loss: 5.1077 - val_accuracy: 0.8021
Epoch 23/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7554 - accuracy: 0.8187 - val_loss: 4.7912 - val_accuracy: 0.7912
Epoch 24/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7355 - accuracy: 0.8197 - val_loss: 4.9164 - val_accuracy: 0.7978
Epoch 25/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7483 - accuracy: 0.8197 - val_loss: 13.4724 - val_accuracy: 0.7631
Epoch 26/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7200 - accuracy: 0.8218 - val_loss: 8.3074 - val_accuracy: 0.7596
Epoch 27/60
93/93 [==============================] - 11s 118ms/step - loss: 4.7192 - accuracy: 0.8231 - val_loss: 12.4468 - val_accuracy: 0.7591
Epoch 28/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7151 - accuracy: 0.8241 - val_loss: 23.8681 - val_accuracy: 0.7689
Epoch 29/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7096 - accuracy: 0.8237 - val_loss: 4.9069 - val_accuracy: 0.8104
Epoch 30/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6991 - accuracy: 0.8257 - val_loss: 4.9858 - val_accuracy: 0.7950
Epoch 31/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6852 - accuracy: 0.8260 - val_loss: 5.0130 - val_accuracy: 0.7678
Epoch 32/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6630 - accuracy: 0.8286 - val_loss: 4.8523 - val_accuracy: 0.7676
Epoch 33/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6837 - accuracy: 0.8281 - val_loss: 5.4347 - val_accuracy: 0.8095
Epoch 34/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6571 - accuracy: 0.8296 - val_loss: 10.4595 - val_accuracy: 0.7410
Epoch 35/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6460 - accuracy: 0.8321 - val_loss: 4.9189 - val_accuracy: 0.8083
Epoch 36/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6430 - accuracy: 0.8327 - val_loss: 5.8674 - val_accuracy: 0.7911
Epoch 37/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6530 - accuracy: 0.8309 - val_loss: 4.7946 - val_accuracy: 0.8032
Epoch 38/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6391 - accuracy: 0.8318 - val_loss: 5.0111 - val_accuracy: 0.8024
Epoch 39/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6521 - accuracy: 0.8336 - val_loss: 8.1558 - val_accuracy: 0.7727
Epoch 40/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6443 - accuracy: 0.8329 - val_loss: 42.8513 - val_accuracy: 0.7688
Epoch 41/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6316 - accuracy: 0.8342 - val_loss: 5.0960 - val_accuracy: 0.8066
Epoch 42/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6322
- accuracy: 0.8335 - val_loss: 5.0634 - val_accuracy: 0.8158 Epoch 43/60 93/93 [==============================] - 11s 117ms/step - loss: 4.6175 - accuracy: 0.8370 - val_loss: 6.0642 - val_accuracy: 0.8062 Epoch 44/60 93/93 [==============================] - 11s 117ms/step - loss: 4.6175 - accuracy: 0.8371 - val_loss: 11.1805 - val_accuracy: 0.7790 Epoch 45/60 93/93 [==============================] - 11s 117ms/step - loss: 4.6056 - accuracy: 0.8377 - val_loss: 4.7359 - val_accuracy: 0.8145 Epoch 46/60 93/93 [==============================] - 11s 117ms/step - loss: 4.6108 - accuracy: 0.8383 - val_loss: 5.7125 - val_accuracy: 0.7713 Epoch 47/60 93/93 [==============================] - 11s 117ms/step - loss: 4.6103 - accuracy: 0.8377 - val_loss: 6.3271 - val_accuracy: 0.8105 Epoch 48/60 93/93 [==============================] - 11s 117ms/step - loss: 4.6020 - accuracy: 0.8383 - val_loss: 14.2876 - val_accuracy: 0.7529 Epoch 49/60 93/93 [==============================] - 11s 117ms/step - loss: 4.6035 - accuracy: 0.8382 - val_loss: 4.8244 - val_accuracy: 0.8143 Epoch 50/60 93/93 [==============================] - 11s 117ms/step - loss: 4.6076 - accuracy: 0.8381 - val_loss: 8.2636 - val_accuracy: 0.7528 Epoch 51/60 93/93 [==============================] - 11s 117ms/step - loss: 4.5927 - accuracy: 0.8399 - val_loss: 4.6473 - val_accuracy: 0.8266 Epoch 52/60 93/93 [==============================] - 11s 117ms/step - loss: 4.5927 - accuracy: 0.8408 - val_loss: 4.6443 - val_accuracy: 0.8276 Epoch 53/60 93/93 [==============================] - 11s 117ms/step - loss: 4.5852 - accuracy: 0.8413 - val_loss: 5.1300 - val_accuracy: 0.7768 Epoch 54/60 93/93 [==============================] - 11s 117ms/step - loss: 4.5787 - accuracy: 0.8426 - val_loss: 8.9590 - val_accuracy: 0.7582 Epoch 55/60 93/93 [==============================] - 11s 117ms/step - loss: 4.5837 - accuracy: 0.8410 - val_loss: 5.1501 - val_accuracy: 0.8117 Epoch 56/60 93/93 [==============================] - 11s 
117ms/step - loss: 4.5875 - accuracy: 0.8422 - val_loss: 31.3518 - val_accuracy: 0.7590 Epoch 57/60 93/93 [==============================] - 11s 117ms/step - loss: 4.5821 - accuracy: 0.8427 - val_loss: 4.8853 - val_accuracy: 0.8144 Epoch 58/60 93/93 [==============================] - 11s 117ms/step - loss: 4.5751 - accuracy: 0.8446 - val_loss: 4.6653 - val_accuracy: 0.8222 Epoch 59/60 93/93 [==============================] - 11s 117ms/step - loss: 4.5752 - accuracy: 0.8447 - val_loss: 6.0078 - val_accuracy: 0.8014 Epoch 60/60 93/93 [==============================] - 11s 118ms/step - loss: 4.5695 - accuracy: 0.8452 - val_loss: 4.8178 - val_accuracy: 0.8192
(Translator's note: results from our own run)
Epoch 1/60
93/93 [==============================] - 22s 87ms/step - loss: 5.3702 - accuracy: 0.7357 - val_loss: 4.8234 - val_accuracy: 0.7959
Epoch 2/60
93/93 [==============================] - 7s 74ms/step - loss: 4.6737 - accuracy: 0.8238 - val_loss: 4.9133 - val_accuracy: 0.8133
Epoch 3/60
93/93 [==============================] - 7s 75ms/step - loss: 4.7244 - accuracy: 0.8216 - val_loss: 5.2728 - val_accuracy: 0.7539
Epoch 4/60
93/93 [==============================] - 7s 74ms/step - loss: 4.7115 - accuracy: 0.8144 - val_loss: 5.2501 - val_accuracy: 0.8020
Epoch 5/60
93/93 [==============================] - 7s 74ms/step - loss: 4.7318 - accuracy: 0.8130 - val_loss: 4.8355 - val_accuracy: 0.8264
Epoch 6/60
93/93 [==============================] - 7s 74ms/step - loss: 4.6112 - accuracy: 0.8364 - val_loss: 13.2324 - val_accuracy: 0.7926
Epoch 7/60
93/93 [==============================] - 7s 78ms/step - loss: 4.6172 - accuracy: 0.8353 - val_loss: 4.7182 - val_accuracy: 0.8403
Epoch 8/60
93/93 [==============================] - 7s 75ms/step - loss: 4.5751 - accuracy: 0.8424 - val_loss: 5.6088 - val_accuracy: 0.8353
Epoch 9/60
93/93 [==============================] - 7s 76ms/step - loss: 4.5851 - accuracy: 0.8425 - val_loss: 4.8051 - val_accuracy: 0.7950
Epoch 10/60
93/93 [==============================] - 7s 74ms/step - loss: 4.6435 - accuracy: 0.8338 - val_loss: 54.9336 - val_accuracy: 0.7872
Epoch 11/60
93/93 [==============================] - 7s 74ms/step - loss: 4.5741 - accuracy: 0.8447 - val_loss: 4.9625 - val_accuracy: 0.8154
Epoch 12/60
93/93 [==============================] - 7s 75ms/step - loss: 4.5540 - accuracy: 0.8484 - val_loss: 32.4925 - val_accuracy: 0.7406
Epoch 13/60
93/93 [==============================] - 7s 75ms/step - loss: 4.5377 - accuracy: 0.8514 - val_loss: 428.6281 - val_accuracy: 0.6626
Epoch 14/60
93/93 [==============================] - 7s 77ms/step - loss: 4.7567 - accuracy: 0.8115 - val_loss: 4.7127 - val_accuracy: 0.8267
Epoch 15/60
93/93 [==============================] - 7s 77ms/step - loss: 4.5393 - accuracy: 0.8501 - val_loss: 4.6364 - val_accuracy: 0.8311
Epoch 16/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4911 - accuracy: 0.8613 - val_loss: 5.4676 - val_accuracy: 0.8243
Epoch 17/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4850 - accuracy: 0.8619 - val_loss: 58570.8320 - val_accuracy: 0.7975
Epoch 18/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4766 - accuracy: 0.8637 - val_loss: 9.7880 - val_accuracy: 0.8211
Epoch 19/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4745 - accuracy: 0.8642 - val_loss: 10.1575 - val_accuracy: 0.8395
Epoch 20/60
93/93 [==============================] - 7s 78ms/step - loss: 4.4661 - accuracy: 0.8669 - val_loss: 4.5302 - val_accuracy: 0.8504
Epoch 21/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4657 - accuracy: 0.8671 - val_loss: 123.4067 - val_accuracy: 0.7972
Epoch 22/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4587 - accuracy: 0.8690 - val_loss: 6.2540 - val_accuracy: 0.8573
Epoch 23/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4521 - accuracy: 0.8704 - val_loss: 4.7138 - val_accuracy: 0.8560
Epoch 24/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4502 - accuracy: 0.8713 - val_loss: 4.7997 - val_accuracy: 0.8558
Epoch 25/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4695 - accuracy: 0.8682 - val_loss: 1727.5392 - val_accuracy: 0.7754
Epoch 26/60
93/93 [==============================] - 7s 78ms/step - loss: 4.4536 - accuracy: 0.8711 - val_loss: 4.5054 - val_accuracy: 0.8570
Epoch 27/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4442 - accuracy: 0.8736 - val_loss: 1847.9580 - val_accuracy: 0.8258
Epoch 28/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4419 - accuracy: 0.8739 - val_loss: 507323.7188 - val_accuracy: 0.6395
Epoch 29/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4410 - accuracy: 0.8749 - val_loss: 110.4861 - val_accuracy: 0.8279
Epoch 30/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4343 - accuracy: 0.8772 - val_loss: 13034.4678 - val_accuracy: 0.7615
Epoch 31/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4371 - accuracy: 0.8762 - val_loss: 4.5390 - val_accuracy: 0.8575
Epoch 32/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4299 - accuracy: 0.8783 - val_loss: 5.1719 - val_accuracy: 0.8466
Epoch 33/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4260 - accuracy: 0.8795 - val_loss: 5.2370 - val_accuracy: 0.8486
Epoch 34/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4335 - accuracy: 0.8776 - val_loss: 4.5271 - val_accuracy: 0.8526
Epoch 35/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4188 - accuracy: 0.8814 - val_loss: 8.0959 - val_accuracy: 0.8575
Epoch 36/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4114 - accuracy: 0.8842 - val_loss: 4.9942 - val_accuracy: 0.8560
Epoch 37/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4314 - accuracy: 0.8788 - val_loss: 5.0881 - val_accuracy: 0.8516
Epoch 38/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4148 - accuracy: 0.8831 - val_loss: 4.5428 - val_accuracy: 0.8618
Epoch 39/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4310 - accuracy: 0.8795 - val_loss: 8.5377 - val_accuracy: 0.8391
Epoch 40/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4110 - accuracy: 0.8840 - val_loss: 4840.1558 - val_accuracy: 0.8243
Epoch 41/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4105 - accuracy: 0.8844 - val_loss: 4.5062 - val_accuracy: 0.8619
Epoch 42/60
93/93 [==============================] - 7s 78ms/step - loss: 4.4047 - accuracy: 0.8863 - val_loss: 4.4981 - val_accuracy: 0.8628
Epoch 43/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3969 - accuracy: 0.8895 - val_loss: 7.4334 - val_accuracy: 0.8504
Epoch 44/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3950 - accuracy: 0.8898 - val_loss: 4.6111 - val_accuracy: 0.8606
Epoch 45/60
93/93 [==============================] - 7s 75ms/step - loss: 4.3893 - accuracy: 0.8907 - val_loss: 25.9290 - val_accuracy: 0.8233
Epoch 46/60
93/93 [==============================] - 7s 75ms/step - loss: 4.3951 - accuracy: 0.8898 - val_loss: 168.4719 - val_accuracy: 0.8410
Epoch 47/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3971 - accuracy: 0.8887 - val_loss: 116.8822 - val_accuracy: 0.8558
Epoch 48/60
93/93 [==============================] - 7s 75ms/step - loss: 4.3864 - accuracy: 0.8924 - val_loss: 4.5119 - val_accuracy: 0.8598
Epoch 49/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3787 - accuracy: 0.8948 - val_loss: 8.3084 - val_accuracy: 0.8552
Epoch 50/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3846 - accuracy: 0.8937 - val_loss: 2636.2273 - val_accuracy: 0.7942
Epoch 51/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4077 - accuracy: 0.8863 - val_loss: 1912.4525 - val_accuracy: 0.7862
Epoch 52/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3784 - accuracy: 0.8954 - val_loss: 4.5436 - val_accuracy: 0.8639
Epoch 53/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3998 - accuracy: 0.8897 - val_loss: 4.6248 - val_accuracy: 0.8334
Epoch 54/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4187 - accuracy: 0.8857 - val_loss: 4.5236 - val_accuracy: 0.8604
Epoch 55/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3752 - accuracy: 0.8970 - val_loss: 4.5569 - val_accuracy: 0.8630
Epoch 56/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3772 - accuracy: 0.8965 - val_loss: 64.3617 - val_accuracy: 0.8405
Epoch 57/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3951 - accuracy: 0.8906 - val_loss: 4.5302 - val_accuracy: 0.8592
Epoch 58/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3737 - accuracy: 0.8967 - val_loss: 180.0808 - val_accuracy: 0.8335
Epoch 59/60
93/93 [==============================] - 7s 75ms/step - loss: 4.3709 - accuracy: 0.8978 - val_loss: 7.8048 - val_accuracy: 0.8477
Epoch 60/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3596 - accuracy: 0.9009 - val_loss: 486.6651 - val_accuracy: 0.8325
CPU times: user 7min 25s, sys: 26.2 s, total: 7min 51s
Wall time: 7min 17s
Visualizing the training progress
def plot_result(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_result("loss")
plot_result("accuracy")
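The validation loss in the logs above occasionally jumps by several orders of magnitude while the validation accuracy stays in a narrow band, so a single spike can dominate a linear-scale loss curve. The following is a minimal sketch, not part of the original example, that summarizes a history dictionary numerically. The helper name `summarize_history` and the `toy_history` dictionary are hypothetical; the toy values are copied from the first six epochs of the second training log above, and the dictionary is assumed to have the same keys as `history.history`.

```python
import numpy as np


def summarize_history(history_dict, metric="accuracy"):
    """Summarize a Keras-style history dict (hypothetical helper).

    Returns the 1-based epoch with the highest val_<metric>, that metric
    value, and the median and maximum of val_loss.
    """
    val_metric = np.asarray(history_dict["val_" + metric], dtype=float)
    val_loss = np.asarray(history_dict["val_loss"], dtype=float)
    best_index = int(np.argmax(val_metric))
    return {
        "best_epoch": best_index + 1,
        "best_val_" + metric: float(val_metric[best_index]),
        # The median is barely affected by isolated spikes in val_loss.
        "median_val_loss": float(np.median(val_loss)),
        "max_val_loss": float(val_loss.max()),
    }


# Stand-in for history.history (assumption: same key names as in the logs).
toy_history = {
    "val_loss": [4.8234, 4.9133, 5.2728, 5.2501, 4.8355, 13.2324],
    "val_accuracy": [0.7959, 0.8133, 0.7539, 0.8020, 0.8264, 0.7926],
}

summary = summarize_history(toy_history)
print(summary)
```

With a real run, `summarize_history(history.history)` would be the corresponding call. For the plot itself, one option is to call `plt.yscale("log")` before `plt.show()` when plotting the loss, so that the spikes do not flatten the rest of the curve.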
Inference
validation_batch = next(iter(val_dataset))
val_predictions = segmentation_model.predict(validation_batch[0])
print(f"Validation prediction shape: {val_predictions.shape}")


def visualize_single_point_cloud(point_clouds, label_clouds, idx):
    label_map = LABELS + ["none"]
    point_cloud = point_clouds[idx]
    label_cloud = label_clouds[idx]
    visualize_data(point_cloud, [label_map[np.argmax(label)] for label in label_cloud])


idx = np.random.choice(len(validation_batch[0]))
print(f"Index selected: {idx}")

# Plotting with ground-truth.
visualize_single_point_cloud(validation_batch[0], validation_batch[1], idx)

# Plotting with predicted labels.
visualize_single_point_cloud(validation_batch[0], val_predictions, idx)
Validation prediction shape: (32, 1024, 5)
Index selected: 24
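The two plots compare ground truth and prediction visually. As a complement, the sketch below scores a single point cloud numerically by taking the argmax over the class axis, the same way `visualize_single_point_cloud` does. It is not part of the original example: the helper names `pointwise_accuracy` and `per_class_iou` and the toy arrays are hypothetical, and the code assumes labels and predictions both have shape (num_points, num_classes).

```python
import numpy as np


def pointwise_accuracy(label_cloud, prediction_cloud):
    """Fraction of points whose predicted class matches the labeled class."""
    true_ids = np.argmax(np.asarray(label_cloud), axis=-1)
    pred_ids = np.argmax(np.asarray(prediction_cloud), axis=-1)
    return float(np.mean(true_ids == pred_ids))


def per_class_iou(label_cloud, prediction_cloud, num_classes):
    """Intersection over union per class; NaN when a class is absent from both."""
    true_ids = np.argmax(np.asarray(label_cloud), axis=-1)
    pred_ids = np.argmax(np.asarray(prediction_cloud), axis=-1)
    ious = []
    for class_id in range(num_classes):
        true_mask = true_ids == class_id
        pred_mask = pred_ids == class_id
        union = np.logical_or(true_mask, pred_mask).sum()
        if union == 0:
            ious.append(float("nan"))
        else:
            intersection = np.logical_and(true_mask, pred_mask).sum()
            ious.append(float(intersection / union))
    return ious


# Toy data: 4 points, 3 classes, one-hot rows (purely illustrative).
toy_labels = np.eye(3)[[0, 0, 1, 2]]
toy_predictions = np.eye(3)[[0, 1, 1, 2]]

toy_accuracy = pointwise_accuracy(toy_labels, toy_predictions)
toy_ious = per_class_iou(toy_labels, toy_predictions, num_classes=3)
print(toy_accuracy, toy_ious)
```

In the tutorial's setting, a call such as `pointwise_accuracy(validation_batch[1][idx], val_predictions[idx])` would score the cloud selected above, assuming the labels in the batch are one-hot encoded with the same last dimension as the predictions.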
Final notes
If you are interested in learning more about this topic, you may find this repository useful.
End