A dual coordinate descent method for large-scale linear SVM
via しかしSVMも最近は速いらしい - 射撃しつつ前転 改
上記の論文の3.1まで読んでL1-Linear SVMを実装してみた.Shrinkingの部分はまだ読んでいない.
やっていることは双対問題
を各$\alpha_i$ごとに最小化していて,勾配方向が$w$を保存していると各成分の要素数程度でできて,1反復あたりの計算量が成分の非零の数程度でできるというもの.
ただ,自分の理解では双対問題には$\sum_i y_i \alpha_i = 0$なる制約があるはずなんだけど,その制約部分が消えている理由がわからなかった.(追記:上記の制約部分はバイアス項から出ているので論文で考えているのはバイアスなしの場合だから制約はなくていい.ただ論文と$x<-[x,1],w<-[w,b]$のように余分に一個次元を付け加えればいいとあるけどそれだとバイアス項も目的関数に入ってきて解いている問題が微妙に違うんじゃないかと思った)
とりあえずa9a,real-sim,news20,rcv1の4つのデータについて実験してみた.いずれもデータの80%の部分を学習用に20%の部分を検証用に用いている.
終了条件とかはLIBLINEARのコードを参考にした.
データセット | データ数 | データの次元 | 非零要素数 | 実行時間(sec) | 誤り確率 |
---|---|---|---|---|---|
a9a | 32561 | 123 | 451592 | 23.5 | 0.15 |
real-sim | 72309 | 20958 | 3709083 | 4.47 | 0.02 |
news20 | 19996 | 1355199 | 9097916 | 6.08 | 0.03 |
rcv1 | 677399 | 47236 | 49856258 | 196.8 | 0.02 |
SVNのメインコード
import java.util.*; class Feature{ int index; double value; Feature(int i , double v){ index = i; value = v; } } class Instance{ ArrayList<Feature> features; int label; Instance(){ features = new ArrayList<Feature>(); } void add(int i , double v){ features.add(new Feature(i , v)); } } public class L1SVM { double alpha[]; double w[]; Instance x[]; double C; double QD[]; int y[]; int seed; public L1SVM(Instance is[] , double U , int seed){ x = is; y = new int[is.length]; for(int i = 0 ; i < y.length ; i++) y[i] = is[i].label; C = U; int n = is.length; alpha = new double[n]; QD = new double[n]; for(int i = 0 ; i < n ; i++){ double s = 0.0; for(Feature f : is[i].features){ s += f.value * f.value; } QD[i] = s; } int fnum = 0; for(Instance i : is){ for(Feature f : i.features) fnum = Math.max(f.index, fnum); } fnum++; w = new double[fnum]; this.seed = seed; } void solve(){ int index[] = new int[x.length]; for(int i = 0 ; i < index.length ; i++){ index[i] = i; } int max_iter = 2000; double eps = 1.0E-2; int iter = 0; Random rand = new Random(seed); while(iter < max_iter){ shuffle(index, rand); double maxPG = Double.NEGATIVE_INFINITY; double minPG = Double.MAX_VALUE; for(int i : index){ double pg = update(i); maxPG = Math.max(maxPG, pg); minPG = Math.min(minPG, pg); } if(maxPG - minPG <= eps){ break; } iter++; } System.out.printf("Finish optimization: iter = %d\n" , iter); int nSV = 0; for(double a : alpha) if(a > 0)nSV++; System.out.printf("#of support vector = %d\n" , nSV); } double update(int i){ Instance xi = x[i]; int yi = y[i]; double G = 0.0; for(Feature f : xi.features){ G += w[f.index] * f.value; } G = yi * G - 1; double PG = 0.0; if(alpha[i] == 0.0){ PG = Math.min(G, 0); }else if(alpha[i] == C){ PG = Math.max(G, 0); }else{ PG = G; } if(Math.abs(PG) > 1.0e-12){ double alpha_old = alpha[i]; alpha[i] = Math.min(Math.max(alpha[i] - G / QD[i], 0.0), C); double d = (alpha[i] - alpha_old) * yi; for(Feature f : xi.features){ w[f.index] += d * f.value; } } return PG; } void shuffle(int array[] , Random rand){ int n = array.length; for(int i = 0 ; i < n ; i++){ int r = rand.nextInt(n - i) + i; int tmp = array[i]; array[i] = array[r]; array[r] = tmp; } } }
テスト用コード
import java.io.File; import java.util.*; public class Test { static List<Instance> loadData(String fname) { List<Instance> ilst = new ArrayList<Instance>(); try{ int feature_num = 0; int nonzeros = 0; Scanner sc = new Scanner(new File(fname)); while(sc.hasNext()){ String line = sc.nextLine(); String arr[] = line.split("\\s+"); Instance in = new Instance(); for(int i = 1 ; i < arr.length ; i++){ String ai = arr[i]; int sep = ai.indexOf(':'); int f = Integer.parseInt(ai.substring(0, sep)) - 1; feature_num = Math.max(f, feature_num); double v = Double.valueOf(ai.substring(sep + 1)); in.add(f, v); } if(arr[0].charAt(0)== '-'){ in.label = -1; }else{ in.label = 1; } nonzeros += in.features.size(); in.features.trimToSize(); ilst.add(in); } System.out.println("# of instance = " + (ilst.size())); System.out.println("# of features = " + (feature_num + 1)); System.out.println("# of nonzeros = " + (nonzeros)); }catch (Exception e) { e.printStackTrace(); } return ilst; } public static void main(String[] args) { List<Instance> ilst = loadData(args[0]); Collections.shuffle(ilst , new Random(11111)); double C = 1.0; double err = 0.0; double time = 0; int cv_num = 5; for(int cv = 0 ; cv < cv_num ; cv++){ List<Instance> learn = new ArrayList<Instance>(); List<Instance> test = new ArrayList<Instance>(); for(int i = 0 ; i < ilst.size() ; i++){ if(i % cv_num == cv){ test.add(ilst.get(i)); }else{ learn.add(ilst.get(i)); } } long t = System.currentTimeMillis(); L1SVM lsvm = new L1SVM(learn.toArray(new Instance[0]) , C , 11111); lsvm.solve(); time += System.currentTimeMillis() - t; System.out.printf("optimization end %d ms\n" , (System.currentTimeMillis() - t)); double w[] = lsvm.w; int en = 0; int tn = 0; for(Instance is : test){ double f = 0.0; for(Feature fi : is.features){ if(fi.index >= w.length)continue; f += fi.value * w[fi.index]; } if(is.label * f < 0)en++; tn++; } System.out.printf("error rate = %f\n" , (1.0 * en) / tn); err += 1.0 * en / tn; } System.out.printf("average error rate = %f\n" , err / cv_num); System.out.printf("average time = %f(s)\n" , time * 0.001 / cv_num); } }