% Tony Hyun Kim
% CS 229, PS#3, Problem 6
% k-means clustering for image compression
%------------------------------------------------------------
% Start from an empty workspace with no open figure windows.
clear('all');
close('all');

% Image to compress. 'mandrill-large.tiff' is the full-size alternative.
source = 'mandrill-small.tiff';
A = imread(source);
A = double(A);   % convert so all later arithmetic is done in floating point

% Reorder the image data into the format required by my k-means
%   implementation (my_kmeans): 'examples' must be an Nc x (Ny*Nx)
%   matrix in which each column is one example, i.e. the Nc color
%   values of a single pixel.
%------------------------------------------------------------
[Ny, Nx, Nc] = size(A);
% Bring the channel dimension to the front, then flatten rows and columns:
%   pixel (i,j) becomes column i + (j-1)*Ny, holding its Nc channel values.
examples = reshape(permute(A, [3 1 2]), Nc, Ny*Nx);

% Run k-means clustering with 16 clusters
%------------------------------------------------------------
k = 16; % number of clusters = number of colors in the compressed palette
% Centroids are snapped to integer color values. Assumes my_kmeans returns
%   an Nc x k matrix with one centroid per column (as indexed below) - TODO confirm.
ucolors = round(my_kmeans(examples, k));

% Construct the compressed image
%------------------------------------------------------------
% Replace every pixel with its nearest centroid (Euclidean distance in
%   color space). The work is vectorized over all pixels and loops only
%   over the k centroids, instead of making Ny*Nx*k separate norm() calls.
%   Squared distance is minimized, which has the same argmin as norm().
%   Ties go to the lowest centroid index because min() returns the first
%   minimum, matching the original per-pixel loop.
pixels = reshape(permute(A,[3 1 2]),Nc,Ny*Nx); % Nc x (Ny*Nx), pixel (i,j) in column i+(j-1)*Ny
sqdists = zeros(Ny*Nx,k);
for l = 1:k % Squared distance from every pixel to the l-th centroid
    diffs = bsxfun(@minus,pixels,ucolors(:,l));
    sqdists(:,l) = sum(diffs.^2,1)';
end
[~, nearest] = min(sqdists,[],2);          % index of nearest centroid per pixel
B = reshape(ucolors(:,nearest),Nc,Ny,Nx);  % one centroid color per pixel
B = permute(B,[2 3 1]);                    % back to Ny x Nx x Nc, same size as A

% Display results
%------------------------------------------------------------
% Show the original (left) and the compressed image (right) side by side.
panels = {A, B};
captions = {sprintf('Original %s',source), ...
            sprintf('k-means compression (k=%d)',k)};
for p = 1:2
    subplot(1,2,p);
    imshow(uint8(round(panels{p})));
    title(captions{p});
end