ICML2011読み会で発表してきました

id:nokunoさんの主催したICML2011読み会で発表してきました

自分はSparse Additive Generative Model for Textという論文について発表しました

発表資料:

他の発表についていくつかメモ

  • Infinite SVM: a Dirichlet Process Mixture of Large-margin Kernel Machines
    • データをある程度セグメントに分けて、それぞれで識別モデルを構成したほうが性能が上がるよねという話という理解
    • これは実務的にもユーザモデルの作成のときに性別で分けて独立にモデルを作成したほうが性能が高いことがあったり、決定木的な方法が割とうまくいくことからも納得できる
  • GoDec: Randomized Low-rank & Sparse Matrix Decomposition in Noisy Case
    • この話を聴いてSAGEでもバックグラウンドのmは学習で求めてもいいのではという気がした
  • Parallel Coordinate Descent for L1-Regularized Loss Minimization
    • SGDはデータ点を一つ持ってきてパラメータを最適化する方法で、SCDはパラメータの次元の一つを持ってきてそれについて最適化する手法
    • この論文では複数の次元を一気に並列で最適化する場合にどれくらいの次元数を一気に更新していいかの理論値をあたえた
    • SGDに比べて収束が早かったり、損失関数によっては解析的に解が求まるため学習率みたいなパラメータを与えなくてもいいというメリットがある
    • この話を聴いてて、今までSGDとSCDをごっちゃにしてたことに気付いて恥ずかしかった
      • 言い訳をするとliblinearを使ってるアルゴリズムの一つである"A dual coordinate descent method for large-scale linear SVM"は双対空間でSCDをしてるので双対だと次元がデータ点に対応するためデータ点を一個持ってきて最適化とか思ってました。学習率とか出てきてない時点で気付けよという話ですが

[NLP] 第五回自然言語処理勉強会で発表してきました

id:nokunoさんの主催する自然言語処理勉強会で、Infer.NETを使ってLDAを実装してみたというタイトルで発表してきました。
Infer.NETはMicrosoftが公開しているグラフィカルモデル上でベイズ推定を行うためのフレームワークです。このようなものを使うことにより、具体的な推論アルゴリズムの導出を人が行うことなく、生成モデルを記述するだけで事後分布の推論が可能になり、簡単に確率モデルを問題に合わせて定義するということが行えるようになるといいなと思って、今回紹介しました。

参考文献

Infer.NETを使う上で参考になるかと思われる書籍をあげておきます。

パターン認識と機械学習 上 - ベイズ理論による統計的予測

パターン認識と機械学習 上 - ベイズ理論による統計的予測

パターン認識と機械学習 下 - ベイズ理論による統計的予測

パターン認識と機械学習 下 - ベイズ理論による統計的予測

まず、確率分布についてはある程度理解する必要があるので(測度論とかまで踏み込む必要は全くない)、おなじみPRMLの上巻の1,2章および下巻の8章のグラフィカルモデルの章が参考になります。

予測にいかす統計モデリングの基本―ベイズ統計入門から応用まで (KS理工学専門書)

予測にいかす統計モデリングの基本―ベイズ統計入門から応用まで (KS理工学専門書)

本書は最近出版された本で、まえがきにあるようにベイズモデリングに関して読み物レベルの記事と専門書の橋渡しを行い、"自分の問題に対してベイズモデリングを行い、モデルに基づいて予測し、その予測結果を通じてモデルの評価を行い、必要があればモデルを改良する、この一連の流れを学ぶ"ことを目的とした本です。

扱っている題材は時系列モデルで、鎖状構造グラフィカルモデルはInfer.NETがあまり得意としないところなのですが、それでもモデリングについての考えは参考になると思います。

初ビット列がでるまでの期待時間

kinabaさんのブログの 「無限ビット列を作ったときに最初に "001" が並ぶインデックスの期待値は 8。では、"000" なら?」という問題に対して、マルチンゲールを使った解説をしてみます。

いま無限に生成されるビット列に対して次に何がでるかを賭けるギャンブルを考えます。配当はフェアな賭けで賭け金の2倍返しとします。ここで000が出現したら賭けは終了するとします。
このとき毎時刻ギャンブラーが1$持ってきてつぎのように賭けます。

1. はじめに0が出ることに賭けて、勝ったら次へ、そうでなければ終了
2. 再び0が出ることに前の儲け2$を全額賭ける、勝ったら次へ、そうでなければ終了
3. 再び0が出ることに前の儲け4$を全額賭ける、勝ったら000が出てるので賭け自体が終了、そうでなければ終了

この話はたとえば010がでる期待値を考えるときは2.のところで0にではなく1に賭けることになります。

たとえばビット列が1011000と出た場合

時刻 0 1 2 3 4 5 6
出現ビット 1 0 1 1 0 0 0
個々の時刻に入ったギャンブラーの支払い 1 1 1 1 1 1 1
個々の時刻に入ったギャンブラーの儲け 0 0 0 0 8 4 2

この場合すべてのギャンブラーが支払った額は7であり、儲けの総額は14になります。

ここで儲けの総額というのは000で停止という条件の元では終了時には常に14になります。一方で支払いは毎時刻1ずつ支払うので000が出現するまでの時刻Tだった場合支払いもTになります。

ここですべてのギャンブラーの収支を考えると-T + 14となります。このため収支の合計の期待値は

14 - E[T]

となります。ここで求めたかった値であるE[T]がでてきました。ではこの値がいくつかですが、基本的にn$払えば半分の確率で2n$当たるという賭けなのでマイナスになることはなさそうです。一方でプラスにもなることはなさそうです(なるのならこんなブログとか書いてないでラスベガスにいってルーレットでぼろ儲けしてます)。確率論ではこの期待値が0になりますというのを保証しており、それがOptional stopping theorem(任意抽出定理)です。簡単にいうと公平な賭けの期待値はいつやめても0になるということを意味してます。
これにより

14 - E[T] = 0 => E[T] = 14

となります。

この辺のことを詳しく知りたい方は以下の本が参考になるかと思います

マルチンゲールによる確率論

マルチンゲールによる確率論

確率と計算 ―乱択アルゴリズムと確率的解析―

確率と計算 ―乱択アルゴリズムと確率的解析―

wat-arrayを使った2次元探索プログラム

岡野原氏の作成したwavelet木を使った高速配列処理ライブラリwat-arrayを利用して、2次元探索のプログラムを書いてみた。
なお、自分はwavelet木のアルゴリズムについては全く分かってないですが、wat-arrayでは配列に対して、操作を行うインターフェイスがしっかり与えられているのでそれを見ながら作りました。

問題定義

2次元座標の集合P={(x,y)}が与えられる。Queryとして(xs,xe,ys,ye)が与えられたときにPの中でxs <= x < x, ys <= y < yeを満たす点の数を答えるというものを考える。

なお、Pの内容は途中で変化したりすることはないものとする(変更が加わった場合は一から作り直す)。

インターフェースとしては次のようになる、2次元座標の表現にはpairを用いることにする。

namespace wat2DSearch{
  typedef std::pair<uint64_t , uint64_t> Point;
  class Wat2DSearch{
    public:
      Wat2DSearch(const std::vector<Point> & P);
      //Compute the frequency of points where xs <= x < xe and ys <= y < ye
      uint64_t FreqRange2D(uint64_t xs , uint64_t xe , uint64_t ys , uint64_t ye);
  };
}

使い方の例

int main(){
  vector<wat2DSearch::Point> P;
  P.push_back(make_pair(1,2));
  P.push_back(make_pair(2,1));
  P.push_back(make_pair(2,4));
  P.push_back(make_pair(4,2));
  wat2DSearch::Wat2DSearch wat(P);
  cout << wat.FreqRange2D(0 , 5 , 0 , 5) << endl;
  cout << wat.FreqRange2D(2 , 3 , 2 , 4) << endl;
  cout << wat.FreqRange2D(2 , 3 , 2 , 5) << endl;
  return 0;
}

以降でどのようにしてこのFreqRange2Dを実現するかについて述べる。

アルゴリズム

Pの値をxの昇順でソートする。たとえばP={(2,1),(4,2),(1,3),(2,4)}のときソート後は

index 0 1 2 3
座標 (1,3) (2,1) (2,3) (4,2)

となる。これをx座標とy座標に分解すると

index 0 1 2 3
x座標 1 2 2 4
y座標 3 1 4 2

となる。

たとえばxs=2,xe=3のとき、x座標が昇順に並んでいることからy座標の配列のうちindexが[1ー3)までの範囲に関してys <= y < yeを満たすものの数を求めることと同じであることが分かる。
与えられた部分配列の中で指定されたRangeに収まるような要素の数を求めるのはwatArray::FreqRangeを用いて求めることができる。
また、与えられたxs,xeに対してindexを求める部分もxが昇順に並んでいることに着目するとxsに対応するindexはx座標の配列の中でxsより小さいものの数であることが分かり、
これはwatArray::FreqSumを用いて簡単に求まる。xeについても同様である。

プログラム

以上を踏まえて作ったのが以下のコードである、watArrayを使うことにより簡潔に書けていることが分かる。

ヘッダ

#include <wat_array/wat_array.hpp>
#include <utility>
#include <vector>
namespace wat2DSearch{
  typedef std::pair<uint64_t , uint64_t> Point;
  class Wat2DSearch{
    public:
      Wat2DSearch(const std::vector<Point> & P);
      uint64_t FreqRange2D(uint64_t xs , uint64_t xe , uint64_t ys , uint64_t ye);
    private:
      wat_array::WatArray xwat;
      wat_array::WatArray ywat;
  };
}

実装部分(2011/1/11 FreqRangeの境界条件での仕様に対応)

#include "Wat2DSearch.hpp"
#include <algorithm>
using namespace std;
namespace wat2DSearch{
  Wat2DSearch::Wat2DSearch(const vector<Point> & P): xwat(), ywat(){
    vector<Point> tmp(P);
    sort(tmp.begin() , tmp.end());
    vector<uint64_t> xArray , yArray;
    for(vector<Point>::iterator it = tmp.begin() ; it != tmp.end() ; ++it){
      xArray.push_back(it->first);
      yArray.push_back(it->second);
    }
    xwat.Init(xArray);
    ywat.Init(yArray);
  }
  uint64_t Wat2DSearch::FreqRange2D(uint64_t xs , uint64_t xe , uint64_t ys , uint64_t ye){
    if(xs >= xwat.alphabet_num() || ys >= ywat.alphabet_num())return 0;
    if(xs >= xe || ys >= ye)return 0;
    xe = min(xe , xwat.alphabet_num());
    ye = min(ye , ywat.alphabet_num());
    uint64_t xsPos = xwat.FreqSum(0 , xs);
    uint64_t xePos = xwat.FreqSum(0 , xe);
    if(ye == ywat.alphabet_num()){ // ye >= ywat.alphabet_num()のときNOT_FOUNDが返ってくるため
      return ywat.FreqRange(ys , ye - 1, xsPos , xePos) +
        ywat.Rank(ye - 1 , xePos) - ywat.Rank(ye - 1 , xsPos);
    }
    return ywat.FreqRange(ys , ye , xsPos , xePos);
  }
}

備考

なお、今回の実装は特に論文とか読んだわけではないので、論文とかでやってる二次元探索の方法とは違う可能性があります。たとえばこの実装だと3次元以上のときの拡張とかが難しかったり、wat_arrayを2つ使ってますがもっと少なくてもいいかもしれません。あとx,yの範囲Rが大きければドキュメントにかかれてあることが確かであれば点数によらずO(log R)で計算量が増えるので適宜座標圧縮的なことが必要かもしれません。

NIPS Oral Sessionのリスト

参考までにOral Sessionの他の論文のリストも挙げておきます(Oral SessionではInvited talkとかもあるけどそれは除外)、accepted paperが300本近くあるなかでOral Sessionは20本しかなく非常に競争率が高いです。

  • Over-complete representations on recurrent neural networks can support persistent percepts
  • A rational decision making framework for inhibitory control
  • A Theory of Multiclass Boosting
  • Online Learning: Random Averages, Combinatorial Parameters, and Learnability
  • Fast global convergence rates of gradient methods for high-dimensional statistical recovery
  • Structured sparsity-inducing norms through submodular functions
  • Semi-Supervised Learning with Adversarially Missing Label Information
  • MAP estimation in Binary MRFs via Bipartite Multi-cuts
  • A Dirty Model for Multi-task Learning
  • Linear Complementarity for Regularized Policy Evaluation and Improvement
  • The Multidimensional Wisdom of Crowds (今回紹介したやつ)
  • Construction of Dependent Dirichlet Processes based on Poisson Processes
  • Slice sampling covariance hyperparameters of latent Gaussian models
  • Tree-Structured Stick Breaking for Hierarchical Data
  • Identifying graph-structured activation patterns in networks
  • Phoneme Recognition with Large Hierarchical Reservoirs
  • Identifying Patients at Risk of Major Adverse Cardiovascular Events Using Symbolic Mismatch
  • On the Convexity of Latent Social Network Inference
  • Learning to combine foveal glimpses with a third-order Boltzmann machine
  • Humans Learn Using Manifolds, Reluctantly

[機械学習] NIPS読む会で発表してきました

id:nokunoさん主催のNIPS読む会で発表してきました。

僕が発表した論文はThe Multidimensional Wisdom of Crowdsというものです。この論文を選んだのはOral Sessionの論文の中でタイトルがちょっと面白そうだったのでという理由です。

機械学習では大量のラベル付けしたデータが必要になることが多いのですが、データ自体は例えば画像ならGoogle画像検索やflickrなどから大量に取ってくることができるのですが、それにラベル付けをするとなるとどうしても人手が必要であることが多く困ってしまうことがあります。
こういったときのためにAmazon Mechanical Turkという画像にラベルをつけるとか簡単なアノテーションを安く大量に行ってもらうサービスがあるのですが、このサービスを普通に使うとアノテータの質が良くなかったりしてラベルの質が良くないという問題があるので、複数のアノテータの結果で多数決を取って決めるとかいう方法を取る必要があります。今回の論文の手法をつかうことにより単純に多数決をとる方法など先行研究よりも高い精度でラベルの推定が行えるという内容です。

詳しくはスライドもしくは元論文をご参考ください。

mean-shiftを実装してみた

次回のCVIM勉強会でmean-shiftについて話すことになってしまったので、理解のためにmean-shiftアルゴリズムを実装してみた。
カーネルは一番簡単なフラットカーネルを利用し、また画像もグレイスケール画像のみを扱うためピクセルの値は[0,256]の一次元データとみなす。
またナイーブな実装だと512*512ぐらいの画像でもすごい時間がかかるので二分探索とFenwickTree(動的に更新する必要はないので別に使う必要なかった、累積和を保存した配列に変更)を使ってグレイスケール画像かつフラットカーネルであることを利用して高速化した関数meanShiftLoopFastを用意した。

元画像は下のようになっている。

bandWidth=10のときは下のようになり、肌の部分が同じ値に収束していることがわかる。

bandWidth=20のときは以下のようになり、やや平滑化されすぎていることがわかる。

import java.awt.image.BufferedImage;

import java.io.File;
import java.util.Arrays;

import javax.imageio.ImageIO;

public class MeanShift {
  int getHigh(double data[] , double upperBound){
    int lo = 0;
    int hi = data.length - 1;
    while(lo < hi){
      int mid = lo + (hi - lo + 1) / 2;
      if(data[mid] > upperBound){
        hi = mid - 1;
      }else{
        lo = mid;
      }
    }
    return lo;
  }
  int getLow(double data[] , double lowerBound){
    int lo = 0;
    int hi = data.length - 1;
    while(lo < hi){
      int mid = lo + (hi - lo) / 2;
      if(data[mid] >= lowerBound){
        hi = mid;
      }else{
        lo = mid + 1;
      }
    }
    return lo;
  }

  double[] meanShiftLoopFast(double input[] , int maxIter , double bandwidth){
    int N = input.length;
    double data[] = input.clone();
    Arrays.sort(data);
    double accumulate[] = new double[N];
    accumulate[0] = data[0];
    for(int i = 1 ; i < N ; ++i){
      accumulate[i] = data[i] + accumulate[i - 1];
    }
    double[] output = new double[N];
    for(int i = 0 ; i < N ; ++i){
      double mean = input[i];
      for(int iter = 0 ; iter < maxIter ; ++iter){
        int div = 0;
        int l = getLow( data , mean - bandwidth);
        int h = getHigh(data, mean + bandwidth);        
        double sum = l == 0 ? accumulate[h] : accumulate[h] -  accumulate[l - 1];
        div = h - l + 1;
        double nextMean = sum / div;
        if(Math.abs(mean - nextMean) < 1.0E-9){
          break;
        }
        mean = nextMean;
      }
      output[i] = mean;
    }
    return output;
  }
  
  double[] meanShiftLoopNaive(double input[] , int maxIter , double bandwidth){
    int N = input.length;
    double[] output = new double[N];
    for(int i = 0 ; i < N ; ++i){
      double mean = input[i];
      for(int iter = 0 ; iter < maxIter ; ++iter){
        int div = 0;
        double sum = 0;
        for(int j = 0 ; j < N ; ++j){
          if(Math.abs(input[j] - mean) <= bandwidth){
            sum += input[j];
            ++div;
          }
        }
        double nextMean = sum / div;
        if(Math.abs(mean - nextMean) < 1.0E-9){
          break;
        }
        mean = nextMean;
      }
      output[i] = mean;
    }
    return output;
  }

  static void write(double data[] , int H , BufferedImage bin){
    int W = data.length / H;
    for(int h = 0 ; h < H ; ++h){
      for(int w = 0 ; w < W ; ++w){
        int d = (int)data[h * W + w];
        int rgb = (d << 16) | (d << 8) | d;
        bin.setRGB(w, h, rgb);
      }
    }
  }
  public static void main(String[] args) throws Exception{
    BufferedImage bin = ImageIO.read(new File("data/Lenna.png"));
    int W = bin.getWidth();
    int H = bin.getHeight();
    double in[] = new double[H * W];
    for(int h = 0 ; h < H ; ++h){
      for(int w = 0 ; w < W ; ++w){
        int rgb = bin.getRGB(w, h);
        int r = (rgb >> 16) & 0xff;
        int g = (rgb >>  8) & 0xff;
        int b = (rgb      ) & 0xff;
        int m = (2 * r + 4 * g + b) / 7;
        in[h * W + w] = m;
      }
    }
    MeanShift m = new MeanShift();
    int maxIter = 200;
    int bandWidth = 15;
    long t = System.currentTimeMillis();
    double[] out = m.meanShiftLoopFast(in, maxIter , bandWidth);
    System.err.println(System.currentTimeMillis() - t);
    write(out , H , bin);
    String outputFileName = String.format("data/output%02d.png", bandWidth);
    ImageIO.write(bin, "png", new File(outputFileName)); 
  }
}