[機械学習] AROWの落ち穂拾い2
とりあえず以下のコードをollのoll.cppに突っ込むことによってAROWを使うようにできる。(あとoll.hppやoll_train.cppの学習手法が並んでいるところにAROW用の値を付け加える)
バイアスの部分とかはちゃんとなってるかあまり自信ないです。
CW(Confidence-weighted)のコードと非常によく似たコードになっている。
// // Adaptive Regularization Of Weight Vector // void oll::updateAROW(const fv_t& fv, const int y, const float alpha) { for (size_t i = 0; i < fv.size(); i++){ if (cov.size() <= fv[i].first) { w.resize(fv[i].first+1); const size_t prevSize = cov.size(); cov.resize(fv[i].first+1); for (size_t j = prevSize; j < cov.size(); j++){ cov[j] = 1.f; } } w[fv[i].first] += fv[i].second * alpha * y * cov[fv[i].first]; cov[fv[i].first] = 1.f / (1.f/cov[fv[i].first] + fv[i].second * fv[i].second / C); } b += alpha * y * covb * bias; covb = 1.f / (1.f/ covb + bias * bias / C); } template <> void oll::trainExample(const AROW_s& a, const fv_t& fv, const int y) { const float score = getMargin(w, b, fv) * y; if(score >= 1){ return ; } const float var = getVariance(fv); const float beta = 1.0 / (var + C); const float alpha = (1.0 - score) * beta; updateAROW(fv, y, alpha); }