[Machine Learning] Implementing CVB0

I implemented CVB0, an inference method for LDA described in "On Smoothing and Inference for Topic Models" (UAI 2009, PDF).
It takes the part of Teh et al.'s CVB that approximates the expectations with a second-order Taylor expansion and approximates it further, keeping only the zeroth-order term. Since this removes the exp computations, it runs faster than the second-order approximation.
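
For reference, the update that the code below implements can be written (in my own notation, not necessarily the paper's exact symbols) as

\gamma_{ik} \propto \frac{(N^{\neg i}_{w_i k} + \beta)\,(N^{\neg i}_{d_i k} + \alpha)}{N^{\neg i}_{k} + W\beta}

where N_{wk}, N_{dk}, N_k are expected topic counts accumulated from the current \gamma values, the superscript \neg i means token i's own contribution is excluded, W is the vocabulary size, and the \gamma_{ik} are renormalized over k after each update.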

In my own experiments it ran roughly twice as fast as the CGS (Collapsed Gibbs Sampler) implemented in the previous article.

Token.java — same as the previous version, but with a document-weight field added

public class Token {
  public int docId;     // id of the document this token belongs to
  public int wordId;    // vocabulary id of the word
  public double weight; // token weight (1.0 unless specified)
  public Token(int d , int w ){
    docId = d;
    wordId = w;
    weight = 1.0;
  }
  public Token(int d , int w , double ww){
    docId = d;
    wordId = w;
    weight = ww;
  }
}

LDACVB0.java

import java.util.*;

public class LDACVB0 {
  int D; // number of documents
  int K; // number of topics
  int W; // number of unique words (vocabulary size)
  // expected topic counts per word, per document, and overall
  float wordExpectation[][];
  float docExpectation[][];
  float topicExpectation[];
  // hyperparameters
  double alpha, beta;
  Token tokens[];
  // per-token topic responsibilities (variational parameters)
  float gamma[][];
  Random rand;

  // defaults: alpha = 50 / K, beta = 0.1
  public LDACVB0(int documentNum, int topicNum, int wordNum, List<Token> tlist,
      int seed) {
    this(documentNum, topicNum, wordNum, tlist, 50.0 / topicNum, 0.1, seed);
  }

  public LDACVB0(int documentNum, int topicNum, int wordNum, List<Token> tlist,
      double alpha, double beta, int seed) {
    wordExpectation = new float[wordNum][topicNum];
    topicExpectation = new float[topicNum];
    docExpectation = new float[documentNum][topicNum];
    D = documentNum;
    K = topicNum;
    W = wordNum;
    tokens = tlist.toArray(new Token[0]);
    gamma = new float[tokens.length][topicNum];
    this.alpha = alpha;
    this.beta = beta;
    rand = new Random(seed);
    init();
  }

  // initialize each token's topic distribution randomly and normalize it to sum to 1
  private void init() {
    for (int i = 0; i < gamma.length ; ++i) {
      double sum = 0.0;
      for(int j = 0 ; j < K ; ++j){
        gamma[i][j] = rand.nextFloat();
        sum += gamma[i][j];
      }
      double sinv = 1.0 / sum;
      for(int j = 0 ; j < K ; ++j){
        gamma[i][j] *= sinv;
      }
    }
  }

  // recompute the expected counts from the current gamma values
  private void updateExpectation(){
    Arrays.fill(topicExpectation, 0.0f);    
    for(int i = 0 ; i < W ; ++i){
      Arrays.fill(wordExpectation[i], 0.0f);
    }    
    for(int i = 0 ; i < D ; ++i){
      Arrays.fill(docExpectation[i], 0.0f);
    }
    for (int i = 0; i < tokens.length; ++i) {
      Token t = tokens[i];
      for (int k = 0; k < K; ++k) {
        topicExpectation[k] += t.weight * gamma[i][k];
        wordExpectation[t.wordId][k] += t.weight * gamma[i][k];
        docExpectation[t.docId][k] += t.weight * gamma[i][k];
      }
    }
  }
  
  // CVB0 update: gamma[i][k] proportional to
  // (E[n_wk] + beta) * (E[n_dk] + alpha) / (E[n_k] + W * beta),
  // where each expected count excludes this token's own contribution
  private void updateGamma(){
    for(int i = 0 ; i < tokens.length ; ++i){
      Token t = tokens[i];
      double sum = 0.0;
      for(int k = 0 ; k < K ; ++k){
        // this token's own (weighted) contribution, removed from each count
        double g = t.weight * gamma[i][k];
        double wk = (wordExpectation[t.wordId][k] + beta) - g;
        double dk = (docExpectation [t.docId][k] + alpha) - g;
        double kk = (topicExpectation[k]      + W * beta) - g;
        gamma[i][k] = (float)(wk * dk / kk);
        sum += gamma[i][k];
      }
      double sinv = 1.0 / sum;
      for(int k = 0 ; k < K ; ++k){
        gamma[i][k] *= sinv;
      }
    }
  }
  
  // one CVB0 iteration: refresh the expected counts, then update every gamma
  public void update() {
    updateExpectation();
    updateGamma();
  }

  // document-topic distributions: theta[d][k] proportional to E[n_dk] + alpha
  public double[][] getTheta() {
    double theta[][] = new double[D][K];
    for(int d = 0 ; d < D ; ++d){
      double sum = 0.0;
      for(int k = 0 ; k < K ; ++k){
        theta[d][k] = alpha + docExpectation[d][k];
        sum += theta[d][k];
      }
      double sinv = 1.0 / sum;
      for(int k = 0 ; k < K ; ++k){
        theta[d][k] *= sinv;
      }
    }
    return theta;
  }

  // topic-word distributions: phi[k][w] proportional to E[n_wk] + beta
  public double[][] getPhi() {
    double phi[][] = new double[K][W];
    for(int k = 0 ; k < K ; ++k){
      double sum = 0.0;
      for(int w = 0 ; w < W ; ++w){
        phi[k][w] = wordExpectation[w][k] + beta;
        sum += phi[k][w];
      }
      double sinv = 1.0 / sum;
      for(int w = 0 ; w < W ; ++w){
        phi[k][w] *= sinv;
      }
    }
    return phi;
  }
}
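
The test code from the previous article is not reproduced here, but as a sketch of how the class is meant to be driven (the toy corpus, the iteration count, and the class name LDACVB0Test are made up for illustration):

import java.util.*;

public class LDACVB0Test {
  public static void main(String[] args) {
    // toy corpus: each row is a document given as a sequence of word ids
    int docs[][] = { { 0, 1, 2, 1 }, { 2, 3, 4, 3 }, { 0, 4, 5, 5 } };
    int W = 6; // vocabulary size

    List<Token> tokens = new ArrayList<Token>();
    for (int d = 0; d < docs.length; ++d) {
      for (int w : docs[d]) {
        tokens.add(new Token(d, w)); // weight defaults to 1.0
      }
    }

    LDACVB0 lda = new LDACVB0(docs.length, 2, W, tokens, 777);
    for (int iter = 0; iter < 100; ++iter) {
      lda.update(); // one CVB0 sweep over all tokens
    }

    double phi[][] = lda.getPhi();
    for (int k = 0; k < phi.length; ++k) {
      System.out.println("topic : " + k);
      for (int w = 0; w < phi[k].length; ++w) {
        System.out.println(w + " " + phi[k][w]);
      }
    }
  }
}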

Running this with the test code from the previous article extracted topics like the following:

topic : 0
graph 0.01639231445222314
weight 0.014217440489921464
pruning 0.013566353552508154
level 0.013337075394709439
network 0.012245316879923073
variables 0.01154173846067167
component 0.009558425902995188
contribution 0.009189845206519427
structure 0.009004521752442873
problem 0.008286439753653205
topic : 1
learning 0.06361167333143003
algorithm 0.05496684033178554
gradient 0.029577348747162977
weight 0.02347752368853067
error 0.021588484323139926
convergence 0.015966324298809526
rate 0.014369645761162692
descent 0.012982777561825492
function 0.010329638976695813
parameter 0.009962002418224576
topic : 2
direction 0.03520721401428869
motion 0.03398034949443053
eye 0.020649724516946414
cell 0.018779895385126283
head 0.015734231197231024
velocity 0.015713319780626806
model 0.014333959786142747
visual 0.01297128537367679
position 0.012224118942964374
system 0.011606187759166229