PyTorch Lightning 入門 – シンプルでスマートなコードで分類器を作成する

Pytorch

概要

PyTorch には開発スピードを高めるためのフレームワークがいくつかあります。

それは Catalyst, fastai, PyTorch Lightning など様々です。

どれも一長一短あり、使いやすいか否かは人によって分かれるかと思います。本記事ではその中でも、PyTorch Lightning に焦点を当て、実装・解説を行っていきます。

PyTorch Lightning を使いこなして爆速でAIの開発をしよう!

※ 本記事は PyTorch について軽く知っている人を対象としています。

PyTorch Ligitning 紹介

早速ですが、まずはこちらの動画を見ていただくのがわかりやすいと思います。

引用: https://pytorch-lightning.readthedocs.io/en/latest/starter/new-project.html

この動画で分かるように、PyTorch Lightning では 素の PyTorch で一般的に書くであろう処理を全て良しなに Lightning Module と Trainer が吸収をしてくれます

Lightning Module では主にモデルの定義や損失関数、オプティマイザー、モデルの学習をどのようなステップで行うのか、を定義していきます。

TrainerはCPU, GPUなどのデバイス情報、ワーカー数、コールバックの処理など、学習においてよく必要とされる細かい処理を全て担ってくれます。

細かい機能を上げるとキリが無いので、ここからは Lightning Module の解説に重点的に、最低限これを書けば PyTorch Lightning でモデルが作れる、ということを解説していきます。

PyTorch Lightning の最低限のコード

今回扱う問題は、お馴染みのMNISTの分類器を作成していきます。PyTorch Lightning に焦点を当てたものなので、モデルの構造は最も簡素なものを作成していきます。

早速ですが、モデル作成~精度検証まで全て含めたコードが以下です。

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM

class Classifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(28 * 28, 10),
            nn.ReLU())
    
    def forward(self, x):
        return self.classifier(x.view(x.size(0), -1))
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.classifier(x.view(x.size(0), -1))
        return nn.functional.cross_entropy(y_hat, y)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.classifier(x.view(x.size(0), -1))
        self.log_dict({'accuracy': FM.accuracy(torch.argmax(y_hat, dim=1), y)}) 

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.02)


train_loader = DataLoader(MNIST('./', train=True, download=True, transform=transforms.ToTensor()), batch_size=1024, num_workers=12)
test_loader = DataLoader(MNIST('./', train=False, download=True, transform=transforms.ToTensor()), batch_size=1024, num_workers=12)

trainer = pl.Trainer(max_epochs=10, gpus=1)
model = Classifier()

trainer.fit(model, train_loader)
trainer.test(model, test_dataloaders=test_loader)

40行程度でMNISTの分類器の作成、学習、精度検証ができました。

なお、今回はモデルは0から定義しましたが、実装済みのものをimportすればもっと短くなります。

まずコードをざっくり見ると、for文が一切登場しなかったりzero_grad, backward, step など素のPyTorchで必須の処理が無かったりと、シンプルにまとまったコードとなっていることが分かります。

ここからは細かくコードを確認していきます。

class Classifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(28 * 28, 10),
            nn.ReLU())

まずはLightning Moduleの定義です。

ここでは LightningModule を継承し、コンストラクタではモデルの定義を行います。今回は1層の全結合だけで構成された簡易的な分類器を定義しています。

モデルの定義については通常のPyTorchと同様の方法で行います。

ここからは各ステップでどのような処理を行うかなどをオーバーライドしていきます。

まずは forward です。forward には推論時の処理を記述します。

    def forward(self, x):
        return self.classifier(x.view(x.size(0), -1))

LightningModuleの特徴として、学習時やテスト時に行う順伝播と推論時に行う順伝播を完全に区別されていることが挙げられます。

このコードでは、推論時に入力をベクトル化し、分類器の出力を返り値とすることが分かります。

次に training_step です。ここには学習時の処理を記述します。

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.classifier(x.view(x.size(0), -1))
        return nn.functional.cross_entropy(y_hat, y)

ここで損失関数の計算なども一緒に行います。

training_stepにおいては、デフォルトで return された値を損失として扱ってくれるため、これに対して逆伝播を行ってくれることになります。

このコードでは分類器の出力と正解との値のクロスエントロピーを計算し、それを損失としていることが分かります。

次に test_step です。ここには精度を測るためのテスト時に行う処理を記述します。

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.classifier(x.view(x.size(0), -1))
        self.log_dict({'accuracy': FM.accuracy(torch.argmax(y_hat, dim=1), y)}) 

ここで、Lightning Module の log_dict というものを用います。

ここにテストした結果の accuracy などを辞書の形式で格納していくと、最終的な平均精度の算出であったり、そのロギングなどを自動で行ってくれます

最後にconfigure_optimizersです。ここにはオプティマイザの記述をしていきます。

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.02)

ここではオプティマイザの設定に加え、それに関連するスケジューラなどの設定も合わせて行うことができます。

今回はシンプルにAdamを用いています。

ここまでで Lightning Module の定義が完了しましたので、ここからはそれ以外の Dataloader や Trainer について見ていきます。

train_loader = DataLoader(MNIST('./', train=True, download=True, transform=transforms.ToTensor()), batch_size=1024, num_workers=12)
test_loader = DataLoader(MNIST('./', train=False, download=True, transform=transforms.ToTensor()), batch_size=1024, num_workers=12)

Dataloader については今回は通常のPyTorchと同様のものを用いました。

PyTorch Lightning には LightningDataModule というデータを扱うために便利なモジュールが存在していますが、こちらの解説は本記事で扱いません。

trainer = pl.Trainer(max_epochs=10, gpus=1)
model = Classifier()

Trainer についてはほぼデフォルトの設定で定義をします。

Trainer では エポック数、GPUの設定を始めとして、ロギングなどのコールバック、アーリーストッピングなどの設定なども行うことができます。

この辺の処理は素のPyTorchで書くとコードが煩雑になりがちなので、これらを全てTrainerに任せることが出来るのは非常に大きいメリットと言えます。


trainer.fit(model, train_loader)
trainer.test(model, test_dataloaders=test_loader)

最後に学習とテストの実施をします。

Lightning Module に定義した training_step, test_step と Trainerの定義に従って学習、テストを実施し、記録をしてくれます。

ここまで確認できたので最後にこのコードを train.py として保存し、実行してみます。

$ python train.py 
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type       | Params
------------------------------------------
0 | classifier | Sequential | 7.9 K 
------------------------------------------
7.9 K     Trainable params
0         Non-trainable params
7.9 K     Total params
0.031     Total estimated model params size (MB)
Epoch 9: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 59/59 [00:01<00:00, 40.84it/s, loss=0.668, v_num=24]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 14.53it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'accuracy': 0.7459999918937683}
--------------------------------------------------------------------------------

たったこれだけのコードしか書いていませんが、モデルの作成~精度検証まで行うことができました。

まとめ

本記事では PyTorch Lightning について紹介し、最低限のコードを書いて分類器を作成しました。

これだけシンプルかつ短いコードでモデルの作成ができるのは驚きです。

PyTorch Lightning を用いると煩雑なコードから開放され、モデルの開発にのみ集中することが可能となるでしょう。

今回紹介した内容は PyTorch Lightning の機能のほんの一握りなので、今後も紹介していきたいと思いいます。

PyTorch Lightning を使いこなして、イケイケなAIエンジニアを目指そう!

コメント

タイトルとURLをコピーしました