min_heapを用いた上位r個の要素の抽出

MG勉強会の発表があるため4.6ランキング検索の部分を読むついでに、最後のサブセクションの上位r個の要素を取り出す部分について実装してみた。

情報検索において、N個の候補集合から上位r個の要素を取り出すことが多い。
値が配列に格納されているとするとこれを実現するためのコードはもっとも単純に行うと以下のようになる

//長さlenの配列arrayの中でトップr個の値をresultに挿入する
void sort_method(int * array , int len, int r , vector<int> & result){
  sort(array , array + len);
  copy(array + len - r  , array + len , back_inserter(result));
}

しかし、Nが大きいとき、MGの例だとN=100万のときにsortの処理にはおおよそ1000000 * lg(1000000) \approx 2000万回の比較が必要となる。しかし、求めたい上位件数はr=100件で十分であることが多いため、この配列全部をソートするのには無駄が多い。
Managing Gigabytesの4.6のSortingのサブセクションではこれを防ぐための方法として、min-heapを用いる方法が紹介されている。
この方法では配列の要素を順に見ていくと同時にこれまでの上位r件の値をheapに保持しておく。
新規の要素を処理する際にはheapのr番めの値(最小値、これはO(1)で取得可能)との比較を行い

  • heapの最小値 < 新規の要素
    • heapの最小値を捨て、新規の要素を追加する
  • otherwise
    • 新規の要素を捨てる

を繰り返す。最後まで処理が終わればヒープには上位r件の値が保持される。そして、ヒープの中にある程度要素があるときには現在のトップrの値がヒープから捨てられる確率は小さくなるため、比較回数の期待値は小さくなる。

以上をコードにすると以下のようになる

//長さlenの配列arrayの中でトップr個の値をresultに挿入する
void heap_method(int * array , int len , int r , vector<int> & result){
  priority_queue<int , vector<int> , greater<int> > q;
  for(int i = 0 ; i < r ; i++){
    q.push(array[i]);
  }

  for(int i = r + 1 ; i < len ; i++){
    if(q.top() < array[i]){
      q.pop();
      q.push(array[i]);
    }
  }

  for(int i = 0 ; i < r ; i++){
    result.push_back(q.top());
    q.pop();
  }
  reverse(result.begin() , result.end());
}

以下に配列の大きさに対して選択にかかった時間を示す。なお、実験環境はIntel Core i7 920(2.67GHz), メモリ3G, Ubuntu 9.04で測定した。また、配列の要素は一様乱数を用いて生成した。なおグラフは両対数グラフとなっている。Nが大きいときにはヒープを用いた方法の方が100倍程度高速となっている。

追記:

コメントで指摘されていますがstlのpartial_sortがほとんどおなじことを行っています。ソース

参考までにstl_algo.hの中を貼り付けておきます

template <class _InputIter, class _RandomAccessIter, class _Distance,
          class _Tp>
_RandomAccessIter __partial_sort_copy(_InputIter __first,
                                      _InputIter __last,
                                      _RandomAccessIter __result_first,
                                      _RandomAccessIter __result_last, 
                                      _Distance*, _Tp*) {
  if (__result_first == __result_last) return __result_last;
  _RandomAccessIter __result_real_last = __result_first;
  while(__first != __last && __result_real_last != __result_last) {
    *__result_real_last = *__first;
    ++__result_real_last;
    ++__first;
  }
  make_heap(__result_first, __result_real_last);
  while (__first != __last) {
    if (*__first < *__result_first) 
      __adjust_heap(__result_first, _Distance(0),
                    _Distance(__result_real_last - __result_first),
                    _Tp(*__first));
    ++__first;
  }
  sort_heap(__result_first, __result_real_last);
  return __result_real_last;
}