TFrecordで使われているtf.train.Exampleの構造について

スポンサーリンク
TFrecordで使われているtf.train.Exampleの構造についてAIを作ってみる
この記事を読んで分かること
  • TFRecordで使われているフォーマット

 

TFRecordの構造

TFRecordフォーマットは、直接データが入っているわけではなく、いくつかのオブジェクトが存在します。

オブジェクトは階層化されていて、以下のような構造となっています。

tf.train.Exampleの階層
  • tf.train.Example
    • tf.train.Features
      • tf.train.Feature
        • tf.train.BytesList
        • tf.train.FloatList
        • tf.train.Int64List

 

各オブジェクトの役割を表にまとめました。

TFRecord オブジェクトの概要
オブジェクト
概要
tf.train.Example(公式サイト)
データセット全体を管理する
「Features」を内包する
tf.train.Features(公式サイト)複数のオブジェクトを束ねる
「Feature」を複数内包する
tf.train.Feature(公式サイト)1つの特徴量を格納する
「BytesList」「FloatList」「Int64List」のいずれか1つを内包する
tf.train.BytesList(公式サイト)
バイト列のデータを格納する
画像もバイナリ化して格納する
tf.train.FloatList(公式サイト)
float型のデータを格納する
tf.train.Int64List(公式サイト)int型のデータを格納する

 

TFRecordの各オブジェクトの説明

TFRecordのオブジェクトはボトムアップで見ていった方が分かりやすいので、「tf.train.Int64List」「tf.train.FloatList」「 tf.train.BytesList」から見ていきます。

 

tf.train.Int64List

tf.train.Int64Listにはint型のリストが格納でき、次の型を強制します。

tf.train.Int64Listの型強制
  • bool
  • enum
  • int32
  • uint32
  • int64
  • uint64

 

int型のリストを渡すと、リストを内包したオブジェクト(tf.train.Int64List)となります。

import tensorflow as tf

int_list = [1, 2, 3]
tf_int_list = tf.train.Int64List(value=int_list)
tf_int_list
実行結果(クリックして表示)
value: 1
value: 2
value: 3

 

tf.train.FloatList

tf.train.FloatListにはfloat型のリストが格納でき、次の型を強制します。

tf.train.FloatListの型強制
  • float(float32)
  • double(float64)

 

float型のリストを渡すと、リストを内包したオブジェクト(tf.train.FloatList)となります。

float_list = [1.0, 2.0, 3.0]
tf_float_list = tf.train.FloatList(value=float_list)
tf_float_list
実行結果(クリックして表示)
value: 1.0
value: 2.0
value: 3.0

 

tf.train.BytesList

tf.train.BytesListにはバイト列のリストが格納でき、次の型を強制します。

tf.train.BytesListの型強制
  • string
  • byte

 

byte型のリストを渡すと、リストを内包したオブジェクト(tf.train.BytesList)となります。

str_list = ["a", "b", "c"]
bin_list = []
for i_str in str_list:
bin_list.append(i_str.encode())
tf_str_list = tf.train.BytesList(value=bin_list)
tf_str_list
実行結果(クリックして表示)
value: "a"
value: "b"
value: "c"

 

tf.train.Feature

tf.train.Featureでは、「tf.train.BytesList」「tf.train.FloatList」「tf.train.Int64List」の3つの値を互換性のある形にします。

int_feature = tf.train.Feature(int64_list=tf_int_list)
int_feature
実行結果(クリックして表示)
int64_list {
  value: 1
  value: 2
  value: 3
}

 

float_feature = tf.train.Feature(float_list=tf_float_list)
float_feature
実行結果(クリックして表示)
float_list {
  value: 1.0
  value: 2.0
  value: 3.0
}

 

str_feature = tf.train.Feature(bytes_list=tf_str_list)
str_feature
実行結果(クリックして表示)
bytes_list {
  value: "a"
  value: "b"
  value: "c"
}

 

tf.train.Features

tf.train.Featuresでは、複数の「tf.train.Feature」をまとめます。

そのためには、tf.train.Featureの要素を辞書型でマッピングして、その情報をtf.train.Featuresの形にします。

feature_map = {
    'feature_int': int_feature,
    'feature_float': float_feature,
    'feature_str': str_feature,
}

tf_features = tf.train.Features(feature=feature_map)
tf_features
実行結果(クリックして表示)
feature {
  key: "feature_float"
  value {
    float_list {
      value: 1.0
      value: 2.0
      value: 3.0
    }
  }
}
feature {
  key: "feature_int"
  value {
    int64_list {
      value: 1
      value: 2
      value: 3
    }
  }
}
feature {
  key: "feature_str"
  value {
    bytes_list {
      value: "a"
      value: "b"
      value: "c"
    }
  }
}

 

tf.train.Example

tf.train.Exampleは、データセットの表銃的な形式で「tf.train.Features」の単なるラッパーです。

辞書型の要素をバイト文字列にシリアル化する方法の1つに過ぎません。

example_proto = tf.train.Example(features=tf_features)
example_proto
実行結果(クリックして表示)
features {
  feature {
    key: "feature_float"
    value {
      float_list {
        value: 1.0
        value: 2.0
        value: 3.0
      }
    }
  }
  feature {
    key: "feature_int"
    value {
      int64_list {
        value: 1
        value: 2
        value: 3
      }
    }
  }
  feature {
    key: "feature_str"
    value {
      bytes_list {
        value: "a"
        value: "b"
        value: "c"
      }
    }
  }
}

tf.train.Exampleのメッセージは、SerializeToStringメソッドを使うとバイナリ文字列にシリアル化することができます。

serialized_example = example_proto.SerializeToString()
serialized_example
実行結果(クリックして表示)
b'\nW\n\x1a\n\x0bfeature_str\x12\x0b\n\t\n\x01a\n\x01b\n\x01c\n!\n\rfeature_float\x12\x10\x12\x0e\n\x0c\x00\x00\x80?\x00\x00\x00@\x00\x00@@\n\x16\n\x0bfeature_int\x12\x07\x1a\x05\n\x03\x01\x02\x03'

tf.train.Example.FromStringメソッドを使用すると、メッセージのデコードをすることができます。

example_proto = tf.train.Example.FromString(serialized_example)
example_proto
実行結果(クリックして表示)
features {
  feature {
    key: "feature_float"
    value {
      float_list {
        value: 1.0
        value: 2.0
        value: 3.0
      }
    }
  }
  feature {
    key: "feature_int"
    value {
      int64_list {
        value: 1
        value: 2
        value: 3
      }
    }
  }
  feature {
    key: "feature_str"
    value {
      bytes_list {
        value: "a"
        value: "b"
        value: "c"
      }
    }
  }
}

 

まとめ

tfRecordのフォーマットにはtf.train.Exampleオブジェクトが使われています。

tf.train.Exampleオブジェクトには、tf.train.Featuresが内包されていて、各特徴量はtf.train.Featureの形式になっています。

tf.train.Featureの中身は3種類あり、バイト列を格納する「tf.train.BytesList」、float型のリストを格納する「tf.train.FloatList」、int型のリストを格納する「tf.train.Int64List」があります。

 

参考文献

公式チュートリアル

CUBE SUGAR CONTAINER様

コメント

タイトルとURLをコピーしました