
Improved Seeding For Clustering With K-Means++

Clustering data into subsets is an important task for many data science applications. At The Data Science Lab we have illustrated how Lloyd’s algorithm for k-means clustering works, including snapshots of Python code to visualize the iterative clustering steps. One of the issues with the procedure is that this algorithm does not supply information as to which K for the k-means is optimal; that has to be found out by alternative methods, so we went a step further and coded up the gap statistic to find the proper k for k-means clustering. In combination with the clustering algorithm, the gap statistic allows one to estimate the best value for k among those in a given range.
An additional problem with the standard k-means procedure still remains though, as shown by the image on the right, where a poor random initialization of the centroids leads to suboptimal clustering:

If the data are made up of well-separated clusters and Lloyd’s algorithm is run only once, there is a real danger that the local minimum it reaches is not the optimal solution.
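The local-minimum problem can be reproduced on a tiny example. The sketch below is our own illustration, not code from this post: the helper lloyd_1d is hypothetical and runs plain Lloyd iterations on nine 1-D points forming three groups, once from centers spread over the groups and once from three adjacent points.

```python
import numpy as np

def lloyd_1d(points, centers, max_iter=100):
    """Hypothetical helper: Lloyd's algorithm on 1-D data from fixed initial centers.

    Returns the final centers and the sum of squared errors (SSE).
    """
    points = np.asarray(points, dtype=float)
    centers = np.asarray(centers, dtype=float)
    for _ in range(max_iter):
        # Assignment step: index of the nearest center for every point
        labels = np.argmin(np.abs(points[:, None] - centers[None, :]), axis=1)
        # Update step: move each center to the mean of its points (keep it if empty)
        new_centers = np.array([points[labels == j].mean() if np.any(labels == j)
                                else centers[j] for j in range(len(centers))])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    labels = np.argmin(np.abs(points[:, None] - centers[None, :]), axis=1)
    sse = float(np.sum((points - centers[labels]) ** 2))
    return centers, sse

points = [0, 1, 2, 10, 11, 12, 20, 21, 22]
good_centers, good_sse = lloyd_1d(points, [1, 11, 21])  # one start in each group
bad_centers, bad_sse = lloyd_1d(points, [0, 1, 2])      # all starts in the first group
print(good_sse, bad_sse)        # 6.0 154.5
print(bad_centers.tolist())     # [0.0, 1.5, 16.0]
```

Started from 1, 11 and 21, the algorithm stays at the three natural groups with a sum of squared errors of 6.0; started from 0, 1 and 2, it stops at centers 0, 1.5 and 16 with an error of 154.5, a fixed point of Lloyd's iterations that no further step can escape.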

The initialization problem for the k-means algorithm is an important practical one and has been discussed extensively. It is desirable to augment the standard k-means clustering procedure with a robust initialization mechanism that makes convergence to a good, near-optimal solution far more likely.

k-means++: the advantages of careful seeding

A solution called k-means++ was proposed in 2007 by Arthur and Vassilvitskii. This algorithm comes with a theoretical guarantee that the initial centers it selects are, in expectation, O(log k)-competitive with the optimal k-means solution: the expected sum of squared distances from each point to its nearest center is at most 8(\ln k + 2) times the optimal value. It is also fairly simple to describe and implement. Starting with a dataset X of N points (\mathrm{x}_1, \ldots, \mathrm{x}_N),

  • choose an initial center c_1 uniformly at random from X, and compute the vector of squared distances between all points in the dataset and c_1: D_i^2 = ||\mathrm{x}_i - c_1 ||^2
  • choose a second center c_2 from X, drawn at random so that point \mathrm{x}_i is picked with probability D_i^2 / \sum_j D_j^2
  • recompute the distance vector as D_i^2 = \mathrm{min} \left(||\mathrm{x}_i - c_1 ||^2, ||\mathrm{x}_i - c_2 ||^2\right)
  • choose each successive center c_l in the same way, drawing it with probability D_i^2 / \sum_j D_j^2, and recompute the distance vector as D_i^2 = \mathrm{min} \left(||\mathrm{x}_i - c_1 ||^2, \ldots, ||\mathrm{x}_i - c_l ||^2\right)
  • when exactly k centers have been chosen, finalize the initialization phase and proceed with the standard k-means algorithm
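The seeding steps above can be written compactly with NumPy. This is a minimal sketch, not the class-based implementation developed later in this post; the function name kmeanspp_seed, the use of numpy.random.default_rng and the uniform fallback for degenerate data are our own assumptions.

```python
import numpy as np

def kmeanspp_seed(X, k, seed=None):
    """Return k rows of X chosen as initial centers by k-means++ seeding (sketch)."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    n = len(X)
    # Step 1: choose c_1 uniformly at random and compute D_i^2 = ||x_i - c_1||^2
    idx = rng.integers(n)
    centers = [X[idx]]
    d2 = np.sum((X - X[idx]) ** 2, axis=1)
    while len(centers) < k:
        total = d2.sum()
        if total == 0:
            # Assumption: fewer distinct points than k, so fall back to a uniform draw
            idx = rng.integers(n)
        else:
            # Draw the next center with probability D_i^2 / sum_j D_j^2
            idx = rng.choice(n, p=d2 / total)
        centers.append(X[idx])
        # D_i^2 = min(D_i^2, ||x_i - c_new||^2): distance to the nearest chosen center
        d2 = np.minimum(d2, np.sum((X - X[idx]) ** 2, axis=1))
    return np.array(centers)

# Three distinct locations, each repeated five times
X = [[0, 0]] * 5 + [[10, 0]] * 5 + [[0, 10]] * 5
print(kmeanspp_seed(X, 3, seed=42))
```

Because a location that already holds a center has D_i^2 = 0, it can never be drawn again, so on this toy input the three centers always land on the three distinct locations (in a random order).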

The interested reader can find a review of the k-means++ algorithm at normaldeviate, a survey of implementations in several languages at rosettacode and a ready-to-use solution in pandas by Jack Maney on GitHub.

A Python implementation of the k-means++ algorithm

Our Python implementation of the k-means++ algorithm builds on the code for standard k-means shown in the previous post. The KMeans class defined below contains all the functions and methods needed to generate toy data and run Lloyd’s clustering algorithm on it:

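import random

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

class KMeans():
    def __init__(self, K, X=None, N=0):
        self.K = K
        if X is None:
            if N == 0:
                raise ValueError("If no data is provided, "
                                 "a parameter N (number of points) is needed")
            # Generate N toy points around K Gaussian centers
            self.N = N
            self.X = self._init_board_gauss(N, K)
        else:
            self.X = np.asarray(X)
            self.N = len(self.X)
        self.mu = None
        self.clusters = None
        self.method = None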

    def _init_board_gauss(self, N, k):
        n = float(N)/k
        X = []
        for i in range(k):
            # Random cluster center in [-1,1]^2 and random spread
            c = (random.uniform(-1,1), random.uniform(-1,1))
            s = random.uniform(0.05,0.15)
            x = []
            while len(x) < n:
                a,b = np.array([np.random.normal(c[0],s),np.random.normal(c[1],s)])
                # Continue drawing points from the distribution in the range [-1,1]
                if abs(a) < 1 and abs(b) < 1:
                    x.append([a,b])
            X.extend(x)
        X = np.array(X)[:N]
        return X

    def plot_board(self):
        X = self.X
        fig = plt.figure(figsize=(5,5))
        if self.mu and self.clusters:
            mu = self.mu
            clus = self.clusters
            for m, clu in clus.items():
                cs = cm.nipy_spectral(1.*m/self.K)
                # Centroid as a star, cluster members as dots of the same color
                plt.plot(mu[m][0], mu[m][1], '*', markersize=12, color=cs)
                pts = np.array(clu)
                plt.plot(pts[:, 0], pts[:, 1], '.', \
                         markersize=8, color=cs, alpha=0.5)
        else:
            # No clustering yet: plot the raw data points
            plt.plot(X[:, 0], X[:, 1], '.', alpha=0.5)
        if self.method == '++':
            tit = 'K-means++'
        else:
            tit = 'K-means with random initialization'
        pars = 'N=%s, K=%s' % (str(self.N), str(self.K))
        plt.title('\n'.join([pars, tit]), fontsize=16)
        plt.savefig('kpp_N%s_K%s.png' % (str(self.N), str(self.K)), \
                    bbox_inches='tight', dpi=200)

    def _cluster_points(self):
        mu = self.mu
        clusters  = {}
        for x in self.X:
            # Index of the nearest current center
            bestmukey = min([(i[0], np.linalg.norm(x-mu[i[0]])) \
                             for i in enumerate(mu)], key=lambda t:t[1])[0]
            try:
                clusters[bestmukey].append(x)
            except KeyError:
                clusters[bestmukey] = [x]
        self.clusters = clusters

    def _reevaluate_centers(self):
        clusters = self.clusters
        newmu = []
        keys = sorted(self.clusters.keys())
        for k in keys:
            newmu.append(np.mean(clusters[k], axis = 0))
        self.mu = newmu

    def _has_converged(self):
        K = len(self.oldmu)
        return(set([tuple(a) for a in self.mu]) == \
               set([tuple(a) for a in self.oldmu])\
               and len(set([tuple(a) for a in self.mu])) == K)

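    def find_centers(self, method='random'):
        self.method = method
        X = self.X
        K = self.K
        # Dummy previous centers so that the convergence test (almost always) fails on entry
        self.oldmu = random.sample(list(X), K)
        if method != '++':
            # Initialize to K random centers
            self.mu = random.sample(list(X), K)
        # With method='++', self.mu must already hold the seeds from init_centers()
        while not self._has_converged():
            self.oldmu = self.mu
            # Assign all points in X to clusters
            self._cluster_points()
            # Reevaluate centers
            self._reevaluate_centers()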

To initialize the board with n data points normally distributed around k centers, we call kmeans = KMeans(k, N=n).

kmeans = KMeans(3, N=200)
kmeans.find_centers()
kmeans.plot_board()

The snippet above creates a board with 200 points around 3 clusters. The call to the find_centers() method runs the standard k-means algorithm, initializing the centroids to 3 random points. Finally, the plot_board() method produces a plot of the data points as clustered by the algorithm, with the centroids marked as stars. In the image below we can see the results of running the algorithm twice. Due to the random initialization of standard k-means, the correct solution is found some of the time (right panel), whereas in other cases a suboptimal end point is reached instead (left panel).

Let us now implement the k-means++ algorithm in its own class, which inherits from the KMeans class defined above.

class KPlusPlus(KMeans):
    def _dist_from_centers(self):
        cent = self.mu
        X = self.X
        D2 = np.array([min([np.linalg.norm(x-c)**2 for c in cent]) for x in X])
        self.D2 = D2

    def _choose_next_center(self):
        # Sample a point with probability proportional to its squared distance D^2
        self.probs = self.D2/self.D2.sum()
        self.cumprobs = self.probs.cumsum()
        r = random.random()
        ind = np.where(self.cumprobs >= r)[0][0]
        return(self.X[ind])

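    def init_centers(self):
        # First center: one data point chosen uniformly at random
        self.mu = random.sample(list(self.X), 1)
        while len(self.mu) < self.K:
            # Update D^2 for the current centers, then draw the next center
            self._dist_from_centers()
            self.mu.append(self._choose_next_center())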

    def plot_init_centers(self):
        X = self.X
        fig = plt.figure(figsize=(5,5))
        plt.plot(X[:, 0], X[:, 1], '.', alpha=0.5)
        # Seeded centers as red circles
        mu = np.array(self.mu)
        plt.plot(mu[:, 0], mu[:, 1], 'ro')
        plt.savefig('kpp_init_N%s_K%s.png' % (str(self.N),str(self.K)), \
                    bbox_inches='tight', dpi=200)

To run the k-means++ initialization stage using this class and visualize the centers found by the algorithm, we simply do:

kplusplus = KPlusPlus(5, N=200)
kplusplus.init_centers()
kplusplus.plot_init_centers()

Let us explore what the function init_centers() is actually doing: to begin with, a first center is chosen uniformly at random from the X data points. Then the successive centers are picked, stopping when we have K=5 of them. The procedure to choose the next center is coded up in the _choose_next_center() function. As we described above, the next center is drawn from a distribution given by the normalized distance vector D_i^2 / \sum_j D_j^2. To implement such a probability distribution, we compute the cumulative probabilities for choosing each of the N points in X. These cumulative probabilities partition the interval [0,1] into segments, each with length equal to the probability of the corresponding point being chosen as a center, as explained in this Stack Overflow thread. Therefore, by picking a random value r \in [0,1) and finding the segment of the partition where r falls, we are effectively choosing a point drawn according to the desired probability distribution. On the right is a plot showing the results of the algorithm for 200 points and 5 clusters.
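The cumulative-probability trick can be tried in isolation. The sketch below is not the post's code: the helper sample_index is hypothetical and uses np.searchsorted instead of np.where, with side='right' so that a point whose probability is zero (for example, a point that is already a center) owns an empty segment and can never be selected, even when r is exactly 0.

```python
import numpy as np

def sample_index(weights, r):
    """Pick the index whose cumulative-probability segment contains r (0 <= r < 1)."""
    cumprobs = np.cumsum(weights, dtype=float) / np.sum(weights)
    # First index with cumprobs[i] > r, i.e. cumprobs[i-1] <= r < cumprobs[i];
    # zero-weight points own an empty segment and can never be returned.
    ind = int(np.searchsorted(cumprobs, r, side='right'))
    # Guard against cumprobs[-1] being rounded slightly below 1
    return min(ind, len(cumprobs) - 1)

D2 = [1.0, 0.0, 3.0]           # squared distances; probabilities 0.25, 0, 0.75
print(sample_index(D2, 0.1))   # 0
print(sample_index(D2, 0.25))  # 2
print(sample_index(D2, 0.9))   # 2

# Empirical check: index frequencies approach D_i^2 / sum_j D_j^2
rng = np.random.default_rng(0)
draws = [sample_index(D2, rng.random()) for _ in range(10000)]
print(np.bincount(draws, minlength=3) / len(draws))
```

Index 1 never appears among the draws, since its segment [0.25, 0.25) is empty; the other two frequencies should come out close to 0.25 and 0.75.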

Finally let us compare the results of k-means with random initialization and k-means++ with proper seeding, using the following code snippets:

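# Random initialization
kplusplus.find_centers()
kplusplus.plot_board()
# k-means++ initialization (re-seed first: find_centers() overwrote self.mu)
kplusplus.init_centers()
kplusplus.find_centers(method='++')
kplusplus.plot_board()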


The standard algorithm with random initialization in a particular instantiation (left panel) fails to identify the 5 optimal centroids, whereas the k-means++ initialization (right panel) succeeds. Because k-means++ spreads the initial centroids across the data instead of picking them uniformly at random, the subsequent Lloyd iterations also tend to converge faster. Arthur and Vassilvitskii report this speed-up in their experiments; their theorem, however, guarantees the quality of the solution (O(log k)-competitive in expectation), not the speed of convergence.
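The convergence claim is easy to probe empirically. The sketch below is a self-contained NumPy re-implementation, not the KMeans/KPlusPlus classes above; the names seed_random, seed_kmeanspp, lloyd and compare are hypothetical. It runs Lloyd's algorithm several times from each kind of seeding and reports the mean number of iterations and the mean final sum of squared errors (SSE). The numbers depend on the data and the random seed, so treat them as an experiment rather than a proof.

```python
import numpy as np

def seed_random(X, k, rng):
    # Standard initialization: k distinct data points drawn uniformly at random
    return X[rng.choice(len(X), size=k, replace=False)]

def seed_kmeanspp(X, k, rng):
    # k-means++ seeding: each new center drawn with probability D^2 / sum(D^2)
    idx = rng.integers(len(X))
    centers = [X[idx]]
    d2 = np.sum((X - X[idx]) ** 2, axis=1)
    while len(centers) < k:
        total = d2.sum()
        idx = rng.integers(len(X)) if total == 0 else rng.choice(len(X), p=d2 / total)
        centers.append(X[idx])
        d2 = np.minimum(d2, np.sum((X - X[idx]) ** 2, axis=1))
    return np.array(centers)

def lloyd(X, centers, max_iter=300):
    """Lloyd's algorithm; returns (centers, number of assignment steps, final SSE)."""
    centers = np.array(centers, dtype=float)
    for it in range(1, max_iter + 1):
        # Squared distance of every point to every center, shape (N, K)
        d2 = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1)
        # Move each center to the mean of its points; an empty cluster keeps its center
        new = np.array([X[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
                        for j in range(len(centers))])
        if np.allclose(new, centers):
            break
        centers = new
    d2 = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    return centers, it, float(d2.min(axis=1).sum())

def compare(X, k, runs=20, seed=0):
    """Mean iteration count and mean SSE for each seeding method."""
    X = np.asarray(X, dtype=float)
    rng = np.random.default_rng(seed)
    results = {}
    for name, seeder in (("random", seed_random), ("k-means++", seed_kmeanspp)):
        its, sses = zip(*[lloyd(X, seeder(X, k, rng))[1:] for _ in range(runs)])
        results[name] = (float(np.mean(its)), float(np.mean(sses)))
    return results

# Toy data: 5 tight Gaussian blobs in the square [-1, 1]^2
rng = np.random.default_rng(1)
true_centers = rng.uniform(-1, 1, size=(5, 2))
X = np.vstack([c + 0.05 * rng.standard_normal((40, 2)) for c in true_centers])
print(compare(X, 5))
```

Averaging over several runs matters here: a single run of either method can be lucky or unlucky, which is exactly the effect shown in the two panels above.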

Table-top data experiment take-away message

The k-means++ method for seeding the initial centroids yields a considerable improvement over Lloyd’s algorithm with random initialization. The seeding stage in k-means++ takes extra time, since the centers are chosen one after another, each drawn from a probability distribution that has to be recomputed at every step. However, the Lloyd iterations that follow this seeding typically converge much more quickly, so the whole procedure often runs in a shorter total computation time. The combination of the k-means++ initialization stage with Lloyd’s algorithm, together with techniques such as the gap statistic for finding the ideal number of clusters, provides a robust way to solve the complete problem of clustering data points.