[機械学習] PRML 14章の混合正規回帰モデルの実装

PRML 14.5.1の混合正規回帰モデルのEMアルゴリズムによるパラメータ推定を実装してみた。

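混合正規回帰モデルは、下図のようなデータ点に対して、一本の直線でフィッティングする代わりに複数の直線でフィッティングするモデルである。具体的には、K 本の直線それぞれに共通の分散 $\beta^{-1}$ の正規ノイズを仮定し、混合係数 $\pi_k$ で重み付けした $p(t \mid \boldsymbol{\phi}) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(t \mid \mathbf{w}_k^{\mathrm{T}} \boldsymbol{\phi}, \beta^{-1})$ を条件付き分布とする。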

これは混合正規分布の推定と同じくEMアルゴリズムで推定できる。詳しくはPRMLの14.5.1を参照。

package chapter14;

import java.util.Random;

public class Main {
  /**
   * Noise-free target function used to synthesize the training data: two lines that
   * switch at |x| = 0.5.
   *
   * @param x input value ({@code main} draws it uniformly from [-1, 1))
   * @return {@code -0.5 * x - 1} when |x| >= 0.5, otherwise {@code 0.5 * x + 1}
   */
  static double generate(double x){
    if(Math.abs(x) >= 0.5){
      return - 0.5 * x - 1;
    }else{
      return 0.5 * x + 1;
    }
  }

  /** Prints every fitted line of {@code ml} as "slope * x + intercept", one per line. */
  private static void printLines(MixtureOfLL ml){
    for(int k = 0 ; k < ml.W.length ; ++k){
      System.out.println(ml.W[k][1] + " * x + " + ml.W[k][0]);
    }
  }

  /**
   * Generates noisy samples from {@link #generate(double)}, prints them as "x\ty", then runs
   * EM on a two-component mixture of linear regressions and prints the fitted lines.
   */
  public static void main(String[] args){
    final int N = 20;
    final int classNum = 2;
    final int maxIter = 30;
    Random rand = new Random(222);
    double[][] Phi = new double[N][2];
    double[] target = new double[N];
    for(int i = 0 ; i < N ; ++i){
      double x = rand.nextDouble() * 2 - 1;
      double y = generate(x) + rand.nextGaussian() * 0.1;
      Phi[i][0] = 1;  // bias feature, so W[k][0] is the intercept
      Phi[i][1] = x;  // so W[k][1] is the slope
      target[i] = y;
      System.out.println(x+"\t"+y);
    }
    MixtureOfLL ml = new MixtureOfLL(classNum, Phi, target, rand.nextLong());
    for(int iter = 0 ; iter < maxIter ; ++iter){
      // state *before* this iteration's update
      System.out.println(iter);
      printLines(ml);
      ml.iteration();
    }
    // the loop above never shows the result of the last update, so print it here
    System.out.println(maxIter);
    printLines(ml);
  }
}
package chapter14;

import java.util.Random;

import org.apache.commons.math.linear.Array2DRowRealMatrix;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;

import static java.lang.Math.*;
/**
 * Mixture of linear regression models fitted with EM (PRML 14.5.1).
 *
 * <p>The model is p(t | phi) = sum_k pi[k] * N(t | W[k]^T phi, betaInv). It consists of K
 * regression lines with mixing coefficients {@code pi} and one noise variance
 * {@code betaInv} shared by every line. Call {@link #iteration()} repeatedly to run EM.
 */
public class MixtureOfLL {
  Random rand;
  /** N = number of data points, K = number of mixture components (lines). */
  int N , K;
  /** Responsibilities, N x K: gamma[n][k] is the posterior weight of line k for point n. */
  double[][] gamma;
  /** Mixing coefficients, length K. */
  double[] pi;
  /** Shared noise variance (1 / beta). This is a variance, not a standard deviation. */
  double betaInv;
  /** Regression weights, K x featureNum; W[k] is the weight vector of line k. */
  double[][] W;
  // design matrix (N x featureNum); row n is the feature vector of data point n
  double[][] Phi;
  /** Copy of {@link #Phi} as a commons-math matrix; the EM updates below do not need it. */
  Array2DRowRealMatrix PhiMatrix;
  /** Regression targets, length N. */
  double[] target;

  /**
   * Creates the model and randomly initialises {@code pi} and {@code W}.
   * {@code betaInv} starts at 1.
   *
   * @param classNum number of regression lines K (must be positive)
   * @param Phi design matrix, N x featureNum; stored by reference, not copied
   * @param target regression targets, length N
   * @param seed seed for the random initialisation
   * @throws IllegalArgumentException if {@code classNum <= 0}, {@code Phi} has no rows or
   *     no columns, or {@code target.length != Phi.length}
   */
  public MixtureOfLL(int classNum , double[][] Phi , double[] target, long seed) {
    if(classNum <= 0){
      throw new IllegalArgumentException("classNum must be positive: " + classNum);
    }
    if(Phi.length == 0 || Phi[0].length == 0){
      throw new IllegalArgumentException("Phi must have at least one row and one column");
    }
    if(target.length != Phi.length){
      throw new IllegalArgumentException(
          "target length " + target.length + " != number of rows of Phi " + Phi.length);
    }
    int dataNum = Phi.length;
    int featureNum = Phi[0].length;
    pi = new double[classNum];
    W  = new double[classNum][featureNum];
    gamma = new double[dataNum][classNum];
    this.Phi = Phi;
    this.target = target;
    this.N = dataNum;
    this.K = classNum;
    rand = new Random(seed);
    PhiMatrix = new Array2DRowRealMatrix(Phi);
    initParameter();
  }

  /**
   * Initialises the parameters. It sets betaInv = 1, draws random mixing coefficients
   * normalised to sum to 1, and draws weights uniformly from [-0.5, 0.5).
   */
  void initParameter(){
    betaInv = 1.0;
    double sum = 0.0;
    for(int k = 0 ; k < pi.length ; ++k){
      pi[k] = rand.nextDouble();
      sum += pi[k];
    }
    double sinv = 1.0 / sum;
    for(int k = 0 ; k < pi.length ; ++k){
      pi[k] *= sinv;
    }
    for(int k = 0 ; k < W.length ; ++k){
      for(int f = 0 ; f < W[k].length ; ++f){
        W[k][f] = rand.nextDouble() - 0.5;
      }
    }
  }

  /**
   * Normal density N(x | mu, sigma^2).
   *
   * @param sigma standard deviation (not variance)
   */
  double normal(double x , double mu , double sigma){
    return 1.0 / sqrt(2.0 * PI * sigma * sigma) * exp( -(x-mu) * (x - mu)/(2*sigma * sigma));
  }

  /** Log of the normal density N(x | mu, variance), parameterised by the variance. */
  private static double logNormal(double x, double mu, double variance){
    double d = x - mu;
    return -0.5 * log(2.0 * PI * variance) - d * d / (2.0 * variance);
  }

  /** Returns the prediction W[k]^T phi_n of line k for data point n. */
  private double predict(int k, int n){
    double[] phi = Phi[n];
    double mean = 0.0;
    for(int f = 0 ; f < phi.length ; ++f){
      mean += W[k][f] * phi[f];
    }
    return mean;
  }

  /**
   * M-step for the shared noise variance (PRML 14.37). It computes
   * 1/beta = (1/N) sum_n sum_k gamma_nk (t_n - W_k^T phi_n)^2 using the current W.
   */
  void optimizeBeta(){
    double s = 0.0;
    for(int n = 0 ; n < N ; ++n){
      for(int k = 0 ; k < K ; ++k){
        double r = target[n] - predict(k, n);
        s += gamma[n][k] * r * r;
      }
    }
    betaInv = s / N;
  }

  /** M-step for the mixing coefficients: pi_k = (1/N) sum_n gamma_nk. */
  void optimizePi(){
    // start from zero; accumulating onto the previous pi would corrupt the estimate
    for(int k = 0 ; k < K ; ++k){
      pi[k] = 0.0;
    }
    for(int n = 0 ; n < N ; ++n){
      for(int k = 0 ; k < K ; ++k){
        pi[k] += gamma[n][k];
      }
    }
    for(int k = 0 ; k < K ; ++k){
      pi[k] /= N;
    }
  }

  /**
   * M-step for the weights: a weighted least-squares solve per line.
   * For each k it solves (Phi^T R_k Phi) W_k = Phi^T R_k t, where R_k = diag(gamma[.][k]).
   * Both sides are accumulated row by row, so no N x N matrix is ever built.
   */
  void optimizeW(){
    int featureNum = W[0].length;
    for(int k = 0 ; k < K ; ++k){
      double[][] prp = new double[featureNum][featureNum];
      double[] prt = new double[featureNum];
      for(int n = 0 ; n < N ; ++n){
        double g = gamma[n][k];
        double[] phi = Phi[n];
        for(int i = 0 ; i < featureNum ; ++i){
          double gphi = g * phi[i];
          prt[i] += gphi * target[n];
          for(int j = 0 ; j < featureNum ; ++j){
            prp[i][j] += gphi * phi[j];
          }
        }
      }
      LUDecompositionImpl luImpl = new LUDecompositionImpl(new Array2DRowRealMatrix(prp));
      W[k] = luImpl.getSolver().solve(prt);
    }
  }

  /**
   * E-step: computes the responsibilities
   * gamma_nk = pi_k N(t_n | W_k^T phi_n, betaInv) / sum_j pi_j N(t_n | W_j^T phi_n, betaInv).
   * The computation is done in the log domain with max-subtraction. If every density
   * underflows to 0, the plain form would produce 0/0 = NaN; this form avoids that.
   */
  void eStep(){
    double[] logp = new double[K];
    for(int n = 0 ; n < N ; ++n){
      double maxLog = Double.NEGATIVE_INFINITY;
      for(int k = 0 ; k < K ; ++k){
        // betaInv is a variance, so it is passed as such (not as a standard deviation)
        logp[k] = log(pi[k]) + logNormal(target[n], predict(k, n), betaInv);
        maxLog = max(maxLog, logp[k]);
      }
      double sum = 0.0;
      for(int k = 0 ; k < K ; ++k){
        gamma[n][k] = exp(logp[k] - maxLog);
        sum += gamma[n][k];
      }
      for(int k = 0 ; k < K ; ++k){
        gamma[n][k] /= sum;
      }
    }
  }

  /**
   * M-step, in PRML order: pi and W come first because they do not depend on beta.
   * beta is then re-estimated from the freshly updated W.
   */
  void mStep(){
    optimizePi();
    optimizeW();
    optimizeBeta();
  }

  /** Runs one full EM iteration (E-step followed by M-step). */
  void iteration(){
    eStep();
    mStep();
  }
}