Keson's blog

Caffe Walkthrough 2 -- Protocol Buffers

Reference:

Protobuf Git

Protobuf installation

Google Developer Guide

proto file format

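Advantages

  • Provides flexible, efficient serialization of structured data
  • Similar to XML, but smaller, faster, and simpler
  • Once the data structures are defined in a .proto file, the Protobuf compiler can compile them into a variety of target languages

A quick overview

Defining a message type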

message SearchRequest {
  required string query = 1;
  optional int32 page_number = 2;
  optional int32 result_per_page = 3;
}

Field types

  • Scalar types, including:
    • double
    • float
    • int32
    • int64
    • uint32
    • uint64
    • string
    • bool
    • bytes
  • Enum types:
message SearchRequest {
  required string query = 1;
  optional int32 page_number = 2;
  optional int32 result_per_page = 3 [default = 10];
  enum Corpus {
    UNIVERSAL = 0;
    WEB = 1;
    IMAGES = 2;
    LOCAL = 3;
    NEWS = 4;
    PRODUCTS = 5;
    VIDEO = 6;
  }
  optional Corpus corpus = 4 [default = UNIVERSAL];
}
  • Other types (for example, other message types)

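Field tags

Each field has a unique numeric tag that identifies it in the binary encoding. Tags 1 through 15 take one byte to encode (the byte holds both the tag and the field's wire type), while tags 16 through 2047 take two bytes. Frequently used fields should therefore be given tags in the range 1 to 15.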
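The byte counts above follow from how the key in front of each field is written: the tag and a 3-bit wire type are combined and stored as a varint, which carries 7 payload bits per byte. The following is a minimal sketch of that arithmetic, not the library's own code; VarintSize and TagSize are hypothetical helper names chosen for illustration.

```cpp
#include <cassert>  // used by the checks mentioned in the note below
#include <cstdint>

// Number of bytes needed to store `value` as a varint (7 payload bits per byte).
int VarintSize(std::uint64_t value) {
  int bytes = 1;
  while (value >= 0x80) {
    value >>= 7;
    ++bytes;
  }
  return bytes;
}

// Size in bytes of the key written before a field's value.
// The key is the varint encoding of (field_number << 3) | wire_type,
// where wire_type occupies the low 3 bits.
int TagSize(std::uint32_t field_number, std::uint32_t wire_type) {
  const std::uint64_t key =
      (static_cast<std::uint64_t>(field_number) << 3) | wire_type;
  return VarintSize(key);
}
```

With these helpers, assert(TagSize(15, 0) == 1) and assert(TagSize(16, 0) == 2) both hold, which matches the boundary between one-byte and two-byte tags.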

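Field rules

  • required: the message must contain exactly one of this field
  • optional: the message may contain zero or one of this field, but not more than one
  • repeated: the field may be repeated any number of times (including zero). For repeated fields of scalar numeric types, append [packed=true] to the definition to get a more efficient encoding, for example:
    repeated int32 samples = 4 [packed=true];
  • reserved: not a field rule as such, but a statement that reserves the tags or names of deleted fields so that they cannot be reused later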
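To see why [packed=true] is more compact, it helps to look at the bytes: a packed field writes its key once, followed by a byte length and then all values back to back, instead of repeating the key before every element. The sketch below is an illustration under simplifying assumptions (non-negative values only, hypothetical helper names AppendVarint and EncodePacked) and is not the library's serializer.

```cpp
#include <cassert>  // used by the check mentioned in the note below
#include <cstdint>
#include <vector>

// Appends `value` to `out` as a base-128 varint, low 7 bits first.
void AppendVarint(std::uint64_t value, std::vector<std::uint8_t>* out) {
  while (value >= 0x80) {
    out->push_back(static_cast<std::uint8_t>((value & 0x7F) | 0x80));
    value >>= 7;
  }
  out->push_back(static_cast<std::uint8_t>(value));
}

// Encodes a packed repeated field of non-negative integers:
// one key with wire type 2 (length-delimited), the payload length in bytes,
// then every value as a varint. An empty field produces no bytes at all.
std::vector<std::uint8_t> EncodePacked(
    std::uint32_t field_number, const std::vector<std::uint32_t>& values) {
  std::vector<std::uint8_t> out;
  if (values.empty()) return out;

  std::vector<std::uint8_t> payload;
  for (std::uint32_t v : values) AppendVarint(v, &payload);

  AppendVarint((static_cast<std::uint64_t>(field_number) << 3) | 2, &out);
  AppendVarint(payload.size(), &out);
  out.insert(out.end(), payload.begin(), payload.end());
  return out;
}
```

For the samples field above (tag 4) holding 3, 270 and 86942, EncodePacked(4, {3, 270, 86942}) yields the eight bytes 22 06 03 8E 02 9E A7 05: one key byte, one length byte, and six bytes of values.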

Using other message types

A message type can use another message type as the type of one of its fields:

message SearchResponse {
  repeated Result result = 1;
}

message Result {
  required string url = 1;
  optional string title = 2;
  repeated string snippets = 3;
}

Import

If the message you want to use is defined in another .proto file, you can bring it in with import:

import "myproject/other_protos.proto";

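By default you can only use definitions from files you import directly. Sometimes, however, you need to move a .proto file to a new location. Rather than updating every place that imports the file, you can leave a placeholder file at the old location that forwards to the new one with import public. Anything pulled in with import public is passed on transitively to every file that imports the file containing that statement. For example: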

// new.proto
// All definitions are moved here

// old.proto
// This is the proto that all clients are importing.
import public "new.proto";
import "other.proto";

// client.proto
import "old.proto";
// You use definitions from old.proto and new.proto, but not other.proto

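On the command line, use -I/--proto_path to give the compiler a list of directories to search for imported files. If the flag is not set, the compiler looks in the directory from which it was invoked. In general, point -I/--proto_path at the root of your project and use fully qualified paths for all imports.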

Nested types

You can define and use message types inside other message types. In the following example, the Result message is defined inside SearchResponse:

message SearchResponse {
  message Result {
    required string url = 1;
    optional string title = 2;
    repeated string snippets = 3;
  }
  repeated Result result = 1;
}

If you want to reuse this message type outside its parent message, refer to it as Parent.Type:

message SomeOtherMessage {
  optional SearchResponse.Result result = 1;
}

You can nest messages as deeply as you like:

message Outer {                  // Level 0
  message MiddleAA {             // Level 1
    message Inner {              // Level 2
      required int64 ival = 1;
      optional bool booly = 2;
    }
  }
  message MiddleBB {             // Level 1
    message Inner {              // Level 2
      required int32 ival = 1;
      optional bool booly = 2;
    }
  }
}

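Generating code from a .proto file

Running the protocol buffer compiler on a .proto file generates code in the target language. The generated code includes getters and setters for each field, as well as methods to serialize a message to an output stream and to parse it from an input stream.

  • C++: the compiler generates a .h and a .cc file from each .proto file, with a class for each message type
  • Java: the compiler generates a .java file with a class for each message type, plus a special Builder class for creating message instances
  • Python: slightly different. The Python compiler generates a module with a static descriptor of each message type, which a metaclass then uses at runtime to create the necessary Python data access classes.
  • Go: the compiler generates a .pb.go file with a type for each message type in the file

A simple usage example

Defining your own protocol format

Define a .proto file and add the messages and their fields, as in the following addressbook.proto: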

syntax = "proto2";

package tutorial;

message Person {
  required string name = 1;
  required int32 id = 2;
  optional string email = 3;

  enum PhoneType {
    MOBILE = 0;
    HOME = 1;
    WORK = 2;
  }

  message PhoneNumber {
    required string number = 1;
    optional PhoneType type = 2 [default = HOME];
  }

  repeated PhoneNumber phone = 4;
}

message AddressBook {
  repeated Person person = 1;
}

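Compiling the .proto file

Once the .proto file is written, you can compile it into the target language with the Protobuf compiler. This example uses C++.

Suppose your .proto file lives in $SRC_DIR and you want the generated code written to $DST_DIR (use the same directory for both if you want the generated files next to the source). The following command then generates addressbook.pb.h and addressbook.pb.cc: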

protoc -I=$SRC_DIR --cpp_out=$DST_DIR $SRC_DIR/addressbook.proto

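The name after package becomes the namespace of the generated code; in the example above, the namespace is tutorial.
The generated header then contains accessors such as the following (the real file contains more):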

// name
inline bool has_name() const;
inline void clear_name();
inline const ::std::string& name() const;
inline void set_name(const ::std::string& value);
inline void set_name(const char* value);
inline ::std::string* mutable_name();
// id
inline bool has_id() const;
inline void clear_id();
inline int32_t id() const;
inline void set_id(int32_t value);
// email
inline bool has_email() const;
inline void clear_email();
inline const ::std::string& email() const;
inline void set_email(const ::std::string& value);
inline void set_email(const char* value);
inline ::std::string* mutable_email();
// phone
inline int phone_size() const;
inline void clear_phone();
inline const ::google::protobuf::RepeatedPtrField< ::tutorial::Person_PhoneNumber >& phone() const;
inline ::google::protobuf::RepeatedPtrField< ::tutorial::Person_PhoneNumber >* mutable_phone();
inline const ::tutorial::Person_PhoneNumber& phone(int index) const;
inline ::tutorial::Person_PhoneNumber* mutable_phone(int index);
inline ::tutorial::Person_PhoneNumber* add_phone();

Write the following application (writeMain.cpp):

#include <iostream>
#include <fstream>
#include <string>
#include "addressbook.pb.h"
using namespace std;

// This function fills in a Person message based on user input.
void PromptForAddress(tutorial::Person* person) {
  cout << "Enter person ID number: ";
  int id;
  cin >> id;
  person->set_id(id);
  cin.ignore(256, '\n');

  cout << "Enter name: ";
  getline(cin, *person->mutable_name());

  cout << "Enter email address (blank for none): ";
  string email;
  getline(cin, email);
  if (!email.empty()) {
    person->set_email(email);
  }

  while (true) {
    cout << "Enter a phone number (or leave blank to finish): ";
    string number;
    getline(cin, number);
    if (number.empty()) {
      break;
    }

    tutorial::Person::PhoneNumber* phone_number = person->add_phone();
    phone_number->set_number(number);

    cout << "Is this a mobile, home, or work phone? ";
    string type;
    getline(cin, type);
    if (type == "mobile") {
      phone_number->set_type(tutorial::Person::MOBILE);
    } else if (type == "home") {
      phone_number->set_type(tutorial::Person::HOME);
    } else if (type == "work") {
      phone_number->set_type(tutorial::Person::WORK);
    } else {
      cout << "Unknown phone type. Using default." << endl;
    }
  }
}

// Main function: Reads the entire address book from a file,
// adds one person based on user input, then writes it back out to the same
// file.
int main(int argc, char* argv[]) {
  // Verify that the version of the library that we linked against is
  // compatible with the version of the headers we compiled against.
  GOOGLE_PROTOBUF_VERIFY_VERSION;

  if (argc != 2) {
    cerr << "Usage: " << argv[0] << " ADDRESS_BOOK_FILE" << endl;
    return -1;
  }

  tutorial::AddressBook address_book;

  {
    // Read the existing address book.
    fstream input(argv[1], ios::in | ios::binary);
    if (!input) {
      cout << argv[1] << ": File not found. Creating a new file." << endl;
    } else if (!address_book.ParseFromIstream(&input)) {
      cerr << "Failed to parse address book." << endl;
      return -1;
    }
  }

  // Add an address.
  PromptForAddress(address_book.add_person());

  {
    // Write the new address book back to disk.
    fstream output(argv[1], ios::out | ios::trunc | ios::binary);
    if (!address_book.SerializeToOstream(&output)) {
      cerr << "Failed to write address book." << endl;
      return -1;
    }
  }

  // Optional: Delete all global objects allocated by libprotobuf.
  google::protobuf::ShutdownProtobufLibrary();

  return 0;
}

Compile it with g++ -o write addressbook.pb.cc writeMain.cpp -lprotobuf

The -lprotobuf flag is required so that the program links against the protobuf library.

Write the application that reads the message back (readMain.cpp):

#include <iostream>
#include <fstream>
#include <string>
#include "addressbook.pb.h"
using namespace std;

// Iterates through all people in the AddressBook and prints info about them.
void ListPeople(const tutorial::AddressBook& address_book) {
  for (int i = 0; i < address_book.person_size(); i++) {
    const tutorial::Person& person = address_book.person(i);

    cout << "Person ID: " << person.id() << endl;
    cout << " Name: " << person.name() << endl;
    if (person.has_email()) {
      cout << " E-mail address: " << person.email() << endl;
    }

    for (int j = 0; j < person.phone_size(); j++) {
      const tutorial::Person::PhoneNumber& phone_number = person.phone(j);

      switch (phone_number.type()) {
        case tutorial::Person::MOBILE:
          cout << " Mobile phone #: ";
          break;
        case tutorial::Person::HOME:
          cout << " Home phone #: ";
          break;
        case tutorial::Person::WORK:
          cout << " Work phone #: ";
          break;
      }
      cout << phone_number.number() << endl;
    }
  }
}

// Main function: Reads the entire address book from a file and prints all
// the information inside.
int main(int argc, char* argv[]) {
  // Verify that the version of the library that we linked against is
  // compatible with the version of the headers we compiled against.
  GOOGLE_PROTOBUF_VERIFY_VERSION;

  if (argc != 2) {
    cerr << "Usage: " << argv[0] << " ADDRESS_BOOK_FILE" << endl;
    return -1;
  }

  tutorial::AddressBook address_book;

  {
    // Read the existing address book.
    fstream input(argv[1], ios::in | ios::binary);
    if (!address_book.ParseFromIstream(&input)) {
      cerr << "Failed to parse address book." << endl;
      return -1;
    }
  }

  ListPeople(address_book);

  // Optional: Delete all global objects allocated by libprotobuf.
  google::protobuf::ShutdownProtobufLibrary();

  return 0;
}

Compile it with g++ -o read addressbook.pb.cc readMain.cpp -lprotobuf

Run ./write Address to write the data.

Run ./read Address to read it back.

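Protocol Buffers in Caffe

Before reading the Caffe source code, first take a look at caffe.proto, located in ../src/caffe/proto, and at the caffe.pb.cc and caffe.pb.h files generated from it.

Classes generated from the messages in caffe.proto

Take, for example, the definition of BlobShape, the simplest message in caffe.proto: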

// Specifies the shape (dimensions) of a Blob.
message BlobShape {
  repeated int64 dim = 1 [packed = true];
}

Even for such a small message, the generated class contains a large number of functions:

class BlobShape : public ::google::protobuf::Message /* @@protoc_insertion_point(class_definition:caffe.BlobShape) */ {
 public:
  BlobShape();
  virtual ~BlobShape();

  BlobShape(const BlobShape& from);

  inline BlobShape& operator=(const BlobShape& from) {
    CopyFrom(from);
    return *this;
  }

  inline const ::google::protobuf::UnknownFieldSet& unknown_fields() const {
    return _internal_metadata_.unknown_fields();
  }

  inline ::google::protobuf::UnknownFieldSet* mutable_unknown_fields() {
    return _internal_metadata_.mutable_unknown_fields();
  }

  static const ::google::protobuf::Descriptor* descriptor();
  static const BlobShape& default_instance();

  static inline const BlobShape* internal_default_instance() {
    return reinterpret_cast<const BlobShape*>(
        &_BlobShape_default_instance_);
  }

  void Swap(BlobShape* other);

  // implements Message ----------------------------------------------

  inline BlobShape* New() const PROTOBUF_FINAL { return New(NULL); }

  BlobShape* New(::google::protobuf::Arena* arena) const PROTOBUF_FINAL;
  void CopyFrom(const ::google::protobuf::Message& from) PROTOBUF_FINAL;
  void MergeFrom(const ::google::protobuf::Message& from) PROTOBUF_FINAL;
  void CopyFrom(const BlobShape& from);
  void MergeFrom(const BlobShape& from);
  void Clear() PROTOBUF_FINAL;
  bool IsInitialized() const PROTOBUF_FINAL;

  size_t ByteSizeLong() const PROTOBUF_FINAL;
  bool MergePartialFromCodedStream(
      ::google::protobuf::io::CodedInputStream* input) PROTOBUF_FINAL;
  void SerializeWithCachedSizes(
      ::google::protobuf::io::CodedOutputStream* output) const PROTOBUF_FINAL;
  ::google::protobuf::uint8* InternalSerializeWithCachedSizesToArray(
      bool deterministic, ::google::protobuf::uint8* target) const PROTOBUF_FINAL;
  ::google::protobuf::uint8* SerializeWithCachedSizesToArray(::google::protobuf::uint8* output)
      const PROTOBUF_FINAL {
    return InternalSerializeWithCachedSizesToArray(false, output);
  }
  int GetCachedSize() const PROTOBUF_FINAL { return _cached_size_; }

 private:
  void SharedCtor();
  void SharedDtor();
  void SetCachedSize(int size) const PROTOBUF_FINAL;
  void InternalSwap(BlobShape* other);

 private:
  inline ::google::protobuf::Arena* GetArenaNoVirtual() const {
    return NULL;
  }
  inline void* MaybeArenaPtr() const {
    return NULL;
  }

 public:
  ::google::protobuf::Metadata GetMetadata() const PROTOBUF_FINAL;

  // nested types ----------------------------------------------------

  // accessors -------------------------------------------------------

  // repeated int64 dim = 1 [packed = true];
  int dim_size() const;
  void clear_dim();
  static const int kDimFieldNumber = 1;
  ::google::protobuf::int64 dim(int index) const;
  void set_dim(int index, ::google::protobuf::int64 value);
  void add_dim(::google::protobuf::int64 value);
  const ::google::protobuf::RepeatedField< ::google::protobuf::int64 >&
      dim() const;
  ::google::protobuf::RepeatedField< ::google::protobuf::int64 >*
      mutable_dim();

  // @@protoc_insertion_point(class_scope:caffe.BlobShape)
 private:
  ::google::protobuf::internal::InternalMetadataWithArena _internal_metadata_;
  ::google::protobuf::internal::HasBits<1> _has_bits_;
  mutable int _cached_size_;
  ::google::protobuf::RepeatedField< ::google::protobuf::int64 > dim_;
  mutable int _dim_cached_byte_size_;
  friend void protobuf_InitDefaults_caffe_2eproto_impl();
  friend void protobuf_AddDesc_caffe_2eproto_impl();
  friend const ::google::protobuf::uint32* protobuf_Offsets_caffe_2eproto();
  friend void protobuf_ShutdownFile_caffe_2eproto();
};
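Much of this class exists to move the dim values to and from the wire format. As a rough illustration of what the parsing side has to do, here is a standalone sketch that decodes a serialized BlobShape whose only content is the packed dim field. It is not the generated code: ReadVarint and DecodeBlobShapeDims are hypothetical helpers, and the sketch assumes the input contains nothing but tag 1 in packed form.

```cpp
#include <cassert>  // used by the check mentioned in the note below
#include <cstddef>
#include <cstdint>
#include <vector>

// Reads one varint starting at data[*pos] and advances *pos past it.
// Returns false if the input ends early or the varint is longer than 10 bytes.
bool ReadVarint(const std::vector<std::uint8_t>& data, std::size_t* pos,
                std::uint64_t* value) {
  *value = 0;
  for (int shift = 0; shift < 64; shift += 7) {
    if (*pos >= data.size()) return false;
    const std::uint8_t byte = data[(*pos)++];
    *value |= static_cast<std::uint64_t>(byte & 0x7F) << shift;
    if ((byte & 0x80) == 0) return true;
  }
  return false;
}

// Decodes the packed `dim` field (tag 1, wire type 2) of a serialized
// BlobShape. Assumption: the message holds only that field, in packed form.
bool DecodeBlobShapeDims(const std::vector<std::uint8_t>& data,
                         std::vector<std::int64_t>* dims) {
  dims->clear();
  std::size_t pos = 0;
  while (pos < data.size()) {
    std::uint64_t key = 0;
    std::uint64_t length = 0;
    if (!ReadVarint(data, &pos, &key)) return false;
    if (key != ((1u << 3) | 2u)) return false;  // expect tag 1, wire type 2
    if (!ReadVarint(data, &pos, &length)) return false;
    if (length > data.size() - pos) return false;

    const std::size_t end = pos + static_cast<std::size_t>(length);
    while (pos < end) {
      std::uint64_t v = 0;
      if (!ReadVarint(data, &pos, &v)) return false;
      dims->push_back(static_cast<std::int64_t>(v));
    }
    if (pos != end) return false;  // a value ran past the declared length
  }
  return true;
}
```

Feeding it the bytes 0A 06 01 03 E0 01 E0 01 gives the dims 1, 3, 224, 224. In real code the generated ParseFromIstream and dim() accessors do this work, and they also accept inputs this sketch rejects, such as unpacked values.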

Some of the more important messages

Here are some of the more important message definitions:

  • BlobProto
  • NetParameter
  • SolverParameter
  • LayerParameter
  • TransformationParameter
  • LossParameter
  • ConvolutionParameter

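We will read the corresponding code closely when we analyze each of these in later posts. To add a layer of your own, you also need to add content here; see the later posts for details.

Let's start with message LayerParameter.

message LayerParameter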

// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
// LayerParameter next available layer-specific ID: 147 (last added: recurrent_param)
message LayerParameter {
  optional string name = 1; // the layer name
  optional string type = 2; // the layer type
  repeated string bottom = 3; // the name of each bottom blob
  repeated string top = 4; // the name of each top blob

  // The train / test phase for computation.
  optional Phase phase = 10;

  // The amount of weight to assign each top blob in the objective.
  // Each layer assigns a default value, usually of either 0 or 1,
  // to each top blob.
  repeated float loss_weight = 5;

  // Specifies training parameters (multipliers on global learning constants,
  // and the name and other settings used for weight sharing).
  repeated ParamSpec param = 6;

  // The blobs containing the numeric parameters of the layer.
  repeated BlobProto blobs = 7;

  // Specifies whether to backpropagate to each bottom. If unspecified,
  // Caffe will automatically infer whether each input needs backpropagation
  // to compute parameter gradients. If set to true for some inputs,
  // backpropagation to those inputs is forced; if set false for some inputs,
  // backpropagation to those inputs is skipped.
  //
  // The size must be either 0 or equal to the number of bottoms.
  repeated bool propagate_down = 11;

  // Rules controlling whether and when a layer is included in the network,
  // based on the current NetState. You may specify a non-zero number of rules
  // to include OR exclude, but not both. If no include or exclude rules are
  // specified, the layer is always included. If the current NetState meets
  // ANY (i.e., one or more) of the specified rules, the layer is
  // included/excluded.
  repeated NetStateRule include = 8;
  repeated NetStateRule exclude = 9;

  // Parameters for data pre-processing.
  optional TransformationParameter transform_param = 100;

  // Parameters shared by loss layers.
  optional LossParameter loss_param = 101;

  // Layer type-specific parameters.
  //
  // Note: certain layers may have more than one computational engine
  // for their implementation. These layers include an Engine type and
  // engine parameter for selecting the implementation.
  // The default for the engine is set by the ENGINE switch at compile-time.
  optional AccuracyParameter accuracy_param = 102;
  optional ArgMaxParameter argmax_param = 103;
  optional BatchNormParameter batch_norm_param = 139;
  optional BiasParameter bias_param = 141;
  optional ConcatParameter concat_param = 104;
  optional ContrastiveLossParameter contrastive_loss_param = 105;
  optional ConvolutionParameter convolution_param = 106;
  optional CropParameter crop_param = 144;
  optional DataParameter data_param = 107;
  optional DropoutParameter dropout_param = 108;
  optional DummyDataParameter dummy_data_param = 109;
  optional EltwiseParameter eltwise_param = 110;
  optional ELUParameter elu_param = 140;
  optional EmbedParameter embed_param = 137;
  optional ExpParameter exp_param = 111;
  optional FlattenParameter flatten_param = 135;
  optional HDF5DataParameter hdf5_data_param = 112;
  optional HDF5OutputParameter hdf5_output_param = 113;
  optional HingeLossParameter hinge_loss_param = 114;
  optional ImageDataParameter image_data_param = 115;
  optional InfogainLossParameter infogain_loss_param = 116;
  optional InnerProductParameter inner_product_param = 117;
  optional InputParameter input_param = 143;
  optional LogParameter log_param = 134;
  optional LRNParameter lrn_param = 118;
  optional MemoryDataParameter memory_data_param = 119;
  optional MVNParameter mvn_param = 120;
  optional ParameterParameter parameter_param = 145;
  optional PoolingParameter pooling_param = 121;
  optional PowerParameter power_param = 122;
  optional PReLUParameter prelu_param = 131;
  optional PythonParameter python_param = 130;
  optional RecurrentParameter recurrent_param = 146;
  optional ReductionParameter reduction_param = 136;
  optional ReLUParameter relu_param = 123;
  optional ReshapeParameter reshape_param = 133;
  optional ScaleParameter scale_param = 142;
  optional SigmoidParameter sigmoid_param = 124;
  optional SoftmaxParameter softmax_param = 125;
  optional SPPParameter spp_param = 132;
  optional SliceParameter slice_param = 126;
  optional TanHParameter tanh_param = 127;
  optional ThresholdParameter threshold_param = 128;
  optional TileParameter tile_param = 138;
  optional WindowDataParameter window_data_param = 129;
}