[機械学習] CVB0を実装してみた
On Smoothing and Inference for Topic Models (UAI 2009) pdfで述べられているLDAの推論方法であるCVB0を実装してみた。
これはTehらのCVBで述べられている期待値の2次の項までの近似の部分をさらに近似して0次の項だけで近似したものとなっている。二次近似に比べexpの計算がいらない分、高速に実行することができる。
手元で実験したところ、前回の記事で実装したCGS(Collapsed Gibbs sampler)より2倍程度高速であった。
Token.java 前回のものに文章重みの項が付け加わっている
public class Token { public int docId; public int wordId; public double weight; public Token(int d , int w ){ docId = d; wordId = w; weight = 1.0; } public Token(int d , int w , double ww){ docId = d; wordId = w; weight = ww; } }
LDACVB0.java
import java.util.*; public class LDACVB0 { int D; // number of document int K; // number of topic int W; // number of unique word float wordExpectation[][]; float docExpectation[][]; float topicExpectation[]; // hyper parameter double alpha, beta; Token tokens[]; // token probability float gamma[][]; Random rand; public LDACVB0(int documentNum, int topicNum, int wordNum, List<Token> tlist, int seed) { this(documentNum, topicNum, wordNum, tlist, 50.0 / topicNum, 0.1, seed); } public LDACVB0(int documentNum, int topicNum, int wordNum, List<Token> tlist, double alpha, double beta, int seed) { wordExpectation = new float[wordNum][topicNum]; topicExpectation = new float[topicNum]; docExpectation = new float[documentNum][topicNum]; D = documentNum; K = topicNum; W = wordNum; tokens = tlist.toArray(new Token[0]); gamma = new float[tokens.length][topicNum]; this.alpha = alpha; this.beta = beta; rand = new Random(seed); init(); } private void init() { for (int i = 0; i < gamma.length ; ++i) { double sum = 0.0; for(int j = 0 ; j < K ; ++j){ gamma[i][j] = rand.nextFloat(); sum += gamma[i][j]; } double sinv = 1.0 / sum; for(int j = 0 ; j < K ; ++j){ gamma[i][j] *= sinv; } } } private void updateExpectation(){ Arrays.fill(topicExpectation, 0.0f); for(int i = 0 ; i < W ; ++i){ Arrays.fill(wordExpectation[i], 0.0f); } for(int i = 0 ; i < D ; ++i){ Arrays.fill(docExpectation[i], 0.0f); } for (int i = 0; i < tokens.length; ++i) { Token t = tokens[i]; for (int k = 0; k < K; ++k) { topicExpectation[k] += t.weight * gamma[i][k]; wordExpectation[t.wordId][k] += t.weight * gamma[i][k]; docExpectation[t.docId][k] += t.weight * gamma[i][k]; } } } private void updateGamma(){ for(int i = 0 ; i < tokens.length ; ++i){ Token t = tokens[i]; double sum = 0.0; for(int k = 0 ; k < K ; ++k){ double g = gamma[i][k]; double wk = (wordExpectation[t.wordId][k] + beta) - g; double dk = (docExpectation [t.docId][k] + alpha) - g; double kk = (topicExpectation[k] + W * beta) - g; gamma[i][k] =(float)(wk * 
dk / kk); sum += gamma[i][k]; } double sinv = 1.0 / sum; for(int k = 0 ; k < K ; ++k){ gamma[i][k] *= sinv; } } } public void update() { updateExpectation(); updateGamma(); } public double[][] getTheta() { double theta[][] = new double[D][K]; for(int d = 0 ; d < D ; ++d){ double sum = 0.0; for(int k = 0 ; k < K ; ++k){ theta[d][k] = alpha + docExpectation[d][k]; sum += theta[d][k]; } double sinv = 1.0 / sum; for(int k = 0 ; k < K ; ++k){ theta[d][k] *= sinv; } } return theta; } public double[][] getPhi() { double phi[][] = new double[K][W]; for(int k = 0 ; k < K ; ++k){ double sum = 0.0; for(int w = 0 ; w < W ; ++w){ phi[k][w] = wordExpectation[w][k] + beta; sum += phi[k][w]; } double sinv = 1.0 / sum; for(int w = 0 ; w < W ; ++w){ phi[k][w] *= sinv; } } return phi; } }
前回のテストコードで実験したところ以下のようなトピックが抽出できた
topic : 0 graph 0.01639231445222314 weight 0.014217440489921464 pruning 0.013566353552508154 level 0.013337075394709439 network 0.012245316879923073 variables 0.01154173846067167 component 0.009558425902995188 contribution 0.009189845206519427 structure 0.009004521752442873 problem 0.008286439753653205 topic : 1 learning 0.06361167333143003 algorithm 0.05496684033178554 gradient 0.029577348747162977 weight 0.02347752368853067 error 0.021588484323139926 convergence 0.015966324298809526 rate 0.014369645761162692 descent 0.012982777561825492 function 0.010329638976695813 parameter 0.009962002418224576 topic : 2 direction 0.03520721401428869 motion 0.03398034949443053 eye 0.020649724516946414 cell 0.018779895385126283 head 0.015734231197231024 velocity 0.015713319780626806 model 0.014333959786142747 visual 0.01297128537367679 position 0.012224118942964374 system 0.011606187759166229