Skip to content

ClasCat® AI Research

クラスキャット – 生成 AI, AI エージェント, MCP

Menu
  • ホーム
    • ClassCat® AI Research ホーム
    • クラスキャット・ホーム
  • OpenAI API
    • OpenAI Python ライブラリ 1.x : 概要
    • OpenAI ブログ
      • GPT の紹介
      • GPT ストアの紹介
      • ChatGPT Team の紹介
    • OpenAI platform 1.x
      • Get Started : イントロダクション
      • Get Started : クイックスタート (Python)
      • Get Started : クイックスタート (Node.js)
      • Get Started : モデル
      • 機能 : 埋め込み
      • 機能 : 埋め込み (ユースケース)
      • ChatGPT : アクション – イントロダクション
      • ChatGPT : アクション – Getting started
      • ChatGPT : アクション – アクション認証
    • OpenAI ヘルプ : ChatGPT
      • ChatGPTとは何ですか?
      • ChatGPT は真実を語っていますか?
      • GPT の作成
      • GPT FAQ
      • GPT vs アシスタント
      • GPT ビルダー
    • OpenAI ヘルプ : ChatGPT > メモリ
      • FAQ
    • OpenAI ヘルプ : GPT ストア
      • 貴方の GPT をフィーチャーする
    • OpenAI Python ライブラリ 0.27 : 概要
    • OpenAI platform
      • Get Started : イントロダクション
      • Get Started : クイックスタート
      • Get Started : モデル
      • ガイド : GPT モデル
      • ガイド : 画像生成 (DALL·E)
      • ガイド : GPT-3.5 Turbo 対応 微調整
      • ガイド : 微調整 1.イントロダクション
      • ガイド : 微調整 2. データセットの準備 / ケーススタディ
      • ガイド : 埋め込み
      • ガイド : 音声テキスト変換
      • ガイド : モデレーション
      • ChatGPT プラグイン : イントロダクション
    • OpenAI Cookbook
      • 概要
      • API 使用方法 : レート制限の操作
      • API 使用方法 : tiktoken でトークンを数える方法
      • GPT : ChatGPT モデルへの入力をフォーマットする方法
      • GPT : 補完をストリームする方法
      • GPT : 大規模言語モデルを扱う方法
      • 埋め込み : 埋め込みの取得
      • GPT-3 の微調整 : 分類サンプルの微調整
      • DALL-E : DALL·E で 画像を生成して編集する方法
      • DALL·E と Segment Anything で動的マスクを作成する方法
      • Whisper プロンプティング・ガイド
  • Gemini API
    • Tutorials : クイックスタート with Python (1) テキスト-to-テキスト生成
    • (2) マルチモーダル入力 / 日本語チャット
    • (3) 埋め込みの使用
    • (4) 高度なユースケース
    • クイックスタート with Node.js
    • クイックスタート with Dart or Flutter (1) 日本語動作確認
    • Gemma
      • 概要 (README)
      • Tutorials : サンプリング
      • Tutorials : KerasNLP による Getting Started
  • Keras 3
    • 新しいマルチバックエンド Keras
    • Keras 3 について
    • Getting Started : エンジニアのための Keras 入門
    • Google Colab 上のインストールと Stable Diffusion デモ
    • コンピュータビジョン – ゼロからの画像分類
    • コンピュータビジョン – 単純な MNIST convnet
    • コンピュータビジョン – EfficientNet を使用した微調整による画像分類
    • コンピュータビジョン – Vision Transformer による画像分類
    • コンピュータビジョン – 最新の MLPモデルによる画像分類
    • コンピュータビジョン – コンパクトな畳込み Transformer
    • Keras Core
      • Keras Core 0.1
        • 新しいマルチバックエンド Keras (README)
        • Keras for TensorFlow, JAX, & PyTorch
        • 開発者ガイド : Getting started with Keras Core
        • 開発者ガイド : 関数型 API
        • 開発者ガイド : シーケンシャル・モデル
        • 開発者ガイド : サブクラス化で新しい層とモデルを作成する
        • 開発者ガイド : 独自のコールバックを書く
      • Keras Core 0.1.1 & 0.1.2 : リリースノート
      • 開発者ガイド
      • Code examples
      • Keras Stable Diffusion
        • 概要
        • 基本的な使い方 (テキスト-to-画像 / 画像-to-画像変換)
        • 混合精度のパフォーマンス
        • インペインティングの簡易アプリケーション
        • (参考) KerasCV – Stable Diffusion を使用した高性能画像生成
  • TensorFlow
    • TF 2 : 初級チュートリアル
    • TF 2 : 上級チュートリアル
    • TF 2 : ガイド
    • TF 1 : チュートリアル
    • TF 1 : ガイド
  • その他
    • 🦜️🔗 LangChain ドキュメント / ユースケース
    • Stable Diffusion WebUI
      • Google Colab で Stable Diffusion WebUI 入門
      • HuggingFace モデル / VAE の導入
      • LoRA の利用
    • Diffusion Models / 拡散モデル
  • クラスキャット
    • 会社案内
    • お問合せ
    • Facebook
    • ClassCat® Blog
Menu

Keras 2 : examples : コンピュータビジョン – PointNet によるポイントクラウド・セグメンテーション

Posted on 12/07/202112/10/2021 by Sales Information

Keras 2 : examples : PointNet によるポイントクラウド・セグメンテーション (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 12/07/2021 (keras 2.7.0)

* 本ページは、Keras の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

  • Code examples : Computer Vision : Point cloud segmentation with PointNet (Author: Soumik Rakshit, Sayak Paul)

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス ★ 無料 Web セミナー開催中 ★

◆ クラスキャットは人工知能・テレワークに関する各種サービスを提供しております。お気軽にご相談ください :

  • 人工知能研究開発支援
    1. 人工知能研修サービス(経営者層向けオンサイト研修)
    2. テクニカルコンサルティングサービス
    3. 実証実験(プロトタイプ構築)
    4. アプリケーションへの実装

  • 人工知能研修サービス

  • PoC(概念実証)を失敗させないための支援

  • テレワーク & オンライン授業を支援
◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。

◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

  • 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
  • E-Mail:sales-info@classcat.com  ;  WebSite: www.classcat.com  ;  Facebook

 

 

Keras 2 : examples : PointNet によるポイントクラウド・セグメンテーション

Description: ポイントクラウドのセグメンテーションのための PointNet ベースのモデルの実装。

 

イントロダクション

「ポイントクラウド (点群)」は幾何学的形状データをストアするための重要なタイプのデータ構造です。イレギュラーな形式ゆえに、深層学習アプリケーションで使用される前に、通常の 3D ボクセルグリッドや画像のコレクションに変換されることも多く、このステップはデータを不必要に大きくしてしまいます。PointNet モデルのファミリーは、ポイントデータの順列不変性 (= permutation-invariance) の特性を尊重し、ポイントクラウドを直接的に消費することによりこの問題を解きます。PointNet モデルのファミリーは、オブジェクト分類, パーツ・セグメンテーション から シーン・セマンティック解析 までの範囲に渡るアプリケーションに対して単純で、統一されたアーキテクチャを提供します。

このサンプルでは、形状セグメンテーションのための PointNet アーキテクチャの実装を実演します。

リファレンス

  • PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation
  • PointNet によるポイントクラウド分類
  • Spatial Transformer Networks

 

インポート

import os
import json
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from glob import glob

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

import matplotlib.pyplot as plt

 

データセットのダウンロード

ShapeNet データセット は、 3D 形状の豊富なアノテーション付きの大規模なデータセットを作り上げる継続的な取り組みです。ShapeNetCore は full ShapeNet データセットのサブセットで、クリーンな単一 3D モデルと (手動で検証された) カテゴリーとアラインメント・アノテーションを含みます。それは 55 の一般的な物体カテゴリーをカバーし、約 51,300 の一意な 3D モデルを含みます。

このサンプルについては、PASCAL 3D+ の12 物体カテゴリーの一つを使用し、これは ShapenetCore データセットの一部として含まれています。

dataset_url = "https://git.io/JiY4i"

dataset_path = keras.utils.get_file(
    fname="shapenet.zip",
    origin=dataset_url,
    cache_subdir="datasets",
    hash_algorithm="auto",
    extract=True,
    archive_format="auto",
    cache_dir="datasets",
)

 

データセットのロード

モデルカテゴリーをそれぞれのディレクトリに、そして可視化のためにセグメンテーション・クラスをカラーに簡単にマップするためにデータセットのメタデータを解析します。

with open("/tmp/.keras/datasets/PartAnnotation/metadata.json") as json_file:
    metadata = json.load(json_file)

print(metadata)
{'Airplane': {'directory': '02691156', 'lables': ['wing', 'body', 'tail', 'engine'], 'colors': ['blue', 'green', 'red', 'pink']}, 'Bag': {'directory': '02773838', 'lables': ['handle', 'body'], 'colors': ['blue', 'green']}, 'Cap': {'directory': '02954340', 'lables': ['panels', 'peak'], 'colors': ['blue', 'green']}, 'Car': {'directory': '02958343', 'lables': ['wheel', 'hood', 'roof'], 'colors': ['blue', 'green', 'red']}, 'Chair': {'directory': '03001627', 'lables': ['leg', 'arm', 'back', 'seat'], 'colors': ['blue', 'green', 'red', 'pink']}, 'Earphone': {'directory': '03261776', 'lables': ['earphone', 'headband'], 'colors': ['blue', 'green']}, 'Guitar': {'directory': '03467517', 'lables': ['head', 'body', 'neck'], 'colors': ['blue', 'green', 'red']}, 'Knife': {'directory': '03624134', 'lables': ['handle', 'blade'], 'colors': ['blue', 'green']}, 'Lamp': {'directory': '03636649', 'lables': ['canopy', 'lampshade', 'base'], 'colors': ['blue', 'green', 'red']}, 'Laptop': {'directory': '03642806', 'lables': ['keyboard'], 'colors': ['blue']}, 'Motorbike': {'directory': '03790512', 'lables': ['wheel', 'handle', 'gas_tank', 'light', 'seat'], 'colors': ['blue', 'green', 'red', 'pink', 'yellow']}, 'Mug': {'directory': '03797390', 'lables': ['handle'], 'colors': ['blue']}, 'Pistol': {'directory': '03948459', 'lables': ['trigger_and_guard', 'handle', 'barrel'], 'colors': ['blue', 'green', 'red']}, 'Rocket': {'directory': '04099429', 'lables': ['nose', 'body', 'fin'], 'colors': ['blue', 'green', 'red']}, 'Skateboard': {'directory': '04225987', 'lables': ['wheel', 'deck'], 'colors': ['blue', 'green']}, 'Table': {'directory': '04379243', 'lables': ['leg', 'top'], 'colors': ['blue', 'green']}}

このサンプルでは、飛行機モデルのパーツをセグメント分けするために PointNet を訓練します。

points_dir = "/tmp/.keras/datasets/PartAnnotation/{}/points".format(
    metadata["Airplane"]["directory"]
)
labels_dir = "/tmp/.keras/datasets/PartAnnotation/{}/points_label".format(
    metadata["Airplane"]["directory"]
)
LABELS = metadata["Airplane"]["lables"]
COLORS = metadata["Airplane"]["colors"]

VAL_SPLIT = 0.2
NUM_SAMPLE_POINTS = 1024
BATCH_SIZE = 32
EPOCHS = 60
INITIAL_LR = 1e-3

 

データセットの構造化

飛行機のポイントクラウドとそれらのラベルから以下の in-memory データ構造を生成します :

  • point_clouds は np.array オブジェクトのリストで、x, y と z 座標の形式でポイントクラウドのデータを表します。軸 0 がポイントクラウドのポイント数を表し、軸 1 は座標を表します。all_labels は各座標のラベルを文字列として表わすリストです (主として可視化目的で必要とされます)。

  • test_point_clouds は point_clouds と同じ形式にありますが、ポイントクラウドの対応するラベルは持ちません。

  • all_labels は np.array オブジェクトのリストで、point_clouds リストに対応して、各座標に対するポイントクラウドのラベルを表します。

  • point_cloud_labels は np.array オブジェクトのリストで、point_clouds リストに対応して、one-hot エンコードされた形式で各座標に対するポイントクラウドのラベルを表します。
point_clouds, test_point_clouds = [], []
point_cloud_labels, all_labels = [], []

points_files = glob(os.path.join(points_dir, "*.pts"))
for point_file in tqdm(points_files):
    point_cloud = np.loadtxt(point_file)
    if point_cloud.shape[0] < NUM_SAMPLE_POINTS:
        continue

    # Get the file-id of the current point cloud for parsing its
    # labels.
    file_id = point_file.split("/")[-1].split(".")[0]
    label_data, num_labels = {}, 0
    for label in LABELS:
        label_file = os.path.join(labels_dir, label, file_id + ".seg")
        if os.path.exists(label_file):
            label_data[label] = np.loadtxt(label_file).astype("float32")
            num_labels = len(label_data[label])

    # Point clouds having labels will be our training samples.
    try:
        label_map = ["none"] * num_labels
        for label in LABELS:
            for i, data in enumerate(label_data[label]):
                label_map[i] = label if data == 1 else label_map[i]
        label_data = [
            LABELS.index(label) if label != "none" else len(LABELS)
            for label in label_map
        ]
        # Apply one-hot encoding to the dense label representation.
        label_data = keras.utils.to_categorical(label_data, num_classes=len(LABELS) + 1)

        point_clouds.append(point_cloud)
        point_cloud_labels.append(label_data)
        all_labels.append(label_map)
    except KeyError:
        test_point_clouds.append(point_cloud)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4045/4045 [03:35<00:00, 18.76it/s]

次に、生成したばかりの in-memory 配列から幾つかのサンプルを見ます :

for _ in range(5):
    i = random.randint(0, len(point_clouds) - 1)
    print(f"point_clouds[{i}].shape:", point_clouds[0].shape)
    print(f"point_cloud_labels[{i}].shape:", point_cloud_labels[0].shape)
    for j in range(5):
        print(
            f"all_labels[{i}][{j}]:",
            all_labels[i][j],
            f"\tpoint_cloud_labels[{i}][{j}]:",
            point_cloud_labels[i][j],
            "\n",
        )

次に、ポイントクラウドの一部をそれらのラベルと一緒に可視化しましょう。

def visualize_data(point_cloud, labels):
    df = pd.DataFrame(
        data={
            "x": point_cloud[:, 0],
            "y": point_cloud[:, 1],
            "z": point_cloud[:, 2],
            "label": labels,
        }
    )
    fig = plt.figure(figsize=(15, 10))
    ax = plt.axes(projection="3d")
    for index, label in enumerate(LABELS):
        c_df = df[df["label"] == label]
        try:
            ax.scatter(
                c_df["x"], c_df["y"], c_df["z"], label=label, alpha=0.5, c=COLORS[index]
            )
        except IndexError:
            pass
    ax.legend()
    plt.show()


visualize_data(point_clouds[0], all_labels[0])
visualize_data(point_clouds[300], all_labels[300])

 

前処理

ロードしたポイントクラウドの総ては可変数のポイントから成り、それらをまとめてバッチ処理することを困難にしています。この問題を打開するために、各ポイントクラウドから固定数をランダムにサンプリングします。また、データをスケール不変にするためにポイントクラウドを正規化します。

for index in tqdm(range(len(point_clouds))):
    current_point_cloud = point_clouds[index]
    current_label_cloud = point_cloud_labels[index]
    current_labels = all_labels[index]
    num_points = len(current_point_cloud)
    # Randomly sampling respective indices.
    sampled_indices = random.sample(list(range(num_points)), NUM_SAMPLE_POINTS)
    # Sampling points corresponding to sampled indices.
    sampled_point_cloud = np.array([current_point_cloud[i] for i in sampled_indices])
    # Sampling corresponding one-hot encoded labels.
    sampled_label_cloud = np.array([current_label_cloud[i] for i in sampled_indices])
    # Sampling corresponding labels for visualization.
    sampled_labels = np.array([current_labels[i] for i in sampled_indices])
    # Normalizing sampled point cloud.
    norm_point_cloud = sampled_point_cloud - np.mean(sampled_point_cloud, axis=0)
    norm_point_cloud /= np.max(np.linalg.norm(norm_point_cloud, axis=1))
    point_clouds[index] = norm_point_cloud
    point_cloud_labels[index] = sampled_label_cloud
    all_labels[index] = sampled_labels
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3694/3694 [00:07<00:00, 478.67it/s]

サンプリングそして正規化されたポイントクラウドをそれらの対応するラベルと一緒に可視化しましょう。

visualize_data(point_clouds[0], all_labels[0])
visualize_data(point_clouds[300], all_labels[300])

 

TensorFlow データセットの作成

訓練と検証データのために tf.data.Dataset オブジェクトを作成します。また訓練ポイントクラウドをそれらにランダム jitter を適用することで増強します。

def load_data(point_cloud_batch, label_cloud_batch):
    point_cloud_batch.set_shape([NUM_SAMPLE_POINTS, 3])
    label_cloud_batch.set_shape([NUM_SAMPLE_POINTS, len(LABELS) + 1])
    return point_cloud_batch, label_cloud_batch


def augment(point_cloud_batch, label_cloud_batch):
    noise = tf.random.uniform(
        tf.shape(label_cloud_batch), -0.005, 0.005, dtype=tf.float64
    )
    point_cloud_batch += noise[:, :, :3]
    return point_cloud_batch, label_cloud_batch


def generate_dataset(point_clouds, label_clouds, is_training=True):
    dataset = tf.data.Dataset.from_tensor_slices((point_clouds, label_clouds))
    dataset = dataset.shuffle(BATCH_SIZE * 100) if is_training else dataset
    dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size=BATCH_SIZE)
    dataset = (
        dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
        if is_training
        else dataset
    )
    return dataset


split_index = int(len(point_clouds) * (1 - VAL_SPLIT))
train_point_clouds = point_clouds[:split_index]
train_label_cloud = point_cloud_labels[:split_index]
total_training_examples = len(train_point_clouds)

val_point_clouds = point_clouds[split_index:]
val_label_cloud = point_cloud_labels[split_index:]

print("Num train point clouds:", len(train_point_clouds))
print("Num train point cloud labels:", len(train_label_cloud))
print("Num val point clouds:", len(val_point_clouds))
print("Num val point cloud labels:", len(val_label_cloud))

train_dataset = generate_dataset(train_point_clouds, train_label_cloud)
val_dataset = generate_dataset(val_point_clouds, val_label_cloud, is_training=False)

print("Train Dataset:", train_dataset)
print("Validation Dataset:", val_dataset)
Num train point clouds: 2955
Num train point cloud labels: 2955
Num val point clouds: 739
Num val point cloud labels: 739

Train Dataset: <ParallelMapDataset shapes: ((None, 1024, 3), (None, 1024, 5)), types: (tf.float64, tf.float32)>
Validation Dataset: <BatchDataset shapes: ((None, 1024, 3), (None, 1024, 5)), types: (tf.float64, tf.float32)>

 

PointNet モデル

下の図は PointNet モデル・ファミリーの内部を描写しています :

PointNet は入力データとして座標の 順序付けられていないセット を消費することを意図していると仮定すると、そのアーキテクチャはポイントクラウド・データの以下の特有の性質に一致している必要があります :

 

順列不変性 (= Permutation invariance)

ポイントクラウド・データの非構造化的な性質を考えると、n ポイントから成るスキャンは n! の順列を持ちます。続くデータ処理は異なる表現に対して不変でなければなりません。PointNet を入力の順列に対して不変であるようにするために、n 個の入力ポイントが高次元空間にマップされた時点で、(最大プーリングのような) 対称関数を使用します。その結果は、n 個の入力ポイントの集合的な特徴 (= signature) を捕捉することを目的とする 大域的特徴ベクトル です。この大域的特徴ベクトルはセグメンテーションのために局所的ポイント特徴とともに使用されます。

 

変換不変性 (= Transformation invariance)

セグメンテーション出力は、オブジェクトが並行移動やスケールのような特定の変換を受ける場合、不変であるべきです。与えられた入力ポイントクラウドに対して、ポーズ正規化を実現するために適切な剛体 (= rigid) 変換やアフィン変換を適用します。n 個の入力ポイントの各々はベクトルとして表現され、埋め込み空間に独立にマップされるので、幾何学的変換の適用は単純に各ポイントを変換行列で行列乗算することになります。これは Spatial Transformer ネットワーク の概念により動機づけられています。

T-Net から成る演算は PointNet の高位アーキテクチャにより動機づけられています。MLP (or 完全結合層) は入力ポイントを独立的にかつ一意に高次元空間にマップするために使用されます ; 最大プーリングは大域特徴ベクトル、その次元は次に完全結合層により削減されます、をエンコードするために使用されます。最後の完全結合層における入力依存な特徴は次にグローバルに訓練可能な重みとバイアスと組み合わされて、3 x 3 変換行列という結果になります。

 

ポイントの相互作用

隣接ポイント間の相互作用は有用な情報を携行していることがしばしばあります (i.e. 単一のポイントは単独で扱われるべきではありません)。分類が大域的特徴だけを利用する必要がある一方で、セグメンテーションは大域的ポイント特徴とともに局所的ポイント特徴も活用できなければなりません。

Note : このセクションで提示される図は 原論文 から引用されています。

PointNet モデルを構成するピースを知った今、モデルを実装できます。基本的なブロック i.e., 畳み込みブロックと多層パーセプトロン・ブロックを実装することから始めます。

def conv_block(x: tf.Tensor, filters: int, name: str) -> tf.Tensor:
    x = layers.Conv1D(filters, kernel_size=1, padding="valid", name=f"{name}_conv")(x)
    x = layers.BatchNormalization(momentum=0.0, name=f"{name}_batch_norm")(x)
    return layers.Activation("relu", name=f"{name}_relu")(x)


def mlp_block(x: tf.Tensor, filters: int, name: str) -> tf.Tensor:
    x = layers.Dense(filters, name=f"{name}_dense")(x)
    x = layers.BatchNormalization(momentum=0.0, name=f"{name}_batch_norm")(x)
    return layers.Activation("relu", name=f"{name}_relu")(x)

特徴空間で直交性を強要するために ( このサンプル から引用された) regularizer を実装します。これは、変換された特徴の大きさが大きく変化し過ぎないことを確実にするために必要です。

class OrthogonalRegularizer(keras.regularizers.Regularizer):
    """Reference: https://keras.io/examples/vision/pointnet/#build-a-model"""

    def __init__(self, num_features, l2reg=0.001):
        self.num_features = num_features
        self.l2reg = l2reg
        self.identity = tf.eye(num_features)

    def __call__(self, x):
        x = tf.reshape(x, (-1, self.num_features, self.num_features))
        xxt = tf.tensordot(x, x, axes=(2, 2))
        xxt = tf.reshape(xxt, (-1, self.num_features, self.num_features))
        return tf.reduce_sum(self.l2reg * tf.square(xxt - self.identity))

    def get_config(self):
        config = super(TransformerEncoder, self).get_config()
        config.update({"num_features": self.num_features, "l2reg_strength": self.l2reg})
        return config

次のピースは先に説明した変換ネットワークです。

def transformation_net(inputs: tf.Tensor, num_features: int, name: str) -> tf.Tensor:
    """
    Reference: https://keras.io/examples/vision/pointnet/#build-a-model.

    The `filters` values come from the original paper:
    https://arxiv.org/abs/1612.00593.
    """
    x = conv_block(inputs, filters=64, name=f"{name}_1")
    x = conv_block(x, filters=128, name=f"{name}_2")
    x = conv_block(x, filters=1024, name=f"{name}_3")
    x = layers.GlobalMaxPooling1D()(x)
    x = mlp_block(x, filters=512, name=f"{name}_1_1")
    x = mlp_block(x, filters=256, name=f"{name}_2_1")
    return layers.Dense(
        num_features * num_features,
        kernel_initializer="zeros",
        bias_initializer=keras.initializers.Constant(np.eye(num_features).flatten()),
        activity_regularizer=OrthogonalRegularizer(num_features),
        name=f"{name}_final",
    )(x)


def transformation_block(inputs: tf.Tensor, num_features: int, name: str) -> tf.Tensor:
    transformed_features = transformation_net(inputs, num_features, name=name)
    transformed_features = layers.Reshape((num_features, num_features))(
        transformed_features
    )
    return layers.Dot(axes=(2, 1), name=f"{name}_mm")([inputs, transformed_features])

最後に、上記のブロックをまとめてセグメンテーション・モデルを実装します。

def get_shape_segmentation_model(num_points: int, num_classes: int) -> keras.Model:
    input_points = keras.Input(shape=(None, 3))

    # PointNet Classification Network.
    transformed_inputs = transformation_block(
        input_points, num_features=3, name="input_transformation_block"
    )
    features_64 = conv_block(transformed_inputs, filters=64, name="features_64")
    features_128_1 = conv_block(features_64, filters=128, name="features_128_1")
    features_128_2 = conv_block(features_128_1, filters=128, name="features_128_2")
    transformed_features = transformation_block(
        features_128_2, num_features=128, name="transformed_features"
    )
    features_512 = conv_block(transformed_features, filters=512, name="features_512")
    features_2048 = conv_block(features_512, filters=2048, name="pre_maxpool_block")
    global_features = layers.MaxPool1D(pool_size=num_points, name="global_features")(
        features_2048
    )
    global_features = tf.tile(global_features, [1, num_points, 1])

    # Segmentation head.
    segmentation_input = layers.Concatenate(name="segmentation_input")(
        [
            features_64,
            features_128_1,
            features_128_2,
            transformed_features,
            features_512,
            global_features,
        ]
    )
    segmentation_features = conv_block(
        segmentation_input, filters=128, name="segmentation_features"
    )
    outputs = layers.Conv1D(
        num_classes, kernel_size=1, activation="softmax", name="segmentation_head"
    )(segmentation_features)
    return keras.Model(input_points, outputs)

 

モデルのインスタンス化

x, y = next(iter(train_dataset))

num_points = x.shape[1]
num_classes = y.shape[-1]

segmentation_model = get_shape_segmentation_model(num_points, num_classes)
segmentation_model.summary()
2021-10-25 01:26:33.563133: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, None, 3)]    0                                            
__________________________________________________________________________________________________
input_transformation_block_1_co (None, None, 64)     256         input_1[0][0]                    
__________________________________________________________________________________________________
input_transformation_block_1_ba (None, None, 64)     256         input_transformation_block_1_conv
__________________________________________________________________________________________________
input_transformation_block_1_re (None, None, 64)     0           input_transformation_block_1_batc
__________________________________________________________________________________________________
input_transformation_block_2_co (None, None, 128)    8320        input_transformation_block_1_relu
__________________________________________________________________________________________________
input_transformation_block_2_ba (None, None, 128)    512         input_transformation_block_2_conv
__________________________________________________________________________________________________
input_transformation_block_2_re (None, None, 128)    0           input_transformation_block_2_batc
__________________________________________________________________________________________________
input_transformation_block_3_co (None, None, 1024)   132096      input_transformation_block_2_relu
__________________________________________________________________________________________________
input_transformation_block_3_ba (None, None, 1024)   4096        input_transformation_block_3_conv
__________________________________________________________________________________________________
input_transformation_block_3_re (None, None, 1024)   0           input_transformation_block_3_batc
__________________________________________________________________________________________________
global_max_pooling1d (GlobalMax (None, 1024)         0           input_transformation_block_3_relu
__________________________________________________________________________________________________
input_transformation_block_1_1_ (None, 512)          524800      global_max_pooling1d[0][0]       
__________________________________________________________________________________________________
input_transformation_block_1_1_ (None, 512)          2048        input_transformation_block_1_1_de
__________________________________________________________________________________________________
input_transformation_block_1_1_ (None, 512)          0           input_transformation_block_1_1_ba
__________________________________________________________________________________________________
input_transformation_block_2_1_ (None, 256)          131328      input_transformation_block_1_1_re
__________________________________________________________________________________________________
input_transformation_block_2_1_ (None, 256)          1024        input_transformation_block_2_1_de
__________________________________________________________________________________________________
input_transformation_block_2_1_ (None, 256)          0           input_transformation_block_2_1_ba
__________________________________________________________________________________________________
input_transformation_block_fina (None, 9)            2313        input_transformation_block_2_1_re
__________________________________________________________________________________________________
reshape (Reshape)               (None, 3, 3)         0           input_transformation_block_final[
__________________________________________________________________________________________________
input_transformation_block_mm ( (None, None, 3)      0           input_1[0][0]                    
                                                                 reshape[0][0]                    
__________________________________________________________________________________________________
features_64_conv (Conv1D)       (None, None, 64)     256         input_transformation_block_mm[0][
__________________________________________________________________________________________________
features_64_batch_norm (BatchNo (None, None, 64)     256         features_64_conv[0][0]           
__________________________________________________________________________________________________
features_64_relu (Activation)   (None, None, 64)     0           features_64_batch_norm[0][0]     
__________________________________________________________________________________________________
features_128_1_conv (Conv1D)    (None, None, 128)    8320        features_64_relu[0][0]           
__________________________________________________________________________________________________
features_128_1_batch_norm (Batc (None, None, 128)    512         features_128_1_conv[0][0]        
__________________________________________________________________________________________________
features_128_1_relu (Activation (None, None, 128)    0           features_128_1_batch_norm[0][0]  
__________________________________________________________________________________________________
features_128_2_conv (Conv1D)    (None, None, 128)    16512       features_128_1_relu[0][0]        
__________________________________________________________________________________________________
features_128_2_batch_norm (Batc (None, None, 128)    512         features_128_2_conv[0][0]        
__________________________________________________________________________________________________
features_128_2_relu (Activation (None, None, 128)    0           features_128_2_batch_norm[0][0]  
__________________________________________________________________________________________________
transformed_features_1_conv (Co (None, None, 64)     8256        features_128_2_relu[0][0]        
__________________________________________________________________________________________________
transformed_features_1_batch_no (None, None, 64)     256         transformed_features_1_conv[0][0]
__________________________________________________________________________________________________
transformed_features_1_relu (Ac (None, None, 64)     0           transformed_features_1_batch_norm
__________________________________________________________________________________________________
transformed_features_2_conv (Co (None, None, 128)    8320        transformed_features_1_relu[0][0]
__________________________________________________________________________________________________
transformed_features_2_batch_no (None, None, 128)    512         transformed_features_2_conv[0][0]
__________________________________________________________________________________________________
transformed_features_2_relu (Ac (None, None, 128)    0           transformed_features_2_batch_norm
__________________________________________________________________________________________________
transformed_features_3_conv (Co (None, None, 1024)   132096      transformed_features_2_relu[0][0]
__________________________________________________________________________________________________
transformed_features_3_batch_no (None, None, 1024)   4096        transformed_features_3_conv[0][0]
__________________________________________________________________________________________________
transformed_features_3_relu (Ac (None, None, 1024)   0           transformed_features_3_batch_norm
__________________________________________________________________________________________________
global_max_pooling1d_1 (GlobalM (None, 1024)         0           transformed_features_3_relu[0][0]
__________________________________________________________________________________________________
transformed_features_1_1_dense  (None, 512)          524800      global_max_pooling1d_1[0][0]     
__________________________________________________________________________________________________
transformed_features_1_1_batch_ (None, 512)          2048        transformed_features_1_1_dense[0]
__________________________________________________________________________________________________
transformed_features_1_1_relu ( (None, 512)          0           transformed_features_1_1_batch_no
__________________________________________________________________________________________________
transformed_features_2_1_dense  (None, 256)          131328      transformed_features_1_1_relu[0][
__________________________________________________________________________________________________
transformed_features_2_1_batch_ (None, 256)          1024        transformed_features_2_1_dense[0]
__________________________________________________________________________________________________
transformed_features_2_1_relu ( (None, 256)          0           transformed_features_2_1_batch_no
__________________________________________________________________________________________________
transformed_features_final (Den (None, 16384)        4210688     transformed_features_2_1_relu[0][
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 128, 128)     0           transformed_features_final[0][0] 
__________________________________________________________________________________________________
transformed_features_mm (Dot)   (None, None, 128)    0           features_128_2_relu[0][0]        
                                                                 reshape_1[0][0]                  
__________________________________________________________________________________________________
features_512_conv (Conv1D)      (None, None, 512)    66048       transformed_features_mm[0][0]    
__________________________________________________________________________________________________
features_512_batch_norm (BatchN (None, None, 512)    2048        features_512_conv[0][0]          
__________________________________________________________________________________________________
features_512_relu (Activation)  (None, None, 512)    0           features_512_batch_norm[0][0]    
__________________________________________________________________________________________________
pre_maxpool_block_conv (Conv1D) (None, None, 2048)   1050624     features_512_relu[0][0]          
__________________________________________________________________________________________________
pre_maxpool_block_batch_norm (B (None, None, 2048)   8192        pre_maxpool_block_conv[0][0]     
__________________________________________________________________________________________________
pre_maxpool_block_relu (Activat (None, None, 2048)   0           pre_maxpool_block_batch_norm[0][0
__________________________________________________________________________________________________
global_features (MaxPooling1D)  (None, None, 2048)   0           pre_maxpool_block_relu[0][0]     
__________________________________________________________________________________________________
tf.tile (TFOpLambda)            (None, None, 2048)   0           global_features[0][0]            
__________________________________________________________________________________________________
segmentation_input (Concatenate (None, None, 3008)   0           features_64_relu[0][0]           
                                                                 features_128_1_relu[0][0]        
                                                                 features_128_2_relu[0][0]        
                                                                 transformed_features_mm[0][0]    
                                                                 features_512_relu[0][0]          
                                                                 tf.tile[0][0]                    
__________________________________________________________________________________________________
segmentation_features_conv (Con (None, None, 128)    385152      segmentation_input[0][0]         
__________________________________________________________________________________________________
segmentation_features_batch_nor (None, None, 128)    512         segmentation_features_conv[0][0] 
__________________________________________________________________________________________________
segmentation_features_relu (Act (None, None, 128)    0           segmentation_features_batch_norm[
__________________________________________________________________________________________________
segmentation_head (Conv1D)      (None, None, 5)      645         segmentation_features_relu[0][0] 
==================================================================================================
Total params: 7,370,062
Trainable params: 7,356,110
Non-trainable params: 13,952

 

訓練

訓練については、著者らは初期学習率を 20 エポック毎に半分にする学習率スケジュールの使用を勧めています。このサンプルでは、15 エポックにしています。

training_step_size = total_training_examples // BATCH_SIZE
total_training_steps = training_step_size * EPOCHS
print(f"Total training steps: {total_training_steps}.")

lr_schedule = keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[training_step_size * 15, training_step_size * 15],
    values=[INITIAL_LR, INITIAL_LR * 0.5, INITIAL_LR * 0.25],
)

steps = tf.range(total_training_steps, dtype=tf.int32)
lrs = [lr_schedule(step) for step in steps]

plt.plot(lrs)
plt.xlabel("Steps")
plt.ylabel("Learning Rate")
plt.show()
Total training steps: 5520.

最後に、実験を実行するためのユティリティを実装してモデル訓練を起動します。

def run_experiment(epochs):

    segmentation_model = get_shape_segmentation_model(num_points, num_classes)
    segmentation_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
        loss=keras.losses.CategoricalCrossentropy(),
        metrics=["accuracy"],
    )

    checkpoint_filepath = "/tmp/checkpoint"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_loss",
        save_best_only=True,
        save_weights_only=True,
    )

    history = segmentation_model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=[checkpoint_callback],
    )

    segmentation_model.load_weights(checkpoint_filepath)
    return segmentation_model, history


segmentation_model, history = run_experiment(epochs=EPOCHS)
Epoch 1/60
93/93 [==============================] - 28s 127ms/step - loss: 5.3556 - accuracy: 0.7448 - val_loss: 5.8386 - val_accuracy: 0.7471
Epoch 2/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7077 - accuracy: 0.8181 - val_loss: 5.2614 - val_accuracy: 0.7793
Epoch 3/60
93/93 [==============================] - 11s 118ms/step - loss: 4.6566 - accuracy: 0.8301 - val_loss: 4.7907 - val_accuracy: 0.8269
Epoch 4/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6059 - accuracy: 0.8406 - val_loss: 4.6031 - val_accuracy: 0.8482
Epoch 5/60
93/93 [==============================] - 11s 118ms/step - loss: 4.5828 - accuracy: 0.8444 - val_loss: 4.7692 - val_accuracy: 0.8220
Epoch 6/60
93/93 [==============================] - 11s 118ms/step - loss: 4.6150 - accuracy: 0.8408 - val_loss: 5.4460 - val_accuracy: 0.8192
Epoch 7/60
93/93 [==============================] - 11s 117ms/step - loss: 67.5943 - accuracy: 0.7378 - val_loss: 1617.1846 - val_accuracy: 0.5191
Epoch 8/60
93/93 [==============================] - 11s 117ms/step - loss: 15.2910 - accuracy: 0.6651 - val_loss: 8.1014 - val_accuracy: 0.7046
Epoch 9/60
93/93 [==============================] - 11s 117ms/step - loss: 6.8878 - accuracy: 0.7368 - val_loss: 14.2311 - val_accuracy: 0.6949
Epoch 10/60
93/93 [==============================] - 11s 117ms/step - loss: 5.8362 - accuracy: 0.7549 - val_loss: 14.6942 - val_accuracy: 0.6350
Epoch 11/60
93/93 [==============================] - 11s 117ms/step - loss: 5.4777 - accuracy: 0.7648 - val_loss: 44.1037 - val_accuracy: 0.6422
Epoch 12/60
93/93 [==============================] - 11s 117ms/step - loss: 5.2688 - accuracy: 0.7712 - val_loss: 4.9977 - val_accuracy: 0.7692
Epoch 13/60
93/93 [==============================] - 11s 117ms/step - loss: 5.1041 - accuracy: 0.7837 - val_loss: 6.0642 - val_accuracy: 0.7577
Epoch 14/60
93/93 [==============================] - 11s 117ms/step - loss: 5.0011 - accuracy: 0.7862 - val_loss: 4.9313 - val_accuracy: 0.7840
Epoch 15/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8910 - accuracy: 0.7953 - val_loss: 5.8368 - val_accuracy: 0.7725
Epoch 16/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8698 - accuracy: 0.8074 - val_loss: 73.0260 - val_accuracy: 0.7251
Epoch 17/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8299 - accuracy: 0.8109 - val_loss: 17.1503 - val_accuracy: 0.7415
Epoch 18/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8147 - accuracy: 0.8111 - val_loss: 62.2765 - val_accuracy: 0.7344
Epoch 19/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8316 - accuracy: 0.8141 - val_loss: 5.2200 - val_accuracy: 0.7890
Epoch 20/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7853 - accuracy: 0.8142 - val_loss: 5.7062 - val_accuracy: 0.7719
Epoch 21/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7753 - accuracy: 0.8157 - val_loss: 6.2089 - val_accuracy: 0.7839
Epoch 22/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7681 - accuracy: 0.8161 - val_loss: 5.1077 - val_accuracy: 0.8021
Epoch 23/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7554 - accuracy: 0.8187 - val_loss: 4.7912 - val_accuracy: 0.7912
Epoch 24/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7355 - accuracy: 0.8197 - val_loss: 4.9164 - val_accuracy: 0.7978
Epoch 25/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7483 - accuracy: 0.8197 - val_loss: 13.4724 - val_accuracy: 0.7631
Epoch 26/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7200 - accuracy: 0.8218 - val_loss: 8.3074 - val_accuracy: 0.7596
Epoch 27/60
93/93 [==============================] - 11s 118ms/step - loss: 4.7192 - accuracy: 0.8231 - val_loss: 12.4468 - val_accuracy: 0.7591
Epoch 28/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7151 - accuracy: 0.8241 - val_loss: 23.8681 - val_accuracy: 0.7689
Epoch 29/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7096 - accuracy: 0.8237 - val_loss: 4.9069 - val_accuracy: 0.8104
Epoch 30/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6991 - accuracy: 0.8257 - val_loss: 4.9858 - val_accuracy: 0.7950
Epoch 31/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6852 - accuracy: 0.8260 - val_loss: 5.0130 - val_accuracy: 0.7678
Epoch 32/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6630 - accuracy: 0.8286 - val_loss: 4.8523 - val_accuracy: 0.7676
Epoch 33/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6837 - accuracy: 0.8281 - val_loss: 5.4347 - val_accuracy: 0.8095
Epoch 34/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6571 - accuracy: 0.8296 - val_loss: 10.4595 - val_accuracy: 0.7410
Epoch 35/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6460 - accuracy: 0.8321 - val_loss: 4.9189 - val_accuracy: 0.8083
Epoch 36/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6430 - accuracy: 0.8327 - val_loss: 5.8674 - val_accuracy: 0.7911
Epoch 37/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6530 - accuracy: 0.8309 - val_loss: 4.7946 - val_accuracy: 0.8032
Epoch 38/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6391 - accuracy: 0.8318 - val_loss: 5.0111 - val_accuracy: 0.8024
Epoch 39/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6521 - accuracy: 0.8336 - val_loss: 8.1558 - val_accuracy: 0.7727
Epoch 40/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6443 - accuracy: 0.8329 - val_loss: 42.8513 - val_accuracy: 0.7688
Epoch 41/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6316 - accuracy: 0.8342 - val_loss: 5.0960 - val_accuracy: 0.8066
Epoch 42/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6322 - accuracy: 0.8335 - val_loss: 5.0634 - val_accuracy: 0.8158
Epoch 43/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6175 - accuracy: 0.8370 - val_loss: 6.0642 - val_accuracy: 0.8062
Epoch 44/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6175 - accuracy: 0.8371 - val_loss: 11.1805 - val_accuracy: 0.7790
Epoch 45/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6056 - accuracy: 0.8377 - val_loss: 4.7359 - val_accuracy: 0.8145
Epoch 46/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6108 - accuracy: 0.8383 - val_loss: 5.7125 - val_accuracy: 0.7713
Epoch 47/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6103 - accuracy: 0.8377 - val_loss: 6.3271 - val_accuracy: 0.8105
Epoch 48/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6020 - accuracy: 0.8383 - val_loss: 14.2876 - val_accuracy: 0.7529
Epoch 49/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6035 - accuracy: 0.8382 - val_loss: 4.8244 - val_accuracy: 0.8143
Epoch 50/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6076 - accuracy: 0.8381 - val_loss: 8.2636 - val_accuracy: 0.7528
Epoch 51/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5927 - accuracy: 0.8399 - val_loss: 4.6473 - val_accuracy: 0.8266
Epoch 52/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5927 - accuracy: 0.8408 - val_loss: 4.6443 - val_accuracy: 0.8276
Epoch 53/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5852 - accuracy: 0.8413 - val_loss: 5.1300 - val_accuracy: 0.7768
Epoch 54/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5787 - accuracy: 0.8426 - val_loss: 8.9590 - val_accuracy: 0.7582
Epoch 55/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5837 - accuracy: 0.8410 - val_loss: 5.1501 - val_accuracy: 0.8117
Epoch 56/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5875 - accuracy: 0.8422 - val_loss: 31.3518 - val_accuracy: 0.7590
Epoch 57/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5821 - accuracy: 0.8427 - val_loss: 4.8853 - val_accuracy: 0.8144
Epoch 58/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5751 - accuracy: 0.8446 - val_loss: 4.6653 - val_accuracy: 0.8222
Epoch 59/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5752 - accuracy: 0.8447 - val_loss: 6.0078 - val_accuracy: 0.8014
Epoch 60/60
93/93 [==============================] - 11s 118ms/step - loss: 4.5695 - accuracy: 0.8452 - val_loss: 4.8178 - val_accuracy: 0.8192

(訳者注: 実験結果)

Epoch 1/60
93/93 [==============================] - 22s 87ms/step - loss: 5.3702 - accuracy: 0.7357 - val_loss: 4.8234 - val_accuracy: 0.7959
Epoch 2/60
93/93 [==============================] - 7s 74ms/step - loss: 4.6737 - accuracy: 0.8238 - val_loss: 4.9133 - val_accuracy: 0.8133
Epoch 3/60
93/93 [==============================] - 7s 75ms/step - loss: 4.7244 - accuracy: 0.8216 - val_loss: 5.2728 - val_accuracy: 0.7539
Epoch 4/60
93/93 [==============================] - 7s 74ms/step - loss: 4.7115 - accuracy: 0.8144 - val_loss: 5.2501 - val_accuracy: 0.8020
Epoch 5/60
93/93 [==============================] - 7s 74ms/step - loss: 4.7318 - accuracy: 0.8130 - val_loss: 4.8355 - val_accuracy: 0.8264
Epoch 6/60
93/93 [==============================] - 7s 74ms/step - loss: 4.6112 - accuracy: 0.8364 - val_loss: 13.2324 - val_accuracy: 0.7926
Epoch 7/60
93/93 [==============================] - 7s 78ms/step - loss: 4.6172 - accuracy: 0.8353 - val_loss: 4.7182 - val_accuracy: 0.8403
Epoch 8/60
93/93 [==============================] - 7s 75ms/step - loss: 4.5751 - accuracy: 0.8424 - val_loss: 5.6088 - val_accuracy: 0.8353
Epoch 9/60
93/93 [==============================] - 7s 76ms/step - loss: 4.5851 - accuracy: 0.8425 - val_loss: 4.8051 - val_accuracy: 0.7950
Epoch 10/60
93/93 [==============================] - 7s 74ms/step - loss: 4.6435 - accuracy: 0.8338 - val_loss: 54.9336 - val_accuracy: 0.7872
Epoch 11/60
93/93 [==============================] - 7s 74ms/step - loss: 4.5741 - accuracy: 0.8447 - val_loss: 4.9625 - val_accuracy: 0.8154
Epoch 12/60
93/93 [==============================] - 7s 75ms/step - loss: 4.5540 - accuracy: 0.8484 - val_loss: 32.4925 - val_accuracy: 0.7406
Epoch 13/60
93/93 [==============================] - 7s 75ms/step - loss: 4.5377 - accuracy: 0.8514 - val_loss: 428.6281 - val_accuracy: 0.6626
Epoch 14/60
93/93 [==============================] - 7s 77ms/step - loss: 4.7567 - accuracy: 0.8115 - val_loss: 4.7127 - val_accuracy: 0.8267
Epoch 15/60
93/93 [==============================] - 7s 77ms/step - loss: 4.5393 - accuracy: 0.8501 - val_loss: 4.6364 - val_accuracy: 0.8311
Epoch 16/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4911 - accuracy: 0.8613 - val_loss: 5.4676 - val_accuracy: 0.8243
Epoch 17/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4850 - accuracy: 0.8619 - val_loss: 58570.8320 - val_accuracy: 0.7975
Epoch 18/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4766 - accuracy: 0.8637 - val_loss: 9.7880 - val_accuracy: 0.8211
Epoch 19/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4745 - accuracy: 0.8642 - val_loss: 10.1575 - val_accuracy: 0.8395
Epoch 20/60
93/93 [==============================] - 7s 78ms/step - loss: 4.4661 - accuracy: 0.8669 - val_loss: 4.5302 - val_accuracy: 0.8504
Epoch 21/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4657 - accuracy: 0.8671 - val_loss: 123.4067 - val_accuracy: 0.7972
Epoch 22/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4587 - accuracy: 0.8690 - val_loss: 6.2540 - val_accuracy: 0.8573
Epoch 23/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4521 - accuracy: 0.8704 - val_loss: 4.7138 - val_accuracy: 0.8560
Epoch 24/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4502 - accuracy: 0.8713 - val_loss: 4.7997 - val_accuracy: 0.8558
Epoch 25/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4695 - accuracy: 0.8682 - val_loss: 1727.5392 - val_accuracy: 0.7754
Epoch 26/60
93/93 [==============================] - 7s 78ms/step - loss: 4.4536 - accuracy: 0.8711 - val_loss: 4.5054 - val_accuracy: 0.8570
Epoch 27/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4442 - accuracy: 0.8736 - val_loss: 1847.9580 - val_accuracy: 0.8258
Epoch 28/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4419 - accuracy: 0.8739 - val_loss: 507323.7188 - val_accuracy: 0.6395
Epoch 29/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4410 - accuracy: 0.8749 - val_loss: 110.4861 - val_accuracy: 0.8279
Epoch 30/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4343 - accuracy: 0.8772 - val_loss: 13034.4678 - val_accuracy: 0.7615
Epoch 31/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4371 - accuracy: 0.8762 - val_loss: 4.5390 - val_accuracy: 0.8575
Epoch 32/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4299 - accuracy: 0.8783 - val_loss: 5.1719 - val_accuracy: 0.8466
Epoch 33/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4260 - accuracy: 0.8795 - val_loss: 5.2370 - val_accuracy: 0.8486
Epoch 34/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4335 - accuracy: 0.8776 - val_loss: 4.5271 - val_accuracy: 0.8526
Epoch 35/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4188 - accuracy: 0.8814 - val_loss: 8.0959 - val_accuracy: 0.8575
Epoch 36/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4114 - accuracy: 0.8842 - val_loss: 4.9942 - val_accuracy: 0.8560
Epoch 37/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4314 - accuracy: 0.8788 - val_loss: 5.0881 - val_accuracy: 0.8516
Epoch 38/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4148 - accuracy: 0.8831 - val_loss: 4.5428 - val_accuracy: 0.8618
Epoch 39/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4310 - accuracy: 0.8795 - val_loss: 8.5377 - val_accuracy: 0.8391
Epoch 40/60
93/93 [==============================] - 7s 74ms/step - loss: 4.4110 - accuracy: 0.8840 - val_loss: 4840.1558 - val_accuracy: 0.8243
Epoch 41/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4105 - accuracy: 0.8844 - val_loss: 4.5062 - val_accuracy: 0.8619
Epoch 42/60
93/93 [==============================] - 7s 78ms/step - loss: 4.4047 - accuracy: 0.8863 - val_loss: 4.4981 - val_accuracy: 0.8628
Epoch 43/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3969 - accuracy: 0.8895 - val_loss: 7.4334 - val_accuracy: 0.8504
Epoch 44/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3950 - accuracy: 0.8898 - val_loss: 4.6111 - val_accuracy: 0.8606
Epoch 45/60
93/93 [==============================] - 7s 75ms/step - loss: 4.3893 - accuracy: 0.8907 - val_loss: 25.9290 - val_accuracy: 0.8233
Epoch 46/60
93/93 [==============================] - 7s 75ms/step - loss: 4.3951 - accuracy: 0.8898 - val_loss: 168.4719 - val_accuracy: 0.8410
Epoch 47/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3971 - accuracy: 0.8887 - val_loss: 116.8822 - val_accuracy: 0.8558
Epoch 48/60
93/93 [==============================] - 7s 75ms/step - loss: 4.3864 - accuracy: 0.8924 - val_loss: 4.5119 - val_accuracy: 0.8598
Epoch 49/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3787 - accuracy: 0.8948 - val_loss: 8.3084 - val_accuracy: 0.8552
Epoch 50/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3846 - accuracy: 0.8937 - val_loss: 2636.2273 - val_accuracy: 0.7942
Epoch 51/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4077 - accuracy: 0.8863 - val_loss: 1912.4525 - val_accuracy: 0.7862
Epoch 52/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3784 - accuracy: 0.8954 - val_loss: 4.5436 - val_accuracy: 0.8639
Epoch 53/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3998 - accuracy: 0.8897 - val_loss: 4.6248 - val_accuracy: 0.8334
Epoch 54/60
93/93 [==============================] - 7s 75ms/step - loss: 4.4187 - accuracy: 0.8857 - val_loss: 4.5236 - val_accuracy: 0.8604
Epoch 55/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3752 - accuracy: 0.8970 - val_loss: 4.5569 - val_accuracy: 0.8630
Epoch 56/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3772 - accuracy: 0.8965 - val_loss: 64.3617 - val_accuracy: 0.8405
Epoch 57/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3951 - accuracy: 0.8906 - val_loss: 4.5302 - val_accuracy: 0.8592
Epoch 58/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3737 - accuracy: 0.8967 - val_loss: 180.0808 - val_accuracy: 0.8335
Epoch 59/60
93/93 [==============================] - 7s 75ms/step - loss: 4.3709 - accuracy: 0.8978 - val_loss: 7.8048 - val_accuracy: 0.8477
Epoch 60/60
93/93 [==============================] - 7s 74ms/step - loss: 4.3596 - accuracy: 0.9009 - val_loss: 486.6651 - val_accuracy: 0.8325
CPU times: user 7min 25s, sys: 26.2 s, total: 7min 51s
Wall time: 7min 17s

 

訓練状況の可視化

def plot_result(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_result("loss")
plot_result("accuracy")

 

推論

validation_batch = next(iter(val_dataset))
val_predictions = segmentation_model.predict(validation_batch[0])
print(f"Validation prediction shape: {val_predictions.shape}")


def visualize_single_point_cloud(point_clouds, label_clouds, idx):
    label_map = LABELS + ["none"]
    point_cloud = point_clouds[idx]
    label_cloud = label_clouds[idx]
    visualize_data(point_cloud, [label_map[np.argmax(label)] for label in label_cloud])


idx = np.random.choice(len(validation_batch[0]))
print(f"Index selected: {idx}")

# Plotting with ground-truth.
visualize_single_point_cloud(validation_batch[0], validation_batch[1], idx)

# Plotting with predicted labels.
visualize_single_point_cloud(validation_batch[0], val_predictions, idx)
Validation prediction shape: (32, 1024, 5)
Index selected: 24

 

Final notes

If you are interested in learning more about this topic, you may find this repository useful.

 

以上



クラスキャット

最近の投稿

  • LangGraph 0.5 : エージェント開発 : エージェント・アーキテクチャ
  • LangGraph 0.5 : エージェント開発 : ワークフローとエージェント
  • LangGraph 0.5 : エージェント開発 : エージェントの実行
  • LangGraph 0.5 : エージェント開発 : prebuilt コンポーネントを使用したエージェント開発
  • LangGraph 0.5 : Get started : ローカルサーバの実行

タグ

AutoGen (13) ClassCat Press Release (20) ClassCat TF/ONNX Hub (11) DGL 0.5 (14) Eager Execution (7) Edward (17) FLUX.1 (16) Gemini (20) HuggingFace Transformers 4.5 (10) HuggingFace Transformers 4.6 (7) HuggingFace Transformers 4.29 (9) Keras 2 Examples (98) Keras 2 Guide (16) Keras 3 (10) Keras Release Note (17) Kubeflow 1.0 (10) LangChain (45) LangGraph (24) LangGraph 0.5 (9) MediaPipe 0.8 (11) Model Context Protocol (16) NNI 1.5 (16) OpenAI Agents SDK (8) OpenAI Cookbook (13) OpenAI platform (10) OpenAI platform 1.x (10) OpenAI ヘルプ (8) TensorFlow 2.0 Advanced Tutorials (33) TensorFlow 2.0 Advanced Tutorials (Alpha) (15) TensorFlow 2.0 Advanced Tutorials (Beta) (16) TensorFlow 2.0 Guide (10) TensorFlow 2.0 Guide (Alpha) (16) TensorFlow 2.0 Guide (Beta) (9) TensorFlow 2.0 Release Note (12) TensorFlow 2.0 Tutorials (20) TensorFlow 2.0 Tutorials (Alpha) (14) TensorFlow 2.0 Tutorials (Beta) (12) TensorFlow 2.4 Guide (24) TensorFlow Deploy (8) TensorFlow Get Started (7) TensorFlow Probability (9) TensorFlow Programmer's Guide (22) TensorFlow Release Note (18) TensorFlow Tutorials (33) TF-Agents 0.4 (11)
2021年12月
月 火 水 木 金 土 日
 12345
6789101112
13141516171819
20212223242526
2728293031  
« 11月   3月 »
© 2025 ClasCat® AI Research | Powered by Minimalist Blog WordPress Theme