Gemma : Tutorials : サンプリング (翻訳/解説)
翻訳 : クラスキャット セールスインフォメーション
作成日時 : 03/03/2024
* 本ページは、google-deepmind/gemma レポジトリの以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Website: www.classcat.com ; ClassCatJP
Gemma : サンプリング・チュートリアル
この colab で Gemma チェックポイントをロードしてそれからサンプリングを行う方法を説明した詳細なチュートリアルを見つけるでしょう。
セットアップ
README (日本語) のインストール手順に従ってください。
チェックポイントのダウンロード
https://www.kaggle.com/models/google/gemma から Flax のチェックポイントをダウンロードしてローカルパスに配置します。
# @title Download the checkpoints
# Download the Flax checkpoints from https://www.kaggle.com/models/google/gemma
# and put the local paths below.
ckpt_path = ''
vocab_path = ''
Python のインポート
# @title Python imports
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm
モデルで生成を始める
Flax で使用するため LLM のチェックポイントをロードして準備します。
# Load parameters
params = params_lib.load_and_format_params(ckpt_path)
トークナイザーをロードします、これは SentencePiece ライブラリを使用して構築します。
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
チェックポイントから正しい configuration を自動的にロードするため、transformer_lib.TransformerConfig.from_params 関数を使用します。このリリースでは未使用トークンにより語彙サイズは入力埋め込みの数よりも小さいことに注意してください。
transformer_config=transformer_lib.TransformerConfig.from_params(
params,
cache_size=1024 # Number of time steps in the transformer's cache
)
transformer = transformer_lib.Transformer(transformer_config)
最後に、モデルとトークナイザーの上にサンプラーを構築します。
# Create a sampler with the right param shapes.
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer'],
)
サンプリングを始める準備ができました! このサンプラーはジャストインタイム・コンパイルを使用しますので、入力 shape の変更は再コンパイルのトリガーとなり、遅くなる可能性があります。最速そして最も効率的な結果のためには、バッチサイズを一定に保持してください。
input_batch = [
"\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
"What are the planets of the solar system?",
]
out_data = sampler(
input_strings=input_batch,
total_generation_steps=300, # number of steps performed when generating
)
for input_string, out_string in zip(input_batch, out_data.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
print()
print(10*'#')
バブルソートの実装と太陽系の説明を得るはずです。
以上