機械学習/AIモデルを構築する際のサンプルサイズ設計

AI

「AI作りたいんだけどどれぐらいデータ集めればいい?」という質問

AI作りたいんだけどどれぐらいデータ集めればいい?」という質問を受けたことはないでしょうか。ディープラーニングや機械学習に関わるAIプロジェクトに携わった人は、こういった質問に遭遇することがあるでしょう。みなさんどう答えていますか?私も毎回悩みます。

シンプルなデータに対する仮説検定や効果検証では手法が確立されているため、いくつかの仮定が置ければサンプルサイズの目安をある程度の確度で見積もることが可能です(ただしそこそこ強い仮定は必要です)。しかし、画像のような高次元データを扱うモデリングタスクではそうはいきません。ある程度分析が進んだ段階では何かしら目安が作れることはありますが、少なくともプロジェクトの初期段階では難しいでしょう。

この質問に対する回答として無難な回答としては、「類似のプロジェクトXXではこれぐらい集めてたみたいですよ、なのでそれぐらいがプロトタイプとしては良いんじゃないでしょうか。」といったところでしょうか。

もしくは、「一旦1000件データがあればチャレンジしても良いかもです。」 といったところでしょうか。

どのような回答をしたとしても事故るときは事故ります。なんなら上記はフワッとしてて何の根拠もない回答です。がしかし、少なくともプロジェクトは前に進むでしょう。こういった判断が結果的にプロジェクトにとって良いケースもあります。

また巷では、次のような話も聞きます。

  • 「モデルのパラメータ数の10倍以上が良い(バーニーのおじさん)」
  • 特徴量の次元×10~30 程度のサンプルサイズが目安である」
  • 1000件あれば安心である」

これらを鵜呑みにするのは危険です。あくまでもこれらは特定のプロジェクトで良い結果が得られた例に過ぎず、自分のプロジェクトの参考にできるかどうかなんて誰も分かりません。

データ収集方針の策定や設計ってすごい難しいですよね。でも我々は何とかしてベターな結論を出し、前に進んでいかなければなりません。

さて、本記事では私の過去の経験などを踏まえ、データをどれぐらい集めるべきかという問題提起を起点としてデータ収集時に考慮すべきことをまとめます。
そして、何とかデータを1000件集めた場合でもどういう事故に遭遇する可能性があるか、ということを例を用いてまとめていきます。

データ収集方針の策定の際の何かの参考になれば幸いです。

データ収集時に何を考慮してデータをどれぐらい集めるべきか

まず最初に結論を言うと正解はありませんし、それを算出するようなカッコいい手法もありません。ケースバイケースで適切なものは異なります。プロジェクトの規模や性質、フェーズによって考慮すべきことは無数にあり、それらを明確なルールに落とすことが困難だからです。

上記を踏まえ、プロジェクト開始時にどういったことを考慮すべきかをできるかぎり洗い出して抽象化してみると以下のようになるでしょうか。

  • 問題設定は定まっているのか
    • 何から何を予測・推定する問題なのか
    • 目標とする評価指標の定義・定量化は可能か
    • 目標精度の目安は定まっているか
    • 絶対に予測を誤ってほしくないデータはどういうものか
  • 収集するデータの品質はどの程度になりそうか
    • クラスバランスに偏りがありそうか
    • データがどれぐらいばらつきそうか
    • ノイズはどれぐらい乗る可能性がありそうか
    • 時間経過による分布の変化は発生しそうか
  • データ収集の方法とコストはどのようになりそうか
  • 収集データの法的リスクは無いか
  • システムリソースはどれぐらい使えそうか
    • モデルのサイズがどの程度になりそうか
    • 学習済みの大規模モデルのゼロショット・フューショット学習や転移学習の検討の余地はあるか
  • LLMやVLMなどを含めた外部APIが使えそうか
  • ミニマムで始めて後々モデルの更新は可能なのか

これらに加えて、システム的な問題やドメイン独自の課題を考慮する必要があり、状況はかなり複雑化していきます。
例えば、以下のような状況です。

  • データ基盤を整えるのが困難なため、手元のデータのみで長持ちするモデルを作りたい
  • 現場の物理的要因によってセンサーが不安定となるため、一部の特徴量は欠損が多いことを考慮したい
  • 季節性バイアスが乗る性質が予想されるため、数カ月分のデータでは不十分
  • など…

このようなドメイン特有の問題は無数に挙げられます。このような課題を考慮していったら要求定義・要件定義だけで相当な時間がかかってしまいます。全てを洗い出しているときりがなく、全てを考慮してサンプルサイズの設計をすることは不可能です。泥沼にはまっていってプロジェクト推進が困難、という状況に陥ることもあるでしょう。

この複雑な状況でどうすべきかはとても難しいですが、様々なステークホルダーと上記のような考慮事項を議論し、合意を取り、できる限り多くバリエーションに富んだデータを集める、がベターな解の一つでしょう。EDAを経てプロトタイプのモデリングを行い、ステークホルダーと議論し、データをさらに集めてEDA、モデリング、…(以下繰り返し)、という具合です。

一つ一つの課題が重たく大変なので、重要度に応じて妥協や保留としたり、絶対妥協してはいけない課題を再定義したり、様々な議論が行われるでしょう。

さて、ここからはデータの品質に着目し、懸念事項を事前に整理せずにとりあえず1000枚!と決めてモデリングを進めた場合にどういう事故が起きそうかを、いくつかの例を用いて説明していきます。

1000件データを集めたら安心と信じてモデルを作ってみる

1000件のデータを集めてモデリングを行った場合に、データの品質によってどのような問題が発生するかを確認していきます。データセットにはMNISTを使用します。

MNISTでは現場の問題と全然違うよ!と感じますが、考え方としては遠くないのでMNISTを使います。

もう少し言うと、MNISTのような単純なデータですら、データ収集方針によってたどり着く結論が大きく変わります

現場のデータはもっと複雑なのは間違いないですが、MNISTですらきちんとした設計をしないとうまくいかないこともある、ということを我々は知らなくてはなりません。

問題設定

今回の問題設定として、工業製品の製造ラインにおける自動分類するシステムを考えます。製品をカメラで撮影して、画像を元に自動でどの製品かを分類する仕組みを作りたい、という要望ですね。

製造ラインでは、現場の方が手作業で0、1、2の形状をした製品を作っているとします。これを画像から自動で分類したいため、現場の方の力を借りて画像を入手し、分類モデルを構築することを考えます。このとき、分類システムの品質基準として分類モデルの精度は各形状のデータについて全て95%の精度で正解していなくてはならないとします。つまり、これを下回るような量の誤りが発生する場合は品質基準を満たさないため、クレームが発生してしまうということです。

分析に用いるデータとして、現場の方の協力を得て作られた製品のサンプル画像を合計1,000枚得られたとします。そして、1,000枚も集められたので分析のデータの量としては十分だろうと判断し、モデリングの作業に入ったとします。この作業は教科書通りとし、Train/Validation/Test とデータセットを分けてTestデータに対する分類精度が品質基準を満たしていたら、そのモデルをリリースすると判断します。リリース後は実運用のデータが入ってくる訳なので、実運用のデータの分類精度が各形状について95%を下回った場合はクレームが発生してしまいますね。

この実運用のデータは分析時は手に入らないデータです。本記事においてはサンプリング前のMNISTデータセットから事前にデータを一定割合避けておき、それを未知の実運用のデータとして、この精度評価を行います。この精度が悪ければ、実運用時に問題が起きたと考えます。

Pythonコードでシミュレーション

ここからはPythonコードでシミュレーションをしていきます。細かい処理は重要ではないので要所要所拾い読みしてください。(何なら飛ばしても大丈夫です、読みたい人は読んでみてください。)

まず、今回用いるモデルとしては非常にシンプルなCNNを用います。

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3)
        self.max_pool = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(8 * 13 * 13, 8)
        self.fc2 = nn.Linear(8, 3)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = self.max_pool(x)
        x = x.view(-1, 8 * 13 * 13)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return torch.log_softmax(x, dim=1)

学習用の関数とテスト用の関数をそれぞれ用意します。

def train(model, device, train_loader, optimizer, val_loader=None, target_classes=[0, 1, 2]):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = nn.CrossEntropyLoss()(output, target)
        loss.backward()
        optimizer.step()

    if val_loader:
        model.eval()
        val_loss = 0
        correct = 0
        correct_per_class = {cls: 0 for cls in range(len(target_classes))}
        total_per_class = {cls: 0 for cls in range(len(target_classes))}
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                val_loss += nn.CrossEntropyLoss()(output, target).item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                for cls in range(len(target_classes)):
                    cls_mask = (target == cls)
                    correct_per_class[cls] += pred[cls_mask].eq(target[cls_mask].view_as(pred[cls_mask])).sum().item()
                    total_per_class[cls] += cls_mask.sum().item()
        val_loss /= len(val_loader.dataset)
        accuracy = 100. * correct / len(val_loader.dataset)
        accuracy_per_class = {cls: 100. * correct_per_class[cls] / total_per_class[cls] for cls in range(len(target_classes))}
        print(f'Validation set: Average loss: {val_loss:.4f}, Overall Accuracy: {correct}/{len(val_loader.dataset)} ({accuracy:.2f}%)')
        for cls in range(len(target_classes)):
            print(f'Class {cls} Accuracy: {accuracy_per_class[cls]:.2f}%')

def test(model, device, test_loader, target_classes=[0, 1, 2]):
    model.eval()
    test_loss = 0
    correct = 0
    correct_per_class = {cls: 0 for cls in range(len(target_classes))}
    total_per_class = {cls: 0 for cls in range(len(target_classes))}
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += nn.CrossEntropyLoss()(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            for cls in range(len(target_classes)):
                cls_mask = (target == cls)
                correct_per_class[cls] += pred[cls_mask].eq(target[cls_mask].view_as(pred[cls_mask])).sum().item()
                total_per_class[cls] += cls_mask.sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    accuracy_per_class = {cls: 100. * correct_per_class[cls] / total_per_class[cls] for cls in range(len(target_classes))}
    print(f'Test set: Average loss: {test_loss:.4f}, Overall Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)')
    for cls in range(len(target_classes)):
        print(f'Class {cls} Accuracy: {accuracy_per_class[cls]:.2f}%')
    return test_loss, accuracy, accuracy_per_class

分析用に用意したデータを Train/Validation/Test に分割する関数を用意します。

def split_dataset_by_class(dataset, class_counts, train_ratio=0.7, val_ratio=0.1):
    targets_np = dataset.dataset.targets.numpy()
    rest_indices = dataset.indices

    train_indices = []
    val_indices = []
    test_indices = []

    for cls, total_count in class_counts.items():
        cls_indices = [idx for idx in rest_indices if targets_np[idx] == cls]
        random.shuffle(cls_indices)

        train_count = int(total_count * train_ratio)
        val_count = int(total_count * val_ratio)
        test_count = total_count - train_count - val_count

        train_indices.extend(cls_indices[:train_count])
        val_indices.extend(cls_indices[train_count:train_count + val_count])
        test_indices.extend(cls_indices[train_count + val_count:train_count + val_count + test_count])

    train_subset = Subset(dataset.dataset, train_indices)
    val_subset = Subset(dataset.dataset, val_indices)
    test_subset = Subset(dataset.dataset, test_indices)

    return train_subset, val_subset, test_subset

データサンプリングのシミュレーション用の関数を定義します。この関数だけ少し重要で、まず実運用相当のデータとして2割(3700枚ぐらい)だけ確保しておきます。そして、実運用相当のデータだけにちょっとした処理をかけられるようにしたり、残ったデータから何らかの方法でサンプリングを行えるようにしておき、実運用で起きうることを再現できるようにしておきます。

def prepare_data(class_counts, sampling_func, real_func=None, target_classes=[0, 1, 2]):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

    indices = [i for i, target in enumerate(mnist.targets) if target in target_classes]
    mnist.targets = mnist.targets[indices]
    mnist.data = mnist.data[indices]

    class_to_idx = {cls: i for i, cls in enumerate(target_classes)}
    mnist.targets = torch.tensor([class_to_idx[target.item()] for target in mnist.targets])

    real_test_size = int(len(mnist) * 0.2)
    rest_size = len(mnist) - real_test_size

    all_indices = list(range(len(mnist)))
    random.shuffle(all_indices)
    real_test_indices = all_indices[:real_test_size]
    rest_indices = all_indices[real_test_size:]

    mnist_real_test = Subset(mnist, real_test_indices)
    if real_func:
        mnist_real_test = real_func(mnist_real_test)

    mnist_rest = Subset(mnist, rest_indices)
    mnist_sampled_rest = sampling_func(mnist_rest, class_counts, class_to_idx)

    train_subset, val_subset, test_subset = split_dataset_by_class(mnist_sampled_rest, class_counts)

    train_loader = DataLoader(train_subset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=64, shuffle=False)
    test_loader = DataLoader(test_subset, batch_size=64, shuffle=False)
    real_test_loader = DataLoader(mnist_real_test, batch_size=64, shuffle=False)

    return train_loader, val_loader, test_loader, real_test_loader

最後に、データ確認用のプロットの関数を用意します。

def plot_samples_from_dataloader(dataloader, target_classes, num_samples_per_class=3):
    class_samples = {cls: [] for cls in target_classes}

    for data, targets in dataloader:
        for img, label in zip(data, targets):
            if len(class_samples[label.item()]) < num_samples_per_class:
                class_samples[label.item()].append(img)
        if all(len(class_samples[cls]) >= num_samples_per_class for cls in target_classes):
            break

    fig, axes = plt.subplots(len(target_classes), num_samples_per_class, figsize=(num_samples_per_class * 2, len(target_classes) * 2))
    for i, cls in enumerate(target_classes):
        for j in range(num_samples_per_class):
            img = class_samples[cls][j].numpy().squeeze()
            axes[i, j].imshow(img, cmap='gray')
            axes[i, j].axis('off')
            if j == 0:
                axes[i, j].set_title(f'Class {cls}')

    plt.tight_layout()
    plt.show()

ここまででシミュレーションの準備は完了です。

ここからは実際に起きうる状況というのを疑似的に再現し、Testデータでの精度検証および、実運用相当のデータでの精度検証をしていきましょう

理想的なデータ収集が出来ている場合

まずは理想的にデータ収集ができているケースを考えていきます。つまり、分析用に集めたデータの分布は実運用相当のデータの分布とほぼ同じであり考慮すべきバイアスはほとんど存在しないであろう、と想定できるケースです。

今回は1000枚データが集まったとするため、各クラスのデータが、

  • 0: 334枚
  • 1: 333枚
  • 2: 333枚

というように、偏りなく均一な枚数で揃えられたとします。

この状況を再現するために各クラスからランダムサンプリングすることをシミュレーションしていきます。

def random_sampling(dataset, class_counts, class_to_idx):
    targets_np = dataset.dataset.targets.numpy()
    rest_indices = dataset.indices

    sampled_indices = []
    for original_cls, count in class_counts.items():
        cls = class_to_idx[original_cls]
        cls_indices = [idx for idx in rest_indices if targets_np[idx] == cls]
        sampled_indices.extend(random.sample(cls_indices, count))

    return Subset(dataset.dataset, sampled_indices)

上記は各クラス毎に指定した枚数だけランダムサンプリングする単純な関数です。これを用いて、理想的なサンプリングが行えたとします。

class_counts = {0: 334, 1: 333, 2: 333}
train_loader, val_loader, test_loader, real_test_loader = prepare_data(class_counts, random_sampling)

plot_samples_from_dataloader(test_loader, [0, 1, 2], num_samples_per_class=3)

例えば、以下のようなデータがサンプリングされており、これを分析用のデータとして用いることができます。

それなりにばらつきがあり、分析やモデリングに用いることは妥当そうです。これをもとに、学習とテストデータでの評価を行ってみます。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(1, 11):
    train(model, device, train_loader, optimizer, val_loader)
test_loss, test_accuracy, test_accuracy_per_class = test(model, device, test_loader)
print(f'Test accuracy: {test_accuracy:.2f}%')

これを実行すると以下のような結果が得られました。

Test set: Average loss: 0.0011, Overall Accuracy: 200/202 (99.01%)
Class 0 Accuracy: 100.00%
Class 1 Accuracy: 97.01%
Class 2 Accuracy: 100.00%

Testセットにおいて、各形状のデータの分類精度が95%を超えているので問題がなさそうです。

では、これをリリースしたと考えて、実運用相当のデータでも精度評価を行ってみましょう。

real_test_loss, real_test_accuracy, real_test_accuracy_per_class = test(model, device, real_test_loader)
print(f'Real test accuracy: {real_test_accuracy:.2f}%')

これを実行すると以下の結果が得られました。

Real Test set: Average loss: 0.0012, Overall Accuracy: 3643/3724 (97.82%)
Class 0 Accuracy: 98.84%
Class 1 Accuracy: 97.86%
Class 2 Accuracy: 96.73%
Real test accuracy: 97.82%

実運用相当データでも同様に95%を超える精度が出たため、問題なさそうです。お客様に満足していただき、プロジェクトは大成功でめでたし、という感じですね。

収集したデータのクラスバランスに偏りがある場合

次は、収集したデータを見たところ、各クラスのデータ数が異なっていて以下のような状況だったとしましょう。

  • 0: 450枚
  • 1: 450枚
  • 2: 100枚

恐らく何らかの都合で「2」のデータがあまり集まらなかったのでしょう。そもそも生産数が少なくてデータが集まらなかったり、現場環境の都合でデータの確保が困難だったり、様々な理由が考えられます。なお、クラスバランス以外のデータの品質には問題は無いとし、各クラス内のばらつきも十分保たれていると仮定します。

この時点で問題は明確ですが、教科書通りデータを分割してテストデータに対して精度の良いモデルを頑張って作ってリリースしたとします。これを実運用に載せた場合、どのような問題が起きるでしょうか。具体的な精度を確認していきましょう。

まずは学習、テストを実施します。

class_counts = {0: 450, 1: 450, 2: 100}
train_loader, val_loader, test_loader, real_test_loader = prepare_data(class_counts, random_sampling)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(1, 11):
    train(model, device, train_loader, optimizer, val_loader)

test_loss, test_accuracy, test_accuracy_per_class = test(model, device, test_loader)
print(f'Test accuracy: {test_accuracy:.2f}%')

これを実行すると以下の結果が得られました。

Test set: Average loss: 0.0054, Overall Accuracy: 198/200 (99.00%)
Class 0 Accuracy: 100.00%
Class 1 Accuracy: 98.89%
Class 2 Accuracy: 95.00%
Test accuracy: 99.00%

Testセットにおいて、各形状のデータの分類精度が95%を超えているので問題がなさそうです。

では、これをリリースしたと考えて、実運用相当のデータでも精度評価を行ってみましょう。

Real Test set: Average loss: 0.0052, Overall Accuracy: 3472/3724 (93.23%)
Class 0 Accuracy: 99.91%
Class 1 Accuracy: 99.04%
Class 2 Accuracy: 80.30%
Real test accuracy: 93.23%

分析時には精度に問題ありませんでしたが、実運用時には「2」の分類精度が著しく低く、品質基準を大幅に下回っていることがわかります。

このように、クラスバランスについて事前に十分な議論がなかったため、実運用時に特定クラスのみ精度が出ないという状況が発生しました。
もしデータ収集の段階でクラスバランスについて十分な議論がされていれば、次のような対策ができたでしょう。

  • 特定クラスのデータを分析用に収集しにくい原因を探る
  • 特定クラスのデータに関する分類の品質低下を許容できるか検討する

このように、データ収集時にクラスバランスを考慮することは非常に重要です。

収集したデータのばらつきが不十分な場合

前節とは異なり、クラスバランスは均一で、それぞれ333枚程度データを集められたとします。しかし、データの収集方針は現場の方に一任しており、十分な議論がされてなかったとします。

この時、現場の人の裁量で特定ロット特定時間で一気にサンプリングがされたとしましょう。また、運が悪く(良く?)その時間のそのロットには優秀な職人がいて、たまたま製造された製品の品質がかなり良く、サンプリングされたデータのばらつきが少なかったとします。

本記事においてはこれをシミュレーションするため、MNISTから偏った抽出を行います。

def biased_sampling(dataset, class_counts, class_to_idx):
    targets_np = dataset.dataset.targets.numpy()
    rest_indices = dataset.indices
    data_np = dataset.dataset.data.numpy()

    sampled_indices = []
    for original_cls, count in class_counts.items():
        cls = class_to_idx[original_cls]
        cls_indices = np.array([idx for idx in rest_indices if targets_np[idx] == cls])

        def calculate_metrics(images):
            y_indices, x_indices = np.indices(images[0].shape)
            total_intensity = images.sum(axis=(1, 2))
            non_zero_pixels = (images > 0).sum(axis=(1, 2))
            return non_zero_pixels, total_intensity

        cls_images = data_np[cls_indices]
        sizes, thicknesses = calculate_metrics(cls_images)

        avg_size = np.mean(sizes)
        avg_thickness = np.mean(thicknesses)

        size_deviations = np.abs(sizes - avg_size)
        thickness_deviations = np.abs(thicknesses - avg_thickness)
        scores = size_deviations + thickness_deviations

        sorted_indices = np.argsort(scores)
        sampled_cls_indices = cls_indices[sorted_indices[:count]]
        sampled_indices.extend(sampled_cls_indices)

    return Subset(dataset.dataset, sampled_indices)

このコードでは、データの太さや大きさを大きい順に並べて偏った抽出を行っています。

この状況でモデリング及び精度評価を行います。まずはサンプリングして実際に分析用に収集されたと想定されるデータを確認してみましょう。

class_counts = {0: 334, 1: 333, 2: 333}
train_loader, val_loader, test_loader, real_test_loader = prepare_data(class_counts, biased_sampling)

plot_samples_from_dataloader(test_loader, [0, 1, 2], num_samples_per_class=3)

少し分かりにくいですが、比較的細くて小さめの安定したデータが収集されていることが分かります。

ではこれを用いてモデルの学習とTestデータでの精度検証を行ってみましょう。

for epoch in range(1, 11):
    print(f'Epoch {epoch}')
    train(model, device, train_loader, optimizer, val_loader)

test_loss, test_accuracy, test_accuracy_per_class = test(model, device, test_loader)
print(f'Test accuracy: {test_accuracy:.2f}%')

これを実行すると以下の結果が得られました。

Test set: Average loss: 0.0003, Overall Accuracy: 200/202 (99.01%)
Class 0 Accuracy: 97.06%
Class 1 Accuracy: 100.00%
Class 2 Accuracy: 100.00%
Test accuracy: 99.01%

Testセットにおいて、各形状のデータの分類精度が95%を超えているので問題がなさそうです。

では、これをリリースしたと考えて、実運用相当のデータでも精度評価を行ってみましょう。

real_test_loss, real_test_accuracy, real_test_accuracy_per_class = test(model, device, real_test_loader)
print(f'Real test accuracy: {real_test_accuracy:.2f}%')

これを実行すると以下の結果が得られました。

Real Test set: Average loss: 0.0014, Overall Accuracy: 3620/3724 (97.21%)
Class 0 Accuracy: 98.97%
Class 1 Accuracy: 98.18%
Class 2 Accuracy: 94.36%
Real test accuracy: 97.21%

分析時には精度に問題が無かったにもかかわらず、実運用相当のデータでは精度に問題が発生することが分かります。

これは、データのばらつき方について十分な議論が無いまま進めたために起きてしまったと考えられます

もし事前にデータのばらつきについて議論できていれば、日時やロット番号、作業者などの割り当てをランダムに行うなどの対策が取れ、十分にばらついたデータを確保できていたかもしれません。(いわゆる層化抽出ですね。)

仮にばらつきが十分でなかったとしても、「特定のバイアスが乗りそうなのでこの種類のデータには注意が必要」といった議論ができ、顧客の期待値コントロールなどにも貢献できたでしょう。

このように、データのばらつき方についても事前に考慮することは重要であると言えます。

収集したデータが抽出された環境と実運用時の環境が一致していない場合

このケースでは、クラスバランスは均一で、製品のばらつきもある程度保てるようにランダムサンプリングができたと仮定します。

一方で、データを収集した環境と実運用の撮影環境が異なり、実運用時には何らかのノイズが乗ってしまうとします。例えば、データ収集時は検証ルームのようなところで撮像がされて、実運用時はライン上で撮像が行われる、という状態です。今回はこれをシミュレーションするため、実運用時のデータにガウシアンノイズを乗せます。

本記事においてはこれをシミュレーションするため、下記のコードでガウシアンノイズを付与します。

def add_gaussian_noise(images, mean, std):
    noise = np.random.normal(mean, std, images.shape).astype(np.float32)
    noisy_images = images + noise
    noisy_images = np.clip(noisy_images, 0, 1)
    return noisy_images

def sampling_with_noise(dataset, mean=0, std=0.2):
    targets_np = dataset.dataset.targets.numpy()
    rest_indices = dataset.indices

    sampled_images = dataset.dataset.data[rest_indices].numpy().astype(np.float32) / 255.0
    noisy_images = add_gaussian_noise(sampled_images, mean, std)
    noisy_images = torch.tensor(noisy_images * 255).type(torch.uint8)

    dataset.dataset.data[rest_indices] = noisy_images

    return Subset(dataset.dataset, rest_indices)

この状況でモデリング及び精度評価を行います。まずはサンプリングして実際に分析用に収集されたと想定されるデータを確認してみましょう。

class_counts = {0: 334, 1: 333, 2: 333}
train_loader, val_loader, test_loader, real_test_loader = prepare_data(class_counts, random_sampling, sampling_with_noise)

plot_samples_from_dataloader(test_loader, [0, 1, 2], num_samples_per_class=3)

分析用に収集されたデータはノイズが乗っておらずきれいなデータであることが分かります。

次に実運用相当のデータを確認します。

plot_samples_from_dataloader(real_test_loader, [0, 1, 2], num_samples_per_class=3)

ノイズが乗っており、それなりに汚いデータであることが分かります。

では分析用のデータを用いてモデルの学習とTestデータでの精度検証を行ってみましょう。

for epoch in range(1, 11):
    print(f'Epoch {epoch}')
    train(model, device, train_loader, optimizer, val_loader)

test_loss, test_accuracy, test_accuracy_per_class = test(model, device, test_loader)
print(f'Test accuracy: {test_accuracy:.2f}%')

これを実行すると以下の結果が得られました。

Test set: Average loss: 0.0008, Overall Accuracy: 199/202 (98.51%)
Class 0 Accuracy: 100.00%
Class 1 Accuracy: 97.01%
Class 2 Accuracy: 98.51%
Test accuracy: 98.51%

Testセットにおいて、各形状のデータの分類精度が95%を超えているので問題がなさそうです。

では、これをリリースしたと考えて、実運用相当のデータでも精度評価を行ってみましょう。

real_test_loss, real_test_accuracy, real_test_accuracy_per_class = test(model, device, real_test_loader)
print(f'Real test accuracy: {real_test_accuracy:.2f}%')

これを実行すると以下の結果が得られました。

Real Test set: Average loss: 0.0019, Overall Accuracy: 3605/3724 (96.80%)
Class 0 Accuracy: 98.89%
Class 1 Accuracy: 93.02%
Class 2 Accuracy: 99.00%
Real test accuracy: 96.80%

分析時には精度に問題が無かったにもかかわらず、実運用相当のデータでは精度に問題が発生していることが分かります。

これは、分析データを収集したときの環境と実運用時の環境について十分な議論がなかったために起きたと考えられます。

事前に実運用時の環境がどのようなものになるか、どのようなノイズが予想されるかを議論できていれば、次のような対策が考えられます。

  • データ拡張でノイズに対してロバストなモデルを設計する
  • データの収集環境を調整して、実運用時と近い環境を用意する

なお、これが発生する要因として環境による違いももちろんありますが、時間経過などに基づく環境の変化も要因としてはあるでしょう。いわゆるデータドリフトという概念です。
ここでの詳しい議論はしませんが、時間経過による変化を事前に議論して対策を講じることも重要です。

このように、分析時と実運用時の環境の違いについても事前に考慮することも重要です。

「どれぐらいデータ集めればいい?」の質問に対して何を考えるべきかの結論

ここまで議論してきた通り、データ数だけで議論が済むような質問では到底ありませんデータの品質ばらつき偏り運用時にどうなるか、など様々な要因によって左右されるからです。本記事で取り上げたのはかなり一般的な例で、これに加えてドメイン特有の問題も絡んでくるでしょう。

ここまでで議論した内容とシミュレーションを踏まえ、「どれぐらいデータ集めればいい?」という問いが来た時にデータサイエンティストとして考えた方が良い観点は以下でしょう。

  • 問題設定は定まっているのか
    • 何から何を予測・推定する問題なのか
    • 目標とする評価指標の定義・定量化は可能か
    • 目標精度の目安は定まっているか
    • 絶対に予測を誤ってほしくないデータはどういうものか
  • 収集するデータの品質はどの程度になりそうか
    • クラスバランスに偏りがありそうか
    • データがどれぐらいばらつきそうか
    • ノイズはどれぐらい乗る可能性がありそうか
    • 時間経過による分布の変化は発生しそうか
  • データ収集の方法とコストはどのようになりそうか
  • 収集データの法的リスクは無いか
  • システムリソースはどれぐらい使えそうか
    • モデルのサイズがどの程度になりそうか
    • 学習済みの大規模モデルのゼロショット・フューショット学習や転移学習の検討の余地はあるか
  • LLMやVLMなどを含めた外部APIが使えそうか
  • ミニマムで始めて後々モデルの更新は可能なのか

これらの項目について、PM、営業、顧客、顧客側のエンジニアなど、様々な立場の人間と議論していくことが必要です。全て考慮し、適切な回答をしましょう。

…と言いたいところですが、全てを考慮することは不可能です。

それっぽいまとめをしましたが、どれだけ考えて議論しても考慮漏れや想定外は発生しますし、すべての不安要素を排除することはできません。考えうる限りの不安要素をつぶしたとしても、予定通りの精度が実運用時で出るなんてことはほぼありえません。データが集まって初めて判明する事実も多くあります。(ごくまれにうまく行くプロジェクトもありますが、上手くいったら逆に不安になるぐらいがちょうど良いです。)

このように非常に難しくどうしようもない状況の中でデータサイエンティストとしてどうすべきでしょうか。冒頭で書いたように、私もいつも頭を悩ませています。

ベターであることとして言えるのは、様々な観点でデータを疑い、様々な観点で仮説を立てつつ、様々なステークホルダーを巻き込みつつ、できる限り多くのデータを集めて考えうる分析・検証を重ね、考えうる案の中で最もビジネスに貢献する意思決定を示す、ということではないでしょうか。

コメント

タイトルとURLをコピーしました