【第5回:モデルの枝刈り】TLT(Transfer Learning Toolkit)のチュートリアル

スポンサーリンク
【第5回_モデルの枝刈り】TLT(Transfer Learning Toolkit)のチュートリアルAIを作ってみる

TLT(Transfer Learning Toolkit)の使い方について、チュートリアル形式で全6回にまとめました。

 

 

この記事を読んで分かること
  • TLTでモデルを枝刈り(軽量化)する方法
  • モデルの再学習をする方法
  • モデルの再評価をする方法

 

以下環境で動作確認を行いました。

動作環境

  • Ubuntu18.04
  • GTX1080

 

TLT(Transfer Learning Toolkit)のチュートリアルに沿った内容を実践してみます。

TLTインストール後にサンプルファイルとして入っている

/notebooks/examples/detectnet_v2/detectnet_v2.ipynb

の内容となります。

 

今回編集するファイルは以下となります。

/workspace/
  ├ examples/
  │  ├ detectnet_v2/
  │  │  ├ detectnet_v2.ipynb # チュートリアルのソースコード 
  │  │  ├ specs/  # $SPECS_DIR
  │  │  │  ├ detectnet_v2_tfrecords_kitti_trainval.txt # kitti → tfrecords に変換する設定ファイル
  │  │  │  ├ detectnet_v2_train_resnet18_kitti.txt # モデルの学習に関する設定ファイル
  │  │  │  ├ detectnet_v2_retrain_resnet18_kitti.txt # モデルの再学習に関する設定ファイル
  ├ tlt-experiments/
  │  ├ data/  # $DATA_DOWNLOAD_DIR
  │  │  ├ training/
  │  │  │  ├ image_2/
  │  │  │  │  ├ 000000.png
  │  │  │  │  ├ 000001.png 
  │  │  │  │  ├   :
  │  │  │  ├ label_2/
  │  │  │  │  ├ 000000.txt
  │  │  │  │  ├ 000001.txt 
  │  │  │  │  ├   :
  │  │  ├ testing/
  │  │  │  ├ image_2/
  │  │  │  │  ├ 000000.png
  │  │  │  │  ├ 000001.png
  │  │  │  │  ├   :
  │  ├ detectnet_v2/ # $USER_EXPERIMENT_DIR

モデルの枝刈り

モデルを軽量化するためにモデルの枝刈りを行っていきます

枝刈りでは影響の少ないネットワークを切断することで軽量化をすることができますが、精度がトレードオフの関係となってしまうのが一般的です。

前回までに作成したDetectNet_v2モデルの枝刈りを行っていきましょう。

TLTの枝刈りはtlt-pruneコマンドを使用します。

 

tlt-pruneコマンドの引数

tlt-pruneコマンドの詳細については、公式サイトのドキュメントに詳しく記載されています。

 

必須の引数

  • -m : モデルのパス
  • -o : 出力先のディレクトリのパス
  • -k : NGCAPIキー

 

オプションの引数

  • -n : 正規化の方法(「max」または「L2」) デフォルトは「max」
  • -eq : 異なるブランチから重みをマージする方法(「arithmetic_mean」「geometric_mean」「union」「intersection」) デフォルトは「union」
  • -pg : 一度に削除するフィルターの数。デフォルトは「8」
  • -pth : 枝刈りの閾値。デフォルト0.1
  • -nf : レイヤーごとに保持するフィルターの最小値
  • -el : 除外するレイヤーのリスト

 

枝刈りを実行

枝刈り後のモデルを格納するディレクトリを作成します。

!mkdir -p $USER_EXPERIMENT_DIR/experiment_dir_prunedprint(deta)

 

tlt-pruneコマンドを使って枝刈りを行います。

!tlt-prune -m $USER_EXPERIMENT_DIR/experiment_dir_unpruned/weights/resnet18_detector.tlt \
            -o $USER_EXPERIMENT_DIR/experiment_dir_pruned/resnet18_nopool_bn_detectnet_v2_pruned.tlt \
            -eq union \
            -pth 0.7 \
            -k $KEY

 

枝刈りの結果、何%のネットワークが切断されたか表示されました。

[INFO] iva.common.magnet_prune: Pruning ratio (pruned model / original model): 0.8543357449145645

 

枝刈り結果の確認

枝刈りして生成されたモデルを確認してみましょう。

先程の学習コマンドで指定した出力先フォルダを確認しています。

!ls -rlt $USER_EXPERIMENT_DIR/experiment_dir_pruned/

 

枝刈り後のモデル(.tlt)が存在していることが確認できました。

total 37592
-rw-r--r-- 1 root root 38494112 Feb 21 10:21 resnet18_nopool_bn_detectnet_v2_pruned.tlt

もともと45MB程度あったモデルファイルが、37MB程度に軽量化されていることが確認できます。

 

枝刈り後のモデルの再学習

モデルが枝刈りされると有用な重みまで削除され、精度が低下する可能性があります。

精度を取り戻すために、枝刈り後のモデルで再学習することが推奨されています。

 

再学習用のコンフィグを作成

再学習用のコンフィグファイルの構成は、1度目の学習に使ったコンフィグファイルと変わりません

学習用のコンフィグの構成はこちらの記事にまとまっています。

【第3回:学習コンフィグの作成】TLT(Transfer Learning Toolkit)のチュートリアル
TLT(Transfer Learning Toolkit)の使い方について、チュートリアル形式で全6回にまとめました。 第1回:入力データの準備 第2回:事前学習モデルの入手 第3回:学習コンフィグの作成  ←★今こ...

 

再学習時のコンフィグを作成するに当たり、以下の2点に注意が必要です。

  • 枝刈り前の精度を取り戻すために、再学習時には正規化をオフ(regularizerのtypeをNO_REG)にすることがおすすめされています。
  • また、detectnet_v2の場合は枝刈りされたモデルを利用するには、load_graphを trueに設定することが必要となります。

 

今回、再学習用のコンフィグファイルは「detectnet_v2_retrain_resnet18_kitti.txt」という名前で作成しました。

一応中身を確認しておきます。

!cat $SPECS_DIR/detectnet_v2_retrain_resnet18_kitti.txt

 

最初の学習時のコンフィグとの変更点はモデルのパスload_graphの設定のみなので、コンフィグファイルの内容は一部のみの掲載としました。

      :
model_config {
  pretrained_model_file: "/workspace/tlt-experiments/detectnet_v2/experiment_dir_pruned/resnet18_nopool_bn_detectnet_v2_pruned.tlt"
  num_layers: 18
  use_batch_norm: true
  load_graph: true
      :

 

再学習の実施

再学習も初回の学習と同じようにtlt-trainコマンドを使います。

オプションは第4回の記事にまとめられています。

【第4回:モデルの学習/評価】TLT(Transfer Learning Toolkit)のチュートリアル
TLT(Transfer Learning Toolkit)の使い方について、チュートリアル形式で全6回にまとめました。 第1回:入力データの準備 第2回:事前学習モデルの入手 第3回:学習コンフィグの作成   ...

 

実際に設定ファイルや出力ディレクトリ等を指定して、モデルの再学習をさせます

# Retraining using the pruned model as pretrained weights 
!tlt-train detectnet_v2 -e $SPECS_DIR/detectnet_v2_retrain_resnet18_kitti.txt \
                        -r $USER_EXPERIMENT_DIR/experiment_dir_retrain \
                        -k $KEY \
                        -n resnet18_detector_pruned \
                        --gpus $NUM_GPUS

 

初回の学習時と同様に学習結果が表示されました

Validation cost: 0.000060
Mean average_precision (in %): 56.8756

class name      average precision (in %)
------------  --------------------------
car                              67.897
cyclist                          53.1432
pedestrian                       49.5866

 

再学習の結果を確認

再学習して生成されたモデルを確認してみましょう。

先程の学習コマンドで指定した出力先フォルダを確認しています。

!ls -rlt $USER_EXPERIMENT_DIR/experiment_dir_retrain/weights

 

モデル(.tlt)が存在していることが確認できました。

-rw-r--r-- 1 root root 45022888 Feb 13 11:44 resnet18_detector_pruned.tlt

 

モデルの評価

初回学習時と同様に、モデルの評価も行っていきます。
!tlt-evaluate detectnet_v2 -e $SPECS_DIR/detectnet_v2_retrain_resnet18_kitti.txt \
                           -m $USER_EXPERIMENT_DIR/experiment_dir_retrain/weights/resnet18_detector_pruned.tlt \
                           -k $KEY

 

実行すると評価の結果が表示されました。

Validation cost: 0.000311
Mean average_precision (in %): 56.9122

class name      average precision (in %)
------------  --------------------------
car                              67.9653
cyclist                          53.1848
pedestrian                       49.5863

まとめ

TLTでモデルを枝刈り(軽量化)する方法について説明しました。

また、枝刈りしたモデルは精度が下がってしまうことがあるので、再学習して再評価を行いました。

tlt-pruneコマンドにはモデル出力先NGCAPIキーが必須の引数でした。

モデルの作成後には精度を確認し、必要であれば再学習を行いましょう。

 

次回の記事はこちら

【第6回:推論】TLT(Transfer Learning Toolkit)のチュートリアル
TLT(Transfer Learning Toolkit)の使い方について、チュートリアル形式で全6回にまとめました。 第1回:入力データの準備 第2回:事前学習モデルの入手 第3回:学習コンフィグの作成   ...

前回の記事はこちら

【第4回:モデルの学習/評価】TLT(Transfer Learning Toolkit)のチュートリアル
TLT(Transfer Learning Toolkit)の使い方について、チュートリアル形式で全6回にまとめました。 第1回:入力データの準備 第2回:事前学習モデルの入手 第3回:学習コンフィグの作成   ...

コメント

タイトルとURLをコピーしました