[機械学習] AROWの落ち穂拾い2

とりあえず以下のコードをollのoll.cppに突っ込むことによってAROWを使うようにできる。(あとoll.hppやoll_train.cppの学習手法が並んでいるところにAROW用の値を付け加える)

バイアスの部分とかはちゃんとなってるかあまり自信ないです。

CW(Confidence-weighted)のコードと非常によく似たコードになっている。

  //
  // Adaptive Regularization Of Weight Vector
  //
  void oll::updateAROW(const fv_t& fv, const int y, const float alpha) {
    for (size_t i = 0; i < fv.size(); i++){
      if (cov.size() <= fv[i].first) {
        w.resize(fv[i].first+1);
        const size_t prevSize = cov.size();
        cov.resize(fv[i].first+1);
        for (size_t j = prevSize; j < cov.size(); j++){
          cov[j] = 1.f;
        }
      }
      w[fv[i].first] += fv[i].second * alpha * y * cov[fv[i].first];
      cov[fv[i].first] = 1.f / (1.f/cov[fv[i].first] + fv[i].second * fv[i].second / C);
    }
    b += alpha * y * covb * bias;
    covb = 1.f / (1.f/ covb + bias * bias / C);
  }

  template <>
  void oll::trainExample(const AROW_s& a, const fv_t& fv, const int y) {
    const float score = getMargin(w, b, fv) * y;
    if(score >= 1){
      return ;
    }
    const float var   = getVariance(fv);
    const float beta  = 1.0 / (var + C);
    const float alpha = (1.0 - score) * beta;
    updateAROW(fv, y, alpha);
  }