Pytorch で ValueError: expected 4D input (got 2D input) などのエラーが起きたときの対処法

AI

概要

Pytorch でニューラルネットワークを実装しているとき、以下のようなエラーが出ることがあります。

ValueError: expected 4D input (got 2D input)
ValueError: expected 2D or 3D input (got 4D input)

上記のエラーはニューラルネットワークの順伝播( forward() ) の過程において、データの次元が合わない、というエラーです。
畳み込み(Convolution)のネットワークから全結合(Linear)のネットワークに繋げるときなどに発生しがちであると考えられます。

具体例

例えば、以下のようなネットワークの定義だと掲題のエラーが発生します。

# ネットワークの定義
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        self.bn1d = nn.BatchNorm1d(32)

    # 順伝播
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1d(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

# データの投入
net = Net()
data = torch.ones([1, 1, 28, 28])
net(data)

# 以下のエラーが発生する
# ValueError: expected 2D or 3D input (got 4D input)

上記の例だと、self.conv1(x) の出力が4次元(バッチサイズ, チャンネル数, 高さ, 幅) であるにも関わらず、BatchNorm1d に渡そうとしています。

BatchNorm1d は 2次元か3次元のデータを想定しているため、上記のようなエラーが発生します。

対処法

順伝播の過程でデータの次元を適切に扱うことでエラーが解消できます。

具体的には、以下の様に forward() 内でデータの shape を出力することでエラーの原因を特定することができます。

    def forward(self, x):
        x = self.conv1(x)
        print(x.shape)
        # 以下略

# torch.Size([1, 32, 26, 26]) が出力される

この様に、エラーが発生している直前のデータの shape を確認することで、原因となったデータの次元を確認することができます。

今回の場合、出力からデータが 4次元となっていることが分かるため、BatchNorm1d ではなく、BatchNorm2dに変更すれば良いことが分かります。

# ネットワークの定義
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        self.bn1d = nn.BatchNorm1d(32)
  # ← ここを nn.BatchNorm2d(32)
 にすれば良い

ネットワークが複雑になるにつれて、データの次元がどうなっているかがわかりにくくなってきます。
把握できなくなったときは、データの shape を確認することでエラーの対応がしやすくしましょう。

コメント

タイトルとURLをコピーしました