ふぁむたろうのブログ

機械学習系のお話やポエムを投稿します

PyTorch-Ignite で学習用コードをスマートにする

  • 12月25日
    • Early Stopping、モデルの保存ができるように train と main を変更しました

皆さんこんにちは。みなさんメリークリスマス☆茜ちゃんだよ☆

今回は PyTorch 用のライブラリである PyTorch-Ignite(https://github.com/pytorch/ignite) を使って Training 用のコードを簡潔に書けるようにします。

本記事のキーワード
本記事の対象
  • 何度か PyTorch のコードを書いてきてもう "for input in XXX_loader" は飽きてきた方
  • 大体似たような Training 用のコードを書いている方
  • Training 部分についてある程度のテンプレートを身につけたい方
  • Ignite という響きに惹かれた方
本記事の対象でない方
  • まだ PyTorch を始めたばかりの方
    • Ignite は各バッチ毎に行っている操作 "output = model(input)" や "model.train()" を隠蔽してしまいます
      • そのため学習のイメージが掴みきれていないうちは Ignite は推奨しません。
    • 初めからそこら辺に興味がない場合は Keras の方が良い気がします
  • 機械学習を始めたばかりの方
  • Training 周りのラッパーを全てご自分で用意されている方
  • 各バッチやデータ毎に細かい処理を行いたい方

Ignite とは

PyTorch の Training 周りをスッキリ書くためのライブラリです。 Training 周りに絞ってあるので、試しやすそうです。

例えば kaggle や業務等で、モデルや学習率をざっくばらんに試したいときに使いそうですね。

Ignite Documentation — ignite master documentation github.com

  • 公式 Document の訳
    • Ignite は PyTorch の Training 周りに絞ったライブラリ
    • 各 metrics や earlystopping、model の保存などを含めた Training loop を扱えるよ

f:id:fam_taro:20181224180408p:plain
左:Ignite を使った場合の学習用コード、右:Ignite を使わない場合の学習用コード

上記は公式画像の Document 内にありましたが、学習用コードがスッキリしてます。

使ってみる

https://github.com/pytorch/ignite/blob/master/examples/mnist/mnist.pyhttps://github.com/pytorch/ignite/blob/master/examples/mnist/mnist_with_tensorboardx.py を元に MNIST 用の Training コードを書いてみます。

Python のバージョンと必要なライブラリ

  • Python 3.6.5
    • 3.6 以降じゃないと f-Strings 使えません

今回は以下のライブラリを使いました。

特に pytorch-ignite はちゃんと "pip install pytorch-ignite" でいれましょう。 ("pip install ignite" だと違うものが入ります)

pytorch-ignite      0.1.2
tensorboardX        1.4
torch               0.4.1
torchvision         0.2.1

データセット

今回は MNIST を使います。 自分で行う場合は別途 Dataset を定義しましょう。 ここでついでに DataLoader も定義します。

from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize


def get_data_loaders(train_batch_size, val_batch_size):
    data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

    train_loader = DataLoader(MNIST(download=True,
                                    root=".",
                                    transform=data_transform,
                                    train=True),
                              batch_size=train_batch_size, shuffle=True)

    val_loader = DataLoader(MNIST(download=False,
                                  root=".",
                                  transform=data_transform,
                                  train=False),
                            batch_size=val_batch_size, shuffle=False)
    return train_loader, val_loader

モデル

今回は参考サイトのままシンプルなネットワークを用います。 使ってみたいモデルがある人は是非ここで定義しましょう。

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self, num_class=10):
        super(Net, self).__init__()
        self.num_class = num_class
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, self.num_class)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)

train に必要な関数の定義

早速 train 本体を記述したいのですが、その前に必要になる関数をここで定義します。

  • write_metrics() は評価値をログに保存するついでに print もするための関数
  • score_function() は validation_loss を ignite.handlers.EarlyStopping オブジェクトに渡すために変換する関数
    • ignite.handlers.EarlyStopping は、与えたスコアが上がった場合に「スコアが改善した」と判断するため
def write_metrics(metrics, writer, mode: str, epoch: int):
    """print metrics & write metrics to log"""
    avg_accuracy = metrics['accuracy']
    avg_nll = metrics['nll']
    print(f"{mode} Results - Epoch: {epoch}  "
          f"Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}")
    writer.add_scalar(f"{mode}/avg_loss", avg_nll, epoch)
    writer.add_scalar(f"{mode}/avg_accuracy", avg_accuracy, epoch)


def score_function(engine):
    """
    ignite.handlers.EarlyStopping では指定スコアが上がると改善したと判定する。
    そのため今回のロスに -1 をかけたものを ignite.handlers.EarlyStopping オブジェクトに渡す
    """
    val_loss = engine.state.metrics['nll']
    return -val_loss

train 本体の定義

ここで train する部分を記述します。 これまで頑張って以下のような for 文を書いていた部分です。

train(**args):
    for i, input in enumerate(train_loader):
        ...

これを pytorch-ignite を使うと以下のように書けます。

def train(epochs, model, train_loader, valid_loader,
          criterion, optimizer, writer, device, log_interval):
    # device: str であることに注意
    # この時点では Dataloader を与えていないことに注意
    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': Accuracy(),
                                                     'nll': Loss(criterion)},
                                            device=device)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        i = (engine.state.iteration - 1) % len(train_loader) + 1
        if i % log_interval == 0:
            print(f"Epoch[{engine.state.epoch}] Iteration[{i}/{len(train_loader)}] "
                  f"Loss: {engine.state.output:.2f}")
            # engine.state.output は criterion(model(input)) を表す?
            writer.add_scalar("training/loss", engine.state.output,
                              engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        write_metrics(metrics, writer, 'training', engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(valid_loader)
        metrics = evaluator.state.metrics
        write_metrics(metrics, writer, 'validation', engine.state.epoch)

    # # Checkpoint setting
    # ./checkpoints/sample_mymodel_{step_number}
    # n_saved 個までパラメータを保持する
    handler = ModelCheckpoint(dirname='./checkpoints', filename_prefix='sample',
                              save_interval=2, n_saved=3, create_dir=True)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {'mymodel': model})

    # # Early stopping
    handler = EarlyStopping(patience=5, score_function=score_function, trainer=trainer)
    # Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset)
    evaluator.add_event_handler(Events.COMPLETED, handler)

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

ポイントとしては以下の通りです。

  • create_supervised_trainer() で trainer を定義
  • create_supervised_evaluator() で evaluator を定義
    • これは train_loader にも valid_loader に対しても使えます
  • @trainer.on(<実行したいタイミング>) で各処理を定義した関数をデコレートする
    • デフォルトでは以下のタイミングに処理を挟み込めるそうです(https://pytorch.org/ignite/engine.html#ignite.engine.Events)
      • COMPLETED
      • EPOCH_COMPLETED
      • EPOCH_STARTED
      • EXCEPTION_RAISED
      • ITERATION_COMPLETED
      • ITERATION_STARTED
      • STARTED
    • 上記の通り大体実行したいタイミングは網羅してそうです
  • 実行したいイベント(モデルの保存や EarlyStopping)は trainer や evaluator に add_event_handler() で追加する

また ModelCheckpoint() によって学習途中のパラメータを保存できるようにしています。 これが 2行で書けるのは嬉しいです。 加えて n_saved によって保存する直近のパラメータの個数を制限できるのも地味に嬉しいです!

さらに今回はコンペを意識して EarlyStopping を導入しました。 注意事項としては evaluator に対して追加していることと、実行するタイミングが Event.COMPLETED になっていることです。

main 関数

main 関数では train() の実行に必要なものを定義します。 コマンドライン引数や config ファイルの読み込みもここで定義すると良いと思います。

ポイントでは、学習結果を見るために tensorboardX を使っていることぐらいです。 もし他に logger を使っている方はそちらでも可能です(ただし train を修正する必要あり)。

あとは Early Stopping を確認するために epochs=100 と大きめに設定していることに注意してください。

import torch
import torch.nn.functional as F
import torch.optim as optim
from tensorboardX import SummaryWriter

from model import Net
from dataset import get_data_loaders
from train import train


def main():
    # 定数
    epochs = 100
    train_batchsize = 128
    valid_batchsize = 4
    log_interval = 50
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 各学習に必要なもの
    model = Net(num_class=10)
    train_loader, valid_loader = get_data_loaders(train_batchsize, valid_batchsize)
    criterion = F.nll_loss
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    log_writer = SummaryWriter('./log')

    # 学習開始
    train(epochs=epochs, model=model,
          train_loader=train_loader, valid_loader=valid_loader,
          criterion=criterion, optimizer=optimizer,
          writer=log_writer, device=device, log_interval=log_interval)

    # モデル保存
    torch.save(model.state_dict(), './checkpoints/final_weights.pt')

    log_writer.close()


if __name__ == '__main__':
    main()

学習実行

以下のように Epoch が回れば OK です。 僕が回したときは epochs = 100 に設定したにも関わらず epoch[23] で学習が終わったので ちゃんと early stopping されていることが確認できました。

$ python main.py
<省略>
Epoch[2] Iteration[1450/1875] Loss: 0.17
<省略>
Epoch[23] Iteration[200/469] Loss: 0.07
Epoch[23] Iteration[250/469] Loss: 0.05
Epoch[23] Iteration[300/469] Loss: 0.24
Epoch[23] Iteration[350/469] Loss: 0.06
Epoch[23] Iteration[400/469] Loss: 0.12
Epoch[23] Iteration[450/469] Loss: 0.08
training Results - Epoch: 23  Avg accuracy: 0.99 Avg loss: 0.03
validation Results - Epoch: 23  Avg accuracy: 0.99 Avg loss: 0.03

学習が終わると以下のディレクトリが新しくできているかと思われます。

  • checkpoints
    • 学習中と学習後のモデルのパラメータが格納されています
    • 私が確認したときは以下のファイルがあり、final_weights.pt を除くとちゃんと n_saved=3 が反映されていることが確認できました
      • final_weights.pt
      • sample_mymodel_18.pth
      • sample_mymodel_20.pth
      • sample_mymodel_22.pth
  • log
    • 学習結果が格納されています
  • raw、processed
    • これは MNIST のデータです(ですので他の Dataset を使った場合は発生しません)

学習結果の確認

$ tensorboard --logdir log

その後該当ページ(デフォルトですと localhost:6006) にアクセスして下記のようなグラフが見れれば OK です!

f:id:fam_taro:20181225195054p:plain
tensorboard による学習結果の確認

まとめ

  • 思ったより自分で定義することが多くてびっくりしました😇。
    • ただし Training 周りをスッキリさせることができ、導入する価値はあるなと思いました(特に for 文周り)。
  • Early Stopping やモデルの保存はちゃんとカバーされてて地味に嬉しかったです
  • 特に PyTorch をある程度書いてきて Training 周りのコードが整理しきれてない人は、テンプレートを知るという意味でも触る価値があるかと思います
  • 今回は触れませんでしたが、example を確認すると pytorch-ignite は GAN や 強化学習にも対応しており 非常に自由度の高いライブラリであることが確認できています

これを使って fine-tuning のテンプレートを作ってみたいなと思いました。

ぜひ皆さんも触ってみてください! ウチもやったんだからさ

今後確認しておきたいこと

  • PyTorch 1.0 対応
  • Multi GPU 対応
  • より細かい操作
    • 他のロスや出力に softmax 等がかかっていないときの対応