mean-shiftを実装してみた
次回のCVIM勉強会でmean-shiftについて話すことになってしまったので、理解のためにmean-shiftアルゴリズムを実装してみた。
カーネルは一番簡単なフラットカーネルを利用し、また画像もグレイスケール画像のみを扱うためピクセルの値は[0,256]の一次元データとみなす。
またナイーブな実装だと512*512ぐらいの画像でもすごい時間がかかるので二分探索とFenwickTree(動的に更新する必要はないので別に使う必要なかった、累積和を保存した配列に変更)を使ってグレイスケール画像かつフラットカーネルであることを利用して高速化した関数meanShiftLoopFastを用意した。
bandWidth=10のときは下のようになり、肌の部分が同じ値に収束していることがわかる。
bandWidth=20のときは以下のようになり、やや平滑化されすぎていることがわかる。
import java.awt.image.BufferedImage; import java.io.File; import java.util.Arrays; import javax.imageio.ImageIO; public class MeanShift { int getHigh(double data[] , double upperBound){ int lo = 0; int hi = data.length - 1; while(lo < hi){ int mid = lo + (hi - lo + 1) / 2; if(data[mid] > upperBound){ hi = mid - 1; }else{ lo = mid; } } return lo; } int getLow(double data[] , double lowerBound){ int lo = 0; int hi = data.length - 1; while(lo < hi){ int mid = lo + (hi - lo) / 2; if(data[mid] >= lowerBound){ hi = mid; }else{ lo = mid + 1; } } return lo; } double[] meanShiftLoopFast(double input[] , int maxIter , double bandwidth){ int N = input.length; double data[] = input.clone(); Arrays.sort(data); double accumulate[] = new double[N]; accumulate[0] = data[0]; for(int i = 1 ; i < N ; ++i){ accumulate[i] = data[i] + accumulate[i - 1]; } double[] output = new double[N]; for(int i = 0 ; i < N ; ++i){ double mean = input[i]; for(int iter = 0 ; iter < maxIter ; ++iter){ int div = 0; int l = getLow( data , mean - bandwidth); int h = getHigh(data, mean + bandwidth); double sum = l == 0 ? accumulate[h] : accumulate[h] - accumulate[l - 1]; div = h - l + 1; double nextMean = sum / div; if(Math.abs(mean - nextMean) < 1.0E-9){ break; } mean = nextMean; } output[i] = mean; } return output; } double[] meanShiftLoopNaive(double input[] , int maxIter , double bandwidth){ int N = input.length; double[] output = new double[N]; for(int i = 0 ; i < N ; ++i){ double mean = input[i]; for(int iter = 0 ; iter < maxIter ; ++iter){ int div = 0; double sum = 0; for(int j = 0 ; j < N ; ++j){ if(Math.abs(input[j] - mean) <= bandwidth){ sum += input[j]; ++div; } } double nextMean = sum / div; if(Math.abs(mean - nextMean) < 1.0E-9){ break; } mean = nextMean; } output[i] = mean; } return output; } static void write(double data[] , int H , BufferedImage bin){ int W = data.length / H; for(int h = 0 ; h < H ; ++h){ for(int w = 0 ; w < W ; ++w){ int d = (int)data[h * W + w]; int rgb = (d << 16) | (d << 8) | d; bin.setRGB(w, h, rgb); } } } public static void main(String[] args) throws Exception{ BufferedImage bin = ImageIO.read(new File("data/Lenna.png")); int W = bin.getWidth(); int H = bin.getHeight(); double in[] = new double[H * W]; for(int h = 0 ; h < H ; ++h){ for(int w = 0 ; w < W ; ++w){ int rgb = bin.getRGB(w, h); int r = (rgb >> 16) & 0xff; int g = (rgb >> 8) & 0xff; int b = (rgb ) & 0xff; int m = (2 * r + 4 * g + b) / 7; in[h * W + w] = m; } } MeanShift m = new MeanShift(); int maxIter = 200; int bandWidth = 15; long t = System.currentTimeMillis(); double[] out = m.meanShiftLoopFast(in, maxIter , bandWidth); System.err.println(System.currentTimeMillis() - t); write(out , H , bin); String outputFileName = String.format("data/output%02d.png", bandWidth); ImageIO.write(bin, "png", new File(outputFileName)); } }