[機械学習] PRML 14章の混合正規回帰モデルの実装
PRML 14.5.1の混合正規回帰モデルのEMアルゴリズムによるパラメータ推定を実装してみた。
一本の直線でフィッティングする代わりに
複数の直線でフィッティングするというモデルである。
これは混合正規分布の推定と同じくEMアルゴリズムで推定できる。詳しくはPRMLの14.5.1を参照。
package chapter14;

import java.util.Random;

/**
 * Demo driver for the mixture-of-linear-regressions model (PRML section 14.5.1).
 *
 * <p>Synthesizes noisy samples from a piecewise-linear target function, then
 * fits a 2-component mixture of linear regressions by EM, printing the two
 * fitted lines after each iteration.
 */
public class Main {

    /**
     * Piecewise-linear target used to generate training data:
     * a descending line on |x| >= 0.5 and an ascending line on |x| < 0.5.
     *
     * @param x input in roughly [-1, 1]
     * @return noiseless target value at {@code x}
     */
    static double generate(double x) {
        if (Math.abs(x) >= 0.5) {
            return -0.5 * x - 1;
        }
        return 0.5 * x + 1;
    }

    public static void main(String[] args) {
        // Fixed seed so the demo output is reproducible.
        Random rand = new Random(222);
        int N = 20;
        double[][] Phi = new double[N][2];
        double[] target = new double[N];
        for (int i = 0; i < N; ++i) {
            double x = rand.nextDouble() * 2 - 1;           // x uniform in [-1, 1)
            double y = generate(x) + rand.nextGaussian() * 0.1; // additive Gaussian noise
            Phi[i][0] = 1;  // bias feature
            Phi[i][1] = x;  // linear feature
            target[i] = y;
            System.out.println(x + "\t" + y);
        }
        // Two mixture components, each a line y = w1 * x + w0.
        MixtureOfLL ml = new MixtureOfLL(2, Phi, target, rand.nextLong());
        for (int iter = 0; iter < 30; ++iter) {
            System.out.println(iter);
            System.out.println(ml.W[0][1] + " * x + " + ml.W[0][0]);
            System.out.println(ml.W[1][1] + " * x + " + ml.W[1][0]);
            ml.iteration();
        }
    }
}
package chapter14;

import java.util.Random;

import org.apache.commons.math.linear.Array2DRowRealMatrix;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;

import static java.lang.Math.*;

/**
 * Mixture of linear regression models (PRML section 14.5.1), fitted by EM.
 *
 * <p>Each of the K components is a linear model with weight vector W[k]; all
 * components share a common noise variance {@code betaInv} (= 1/beta in PRML
 * notation). The E step computes responsibilities {@code gamma[n][k]}, the
 * M step re-estimates the mixing coefficients, weights, and noise variance.
 */
public class MixtureOfLL {
    Random rand;
    int N, K;              // number of data points / mixture components
    double gamma[][];      // responsibilities, gamma[n][k]
    double pi[];           // mixing coefficients, sum to 1
    double betaInv;        // shared noise VARIANCE (1/beta), not std dev
    double W[][];          // weight vector of each component, W[k][f]
    // design matrix
    double Phi[][];
    Array2DRowRealMatrix PhiMatrix;
    double target[];

    /**
     * @param classNum number of mixture components K
     * @param Phi      N x F design matrix (rows are feature vectors)
     * @param target   length-N target values
     * @param seed     seed for the random parameter initialization
     */
    public MixtureOfLL(int classNum, double Phi[][], double target[], long seed) {
        int dataNum = Phi.length;
        int featureNum = Phi[0].length;
        pi = new double[classNum];
        W = new double[classNum][featureNum];
        gamma = new double[dataNum][classNum];
        this.Phi = Phi;
        this.target = target;
        this.N = dataNum;
        this.K = classNum;
        rand = new Random(seed);
        PhiMatrix = new Array2DRowRealMatrix(Phi);
        initParameter();
    }

    /** Random initialization: pi uniform-random then normalized, W in [-0.5, 0.5). */
    void initParameter() {
        betaInv = 1.0;
        double sum = 0.0;
        for (int k = 0; k < pi.length; ++k) {
            pi[k] = rand.nextDouble();
            sum += pi[k];
        }
        double sinv = 1.0 / sum;
        for (int k = 0; k < pi.length; ++k) {
            pi[k] *= sinv;
        }
        for (int k = 0; k < W.length; ++k) {
            for (int f = 0; f < W[k].length; ++f) {
                W[k][f] = rand.nextDouble() - 0.5;
            }
        }
    }

    /** Univariate Gaussian density; note {@code sigma} is a standard deviation. */
    double normal(double x, double mu, double sigma) {
        return 1.0 / sqrt(2.0 * PI * sigma * sigma)
                * exp(-(x - mu) * (x - mu) / (2 * sigma * sigma));
    }

    /** M step for the noise variance: responsibility-weighted mean squared residual (PRML 14.41). */
    void optimizeBeta() {
        double s = 0.0;
        for (int n = 0; n < gamma.length; ++n) {
            double phi[] = Phi[n];
            for (int k = 0; k < gamma[n].length; ++k) {
                double mean = 0.0;
                for (int f = 0; f < phi.length; ++f) {
                    mean += W[k][f] * phi[f];
                }
                s += gamma[n][k] * (target[n] - mean) * (target[n] - mean);
            }
        }
        betaInv = s / N;
    }

    /** M step for the mixing coefficients: pi[k] = (1/N) * sum_n gamma[n][k] (PRML 14.38). */
    void optimizePi() {
        // BUG FIX: pi must be recomputed from scratch each M step. The original
        // accumulated on top of the previous values, yielding
        // (pi_old[k] + sum_n gamma[n][k]) / N, which does not even sum to 1.
        for (int k = 0; k < pi.length; ++k) {
            pi[k] = 0.0;
        }
        for (int n = 0; n < gamma.length; ++n) {
            for (int k = 0; k < gamma[n].length; ++k) {
                pi[k] += gamma[n][k];
            }
        }
        for (int k = 0; k < pi.length; ++k) {
            pi[k] /= N;
        }
    }

    /**
     * M step for the weights: solve the responsibility-weighted normal equations
     * (Phi^T R_k Phi) W_k = Phi^T R_k t for each component k (PRML 14.42),
     * where R_k is the diagonal matrix of responsibilities for component k.
     */
    void optimizeW() {
        for (int k = 0; k < K; ++k) {
            double R[][] = new double[N][N];
            for (int n = 0; n < N; ++n) R[n][n] = gamma[n][k];
            Array2DRowRealMatrix Rmat = new Array2DRowRealMatrix(R);
            RealMatrix PhiR = PhiMatrix.transpose().multiply(Rmat);
            double[] prt = PhiR.operate(target);
            RealMatrix PRP = PhiR.multiply(PhiMatrix);
            // solve (PRP) W_k = prt for W_k via LU decomposition
            LUDecompositionImpl luImpl = new LUDecompositionImpl(PRP);
            W[k] = luImpl.getSolver().solve(prt);
        }
    }

    /** E step: gamma[n][k] proportional to pi[k] * N(t_n | W_k^T phi_n, betaInv) (PRML 14.37). */
    void eStep() {
        for (int n = 0; n < gamma.length; ++n) {
            double sum = 0.0;
            double phi[] = Phi[n];
            for (int k = 0; k < gamma[0].length; ++k) {
                double mean = 0.0;
                for (int f = 0; f < phi.length; ++f) {
                    mean += W[k][f] * phi[f];
                }
                // BUG FIX: betaInv is a variance, but normal() expects a standard
                // deviation; the original passed betaInv directly, effectively
                // evaluating the likelihood with variance betaInv^2.
                gamma[n][k] = pi[k] * normal(target[n], mean, sqrt(betaInv));
                sum += gamma[n][k];
            }
            double sinv = 1.0 / sum;
            for (int k = 0; k < gamma[0].length; ++k) {
                gamma[n][k] *= sinv;
            }
        }
    }

    /** M step: re-estimate noise variance, mixing coefficients, and weights. */
    void mStep() {
        optimizeBeta();
        optimizePi();
        optimizeW();
    }

    /** One full EM iteration (E step then M step). */
    void iteration() {
        eStep();
        mStep();
    }
}