[機械学習] LDAのコードを書いてみた

昔書いたことがあったけど、どこかにいってしまったのでもう一度書いてみた。推論方法にはギブスサンプリングと変分ベイズの2つがあるけど、導出も実装もより楽なcollapsed gibbs sampling(Griffiths and Steyvers, PNAS, 2004)の方を採用。

Token.java

package lda;

public class Token {
  public int docId;
  public int wordId;
  public Token(int d , int w){
    docId = d;
    wordId = w;
  }
}

LDA.java

package lda;

import java.util.*;

public class LDA {
  int D; // number of document
  int K; // number of topic
  int W; // number of unique word
  int wordCount[][];
  int docCount[][];
  int topicCount[];
  // hyper parameter
  double alpha, beta;
  Token tokens[];
  double P[];
  // topic assignment
  int z[];
  Random rand;

  public LDA(int documentNum, int topicNum, int wordNum, List<Token> tlist,
      int seed) {
    this(documentNum, topicNum, wordNum, tlist, 50.0 / topicNum, 0.1, seed);
  }

  public LDA(int documentNum, int topicNum, int wordNum, List<Token> tlist,
      double alpha, double beta, int seed) {
    wordCount = new int[wordNum][topicNum];
    topicCount = new int[topicNum];
    docCount = new int[documentNum][topicNum];
    D = documentNum;
    K = topicNum;
    W = wordNum;
    tokens = tlist.toArray(new Token[0]);
    z = new int[tokens.length];
    this.alpha = alpha;
    this.beta = beta;
    P = new double[K];
    rand = new Random(seed);
    init();
  }

  private void init() {
    for (int i = 0; i < z.length; ++i) {
      Token t = tokens[i];
      int assign = rand.nextInt(K);
      wordCount[t.wordId][assign]++;
      docCount[t.docId][assign]++;
      topicCount[assign]++;
      z[i] = assign;
    }
  }

  private int selectNextTopic(Token t) {
    for (int k = 0; k < P.length; ++k) {
      P[k] = (wordCount[t.wordId][k] + beta) * (docCount[t.docId][k] + alpha)
          / (topicCount[k] + W * beta);
      if (k != 0) {
        P[k] += P[k - 1];
      }
    }
    double u = rand.nextDouble() * P[K - 1];
    for (int k = 0; k < P.length; ++k) {
      if (u < P[k]) {
        return k;
      }
    }
    return K - 1;
  }

  private void resample(int tokenId) {
    Token t = tokens[tokenId];
    int assign = z[tokenId];
    // remove from current topic
    wordCount[t.wordId][assign]--;
    docCount[t.docId][assign]--;
    topicCount[assign]--;
    assign = selectNextTopic(t);
    wordCount[t.wordId][assign]++;
    docCount[t.docId][assign]++;
    topicCount[assign]++;
    z[tokenId] = assign;
  }

  public void update() {
    for (int i = 0; i < z.length; ++i) {
      resample(i);
    }
  }

  public double[][] getTheta() {
    double theta[][] = new double[D][K];
    for (int i = 0; i < D; ++i) {
      double sum = 0.0;
      for (int j = 0; j < K; ++j) {
        theta[i][j] = alpha + docCount[i][j];
        sum += theta[i][j];
      }
      // normalize
      double sinv = 1.0 / sum;
      for (int j = 0; j < K; ++j) {
        theta[i][j] *= sinv;
      }
    }
    return theta;
  }

  public double[][] getPhi() {
    double phi[][] = new double[K][W];
    for (int i = 0; i < K; ++i) {
      double sum = 0.0;
      for (int j = 0; j < W; ++j) {
        phi[i][j] = beta + wordCount[j][i];
        sum += phi[i][j];
      }
      // normalize
      double sinv = 1.0 / sum;
      for (int j = 0; j < W; ++j) {
        phi[i][j] *= sinv;
      }
    }
    return phi;
  }
}

基本的な使い方としてはコンストラクタで(文章ID, 単語ID)のリストと潜在クラスの数を渡して、一定回数update()を呼び出す。推定したパラメータはgetTheta,getPhiで取得する。注意としてこの場合単一のサンプルからパラメータを推定しているが、まじめにやろうとすると複数のサンプルを取って平均する必要がある。

ThetaはD×Kの行列で文章が持っている潜在クラスの分布となる。またPhiはK×Wの行列でそのクラスにおける単語の出現確率を表す。

例としてUCI Machine Learning Repository: Bag of Words Data SetのNIPSデータセットでK=50とした場合の各クラスにおけるトピックの単語生成確率の一部を示す。これを見ると反復が進むに連れ、単語の出現確率がトピックによって大きく違ってくることがわかる。

1反復目

topic : 0
network 0.014092019554221354
model 0.00961357020472769
input 0.007852382258297598
learning 0.0075504643246238675
function 0.006795669490439543
unit 0.006267313106510514
neural 0.00594023534503064
algorithm 0.005764116550387631
result 0.005059641371815594
set 0.00485836274936644
topic : 5
network 0.01344023110856413
learning 0.009439113815432327
neural 0.008508036206338762
model 0.00795442249282367
input 0.007778272674887049
unit 0.0069226878449091785
function 0.006218088573162697
data 0.006218088573162697
system 0.006117431534341771
set 0.005890953196994688
topic : 49
network 0.011941178364697854
model 0.009365192988368095
learning 0.008919349365541791
input 0.008597351193500571
set 0.008126738480517249
function 0.0075818184970628776
neural 0.006120442177798879
system 0.00587275127622871
algorithm 0.005649829464815558
data 0.005501214923873456

10反復目

topic : 0
network 0.025539408544249742
linear 0.014513134050736839
contour 0.013714128652656194
capacity 0.013234725413807807
problem 0.011988276992802
wavelet 0.010582027492180065
function 0.007481886547627161
direction 0.007481886547627161
rotation 0.007449926331703935
point 0.007002483308778774
topic : 5
network 0.06928157250824087
output 0.0305151994792456
training 0.02848098200396099
neural 0.02346324556492562
recognition 0.021215919592230244
net 0.019821027609177942
layer 0.01722497864071949
set 0.01650815914942872
input 0.016488785649664106
character 0.015035773167317957
topic : 49
unit 0.09013773693760363
hidden 0.04481769982166848
training 0.03433223140859152
network 0.03115149308244893
weight 0.029243050086763374
rbf 0.01978857961388436
error 0.018406603651491372
output 0.014216803511538027
generalization 0.013975506121278933
input 0.013865825489342983

50反復目

topic : 0
segment 0.02273003693935464
contour 0.017135271693618368
region 0.0159885185687159
local 0.012166008152374348
surface 0.010637003985837726
wavelet 0.010463253512367656
shape 0.010393753322979626
point 0.009386000576853216
group 0.009142749913995118
rotation 0.00882999906174899
topic : 5
recognition 0.030339713633886357
character 0.028085248199496936
network 0.022057003668412175
digit 0.019434962347980998
system 0.015122071951757753
neural 0.014582960652229847
input 0.014043849352701941
set 0.011666858622965267
output 0.010833686614603958
training 0.010711161319256706
topic : 49
unit 0.0814758439901167
network 0.07836241125756047
input 0.0627824351143984
output 0.05246838840778206
hidden 0.05054651635064858
layer 0.04809933259789862
weight 0.04242340378916441
training 0.034710290599868715
pattern 0.022192497267739322
set 0.01342876068721066

100反復目

topic : 0
region 0.02414254659304815
segment 0.021018255724856386
local 0.015230095800627433
contour 0.015230095800627433
surface 0.013355521279712375
depth 0.010527215862191408
shape 0.00980369587166279
texture 0.00950771042099199
edge 0.00901440133654066
grouping 0.008817077702760127
topic : 5
character 0.03326499065788623
recognition 0.03314766417641939
digit 0.02338023459430484
system 0.012674193160455554
set 0.010239668670018593
handwritten 0.010122342188551752
image 0.009418383299750702
segmentation 0.008919745753516626
feature 0.008831750892416494
pattern 0.008802419272049784
topic : 49
input 0.07237126180147059
network 0.0704561431890157
unit 0.06993275673592245
output 0.05644366042210978
layer 0.052220883357380066
hidden 0.04764125189281403
weight 0.04369206320129215
training 0.035734210084942354
pattern 0.018307820226269023
set 0.01409693830820052

200反復目

topic : 0
image 0.05129741438909944
images 0.01835710548219398
region 0.017218428137263916
pixel 0.01667620083015436
segment 0.01393795292925111
local 0.013368614256786078
vision 0.012663718757543655
contour 0.012040157354367668
scene 0.009817025395218492
surface 0.009464577645597282
topic : 5
character 0.03446915797682203
recognition 0.0332252221639097
digit 0.02409291290326061
system 0.012988510280677329
image 0.011107436612370871
set 0.010348939165473107
handwritten 0.010348939165473107
feature 0.009893840697334448
pixel 0.009742141207954894
pattern 0.009681461412203073
topic : 49
input 0.07946527053340079
network 0.07009151773133951
unit 0.06570943206559061
output 0.05974013426894054
layer 0.0505389045245859
hidden 0.04640985267680675
weight 0.041579207061020564
training 0.03748465982478275
net 0.01888667420400594
pattern 0.017690514337239834

実験に使ったのは以下のコードである

class PComp implements Comparable<PComp>{
  int id;
  double prob;
  @Override
  public int compareTo(PComp o) {
    return Double.compare(prob, o.prob);
  }
}
public class NIPS {    
  public static void main(String[] args) throws Exception{
    Scanner sc = new Scanner(new File("data/docword.nips.txt"));
    int D = sc.nextInt();
    int W = sc.nextInt();
    int N = sc.nextInt();
    List<Token> tlist = new ArrayList<Token>();
    for(int i = 0 ; i < N ; ++i){
      int did = sc.nextInt() - 1;
      int wid = sc.nextInt() - 1;
      int count = sc.nextInt();
      for(int c = 0 ; c < count ; ++c){
        tlist.add(new Token(did , wid));
      }
    }
    String words[] = new String[W];
    sc = new Scanner(new File("data/vocab.nips.txt"));
    for(int i = 0 ; i < W ; ++i){
      words[i] = sc.nextLine();
    }
    int K = 50;
    LDA lda = new LDA(D, K , W, tlist, 777);
    for(int i = 0 ; i <= 200 ; ++i){
      lda.update();
      if(i % 10 == 0){
        PrintWriter out = new PrintWriter("output/wordtopic"+i+".txt");
        double phi[][] = lda.getPhi();
        outputWordTopicProb(phi , words, out);
        out.close();
      }
    }
  }

  private static void outputWordTopicProb(double phi[][], String[] words, PrintWriter out) {
    int K = phi.length;
    int W = phi[0].length;
    for(int k = 0 ; k < K ; ++k){
      out.println("topic : " + k);
      PComp ps[] = new PComp[W];
      for(int w = 0 ; w < W ; ++w){
        PComp pc = new PComp();
        pc.id = w; pc.prob = phi[k][w];
        ps[w] = pc;
      }
      Arrays.sort(ps);
      for(int i = 0 ; i < 10 ; ++i){
        // output related word
        PComp p = ps[W - 1 - i];
        out.println(words[p.id]+" "+p.prob);        
      }
    }
  }
}

[機械学習] トピックモデル関係の論文メモ

最近読んだトピックモデル関係の論文のざっとしたメモ。内容については間違って理解しているところも多々あると思います。

(追記 12/24) 最後のほうに論文を読む基礎となる文献を追加しました。

Efficient Methods for Topic Model Inference on Streaming Document Collections (KDD 2009)

論文の話は2つあって一つ目がSparseLDAというCollapsed Gibbs samplerの省メモリかつ高速な方法の提案と2つ目はオンラインで文章が入力されるような場合において訓練データと新規データをどう使うかという戦略について述べて実験している。

Collapsed Gibbs samplerを高速化しようという論文はPorteous et al.(KDD 2008)でも述べられているけどそれよりも2倍ぐらい高速(通常のCollapsed Gibbs samplerの20倍)かつ省メモリというのがSparseLDAの売り。方法については3.4で述べられているけど基本的にはLDAにおいて単語-トピックカウントの行列がほとんど0であることを利用している。

Modeling Social Annotation Data with Content Relevance using a Topic Model (NIPS 2009)

Social Bookmarkのような文章+annotationがついたデータの生成モデルを考えようという話。
ただ、Social Bookmarkは「あとで読む」みたいなトピックとは関係しないアノテーションが含まれるのでその部分をうまいことモデル化する。
文章の生成モデルのところはLDAとほぼおなじで、その後アノテーションを生成するときにトピックを文章のクラスの経験分布から生成する(例えば文章にトピック1の単語が2つ、トピック2の単語が3つ、トピック3の単語が5つ含まれる場合、各トピックの選択確率は(2/10,3/10,5/10)となる)。つぎにアノテーションが関係するかしないかをベルヌーイ分布により選択し、関係しない場合はトピックによらない多項分布から、そうでないときはトピックに応じた多項分布から生成する。

推論方法はGibbs samplingを使っている。

Supervised topic models (NIPS 2007)

LDAのモデルにおいて各文章に教師データがついているようなモデルについての論文
文章の生成モデルはLDAと同じで、つぎに文章のクラスの経験分布とパラメータηの内積を平均とした正規分布からresponceをつくる。
文章とresponceが与えられた元で学習を行うのがこの論文の話。推論は変分ベイズを使う。
このままだと回帰にしか使えないが一般化線形モデルを利用することによってラベルのような離散値の出力にも対応できるというのが論文2.3で述べられている
実装はhttp://www.cs.princeton.edu/~chongw/slda/にある。

Named Entitiy Recognition in Query (SIGIR 2009)

Named Entity(固有表現)をモデル化するという話 + 教師データを用いるためにLDAを拡張したWS(Weak Supervised)-LDAという方法を考えたという話

WS-LDAは(LDAのモデルの確率値)+λ*(文章の経験分布と教師データのクラスラベルの値と内積をとったもの)を目的とする。λ=0のときは単なるLDAになる。そうでないときに関しては4.2.2のE-stepの式をみるとわかるが文章のトピック帰属確率の通常の変分ベイズの更新式に教師データの影響を足した形で更新することになる。

Dirichlet-Bernoulli Alignment: A Generative Model for Multi-Class Multi-Label Multi-Instance Corpora (NIPS 2009)

webページの分類みたいなことを考えるとページは複数のテキスト、画像などから構成され(Multi Instance)、スポーツ、政治などの複数のクラスないしはトピックを持つ(Multi-Class)、その場合分類でのラベル付も複数のラベルとなりえる(Multi Label)。このようなMulti-Class Multi-Label Multi-Instanceの問題を考えている論文。

データの生成モデルを次のようにする。

ラベルとデータが与えられた時の推論アルゴリズムとしては変分ベイズを用いる。

追記:(12/24)

上記の論文を読む上で基本となるLDA関係の文献を挙げておく

Latent Dirichlet allocation (JMLR 2003)

LDAのオリジナル論文。ジャーナル版であるため会議の論文では省略されがちなほかのモデルとの関係性や推論方法に関して詳しく述べられている。推論には変分ベイズを用いている

Finding scientific topics (PNAS 2004)

LDAをギブスサンプラーを用いて推論しようという話が乗っている論文。更新式の導出は丁寧ではないので Wikipeidaの説明 (http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation)を見ることをお勧めする。
余談ではあるがこの論文が載ったPNASは学術雑誌としてはNature,Scienceの次ぐらいに位置している非常に権威のある論文誌である。

A collapsed variational Bayesian inference algorithm for latent Dirichlet Allocation (NIPS 2007)

LDAの変分ベイズでの推論方法を工夫するともっと精度がよくなりますよという話。この話の内容に触れると変分ベイズについて話さなければならないので具体的なところは論文を見てください。

On Smoothing and Inference for Topic Models (UAI 2009)

上記3つの推論アルゴリズムとCVB0という方法をいろいろなデータセットに対して比較している実験。また、LDA系の論文を読むとハイパーパラメータに関してパラメータから推定するというものと決めうちで与えるというものの2種類あるが実際ハイパーパラメータの学習を行うことによりどの程度精度が変わるかについても実験している。実験結果を見ると結構ハイパーパラメータの学習が結果に影響してくることがわかる。

ほかにもトピック数を変えたときなどの実験も行っており、どの推論方法を選ぶべきかの指針となる。