Pytorch: 画像ラベル 質問する

Pytorch: 画像ラベル 質問する

私は 31 クラス (Office データセット) の画像分類器に取り組んでいます。クラスごとに 1 つのフォルダーがあります。データセットをロードしてdatasets.ImageFolder各画像にラベルを割り当て、トレーニングする PyTorch を使用して記述した Python スクリプトがあります。以下は、データをロードするためのコード スニペットです。

from torchvision import datasets, transforms
import torch

def load_training(root_path, dir, batch_size, kwargs):
    transform = transforms.Compose(
        [transforms.Resize([256, 256]),
         transforms.RandomCrop(224),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor()])
    data = datasets.ImageFolder(root=root_path + dir, transform=transform)
    train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)
    return train_loader

コードは各フォルダーを取得し、そのフォルダー内のすべての画像に同じラベルを割り当てます。どのラベルがどの画像/画像フォルダーに割り当てられているかを確認する方法はありますか?

ベストアンサー1

class_to_idxImageFolder クラスには、クラス名をインデックス (ラベル) にマッピングする辞書である属性があります。したがって、 を使用してクラスにアクセスしdata.classes、 を使用して各クラスのラベルを取得できますdata.class_to_idx

参考のために:pytorchvision のマスターファイル

おすすめ記事