TLT(Transfer Learning Toolkit)の使い方について、チュートリアル形式で全6回にまとめました。
- 第1回:入力データの準備
- 第2回:事前学習モデルの入手
- 第3回:学習コンフィグの作成
- 第4回:モデルの学習/評価
- 第5回:モデルの枝刈り ←★今ここ
- 第6回:推論
- TLTでモデルを枝刈り(軽量化)する方法
- モデルの再学習をする方法
- モデルの再評価をする方法
以下環境で動作確認を行いました。
TLT(Transfer Learning Toolkit)のチュートリアルに沿った内容を実践してみます。
TLTインストール後にサンプルファイルとして入っている
/notebooks/examples/detectnet_v2/detectnet_v2.ipynb
の内容となります。
今回編集するファイルは以下となります。
/workspace/ ├ examples/ │ ├ detectnet_v2/ │ │ ├ detectnet_v2.ipynb # チュートリアルのソースコード │ │ ├ specs/ # $SPECS_DIR │ │ │ ├ detectnet_v2_tfrecords_kitti_trainval.txt # kitti → tfrecords に変換する設定ファイル │ │ │ ├ detectnet_v2_train_resnet18_kitti.txt # モデルの学習に関する設定ファイル │ │ │ ├ detectnet_v2_retrain_resnet18_kitti.txt # モデルの再学習に関する設定ファイル ├ tlt-experiments/ │ ├ data/ # $DATA_DOWNLOAD_DIR │ │ ├ training/ │ │ │ ├ image_2/ │ │ │ │ ├ 000000.png │ │ │ │ ├ 000001.png │ │ │ │ ├ : │ │ │ ├ label_2/ │ │ │ │ ├ 000000.txt │ │ │ │ ├ 000001.txt │ │ │ │ ├ : │ │ ├ testing/ │ │ │ ├ image_2/ │ │ │ │ ├ 000000.png │ │ │ │ ├ 000001.png │ │ │ │ ├ : │ ├ detectnet_v2/ # $USER_EXPERIMENT_DIR
モデルの枝刈り
モデルを軽量化するためにモデルの枝刈りを行っていきます。
枝刈りでは影響の少ないネットワークを切断することで軽量化をすることができますが、精度がトレードオフの関係となってしまうのが一般的です。
前回までに作成したDetectNet_v2モデルの枝刈りを行っていきましょう。
TLTの枝刈りはtlt-pruneコマンドを使用します。
tlt-pruneコマンドの引数
tlt-pruneコマンドの詳細については、公式サイトのドキュメントに詳しく記載されています。
必須の引数
- -m : モデルのパス
- -o : 出力先のディレクトリのパス
- -k : NGCAPIキー
オプションの引数
- -n : 正規化の方法(「max」または「L2」) デフォルトは「max」
- -eq : 異なるブランチから重みをマージする方法(「arithmetic_mean」「geometric_mean」「union」「intersection」) デフォルトは「union」
- -pg : 一度に削除するフィルターの数。デフォルトは「8」
- -pth : 枝刈りの閾値。デフォルト0.1
- -nf : レイヤーごとに保持するフィルターの最小値
- -el : 除外するレイヤーのリスト
枝刈りを実行
枝刈り後のモデルを格納するディレクトリを作成します。
!mkdir -p $USER_EXPERIMENT_DIR/experiment_dir_prunedprint(deta)
tlt-pruneコマンドを使って枝刈りを行います。
!tlt-prune -m $USER_EXPERIMENT_DIR/experiment_dir_unpruned/weights/resnet18_detector.tlt \
-o $USER_EXPERIMENT_DIR/experiment_dir_pruned/resnet18_nopool_bn_detectnet_v2_pruned.tlt \
-eq union \
-pth 0.7 \
-k $KEY
枝刈りの結果、何%のネットワークが切断されたか表示されました。
[INFO] iva.common.magnet_prune: Pruning ratio (pruned model / original model): 0.8543357449145645
枝刈り結果の確認
枝刈りして生成されたモデルを確認してみましょう。
先程の学習コマンドで指定した出力先フォルダを確認しています。
!ls -rlt $USER_EXPERIMENT_DIR/experiment_dir_pruned/
枝刈り後のモデル(.tlt)が存在していることが確認できました。
total 37592 -rw-r--r-- 1 root root 38494112 Feb 21 10:21 resnet18_nopool_bn_detectnet_v2_pruned.tlt
もともと45MB程度あったモデルファイルが、37MB程度に軽量化されていることが確認できます。
枝刈り後のモデルの再学習
モデルが枝刈りされると有用な重みまで削除され、精度が低下する可能性があります。
精度を取り戻すために、枝刈り後のモデルで再学習することが推奨されています。
再学習用のコンフィグを作成
再学習用のコンフィグファイルの構成は、1度目の学習に使ったコンフィグファイルと変わりません。
学習用のコンフィグの構成はこちらの記事にまとまっています。

再学習時のコンフィグを作成するに当たり、以下の2点に注意が必要です。
今回、再学習用のコンフィグファイルは「detectnet_v2_retrain_resnet18_kitti.txt」という名前で作成しました。
一応中身を確認しておきます。
!cat $SPECS_DIR/detectnet_v2_retrain_resnet18_kitti.txt
最初の学習時のコンフィグとの変更点はモデルのパスとload_graphの設定のみなので、コンフィグファイルの内容は一部のみの掲載としました。
: model_config { pretrained_model_file: "/workspace/tlt-experiments/detectnet_v2/experiment_dir_pruned/resnet18_nopool_bn_detectnet_v2_pruned.tlt" num_layers: 18 use_batch_norm: true load_graph: true :
再学習の実施
再学習も初回の学習と同じようにtlt-trainコマンドを使います。
オプションは第4回の記事にまとめられています。

実際に設定ファイルや出力ディレクトリ等を指定して、モデルの再学習をさせます。
# Retraining using the pruned model as pretrained weights
!tlt-train detectnet_v2 -e $SPECS_DIR/detectnet_v2_retrain_resnet18_kitti.txt \
-r $USER_EXPERIMENT_DIR/experiment_dir_retrain \
-k $KEY \
-n resnet18_detector_pruned \
--gpus $NUM_GPUS
初回の学習時と同様に学習結果が表示されました。
Validation cost: 0.000060 Mean average_precision (in %): 56.8756 class name average precision (in %) ------------ -------------------------- car 67.897 cyclist 53.1432 pedestrian 49.5866
再学習の結果を確認
再学習して生成されたモデルを確認してみましょう。
先程の学習コマンドで指定した出力先フォルダを確認しています。
!ls -rlt $USER_EXPERIMENT_DIR/experiment_dir_retrain/weights
モデル(.tlt)が存在していることが確認できました。
モデルの評価
!tlt-evaluate detectnet_v2 -e $SPECS_DIR/detectnet_v2_retrain_resnet18_kitti.txt \
-m $USER_EXPERIMENT_DIR/experiment_dir_retrain/weights/resnet18_detector_pruned.tlt \
-k $KEY
実行すると評価の結果が表示されました。
Validation cost: 0.000311 Mean average_precision (in %): 56.9122 class name average precision (in %) ------------ -------------------------- car 67.9653 cyclist 53.1848 pedestrian 49.5863
まとめ
TLTでモデルを枝刈り(軽量化)する方法について説明しました。
また、枝刈りしたモデルは精度が下がってしまうことがあるので、再学習して再評価を行いました。
tlt-pruneコマンドにはモデル、出力先、NGCAPIキーが必須の引数でした。
モデルの作成後には精度を確認し、必要であれば再学習を行いましょう。
次回の記事はこちら

前回の記事はこちら

コメント