[機械学習] AROWのコードを書いてみた

昨日のPFIセミナーで紹介されていたAROW (Adaptive Regularization of Weight Vectors)を実装してみた。

AROWはCrammerらによりNIPS 2009で提案された手法で、彼らが以前提案したConfidence Weighted (CW)よりもノイズに強く、またCWとほぼ同等の性能を持っている。

今回実装したのは共分散行列を対角に近似した場合の式に基づいている。これは共分散行列をフルに持とうとすると素性の数の2乗程度のメモリが必要で、コストが大きすぎるためである。

追記 Featureクラスの定義を書くのを忘れてたので追加

/**
 * One non-zero entry of a sparse feature vector: a feature position paired
 * with the value observed at that position.
 */
public class Feature {
	/** Position of the feature; used elsewhere in this file as an array index. */
	int index;

	/** Value of the feature at {@link #index}. */
	double weight;

	/**
	 * Creates a feature entry.
	 *
	 * @param i position of the feature
	 * @param w value of the feature
	 */
	Feature(int i , double w){
		this.weight = w;
		this.index = i;
	}
}
import java.util.Arrays;

/**
 * Online binary classifier implementing AROW (Adaptive Regularization of
 * Weight Vectors) with a diagonal approximation of the covariance matrix:
 * one mean weight and one variance value are kept per feature.
 *
 * <p>Labels are expected to be +1 or -1. Feature indices must lie in
 * {@code [0, featureSize)}; an out-of-range index surfaces as an
 * {@link ArrayIndexOutOfBoundsException} from the array access.
 */
public class AROW {
  /** Regularization parameter used by {@link #AROW(int)}. */
  static final double DEFAULT_R = 0.1;

  /** Mean weight per feature index; starts at 0. */
  double[] mean;
  /** Diagonal of the covariance matrix (per-feature variance); starts at 1. */
  double[] cov;
  /** Number of features, i.e. the length of {@link #mean} and {@link #cov}. */
  int F;
  /** Regularization parameter r of the update rule. */
  double r;

  /**
   * Creates a model with the default regularization parameter (0.1).
   *
   * @param featureSize number of features the model can address
   */
  public AROW(int featureSize) {
    this(featureSize, DEFAULT_R);
  }

  /**
   * Creates a model with an explicit regularization parameter.
   *
   * @param featureSize number of features the model can address
   * @param r regularization parameter; must be positive and finite because
   *          the update divides by {@code r} and by {@code confidence + r}
   * @throws IllegalArgumentException if {@code r} is not a positive finite number
   */
  public AROW(int featureSize, double r) {
    // !(r > 0.0) also rejects NaN.
    if (!(r > 0.0) || Double.isInfinite(r)) {
      throw new IllegalArgumentException("r must be a positive finite number: " + r);
    }
    F = featureSize;
    mean = new double[F];
    cov  = new double[F];
    Arrays.fill(cov, 1.0);
    this.r = r;
  }

  /**
   * Computes the signed margin: the dot product of the mean weights with x.
   *
   * @param x sparse feature vector
   * @return sum over x of {@code mean[index] * weight}
   */
  double getMargin(Feature[] x){
    double res = 0.0;
    for(Feature f : x){
      res += mean[f.index] * f.weight;
    }
    return res;
  }

  /**
   * Computes the confidence term x' * Sigma * x for the diagonal covariance.
   *
   * @param x sparse feature vector
   * @return sum over x of {@code cov[index] * weight * weight}
   */
  double getConfidence(Feature[] x){
    double res = 0.0;
    for(Feature f : x){
      res += cov[f.index] * f.weight * f.weight;
    }
    return res;
  }

  /**
   * Performs one online update with a labelled example.
   *
   * <p>When the margin already satisfies {@code margin * label >= 1} the
   * model is left untouched. Otherwise the means of the active features are
   * moved towards the label and their variances are shrunk.
   *
   * @param x sparse feature vector
   * @param label true label, +1 or -1
   * @return 1 if the model, before this update, predicted the wrong sign for
   *         this example (using the same rule as {@link #predict}), else 0
   */
  int update(Feature[] x , int label){
    double m = getMargin(x);
    // Use the same sign rule as predict() so that a zero margin, which
    // predict() reports as -1, is counted as a mistake for a +1 label.
    int loss = sign(m) * label < 0 ? 1 : 0;
    if(m * label >= 1){
      return 0;
    }
    double v = getConfidence(x);
    double beta = 1.0 / (v + r);
    double alpha = (1.0 - label * m) * beta;
    // update mean (uses the variances from before this update)
    for(Feature f : x){
      mean[f.index] += alpha * label * cov[f.index] * f.weight;
    }
    // update covariance: add weight^2 / r to the inverse variance
    for(Feature f : x){
      cov[f.index] = 1.0 / ( (1.0/cov[f.index]) + f.weight * f.weight / r);
    }
    return loss;
  }

  /**
   * Predicts the label of an example.
   *
   * @param x sparse feature vector
   * @return +1 if the margin is strictly positive, otherwise -1
   */
  int predict(Feature[] x){
    return sign(getMargin(x));
  }

  /** Maps a margin to a label: +1 when strictly positive, otherwise -1. */
  private static int sign(double margin){
    return margin > 0 ? 1 : -1;
  }
}

テストコード

import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.Scanner;


/**
 * One labelled example: a sparse feature vector together with its class label.
 */
class Instance{
  /** Sparse feature vector of the example. */
  Feature[] x;

  /** Class label of the example; the loader in this file assigns +1 or -1. */
  int label;
}

/**
 * Command-line experiment: loads a binary classification data set in
 * "label index:value index:value ..." text form, shuffles it with a fixed
 * seed, trains {@link AROW} on the first part and reports the error on the
 * remainder after every pass over the training data.
 */
public class New20{
  /** Data file read when no path is given on the command line. */
  private static final String DEFAULT_DATA_FILE = "news20.binary";
  /** Maximum number of shuffled instances used for training. */
  private static final int MAX_TRAIN_SIZE = 15000;
  /** Number of passes over the training split. */
  private static final int EPOCHS = 10;
  /** Seed of the shuffle; fixed so that runs are reproducible. */
  private static final long SHUFFLE_SEED = 1111L;

  /**
   * Runs the experiment.
   *
   * @param args optional; {@code args[0]} is the data file path, defaulting
   *             to {@code news20.binary} in the working directory
   * @throws Exception if the file cannot be opened or a line is malformed
   */
  public static void main(String[] args) throws Exception{
    String path = (args != null && args.length > 0) ? args[0] : DEFAULT_DATA_FILE;
    List<Instance> data = loadInstances(new File(path));
    Collections.shuffle(data , new Random(SHUFFLE_SEED));

    // Clamp the split so data sets smaller than MAX_TRAIN_SIZE do not
    // run past the end of the list (the test split is then empty).
    int trainSize = Math.min(MAX_TRAIN_SIZE, data.size());
    List<Instance> train = new ArrayList<Instance>(data.subList(0, trainSize));
    List<Instance> test  = new ArrayList<Instance>(data.subList(trainSize, data.size()));

    // Size the model from the data instead of a constant tied to one file.
    AROW arow = new AROW(featureCount(data));
    for(int rep = 0 ; rep < EPOCHS ; ++rep){
      for(Instance i : train){
        arow.update(i.x, i.label);
      }
      int mistake = countMistakes(arow, test);
      System.out.println(rep + " th iteration :");
      System.out.println("number of mistake = " + mistake + " error rate = " + (mistake * 1.0 / test.size()));
    }
  }

  /** Reads every non-blank line of the file as one instance and always closes the file. */
  private static List<Instance> loadInstances(File file) throws java.io.FileNotFoundException{
    List<Instance> data = new ArrayList<Instance>();
    Scanner sc = new Scanner(file);
    try{
      int lineNo = 0;
      // hasNextLine() is the check that matches nextLine().
      while(sc.hasNextLine()){
        String line = sc.nextLine().trim();
        lineNo++;
        if(line.length() == 0){
          continue;
        }
        data.add(parseLine(line, lineNo));
      }
    }finally{
      sc.close();
    }
    return data;
  }

  /** Parses one trimmed, non-empty "label index:value ..." line; file indices are 1-based. */
  private static Instance parseLine(String line, int lineNo){
    // "+" collapses runs of separators so repeated spaces do not yield empty tokens.
    String[] arr = line.split("[\\s:]+");
    if(arr.length % 2 == 0){
      throw new IllegalArgumentException(
          "line " + lineNo + ": expected 'label index:value ...' but found an unpaired token");
    }
    // Numeric test so that both "+1" and an unsigned "1" map to the positive class.
    int label = Double.parseDouble(arr[0]) > 0 ? +1 : -1;
    Feature[] x = new Feature[(arr.length - 1) / 2];
    for(int i = 1 ; i < arr.length ; i += 2){
      int index = Integer.parseInt(arr[i]) - 1;
      double w = Double.parseDouble(arr[i + 1]);
      x[(i - 1) / 2] = new Feature(index, w);
    }
    Instance inst = new Instance();
    inst.label = label;
    inst.x     = x;
    return inst;
  }

  /** Returns the largest zero-based feature index in the data plus one (0 for empty data). */
  private static int featureCount(List<Instance> data){
    int count = 0;
    for(Instance inst : data){
      for(Feature f : inst.x){
        count = Math.max(count, f.index + 1);
      }
    }
    return count;
  }

  /** Counts the instances whose predicted label differs from the stored label. */
  private static int countMistakes(AROW arow, List<Instance> test){
    int mistake = 0;
    for(Instance i : test){
      if(arow.predict(i.x) != i.label){
        mistake++;
      }
    }
    return mistake;
  }
}


実行結果

0 th iteration :
number of mistake = 158 error rate = 0.03162530024019215
1 th iteration :
number of mistake = 156 error rate = 0.03122497998398719
2 th iteration :
number of mistake = 152 error rate = 0.03042433947157726
3 th iteration :
number of mistake = 153 error rate = 0.030624499599679743
4 th iteration :
number of mistake = 153 error rate = 0.030624499599679743
5 th iteration :
number of mistake = 152 error rate = 0.03042433947157726
6 th iteration :
number of mistake = 152 error rate = 0.03042433947157726
7 th iteration :
number of mistake = 152 error rate = 0.03042433947157726
8 th iteration :
number of mistake = 152 error rate = 0.03042433947157726
9 th iteration :
number of mistake = 152 error rate = 0.03042433947157726

結果を見ると一回目の反復で誤り率が約3%となっており、収束が非常に早いことがわかる。