[IR] Google WSDM'09講演で述べられている符号化方式を実装してみた
MG勉強会の後にid:sleepy_yoshiさんに教えてもらったWSDM 2009における講演"Challenges in Building Large-Scale Information Retrieval Systems"で述べられている符号化方式のGroup Varint Encodingを実装してみた。
整数の符号化方式
転置インデックスなどで文章番号のリストを前の値との差分で表すなどの方法を用いると出現する、ほとんどの値は小さな値となるためこれを4バイト使って表現するのは記憶容量の無駄である。
このためVarint Encoding、ガンマ符号、デルタ符号、Rice Coding、Simple 9、pForDeltaなど様々な符号化方式が提案されている。このうちVarint Encodingは実装が手軽なことからよく用いられている。Varint Encodingでは次のように整数xを表現する。
- xが7bit以下であればxのまま出力
- xが8bit以上であればxの下位7bitとxがまだ続くことを示すためにMSBに1をセットして出力
- xを7bit右シフトして、上記処理を繰り返す
例えば
値 | 固定長符号 | Varint Encoding |
---|---|---|
5 | 00000000 00000000 00000000 00000101 | 00000101 |
130 | 00000000 00000000 00000000 10000010 | 10000010 00000001 |
24706 | 00000000 00000000 01100000 10000010 | 10000010 11000001 00000001 |
となります。
仮定
上の表では5を00000000 00000000 00000000 00000101と表現されると書きましたがintel系のCPUではリトルエンディアンという記録方式を採用しており、00000101 00000000 00000000 00000000と逆順に格納されます。以下の議論ではこのことを仮定します。
また、符号化する対象の数字は4バイト(2^32 -1)以下とします。
Group Varint Encoding
Varint Encodingの欠点としてDecoding時に大量の分岐処理・シフト演算・マスク処理が必要ということがあります。例えばどこまでが一つの数字かというのを決めるためにMSBが1かどうかをみなければならないということとMSBを消さなければならないということが挙げられます。
今符号は4バイトで表せるのでバイト長は2ビットで表現できます。このため、4つの整数を一組にまとめるとその組の数のバイト長は1バイトで表現できます。
すなわち、以下の4つの数を
00000001 00001111 11111111 00000011 11111111 11111111 00000111 ======== ======== ================== ============================ 1 15 511 131071
の代わりに
00000110 00000001 00001111 11111111 00000011 11111111 11111111 00000001 ======== ======== ======== ================== ============================ tag 1 15 511 131071
と表現します。*1始めのtagは00|00|01|10と分けると1バイト,1バイト,2バイト,3バイトの数が連続してるということを意味します。
ここでtagの種類は256通りしかないことを利用して、各タグにおいて数が始まるoffsetとバイト長に応じたMaskに関する表を作って置きます。
00|00|01|10 => Offsets: +1,+2,+3,+5; Masks: 0xff,0xff,0xffff,0xffffff
例えば511をDecodeする時はOffset 5から11111111 00000011 11111111 11111111を取ってきて0xffffとandを取ってやればよい。
とりあえず、Decodeする要素数を4の倍数とするとDecodeのコードは以下のようになります。
uint32_t offsets[256][4]; uint32_t masks[256][4]; int length[256]; void decode_vb_group(const uint8_t * data , size_t decode_length , uint32_t * result){ for(int i = 0 ; i < decode_length ; i += 4){ uint8_t tag = *data; for(int j = 0 ; j < 4 ; ++j) *(result++) = masks[tag][j] & (*((uint32_t *)(data + offsets[tag][j]))); data += length[tag]; } }
コードを見るとシフト演算や条件分岐が消えていることが分かります。たとえば通常のVarint Encodingでは次のような復号処理となると思います。
void decode_vb(const uint8_t * data , size_t decode_length , uint32_t * result){ int r ,b; for(int i = 0 ; i < decode_length ; ++i){ r = 0x7f & (*data); b = 7; while(0x80 & (*data++)){ r = r | ((0x7f & (*data)) << b); b += 7; } *(result++) = r; } }
実験
以下のような小さな値が多くなるようにランダムな配列を作成して実験を行いました。
// 長さlenのランダムな値の入った配列を作成する void generate(int len , vector<uint32_t> & data){ data.clear(); data.resize(len); mt19937 gen(777); uniform_int<> dst( 1 , 0x7fffffff); variate_generator< mt19937& , uniform_int<> > rand(gen , dst); // 小さい値が多くなるようにするため上位ビットをランダムにマスクする int mask[] = {0xf , 0xf , 0xf , 0xf , 0xff , 0xfff , 0xfffff , 0xffffffff}; for(int i = 0 ; i < len ; ++i){ int m = mask[rand() % 8]; int r = rand(); data[i] = 1 + (r & m); } }
結果としては長さ100万のデータの復号を100回行うのにかかった時間は
Varint Encoding | Group Varing Encoding |
---|---|
0.85sec(117M numbers/sec) | 0.16sec(625M numbers/sec) |
となり、大幅に計算時間が短縮できてることが分かります。
感想
Varing Encodingを4つまとめて行うことによって大幅な計算時間の短縮が行えるというのは驚きでした。ところで疑問があるのですがOffsetから4バイト読むために
(*((uint32_t *)(data + offsets[tag][j])))
実装コード
最後に実験に用いた符号化も含めた全てのコードを載せておきます。
#include <ctime> #include <vector> #include <iostream> #include <boost/random.hpp> #include <boost/lexical_cast.hpp> #include <stdint.h> using namespace std; using namespace boost; #include <time.h> #include <sys/time.h> #include <stdio.h> double gettimeofday_sec() { struct timeval tv; gettimeofday(&tv, NULL); return tv.tv_sec + (double)tv.tv_usec*1e-6; } // 長さlenのランダムな値の入った配列を作成する void generate(int len , vector<uint32_t> & data){ data.clear(); data.resize(len); mt19937 gen(777); uniform_int<> dst( 1 , 0x7fffffff); variate_generator< mt19937& , uniform_int<> > rand(gen , dst); // 小さい値が多くなるようにするため上位ビットをランダムにマスクする int mask[] = {0xf , 0xf , 0xf , 0xf , 0xff , 0xfff , 0xfffff , 0xffffffff}; for(int i = 0 ; i < len ; ++i){ int m = mask[rand() % 8]; int r = rand(); data[i] = 1 + (r & m); } } void encode_vb(const vector<uint32_t> & data , vector<uint8_t> & result){ result.clear(); vector<uint32_t>::const_iterator it; for(it = data.begin() ; it != data.end() ; ++it){ uint32_t v = *it; //encode v while(v){ if(v <= 0x7f){ result.push_back((uint8_t)v); break; }else{ uint8_t e = 0x80 | (v & 0x7f); result.push_back(e); v = v >> 7; } } } } void encode_vb_group(const vector<uint32_t> & data , vector<uint8_t> & result){ result.clear(); for(int i = 0 ; i < data.size() ; i += 4){ uint8_t tag = 0; for(int j = 0 ; j < 4 ; ++j){ int v = data[i + j]; uint8_t t = 3; if(v < 0x100){ t = 0; }else if(v < 0x10000){ t = 1; }else if(v < 0x1000000){ t = 2; } tag |= t << (6 - 2 * j); } result.push_back(tag); for(int j = 0 ; j < 4 ; ++j){ int v = data[i + j]; while(v >= 0x100){ result.push_back((uint8_t)(v & 0xff)); v = v >> 8; } result.push_back((uint8_t)v); } } } void decode_vb(const uint8_t * data , size_t decode_length , uint32_t * result){ int r ,b; for(int i = 0 ; i < decode_length ; ++i){ r = 0x7f & (*data); b = 7; while(0x80 & (*data++)){ r = r | ((0x7f & (*data)) << b); b += 7; } *(result++) = r; } } uint32_t offsets[256][4]; uint32_t masks[256][4]; int length[256]; void init_group(){ for(int i = 0 ; i < 256 ; ++i){ offsets[i][0] = 1; for(int j = 0 ; j < 4 ; ++j){ int t = 0x3 & (i >> ( 6 - 2 * j)); if(j < 3){ offsets[i][j + 1] = offsets[i][j] + t + 1; }else if(j == 3){ length[i] = offsets[i][3] + t + 1; } if(t == 0){ masks[i][j] = 0xff; }else if(t == 1){ masks[i][j] = 0xffff; }else if(t == 2){ masks[i][j] = 0xffffff; }else{ masks[i][j] = 0xffffffff; } } } } void decode_vb_group(const uint8_t * data , size_t decode_length , uint32_t * result){ for(int i = 0 ; i < decode_length ; i += 4){ uint8_t tag = *data; for(int j = 0 ; j < 4 ; ++j) *(result++) = masks[tag][j] & (*((uint32_t *)(data + offsets[tag][j]))); data += length[tag]; } } int main(int argc , char ** argv){ if(argc < 2){ cerr << argv[0] << " array-length" << endl; return 1; } int len = lexical_cast<int>(argv[1]); if(len % 4 != 0){ // 簡単のため長さは4の倍数にする len = (len >> 2) << 2; } vector<uint32_t> data; generate(len , data); vector<uint8_t> vb_enc; encode_vb(data , vb_enc); //cout << data.size() * 4 << " " << vb_enc.size() << " " << (vb_enc.size() * 1.0 / (data.size() * 4.0)) << endl; uint32_t * dec = new uint32_t[ data.size() ]; double t1 = gettimeofday_sec(); for(int rep = 0 ; rep < 100 ; ++rep){ decode_vb(&(vb_enc[0]) , data.size() , dec); } double t2 = gettimeofday_sec(); cout << (t2 - t1) << "sec" << endl; init_group(); vector<uint8_t> vb_enc_g; encode_vb_group(data , vb_enc_g); //cout << data.size() * 4 << " " << vb_enc_g.size() << " " << (vb_enc_g.size() * 1.0 / (data.size() * 4.0)) << endl; t1 = gettimeofday_sec(); for(int rep = 0 ; rep < 100 ; ++rep){ decode_vb_group(&(vb_enc_g[0]) , data.size() , dec); } t2 = gettimeofday_sec(); cout << (t2 - t1) << "sec" << endl; delete dec; return 0; }