PyTorch でn次元のTensorがあったときに、n-1次元のTensorを使ってインデックスする方法です。
以下のように a
というTensorと、 idx
というインデックスがあったときに、 answer
を取る方法です。
a = torch.tensor([ [100, 200, 300, 400], [500, 600, 700, 800], ]) idx = torch.tensor([2, 3]) # 欲しい値は以下 answer = torch.tensor([300, 800])
方法のコード
>>> a.gather(1, idx.unsqueeze(1)).reshape(idx.shape) tensor([300, 800])
一旦 idx
を .unsqueeze
することで次元の数をあわせてから、 .gather
を使ってインデックスしています。
正直これがベストかわからないので、うまい方法があれば教えてください(僕はPyTorch初心者です)。
欲しいケース
transformersのMaskedLMモデルを実行した結果から、元のセンテンスの各文字がどんな値だったかを取得するのに必要でした。
上記の a
が result = model(...)
として result.logits
、 idx
が inputs = tokenizer.batch_encode_plus(...)
として inputs.input_ids
です。
この記事はShodo (https://shodo.ink) で執筆されました。