[機械学習] LDAのコードを書いてみた

昔書いたことがあったけど、どこかにいってしまったのでもう一度書いてみた。推論方法にはギブスサンプリングと変分ベイズの2つがあるけど、導出も実装もより楽なcollapsed gibbs sampling(Griffiths and Steyvers, PNAS, 2004)の方を採用。

Token.java

package lda;

/**
 * One word occurrence in the corpus, identified by the document it belongs to and the
 * vocabulary entry it is an instance of.
 */
public class Token {
  /** Identifier of the document that contains this occurrence. */
  public int docId;

  /** Identifier of the word (vocabulary entry) this occurrence is an instance of. */
  public int wordId;

  /**
   * Creates a token for the given document/word pair.
   *
   * @param d document identifier
   * @param w word identifier
   */
  public Token(int d, int w) {
    this.wordId = w;
    this.docId = d;
  }
}

LDA.java

package lda;

import java.util.*;

/**
 * Latent Dirichlet Allocation trained with collapsed Gibbs sampling.
 *
 * <p>Construction assigns a random topic to every token. Each call to {@link #update()} then
 * resamples the topic of every token once, in token order. Point estimates computed from the
 * current counts are available through {@link #getTheta()} and {@link #getPhi()}.
 */
public class LDA {
  int D; // number of documents
  int K; // number of topics
  int W; // number of unique words
  int[][] wordCount; // [word][topic] : tokens of that word currently assigned to that topic
  int[][] docCount; // [document][topic] : tokens of that document currently assigned to that topic
  int[] topicCount; // [topic] : tokens currently assigned to that topic
  // hyper parameters (added to the document-topic and word-topic counts respectively)
  double alpha, beta;
  Token[] tokens;
  double[] P; // scratch buffer holding cumulative unnormalised topic weights
  // current topic assignment of every token
  int[] z;
  Random rand;

  /**
   * Creates a model with the default hyper parameters alpha = 50 / topicNum and beta = 0.1.
   *
   * @param documentNum number of documents
   * @param topicNum number of topics
   * @param wordNum number of unique words
   * @param tlist all tokens of the corpus
   * @param seed seed for the random number generator
   */
  public LDA(int documentNum, int topicNum, int wordNum, List<Token> tlist,
      int seed) {
    this(documentNum, topicNum, wordNum, tlist, 50.0 / topicNum, 0.1, seed);
  }

  /**
   * Creates a model with explicit hyper parameters and randomly initialised topic assignments.
   *
   * @param documentNum number of documents
   * @param topicNum number of topics
   * @param wordNum number of unique words
   * @param tlist all tokens of the corpus; their ids are used as array indices
   * @param alpha value added to every document-topic count
   * @param beta value added to every word-topic count
   * @param seed seed for the random number generator
   */
  public LDA(int documentNum, int topicNum, int wordNum, List<Token> tlist,
      double alpha, double beta, int seed) {
    this.D = documentNum;
    this.K = topicNum;
    this.W = wordNum;
    this.alpha = alpha;
    this.beta = beta;
    this.wordCount = new int[wordNum][topicNum];
    this.topicCount = new int[topicNum];
    this.docCount = new int[documentNum][topicNum];
    this.tokens = tlist.toArray(new Token[0]);
    this.z = new int[this.tokens.length];
    this.P = new double[topicNum];
    this.rand = new Random(seed);
    assignRandomTopics();
  }

  /** Draws a uniformly random topic for every token and records it in the count tables. */
  private void assignRandomTopics() {
    for (int n = 0; n < z.length; ++n) {
      final Token token = tokens[n];
      final int topic = rand.nextInt(K);
      shiftCounts(token, topic, 1);
      z[n] = topic;
    }
  }

  /** Adds {@code delta} to the three counters that tie {@code token} to {@code topic}. */
  private void shiftCounts(Token token, int topic, int delta) {
    wordCount[token.wordId][topic] += delta;
    docCount[token.docId][topic] += delta;
    topicCount[topic] += delta;
  }

  /**
   * Samples a topic for {@code t} with probability proportional to
   * (wordCount + beta) * (docCount + alpha) / (topicCount + W * beta).
   *
   * @return the sampled topic index
   */
  private int selectNextTopic(Token t) {
    // Fill P with the running (cumulative) sum of the unnormalised weights.
    for (int k = 0; k < P.length; ++k) {
      final double weight = (wordCount[t.wordId][k] + beta) * (docCount[t.docId][k] + alpha)
          / (topicCount[k] + W * beta);
      P[k] = (k == 0) ? weight : weight + P[k - 1];
    }
    // Scale a uniform draw by the total mass and locate the first bucket that exceeds it.
    final double threshold = rand.nextDouble() * P[K - 1];
    int chosen = 0;
    while (chosen < P.length && !(threshold < P[chosen])) {
      ++chosen;
    }
    // Fall back to the last topic if no bucket exceeded the threshold.
    return (chosen < P.length) ? chosen : K - 1;
  }

  /** Removes the token from its current topic, samples a new one and records it. */
  private void resample(int tokenId) {
    final Token token = tokens[tokenId];
    // The token's own assignment is excluded from the counts while sampling.
    shiftCounts(token, z[tokenId], -1);
    final int next = selectNextTopic(token);
    shiftCounts(token, next, 1);
    z[tokenId] = next;
  }

  /** Performs one sweep: resamples the topic of every token once, in token order. */
  public void update() {
    final int tokenNum = z.length;
    for (int tokenId = 0; tokenId < tokenNum; ++tokenId) {
      resample(tokenId);
    }
  }

  /**
   * Scales {@code row} in place by the reciprocal of the sum of its entries.
   */
  private static void normalize(double[] row) {
    double total = 0.0;
    for (int j = 0; j < row.length; ++j) {
      total += row[j];
    }
    final double inverse = 1.0 / total;
    for (int j = 0; j < row.length; ++j) {
      row[j] *= inverse;
    }
  }

  /**
   * Estimates the per-document topic distribution from the current assignment.
   *
   * @return a newly allocated D x K matrix; row d is (alpha + docCount[d][k]) normalised over k
   */
  public double[][] getTheta() {
    final double[][] theta = new double[D][K];
    for (int d = 0; d < D; ++d) {
      final double[] row = theta[d];
      for (int k = 0; k < K; ++k) {
        row[k] = alpha + docCount[d][k];
      }
      normalize(row);
    }
    return theta;
  }

  /**
   * Estimates the per-topic word distribution from the current assignment.
   *
   * @return a newly allocated K x W matrix; row k is (beta + wordCount[w][k]) normalised over w
   */
  public double[][] getPhi() {
    final double[][] phi = new double[K][W];
    for (int k = 0; k < K; ++k) {
      final double[] row = phi[k];
      for (int w = 0; w < W; ++w) {
        row[w] = beta + wordCount[w][k];
      }
      normalize(row);
    }
    return phi;
  }
}

基本的な使い方としては、コンストラクタに文書数・トピック数・語彙数、(文書ID, 単語ID)のリスト、乱数のシードを渡し、update()を一定回数呼び出す。推定したパラメータはgetTheta, getPhiで取得する。なお、このコードでは単一のサンプルからパラメータを推定しているが、まじめにやるなら複数のサンプルを取って平均する必要がある。

ThetaはD×Kの行列で、各文書における潜在トピックの分布を表す。またPhiはK×Wの行列で、各トピックにおける単語の出現確率を表す。

例として、UCI Machine Learning Repository: Bag of Words Data SetのNIPSデータセットでK=50とした場合の、各トピックにおける単語生成確率の上位10語を、一部のトピックについて示す。これを見ると、反復が進むにつれて、単語の出現確率がトピックごとに大きく違ってくることがわかる。

1反復目

topic : 0
network 0.014092019554221354
model 0.00961357020472769
input 0.007852382258297598
learning 0.0075504643246238675
function 0.006795669490439543
unit 0.006267313106510514
neural 0.00594023534503064
algorithm 0.005764116550387631
result 0.005059641371815594
set 0.00485836274936644
topic : 5
network 0.01344023110856413
learning 0.009439113815432327
neural 0.008508036206338762
model 0.00795442249282367
input 0.007778272674887049
unit 0.0069226878449091785
function 0.006218088573162697
data 0.006218088573162697
system 0.006117431534341771
set 0.005890953196994688
topic : 49
network 0.011941178364697854
model 0.009365192988368095
learning 0.008919349365541791
input 0.008597351193500571
set 0.008126738480517249
function 0.0075818184970628776
neural 0.006120442177798879
system 0.00587275127622871
algorithm 0.005649829464815558
data 0.005501214923873456

10反復目

topic : 0
network 0.025539408544249742
linear 0.014513134050736839
contour 0.013714128652656194
capacity 0.013234725413807807
problem 0.011988276992802
wavelet 0.010582027492180065
function 0.007481886547627161
direction 0.007481886547627161
rotation 0.007449926331703935
point 0.007002483308778774
topic : 5
network 0.06928157250824087
output 0.0305151994792456
training 0.02848098200396099
neural 0.02346324556492562
recognition 0.021215919592230244
net 0.019821027609177942
layer 0.01722497864071949
set 0.01650815914942872
input 0.016488785649664106
character 0.015035773167317957
topic : 49
unit 0.09013773693760363
hidden 0.04481769982166848
training 0.03433223140859152
network 0.03115149308244893
weight 0.029243050086763374
rbf 0.01978857961388436
error 0.018406603651491372
output 0.014216803511538027
generalization 0.013975506121278933
input 0.013865825489342983

50反復目

topic : 0
segment 0.02273003693935464
contour 0.017135271693618368
region 0.0159885185687159
local 0.012166008152374348
surface 0.010637003985837726
wavelet 0.010463253512367656
shape 0.010393753322979626
point 0.009386000576853216
group 0.009142749913995118
rotation 0.00882999906174899
topic : 5
recognition 0.030339713633886357
character 0.028085248199496936
network 0.022057003668412175
digit 0.019434962347980998
system 0.015122071951757753
neural 0.014582960652229847
input 0.014043849352701941
set 0.011666858622965267
output 0.010833686614603958
training 0.010711161319256706
topic : 49
unit 0.0814758439901167
network 0.07836241125756047
input 0.0627824351143984
output 0.05246838840778206
hidden 0.05054651635064858
layer 0.04809933259789862
weight 0.04242340378916441
training 0.034710290599868715
pattern 0.022192497267739322
set 0.01342876068721066

100反復目

topic : 0
region 0.02414254659304815
segment 0.021018255724856386
local 0.015230095800627433
contour 0.015230095800627433
surface 0.013355521279712375
depth 0.010527215862191408
shape 0.00980369587166279
texture 0.00950771042099199
edge 0.00901440133654066
grouping 0.008817077702760127
topic : 5
character 0.03326499065788623
recognition 0.03314766417641939
digit 0.02338023459430484
system 0.012674193160455554
set 0.010239668670018593
handwritten 0.010122342188551752
image 0.009418383299750702
segmentation 0.008919745753516626
feature 0.008831750892416494
pattern 0.008802419272049784
topic : 49
input 0.07237126180147059
network 0.0704561431890157
unit 0.06993275673592245
output 0.05644366042210978
layer 0.052220883357380066
hidden 0.04764125189281403
weight 0.04369206320129215
training 0.035734210084942354
pattern 0.018307820226269023
set 0.01409693830820052

200反復目

topic : 0
image 0.05129741438909944
images 0.01835710548219398
region 0.017218428137263916
pixel 0.01667620083015436
segment 0.01393795292925111
local 0.013368614256786078
vision 0.012663718757543655
contour 0.012040157354367668
scene 0.009817025395218492
surface 0.009464577645597282
topic : 5
character 0.03446915797682203
recognition 0.0332252221639097
digit 0.02409291290326061
system 0.012988510280677329
image 0.011107436612370871
set 0.010348939165473107
handwritten 0.010348939165473107
feature 0.009893840697334448
pixel 0.009742141207954894
pattern 0.009681461412203073
topic : 49
input 0.07946527053340079
network 0.07009151773133951
unit 0.06570943206559061
output 0.05974013426894054
layer 0.0505389045245859
hidden 0.04640985267680675
weight 0.041579207061020564
training 0.03748465982478275
net 0.01888667420400594
pattern 0.017690514337239834

実験に使ったのは以下のコードである

/**
 * Pairs an integer id with a probability so that entries can be sorted by probability while
 * remembering which id each probability belongs to. The natural ordering is ascending by
 * {@code prob}, as defined by {@link Double#compare(double, double)}.
 */
class PComp implements Comparable<PComp> {
  /** Identifier carried along with the probability. */
  int id;

  /** Probability used as the sort key. */
  double prob;

  /**
   * Orders entries by ascending probability.
   *
   * @param other the entry to compare against
   * @return a negative value, zero, or a positive value if this entry's probability is less
   *     than, equal to, or greater than {@code other}'s
   */
  @Override
  public int compareTo(PComp other) {
    final double mine = this.prob;
    final double theirs = other.prob;
    return Double.compare(mine, theirs);
  }
}
/**
 * Experiment driver: fits LDA with K = 50 topics to the NIPS bag-of-words data and
 * periodically dumps the most probable words of every topic.
 *
 * <p>Input is read from {@code data/docword.nips.txt} and {@code data/vocab.nips.txt}; output is
 * written to {@code output/wordtopic<i>.txt}.
 */
public class NIPS {
  /** Maximum number of words printed per topic. */
  private static final int TOP_WORDS = 10;

  /**
   * Reads the corpus and vocabulary, runs the sampler and writes the topic/word tables.
   *
   * @param args unused
   * @throws Exception if an input file cannot be read or an output file cannot be created
   */
  public static void main(String[] args) throws Exception {
    int D;
    int W;
    List<Token> tlist = new ArrayList<Token>();
    // The file starts with three integers (D, W, N) followed by N "doc word count" triples.
    Scanner sc = new Scanner(new File("data/docword.nips.txt"));
    try {
      D = sc.nextInt();
      W = sc.nextInt();
      int N = sc.nextInt();
      for (int i = 0; i < N; ++i) {
        // Ids in the file are shifted down by one so they can be used as array indices.
        int did = sc.nextInt() - 1;
        int wid = sc.nextInt() - 1;
        int count = sc.nextInt();
        // Expand "word occurs count times" into count individual tokens.
        for (int c = 0; c < count; ++c) {
          tlist.add(new Token(did, wid));
        }
      }
    } finally {
      // Close the file even when parsing fails.
      sc.close();
    }

    // One vocabulary entry per line; line i is the word with id i.
    String words[] = new String[W];
    sc = new Scanner(new File("data/vocab.nips.txt"));
    try {
      for (int i = 0; i < W; ++i) {
        words[i] = sc.nextLine();
      }
    } finally {
      sc.close();
    }

    int K = 50;
    LDA lda = new LDA(D, K, W, tlist, 777);
    // i runs from 0 to 200 inclusive, so wordtopic<i>.txt reflects the state after i + 1 sweeps.
    for (int i = 0; i <= 200; ++i) {
      lda.update();
      if (i % 10 == 0) {
        PrintWriter out = new PrintWriter("output/wordtopic" + i + ".txt");
        try {
          double phi[][] = lda.getPhi();
          outputWordTopicProb(phi, words, out);
        } finally {
          // Flush and release the file even if writing throws.
          out.close();
        }
      }
    }
  }

  /**
   * Prints, for every topic, a header line followed by its most probable words (at most
   * {@link #TOP_WORDS}) in descending order of probability, one "word probability" pair per line.
   *
   * @param phi topic-by-word probability matrix
   * @param words vocabulary, indexed by word id
   * @param out destination writer; not closed by this method
   */
  private static void outputWordTopicProb(double phi[][], String[] words, PrintWriter out) {
    int K = phi.length;
    for (int k = 0; k < K; ++k) {
      // Take the width from the row itself so that an empty phi is handled without indexing phi[0].
      int W = phi[k].length;
      out.println("topic : " + k);
      PComp ps[] = new PComp[W];
      for (int w = 0; w < W; ++w) {
        PComp pc = new PComp();
        pc.id = w;
        pc.prob = phi[k][w];
        ps[w] = pc;
      }
      // Ascending sort; the most probable words end up at the tail of the array.
      Arrays.sort(ps);
      // Never read past the front of the array when the vocabulary has fewer than TOP_WORDS words.
      int shown = Math.min(TOP_WORDS, W);
      for (int i = 0; i < shown; ++i) {
        PComp p = ps[W - 1 - i];
        out.println(words[p.id] + " " + p.prob);
      }
    }
  }
}