[Machine Learning] Writing an LDA implementation
I had written this once before, but the code went missing somewhere, so I wrote it again. There are two main inference methods, Gibbs sampling and variational Bayes; I went with collapsed Gibbs sampling (Griffiths and Steyvers, PNAS, 2004), which is easier to both derive and implement.
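For reference, the per-token update that the sampler below implements is the usual collapsed Gibbs formula (the notation here is mine; all counts are taken after removing the token currently being resampled):

P(z_i = k \mid \mathbf{z}_{-i}, \mathbf{w}) \propto \frac{n^{-i}_{w_i,k} + \beta}{n^{-i}_{\cdot,k} + W\beta} \left( n^{-i}_{d_i,k} + \alpha \right)

where n_{w,k} is the number of times word w is assigned to topic k, n_{\cdot,k} the total number of tokens assigned to topic k, n_{d,k} the number of tokens in document d assigned to topic k, and W the vocabulary size. The normalizing term over the document is constant in k and can be dropped, which is exactly the product computed per topic in selectNextTopic below.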
Token.java
package lda;

public class Token {
    public int docId;
    public int wordId;

    public Token(int d, int w) {
        docId = d;
        wordId = w;
    }
}
LDA.java
package lda;

import java.util.*;

public class LDA {
    int D; // number of documents
    int K; // number of topics
    int W; // number of unique words
    int wordCount[][];  // word-topic counts: wordCount[w][k]
    int docCount[][];   // document-topic counts: docCount[d][k]
    int topicCount[];   // total number of tokens assigned to each topic
    // hyper parameters
    double alpha, beta;
    Token tokens[];
    double P[];
    // topic assignment of each token
    int z[];
    Random rand;

    public LDA(int documentNum, int topicNum, int wordNum, List<Token> tlist, int seed) {
        this(documentNum, topicNum, wordNum, tlist, 50.0 / topicNum, 0.1, seed);
    }

    public LDA(int documentNum, int topicNum, int wordNum, List<Token> tlist,
               double alpha, double beta, int seed) {
        wordCount = new int[wordNum][topicNum];
        topicCount = new int[topicNum];
        docCount = new int[documentNum][topicNum];
        D = documentNum;
        K = topicNum;
        W = wordNum;
        tokens = tlist.toArray(new Token[0]);
        z = new int[tokens.length];
        this.alpha = alpha;
        this.beta = beta;
        P = new double[K];
        rand = new Random(seed);
        init();
    }

    private void init() {
        // assign each token to a random topic
        for (int i = 0; i < z.length; ++i) {
            Token t = tokens[i];
            int assign = rand.nextInt(K);
            wordCount[t.wordId][assign]++;
            docCount[t.docId][assign]++;
            topicCount[assign]++;
            z[i] = assign;
        }
    }

    private int selectNextTopic(Token t) {
        // build the cumulative unnormalized distribution over topics and sample from it
        for (int k = 0; k < P.length; ++k) {
            P[k] = (wordCount[t.wordId][k] + beta) * (docCount[t.docId][k] + alpha)
                    / (topicCount[k] + W * beta);
            if (k != 0) {
                P[k] += P[k - 1];
            }
        }
        double u = rand.nextDouble() * P[K - 1];
        for (int k = 0; k < P.length; ++k) {
            if (u < P[k]) {
                return k;
            }
        }
        return K - 1;
    }

    private void resample(int tokenId) {
        Token t = tokens[tokenId];
        int assign = z[tokenId];
        // remove from current topic
        wordCount[t.wordId][assign]--;
        docCount[t.docId][assign]--;
        topicCount[assign]--;
        // reassign and add back
        assign = selectNextTopic(t);
        wordCount[t.wordId][assign]++;
        docCount[t.docId][assign]++;
        topicCount[assign]++;
        z[tokenId] = assign;
    }

    public void update() {
        for (int i = 0; i < z.length; ++i) {
            resample(i);
        }
    }

    public double[][] getTheta() {
        double theta[][] = new double[D][K];
        for (int i = 0; i < D; ++i) {
            double sum = 0.0;
            for (int j = 0; j < K; ++j) {
                theta[i][j] = alpha + docCount[i][j];
                sum += theta[i][j];
            }
            // normalize
            double sinv = 1.0 / sum;
            for (int j = 0; j < K; ++j) {
                theta[i][j] *= sinv;
            }
        }
        return theta;
    }

    public double[][] getPhi() {
        double phi[][] = new double[K][W];
        for (int i = 0; i < K; ++i) {
            double sum = 0.0;
            for (int j = 0; j < W; ++j) {
                phi[i][j] = beta + wordCount[j][i];
                sum += phi[i][j];
            }
            // normalize
            double sinv = 1.0 / sum;
            for (int j = 0; j < W; ++j) {
                phi[i][j] *= sinv;
            }
        }
        return phi;
    }
}
The basic usage is to pass a list of (document ID, word ID) tokens and the number of latent classes to the constructor, then call update() a fixed number of times. The estimated parameters are obtained with getTheta and getPhi. Note that this estimates the parameters from a single sample; to do it properly, you would need to take multiple samples and average them.
Theta is a D×K matrix giving each document's distribution over latent classes, and Phi is a K×W matrix giving the word emission probabilities for each class.
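As a rough usage sketch (the toy token list, iteration counts, and sample spacing below are illustrative assumptions, not part of the implementation above):

import java.util.*;
import lda.LDA;
import lda.Token;

public class Example {
    public static void main(String[] args) {
        // toy corpus: 2 documents, 3 unique words (hypothetical data)
        List<Token> tokens = new ArrayList<Token>();
        tokens.add(new Token(0, 0));
        tokens.add(new Token(0, 1));
        tokens.add(new Token(1, 1));
        tokens.add(new Token(1, 2));

        int D = 2, K = 2, W = 3;
        LDA lda = new LDA(D, K, W, tokens, 777);

        // burn-in
        for (int i = 0; i < 100; ++i) {
            lda.update();
        }

        // take several samples and average theta (the spacing of 10 iterations is arbitrary)
        int samples = 10;
        double avgTheta[][] = new double[D][K];
        for (int s = 0; s < samples; ++s) {
            for (int i = 0; i < 10; ++i) {
                lda.update();
            }
            double theta[][] = lda.getTheta();
            for (int d = 0; d < D; ++d) {
                for (int k = 0; k < K; ++k) {
                    avgTheta[d][k] += theta[d][k] / samples;
                }
            }
        }
        System.out.println(Arrays.deepToString(avgTheta));
    }
}

Averaging getPhi over the same samples works the same way.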
As an example, here is part of the per-topic word generation probabilities obtained with K=50 on the NIPS dataset from the UCI Machine Learning Repository: Bag of Words Data Set. As the iterations progress, the word probabilities come to differ substantially from topic to topic.
Iteration 1
topic 0: network 0.014092019554221354, model 0.00961357020472769, input 0.007852382258297598, learning 0.0075504643246238675, function 0.006795669490439543, unit 0.006267313106510514, neural 0.00594023534503064, algorithm 0.005764116550387631, result 0.005059641371815594, set 0.00485836274936644
topic 5: network 0.01344023110856413, learning 0.009439113815432327, neural 0.008508036206338762, model 0.00795442249282367, input 0.007778272674887049, unit 0.0069226878449091785, function 0.006218088573162697, data 0.006218088573162697, system 0.006117431534341771, set 0.005890953196994688
topic 49: network 0.011941178364697854, model 0.009365192988368095, learning 0.008919349365541791, input 0.008597351193500571, set 0.008126738480517249, function 0.0075818184970628776, neural 0.006120442177798879, system 0.00587275127622871, algorithm 0.005649829464815558, data 0.005501214923873456
Iteration 10
topic 0: network 0.025539408544249742, linear 0.014513134050736839, contour 0.013714128652656194, capacity 0.013234725413807807, problem 0.011988276992802, wavelet 0.010582027492180065, function 0.007481886547627161, direction 0.007481886547627161, rotation 0.007449926331703935, point 0.007002483308778774
topic 5: network 0.06928157250824087, output 0.0305151994792456, training 0.02848098200396099, neural 0.02346324556492562, recognition 0.021215919592230244, net 0.019821027609177942, layer 0.01722497864071949, set 0.01650815914942872, input 0.016488785649664106, character 0.015035773167317957
topic 49: unit 0.09013773693760363, hidden 0.04481769982166848, training 0.03433223140859152, network 0.03115149308244893, weight 0.029243050086763374, rbf 0.01978857961388436, error 0.018406603651491372, output 0.014216803511538027, generalization 0.013975506121278933, input 0.013865825489342983
Iteration 50
topic 0: segment 0.02273003693935464, contour 0.017135271693618368, region 0.0159885185687159, local 0.012166008152374348, surface 0.010637003985837726, wavelet 0.010463253512367656, shape 0.010393753322979626, point 0.009386000576853216, group 0.009142749913995118, rotation 0.00882999906174899
topic 5: recognition 0.030339713633886357, character 0.028085248199496936, network 0.022057003668412175, digit 0.019434962347980998, system 0.015122071951757753, neural 0.014582960652229847, input 0.014043849352701941, set 0.011666858622965267, output 0.010833686614603958, training 0.010711161319256706
topic 49: unit 0.0814758439901167, network 0.07836241125756047, input 0.0627824351143984, output 0.05246838840778206, hidden 0.05054651635064858, layer 0.04809933259789862, weight 0.04242340378916441, training 0.034710290599868715, pattern 0.022192497267739322, set 0.01342876068721066
Iteration 100
topic 0: region 0.02414254659304815, segment 0.021018255724856386, local 0.015230095800627433, contour 0.015230095800627433, surface 0.013355521279712375, depth 0.010527215862191408, shape 0.00980369587166279, texture 0.00950771042099199, edge 0.00901440133654066, grouping 0.008817077702760127
topic 5: character 0.03326499065788623, recognition 0.03314766417641939, digit 0.02338023459430484, system 0.012674193160455554, set 0.010239668670018593, handwritten 0.010122342188551752, image 0.009418383299750702, segmentation 0.008919745753516626, feature 0.008831750892416494, pattern 0.008802419272049784
topic 49: input 0.07237126180147059, network 0.0704561431890157, unit 0.06993275673592245, output 0.05644366042210978, layer 0.052220883357380066, hidden 0.04764125189281403, weight 0.04369206320129215, training 0.035734210084942354, pattern 0.018307820226269023, set 0.01409693830820052
Iteration 200
topic 0: image 0.05129741438909944, images 0.01835710548219398, region 0.017218428137263916, pixel 0.01667620083015436, segment 0.01393795292925111, local 0.013368614256786078, vision 0.012663718757543655, contour 0.012040157354367668, scene 0.009817025395218492, surface 0.009464577645597282
topic 5: character 0.03446915797682203, recognition 0.0332252221639097, digit 0.02409291290326061, system 0.012988510280677329, image 0.011107436612370871, set 0.010348939165473107, handwritten 0.010348939165473107, feature 0.009893840697334448, pixel 0.009742141207954894, pattern 0.009681461412203073
topic 49: input 0.07946527053340079, network 0.07009151773133951, unit 0.06570943206559061, output 0.05974013426894054, layer 0.0505389045245859, hidden 0.04640985267680675, weight 0.041579207061020564, training 0.03748465982478275, net 0.01888667420400594, pattern 0.017690514337239834
The code used for the experiment is shown below.
import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Scanner;

import lda.LDA;
import lda.Token;

class PComp implements Comparable<PComp> {
    int id;
    double prob;

    @Override
    public int compareTo(PComp o) {
        return Double.compare(prob, o.prob);
    }
}

public class NIPS {
    public static void main(String[] args) throws Exception {
        // read the bag-of-words file: header values followed by "docId wordId count" lines
        Scanner sc = new Scanner(new File("data/docword.nips.txt"));
        int D = sc.nextInt();
        int W = sc.nextInt();
        int N = sc.nextInt();
        List<Token> tlist = new ArrayList<Token>();
        for (int i = 0; i < N; ++i) {
            int did = sc.nextInt() - 1;
            int wid = sc.nextInt() - 1;
            int count = sc.nextInt();
            for (int c = 0; c < count; ++c) {
                tlist.add(new Token(did, wid));
            }
        }
        // read the vocabulary (one word per line)
        String words[] = new String[W];
        sc = new Scanner(new File("data/vocab.nips.txt"));
        for (int i = 0; i < W; ++i) {
            words[i] = sc.nextLine();
        }

        int K = 50;
        LDA lda = new LDA(D, K, W, tlist, 777);
        for (int i = 0; i <= 200; ++i) {
            lda.update();
            if (i % 10 == 0) {
                PrintWriter out = new PrintWriter("output/wordtopic" + i + ".txt");
                double phi[][] = lda.getPhi();
                outputWordTopicProb(phi, words, out);
                out.close();
            }
        }
    }

    private static void outputWordTopicProb(double phi[][], String[] words, PrintWriter out) {
        int K = phi.length;
        int W = phi[0].length;
        for (int k = 0; k < K; ++k) {
            out.println("topic : " + k);
            PComp ps[] = new PComp[W];
            for (int w = 0; w < W; ++w) {
                PComp pc = new PComp();
                pc.id = w;
                pc.prob = phi[k][w];
                ps[w] = pc;
            }
            Arrays.sort(ps);
            for (int i = 0; i < 10; ++i) {
                // output the top 10 related words for this topic
                PComp p = ps[W - 1 - i];
                out.println(words[p.id] + " " + p.prob);
            }
        }
    }
}
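For reference, the parsing above assumes the UCI Bag of Words layout: docword.nips.txt starts with three header values (the number of documents D, the vocabulary size W, and the number of docId/wordId entries), followed by one "docId wordId count" triple per line with 1-based IDs, and vocab.nips.txt lists one word per line in wordId order. Schematically:

D
W
NNZ
docId wordId count
docId wordId count
...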