Make組ブログ

Python、Webアプリや製品・サービス開発についてhirokikyが書きます。

PyTorchでn次元のTensorからn-1次元のTensorを使ってインデックスする

PyTorch でn次元のTensorがあったときに、n-1次元のTensorを使ってインデックスする方法です。 以下のように a というTensorと、 idx というインデックスがあったときに、 answer を取る方法です。

a = torch.tensor([
    [100, 200, 300, 400],
    [500, 600, 700, 800],
])
idx = torch.tensor([2, 3])

# 欲しい値は以下
answer = torch.tensor([300, 800])

方法のコード

>>> a.gather(1, idx.unsqueeze(1)).reshape(idx.shape)
tensor([300, 800])

一旦 idx.unsqueeze することで次元の数をあわせてから、 .gather を使ってインデックスしています。 正直これがベストかわからないので、うまい方法があれば教えてください(僕はPyTorch初心者です)。

欲しいケース

transformersのMaskedLMモデルを実行した結果から、元のセンテンスの各文字がどんな値だったかを取得するのに必要でした。 上記の aresult = model(...) として result.logitsidxinputs = tokenizer.batch_encode_plus(...) として inputs.input_ids です。

この記事はShodo (https://shodo.ink) で執筆されました。