mean-shiftを実装してみた

次回のCVIM勉強会でmean-shiftについて話すことになってしまったので、理解のためにmean-shiftアルゴリズムを実装してみた。
カーネルは一番簡単なフラットカーネルを利用し、また画像もグレイスケール画像のみを扱うためピクセルの値は[0,256]の一次元データとみなす。
またナイーブな実装だと512*512ぐらいの画像でもすごい時間がかかるので二分探索とFenwickTree(動的に更新する必要はないので別に使う必要なかった、累積和を保存した配列に変更)を使ってグレイスケール画像かつフラットカーネルであることを利用して高速化した関数meanShiftLoopFastを用意した。

元画像は下のようになっている。

bandWidth=10のときは下のようになり、肌の部分が同じ値に収束していることがわかる。

bandWidth=20のときは以下のようになり、やや平滑化されすぎていることがわかる。

import java.awt.image.BufferedImage;

import java.io.File;
import java.util.Arrays;

import javax.imageio.ImageIO;

public class MeanShift {
  int getHigh(double data[] , double upperBound){
    int lo = 0;
    int hi = data.length - 1;
    while(lo < hi){
      int mid = lo + (hi - lo + 1) / 2;
      if(data[mid] > upperBound){
        hi = mid - 1;
      }else{
        lo = mid;
      }
    }
    return lo;
  }
  int getLow(double data[] , double lowerBound){
    int lo = 0;
    int hi = data.length - 1;
    while(lo < hi){
      int mid = lo + (hi - lo) / 2;
      if(data[mid] >= lowerBound){
        hi = mid;
      }else{
        lo = mid + 1;
      }
    }
    return lo;
  }

  double[] meanShiftLoopFast(double input[] , int maxIter , double bandwidth){
    int N = input.length;
    double data[] = input.clone();
    Arrays.sort(data);
    double accumulate[] = new double[N];
    accumulate[0] = data[0];
    for(int i = 1 ; i < N ; ++i){
      accumulate[i] = data[i] + accumulate[i - 1];
    }
    double[] output = new double[N];
    for(int i = 0 ; i < N ; ++i){
      double mean = input[i];
      for(int iter = 0 ; iter < maxIter ; ++iter){
        int div = 0;
        int l = getLow( data , mean - bandwidth);
        int h = getHigh(data, mean + bandwidth);        
        double sum = l == 0 ? accumulate[h] : accumulate[h] -  accumulate[l - 1];
        div = h - l + 1;
        double nextMean = sum / div;
        if(Math.abs(mean - nextMean) < 1.0E-9){
          break;
        }
        mean = nextMean;
      }
      output[i] = mean;
    }
    return output;
  }
  
  double[] meanShiftLoopNaive(double input[] , int maxIter , double bandwidth){
    int N = input.length;
    double[] output = new double[N];
    for(int i = 0 ; i < N ; ++i){
      double mean = input[i];
      for(int iter = 0 ; iter < maxIter ; ++iter){
        int div = 0;
        double sum = 0;
        for(int j = 0 ; j < N ; ++j){
          if(Math.abs(input[j] - mean) <= bandwidth){
            sum += input[j];
            ++div;
          }
        }
        double nextMean = sum / div;
        if(Math.abs(mean - nextMean) < 1.0E-9){
          break;
        }
        mean = nextMean;
      }
      output[i] = mean;
    }
    return output;
  }

  static void write(double data[] , int H , BufferedImage bin){
    int W = data.length / H;
    for(int h = 0 ; h < H ; ++h){
      for(int w = 0 ; w < W ; ++w){
        int d = (int)data[h * W + w];
        int rgb = (d << 16) | (d << 8) | d;
        bin.setRGB(w, h, rgb);
      }
    }
  }
  public static void main(String[] args) throws Exception{
    BufferedImage bin = ImageIO.read(new File("data/Lenna.png"));
    int W = bin.getWidth();
    int H = bin.getHeight();
    double in[] = new double[H * W];
    for(int h = 0 ; h < H ; ++h){
      for(int w = 0 ; w < W ; ++w){
        int rgb = bin.getRGB(w, h);
        int r = (rgb >> 16) & 0xff;
        int g = (rgb >>  8) & 0xff;
        int b = (rgb      ) & 0xff;
        int m = (2 * r + 4 * g + b) / 7;
        in[h * W + w] = m;
      }
    }
    MeanShift m = new MeanShift();
    int maxIter = 200;
    int bandWidth = 15;
    long t = System.currentTimeMillis();
    double[] out = m.meanShiftLoopFast(in, maxIter , bandWidth);
    System.err.println(System.currentTimeMillis() - t);
    write(out , H , bin);
    String outputFileName = String.format("data/output%02d.png", bandWidth);
    ImageIO.write(bin, "png", new File(outputFileName)); 
  }
}