概要
Pytorch でニューラルネットワークを実装しているとき、以下のようなエラーが出ることがあります。
ValueError: expected 4D input (got 2D input)
ValueError: expected 2D or 3D input (got 4D input)
上記のエラーはニューラルネットワークの順伝播( forward() ) の過程において、データの次元が合わない、というエラーです。
畳み込み(Convolution)のネットワークから全結合(Linear)のネットワークに繋げるときなどに発生しがちであると考えられます。
具体例
例えば、以下のようなネットワークの定義だと掲題のエラーが発生します。
# ネットワークの定義
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
self.bn1d = nn.BatchNorm1d(32)
# 順伝播
def forward(self, x):
x = self.conv1(x)
x = self.bn1d(x)
x = self.conv2(x)
x = F.max_pool2d(x, 2)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
# データの投入
net = Net()
data = torch.ones([1, 1, 28, 28])
net(data)
# 以下のエラーが発生する
# ValueError: expected 2D or 3D input (got 4D input)
上記の例だと、self.conv1(x) の出力が4次元(バッチサイズ, チャンネル数, 高さ, 幅) であるにも関わらず、BatchNorm1d に渡そうとしています。
BatchNorm1d は 2次元か3次元のデータを想定しているため、上記のようなエラーが発生します。
対処法
順伝播の過程でデータの次元を適切に扱うことでエラーが解消できます。
具体的には、以下の様に forward() 内でデータの shape を出力することでエラーの原因を特定することができます。
def forward(self, x):
x = self.conv1(x)
print(x.shape)
# 以下略
# torch.Size([1, 32, 26, 26]) が出力される
この様に、エラーが発生している直前のデータの shape を確認することで、原因となったデータの次元を確認することができます。
今回の場合、出力からデータが 4次元となっていることが分かるため、BatchNorm1d ではなく、BatchNorm2dに変更すれば良いことが分かります。
# ネットワークの定義
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
self.bn1d = nn.BatchNorm1d(32)
# ← ここを nn.BatchNorm2d(32)
にすれば良い
ネットワークが複雑になるにつれて、データの次元がどうなっているかがわかりにくくなってきます。
把握できなくなったときは、データの shape を確認することでエラーの対応がしやすくしましょう。
コメント