私は 31 クラス (Office データセット) の画像分類器に取り組んでいます。クラスごとに 1 つのフォルダーがあります。データセットをロードしてdatasets.ImageFolder
各画像にラベルを割り当て、トレーニングする PyTorch を使用して記述した Python スクリプトがあります。以下は、データをロードするためのコード スニペットです。
from torchvision import datasets, transforms
import torch
def load_training(root_path, dir, batch_size, kwargs):
transform = transforms.Compose(
[transforms.Resize([256, 256]),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()])
data = datasets.ImageFolder(root=root_path + dir, transform=transform)
train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)
return train_loader
コードは各フォルダーを取得し、そのフォルダー内のすべての画像に同じラベルを割り当てます。どのラベルがどの画像/画像フォルダーに割り当てられているかを確認する方法はありますか?
ベストアンサー1
class_to_idx
ImageFolder クラスには、クラス名をインデックス (ラベル) にマッピングする辞書である属性があります。したがって、 を使用してクラスにアクセスしdata.classes
、 を使用して各クラスのラベルを取得できますdata.class_to_idx
。
参考のために:pytorchvision のマスターファイル