Keras 3 : Google Colab 上のインストールと Stable Diffusion デモ (翻訳/解説)
翻訳 : クラスキャット セールスインフォメーション
作成日時 : 12/01/2023
* 本ページは、以下の fchollet によるノートブックを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
クラスキャット 人工知能 研究開発支援サービス
◆ クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Website: www.classcat.com ; ClassCatJP
Keras 3 : Google Colab 上のインストールと Stable Diffusion デモ
Keras 3 + KerasNLP + KerasCV インストール・メモ
以下で、スターマーク (*) は一時的な状況を示しています – すぐに改善されます!
- このデモを実行するには Cuda 12, TF 2.15 と Keras 3 が必要です。
- これら 3 つのパッケージまだ Colab のデフォルトではありません (*)。
- Cuda 12 は ‘pip install torch’ 経由でインストールされます。それは torch の依存関係です。
- pip のターゲット tensorflow[and-cuda] と jax[cuda12_pip] もありますが、Colab では (そして Colab でのみ) Cuda 12 のインストールは失敗します。
- TF 2.15 (現時点での最新版) は ‘pip install tensorflow’ によりインストールされます。
- keras-nlp と keras-cv はデータ処理のために TensorFlow への依存関係を持ちます (keras-nlp はトークン化のために tf-text を使用し、keras-cv は一部のデータ増強層で TF ops を使用しています)。従ってそれらは TensorFlow をインストールします、これは次に Keras 2 をインストールします。
- 実際に、TF 2.16 がリリースされるまでは、TensorFlow は Keras 3 ではなく、Keras 2 をインストールします (*)。
- 従って、Keras 3 は最上位でインストールされる必要があり、行 ‘pip install keras’ は tensorflow, keras-nlp や keras-cv のインストール後に来る必要があります。
- pip は Keras がインストールされている TensorFlow と互換ではないと警告を出しますが、来るべき TF 2.16 リリースまでは警告を無視してください (*)。
!pip install --upgrade torch
!pip install --upgrade tensorflow
!pip install --upgrade jax
!pip install --upgrade keras-nlp
!pip install --upgrade keras-cv
!pip install --upgrade keras
# Some care is required to install Keras 3. This is a temporary situation.
# See installation notes at the end of this notebook for details.
バックエンドの選択と表示ユティリティ
ここでは “jax” を選択しています :
#@title Backend selection and display utilities [run me]
import matplotlib.pyplot as plt
from IPython.core.display import display, HTML
def big_print(a,b):
html = '{}{}'.format(a, b)
display(HTML(html))
def plot_images(images):
plt.figure(figsize=(20, 20))
for i in range(len(images)):
ax = plt.subplot(1, len(images), i + 1)
plt.imshow(images[i])
plt.axis("off")
backend = 'jax' # @param ["jax", "tensorflow", "torch"]
インポート
import math, os, random
os.environ['KERAS_BACKEND'] = backend
import keras
import keras_cv
import keras_nlp
backend = keras.config.backend()
big_print('\u2B50 ', 'Keras version '+keras.version())
big_print('\u2B50 ', 'Running on '+backend.upper())
⭐ Keras version 3.0.0 ⭐ Running on JAX
Keras 3: Keras-NLP モデルのロードと実行
OPT は causal 言語モデルで、入力プロンプトを続けます。
# model
nlp_model = keras_nlp.models.OPTCausalLM.from_preset("opt_125m_en")
nlp_model.compile(sampler=keras_nlp.samplers.ContrastiveSampler())
prompt = "Hi, I'm a {} machine learning developer. \
What are you working on?".format(backend.upper())
response = nlp_model.generate(prompt, max_length=57)
response = response.replace(prompt, '')
big_print("\U0001F64B ",prompt)
big_print("\U0001F916 ",response)
🙋 Hi, I'm a JAX machine learning developer. What are you working on? 🤖 A lot of data science stuff right now, I'm trying to figure out how to get a handle on what's going on in the world
Keras 3: Keras-CV モデルのロードと実行
Stable diffusion モデルはテキストプロンプトから画像を生成します。
stable_diffusion = keras_cv.models.StableDiffusion()
if backend=="torch":
stable_diffusion.jit_compile = False # work in progress on PyTorch...
By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE
prompt = "A refined pencil sketch of a {} machine learning developer.".format(backend.upper())
images = stable_diffusion.text_to_image(prompt, batch_size=3)
big_print("\U0001F4DD ",prompt)
plot_images(images)
以上