PyTorch で MNIST データセットの学習を実装する時のあれこれ
MNIST は独自のデータ形式 (.ubyte) で PIL 画像として配布されている。
これをニューラルネットワークに突っ込むために、transform の処理を行う必要がある。
それが下記のコードである。
import torch
from torchvision import datasets, transforms
#データ前処理 transform を設定
transform = transforms.Compose(
[transforms.ToTensor(), # Tensor変換とshape変換 [H, W, C] -> [C, H, W]
transforms.Normalize((0.5, ), (0.5, ))]) # 標準化 平均:0.5 標準偏差:0.5
#訓練用(train + validation)のデータセット サイズ:(channel, height, width) = (1,28,28) 60000枚
trainval_dataset = datasets.MNIST(root='../data',
train=True, # True:訓練用60,000枚, False:テスト用10,000枚
download=True,
transform=transform)Compose に関しては、引数でリストを受け取り、リスト内の変換を順番に適用している。
今回の例においてはテンソル化した後、平均0.5、標準偏差0.5の正規化を行なっている。
ToTensor() はテンソル化に加え、軸の順番を変更しチャンネルを先頭に持ってくる役割も持つ。