Make組ブログ

Python、Webアプリや製品・サービス開発についてhirokikyが書きます。

SageMakerでのトレーニング時は事前にデータをエンコードしておく

SageMakerでトレーニングする際、強めのマシンを使う場合は事前にデータをエンコードしておくと良いです。データ量が多いと dataset.map(encoder) をするのも案外時間がかかります。そうするとGPUやTrnを有効活用していない時間も課金されてしまうので、事前にエンコードしたものを使いましょう。とくにSpotInstanceを使って繰り返しSageMakerのサーバーを起動する場合はやっておくことをおすすめします。

事前にデータをエンコードする

事前にエンコードするには、以下のように普段通りのエンコードし、 dataset.save_to_disk() を呼び出します。

dataset = load_dataset(...)
encoder = Encoder()
dataset = dataset.map(encoder)
dataset = dataset.remove_columns([...])
dataset.save_to_disk("data/mydataset_enc/")

読み込む際は load_from_disk() を使います。

from datasets import load_from_disk

load_from_disk("data/mydataset_enc/")

このフォルダーをS3にアップロードしておきましょう。

SageMakerでエンコードしたデータセットを使う

SageMakerでの学習を開始する際に、先ほどアップロードしたS3のパスを指定します。

from sagemaker.huggingface import HuggingFace

estimator = HuggingFace(
    entry_point="train_model.py",
)
estimator.fit({
    "train": "s3://my-dataset/mydataset_enc/",
})

このように指定するとSageMakerが実行前に自動でこのフォルダーをダウンロードしてくれます。
SageMakerの内部で実行する train_model.py では以下のようにファイルを読み込みましょう。

dataset = load_from_disk(os.environ["SM_CHANNEL_TRAIN"])

以上です!

これでSageMakerを起動するたびに大規模なデータセットエンコードを待つ必要がありません。

おわりに

Shodoでは独自のAIモデルを使ってAI校正のサービスを提供しています。

現在、Shodoではアドベントカレンダー応援クーポンを配布しております。80%オフでShodoを最長3ヶ月間使えるクーポンです。以下のクーポンコードをご購入時に入力して、このアドベントカレンダーの季節にShodoのAI校正をブログの執筆にお役立てください。

XMAS2024

shodo.ink

執筆:@hirokiky
Shodoで執筆されました