min_heapを用いた上位r個の要素の抽出
MG勉強会の発表があるため4.6ランキング検索の部分を読むついでに、最後のサブセクションの上位r個の要素を取り出す部分について実装してみた。
情報検索において、N個の候補集合から上位r個の要素を取り出すことが多い。
値が配列に格納されているとするとこれを実現するためのコードはもっとも単純に行うと以下のようになる
//長さlenの配列arrayの中でトップr個の値をresultに挿入する void sort_method(int * array , int len, int r , vector<int> & result){ sort(array , array + len); copy(array + len - r , array + len , back_inserter(result)); }
しかし、Nが大きいとき、MGの例だとN=100万のときにsortの処理にはおおよそ1000000 * lg(1000000) \approx 2000万回の比較が必要となる。しかし、求めたい上位件数はr=100件で十分であることが多いため、この配列全部をソートするのには無駄が多い。
Managing Gigabytesの4.6のSortingのサブセクションではこれを防ぐための方法として、min-heapを用いる方法が紹介されている。
この方法では配列の要素を順に見ていくと同時にこれまでの上位r件の値をheapに保持しておく。
新規の要素を処理する際にはheapのr番めの値(最小値、これはO(1)で取得可能)との比較を行い
- heapの最小値 < 新規の要素
- heapの最小値を捨て、新規の要素を追加する
- otherwise
- 新規の要素を捨てる
を繰り返す。最後まで処理が終わればヒープには上位r件の値が保持される。そして、ヒープの中にある程度要素があるときには現在のトップrの値がヒープから捨てられる確率は小さくなるため、比較回数の期待値は小さくなる。
以上をコードにすると以下のようになる
//長さlenの配列arrayの中でトップr個の値をresultに挿入する void heap_method(int * array , int len , int r , vector<int> & result){ priority_queue<int , vector<int> , greater<int> > q; for(int i = 0 ; i < r ; i++){ q.push(array[i]); } for(int i = r + 1 ; i < len ; i++){ if(q.top() < array[i]){ q.pop(); q.push(array[i]); } } for(int i = 0 ; i < r ; i++){ result.push_back(q.top()); q.pop(); } reverse(result.begin() , result.end()); }
以下に配列の大きさに対して選択にかかった時間を示す。なお、実験環境はIntel Core i7 920(2.67GHz), メモリ3G, Ubuntu 9.04で測定した。また、配列の要素は一様乱数を用いて生成した。なおグラフは両対数グラフとなっている。Nが大きいときにはヒープを用いた方法の方が100倍程度高速となっている。
追記:
コメントで指摘されていますがstlのpartial_sortがほとんどおなじことを行っています。ソース
参考までにstl_algo.hの中を貼り付けておきます
template <class _InputIter, class _RandomAccessIter, class _Distance, class _Tp> _RandomAccessIter __partial_sort_copy(_InputIter __first, _InputIter __last, _RandomAccessIter __result_first, _RandomAccessIter __result_last, _Distance*, _Tp*) { if (__result_first == __result_last) return __result_last; _RandomAccessIter __result_real_last = __result_first; while(__first != __last && __result_real_last != __result_last) { *__result_real_last = *__first; ++__result_real_last; ++__first; } make_heap(__result_first, __result_real_last); while (__first != __last) { if (*__first < *__result_first) __adjust_heap(__result_first, _Distance(0), _Distance(__result_real_last - __result_first), _Tp(*__first)); ++__first; } sort_heap(__result_first, __result_real_last); return __result_real_last; }