Improved Seeding For Clustering With K-Means++

Clustering data into subsets is an important task for many data science applications. At The Data Science Lab we have illustrated how Lloyd’s algorithm for k-means clustering works, including snapshots of Python code to visualize the iterative clustering steps. One of the issues with the procedure is that

this algorithm does not supply information as to which K for the k-means is optimal; that has to be found out by alternative methods,

so we went a step further and coded up the gap statistic to find the proper k for k-means clustering. In combination with the clustering algorithm, the gap statistic allows us to estimate the best value of k among those in a given range.
An additional problem with the standard k-means procedure still remains, though: as shown in the image on the right, a poor random initialization of the centroids can lead to suboptimal clustering:

If the target distribution is disjointedly clustered and only one instantiation of Lloyd’s algorithm is used, the danger exists that the local minimum reached is not the optimal solution.

The initialization problem for the k-means algorithm is an important practical one, and has been discussed extensively. It is desirable to augment the standard k-means clustering procedure with a robust initialization mechanism that makes convergence to the optimal solution much more likely.

k-means++: the advantages of careful seeding

A solution called k-means++ was proposed in 2007 by Arthur and Vassilvitskii. This algorithm comes with a theoretical guarantee to find a solution that is, in expectation, O(log k)-competitive with the optimal k-means solution. It is also fairly simple to describe and implement. Starting with a dataset X of N points (\mathrm{x}_1, \ldots, \mathrm{x}_N),

  • choose an initial center c_1 uniformly at random from X, and compute the vector of squared distances between all points in the dataset and c_1: D_i^2 = ||\mathrm{x}_i - c_1 ||^2
  • choose a second center c_2 from X, drawn at random from the probability distribution D_i^2 / \sum_j D_j^2
  • recompute the distance vector as D_i^2 = \mathrm{min} \left(||\mathrm{x}_i - c_1 ||^2, ||\mathrm{x}_i - c_2 ||^2\right)
  • repeat the previous two steps for each successive center c_l, recomputing the distance vector as D_i^2 = \mathrm{min} \left(||\mathrm{x}_i - c_1 ||^2, \ldots, ||\mathrm{x}_i - c_l ||^2\right)
  • when exactly k centers have been chosen, finalize the initialization phase and proceed with the standard k-means algorithm
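The seeding steps above can be sketched as a standalone function (a minimal sketch with NumPy; the name kpp_seed and its arguments are our own, unrelated to the class-based implementation later in the post):

```python
import numpy as np

def kpp_seed(X, k, seed=0):
    """Pick k initial centers from the (N, d) array X following the steps above."""
    rng = np.random.default_rng(seed)
    N = len(X)
    # First center: drawn uniformly at random from X
    centers = [X[rng.integers(N)]]
    # Squared distance of every point to its nearest chosen center
    D2 = np.sum((X - centers[0])**2, axis=1)
    for _ in range(1, k):
        # Next center: drawn with probability D_i^2 / sum_j D_j^2
        probs = D2 / D2.sum()
        centers.append(X[rng.choice(N, p=probs)])
        # Recompute the distance vector against the newly added center only
        D2 = np.minimum(D2, np.sum((X - centers[-1])**2, axis=1))
    return np.array(centers)
```

Because the running minimum over previously chosen centers is carried along in D2, each new center requires only one pass over the data.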

The interested reader can find a review of the k-means++ algorithm at normaldeviate, a survey of implementations in several languages at rosettacode, and a ready-to-use solution in pandas by Jack Maney on GitHub.

A Python implementation of the k-means++ algorithm

Our Python implementation of the k-means++ algorithm builds on the code for standard k-means shown in the previous post. The KMeans class defined below contains all necessary functions and methods to generate toy data and run Lloyd’s clustering algorithm on it (the numpy, random and matplotlib imports are as in the previous posts of the series):

class KMeans():
    def __init__(self, K, X=None, N=0):
        self.K = K
        if X is None:
            if N == 0:
                raise Exception("If no data is provided, \
                                 a parameter N (number of points) is needed")
            else:
                self.N = N
                self.X = self._init_board_gauss(N, K)
        else:
            self.X = X
            self.N = len(X)
        self.mu = None
        self.clusters = None
        self.method = None

    def _init_board_gauss(self, N, k):
        n = float(N)/k
        X = []
        for i in range(k):
            c = (random.uniform(-1,1), random.uniform(-1,1))
            s = random.uniform(0.05,0.15)
            x = []
            while len(x) < n:
                a, b = np.random.normal(c[0],s), np.random.normal(c[1],s)
                # Continue drawing points from the distribution in the range [-1,1]
                if abs(a) < 1 and abs(b) < 1:
                    x.append([a,b])
            X.extend(x)
        X = np.array(X)[:N]
        return X

    def plot_board(self):
        X = self.X
        fig = plt.figure(figsize=(5,5))
        if self.mu and self.clusters:
            mu = self.mu
            clus = self.clusters
            K = self.K
            for m, clu in clus.items():
                cs = cm.spectral(1.*m/self.K)
                plt.plot(mu[m][0], mu[m][1], 'o', marker='*', \
                         markersize=12, color=cs)
                plt.plot(zip(*clus[m])[0], zip(*clus[m])[1], '.', \
                         markersize=8, color=cs, alpha=0.5)
        else:
            plt.plot(zip(*X)[0], zip(*X)[1], '.', alpha=0.5)
        if self.method == '++':
            tit = 'K-means++'
        else:
            tit = 'K-means with random initialization'
        pars = 'N=%s, K=%s' % (str(self.N), str(self.K))
        plt.title('\n'.join([pars, tit]), fontsize=16)
        plt.savefig('kpp_N%s_K%s.png' % (str(self.N), str(self.K)), \
                    bbox_inches='tight', dpi=200)

    def _cluster_points(self):
        mu = self.mu
        clusters = {}
        for x in self.X:
            bestmukey = min([(i[0], np.linalg.norm(x-mu[i[0]])) \
                             for i in enumerate(mu)], key=lambda t:t[1])[0]
            try:
                clusters[bestmukey].append(x)
            except KeyError:
                clusters[bestmukey] = [x]
        self.clusters = clusters

    def _reevaluate_centers(self):
        clusters = self.clusters
        newmu = []
        keys = sorted(self.clusters.keys())
        for k in keys:
            newmu.append(np.mean(clusters[k], axis=0))
        self.mu = newmu

    def _has_converged(self):
        K = len(self.oldmu)
        return(set([tuple(a) for a in self.mu]) == \
               set([tuple(a) for a in self.oldmu]) \
               and len(set([tuple(a) for a in self.mu])) == K)

    def find_centers(self, method='random'):
        self.method = method
        X = self.X
        K = self.K
        self.oldmu = random.sample(X, K)
        if method != '++':
            # Initialize to K random centers
            self.mu = random.sample(X, K)
        while not self._has_converged():
            self.oldmu = self.mu
            # Assign all points in X to clusters
            self._cluster_points()
            # Reevaluate centers
            self._reevaluate_centers()

To initialize the board with n data points normally distributed around k centers, we call kmeans = KMeans(k, N=n).

kmeans = KMeans(3, N=200)
kmeans.find_centers()
kmeans.plot_board()

The snippet above creates a board with 200 points around 3 clusters. The call to the find_centers() function runs the standard k-means algorithm, initializing the centroids to 3 random points. Finally, the function plot_board() produces a plot of the data points as clustered by the algorithm, with the centroids marked as stars. In the image below we can see the results of running the algorithm twice. Due to the random initialization of the standard k-means, the correct solution is found some of the time (right panel), whereas in other cases a suboptimal end point is reached instead (left panel). [Figure: kpp_N200_K3.png]

Let us now implement the k-means++ algorithm in its own class, which inherits from the KMeans class defined above.

class KPlusPlus(KMeans):
    def _dist_from_centers(self):
        cent = self.mu
        X = self.X
        D2 = np.array([min([np.linalg.norm(x-c)**2 for c in cent]) for x in X])
        self.D2 = D2

    def _choose_next_center(self):
        self.probs = self.D2/self.D2.sum()
        self.cumprobs = self.probs.cumsum()
        r = random.random()
        ind = np.where(self.cumprobs >= r)[0][0]
        return(self.X[ind])

    def init_centers(self):
        self.mu = random.sample(self.X, 1)
        while len(self.mu) < self.K:
            self._dist_from_centers()
            self.mu.append(self._choose_next_center())

    def plot_init_centers(self):
        X = self.X
        fig = plt.figure(figsize=(5,5))
        plt.plot(zip(*X)[0], zip(*X)[1], '.', alpha=0.5)
        plt.plot(zip(*self.mu)[0], zip(*self.mu)[1], 'ro')
        plt.savefig('kpp_init_N%s_K%s.png' % (str(self.N), str(self.K)), \
                    bbox_inches='tight', dpi=200)

To run the k-means++ initialization stage using this class and visualize the centers found by the algorithm, we simply do:

kplusplus = KPlusPlus(5, N=200)
kplusplus.init_centers()
kplusplus.plot_init_centers()

[Figure: kpp_init_N200_K5.png] Let us explore what the function init_centers() is actually doing: to begin with, a random point is chosen as the first center from the X data points via random.sample(self.X, 1). Then the successive centers are picked, stopping when we have K=5 of them. The procedure to choose the next most suitable center is coded up in the _choose_next_center() function. As described above, the next center is drawn from a distribution given by the normalized distance vector D_i^2 / \sum_j D_j^2. To implement this probability distribution, we compute the cumulative probabilities for choosing each of the N points in X. These cumulative probabilities partition the interval [0,1] into segments, each with length equal to the probability of the corresponding point being chosen as a center, as explained in this stackoverflow thread. Therefore, by picking a random value r \in [0,1] and finding the point corresponding to the segment of the partition where that r value falls, we are effectively choosing a point drawn according to the desired probability distribution. On the right is a plot showing the results of the algorithm for 200 points and 5 clusters.
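The cumulative-probability lookup can be checked in isolation (a hypothetical standalone snippet with made-up weights, mirroring the logic of _choose_next_center()):

```python
import numpy as np

# Hypothetical squared distances for three points
D2 = np.array([1.0, 3.0, 6.0])
probs = D2 / D2.sum()       # [0.1, 0.3, 0.6]
cumprobs = probs.cumsum()   # [0.1, 0.4, 1.0]: segment boundaries in [0, 1]

def pick(r):
    # np.where returns a tuple of index arrays; the first [0] selects
    # the array, the second [0] the first index whose segment contains r
    return np.where(cumprobs >= r)[0][0]

assert pick(0.05) == 0   # r in [0, 0.1) -> point 0
assert pick(0.25) == 1   # r in [0.1, 0.4) -> point 1
assert pick(0.95) == 2   # r in [0.4, 1.0] -> point 2
```

np.searchsorted(cumprobs, r) would perform the same lookup in logarithmic rather than linear time.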

Finally let us compare the results of k-means with random initialization and k-means++ with proper seeding, using the following code snippets:

# Random initialization
kplusplus.find_centers()
kplusplus.plot_board()

# k-means++ initialization
kplusplus.find_centers(method='++')
kplusplus.plot_board()


The standard algorithm with random initialization in a particular instantiation (left panel) fails at identifying the 5 optimal centroids for the clustering, whereas the k-means++ initialization (right panel) succeeds in doing so. By picking a carefully chosen rather than purely random set of centroids to initiate the clustering process, the k-means++ algorithm also reaches convergence faster, as guaranteed by the theorems proved in the Arthur and Vassilvitskii article.

Table-top data experiment take-away message

The k-means++ method for finding a proper seeding for the choice of initial centroids yields considerable improvement over the standard Lloyd’s implementation of the k-means algorithm. The initial selection in k-means++ takes extra time and involves choosing centers in successive order, drawing them from a probability distribution that has to be recomputed at each step. However, the k-means part of the algorithm converges very quickly after this seeding, and thus the whole procedure actually runs in a shorter computation time. The combination of the k-means++ initialization stage with the standard Lloyd’s algorithm, together with various additional techniques to find an optimal value for the ideal number of clusters, provides a robust way to solve the complete problem of clustering data points.



    • datasciencelab

      Thanks for the comment, you’re totally right!
      As this post follows up on the previous ones in the clustering series, I didn’t include the imports again in the code shown here. Thanks for noticing!

  1. abulrafay

    when I run the instruction!
    I got error:
    Traceback (most recent call last):
    File “”, line 1, in
    File “C:/Python33/”, line 90, in find_centers
    self.oldmu = random.sample(X, K)
    File “C:\Python33\lib\”, line 298, in sample
    raise TypeError(“Population must be a sequence or set. For dicts, use list(d).”)
    TypeError: Population must be a sequence or set. For dicts, use list(d).

  2. Anonymous

`self.mu = random.sample(X, K)` doesn’t do what you intend. For the random initialization method, I think you meant something more like:

    indices = np.random.choice(len(X), 2 * K)
    self.oldmu = X[indices[:K]]
    self.mu = X[indices[K:]]

    This picks two sets of (guaranteed distinct) points as your initial values for the cluster centers.

  3. Iain Strachan

    I think your implementation of the K-means++ algorithm is inefficient for large values of k, because it involves O(k^2) calculations but need only involve O(k). I found this also to be the case in the Matlab Statistical toolbox version of the code. The key is in the step where you compute the distance between each point in X and the nearest centre already chosen. The point is that you don’t need to recalculate the distances for each centre already chosen, because you have already worked that out on the previous step. So you only need an array of distances of the length of X and then compute the distance of each X to the SINGLE new centre you have chosen, and if that is less than the existing minimum for X then replace it in the array.

    When I implemented this in Matlab (about 6 lines of code!) it made a dramatic speed up – around a factor 100 when I was using a large number of centres.

    I am still trying to learn Python, but it looks to me that you are recomputing the distance of every x to every c each time. (If I’m wrong please forgive!)

    It probably makes it easier to understand the algorithm from a tutorial point of view, but it doesn’t scale well with increasing k.

    • datasciencelab

      Thanks for this thoughtful comment! Indeed the code is not written for speed or optimised in any way. As you mention, it was written in the spirit of a tutorial but not intended for production. Your suggested fixes would absolutely speed up the performance and I’m glad you pointed them out.

  4. yummy

    Error message:

    File “”, line 153, in main
    File “”, line 102, in find_centers
    while not self._has_converged():
    File “”, line 91, in _has_converged
    return(set([tuple(a) for a in self.mu]) == set([tuple(a) for a in self.oldmu])\
    TypeError: ‘NoneType’ object is not iterable

  5. Tyler Durden

    You should use ‘probability’ and not ‘cumulative probability’ as the distribution, or else cluster quality is unchanged from random seeding.

  6. Jason

    What does the “[0][0]” do in “ind = np.where(self.cumprobs >= r)[0][0]”? I do not use numpy, and I am trying to code this algorithm in pure Python.

  7. Bernd Fritzke

    k-means++ clearly is a major improvement over k-means with random initialization. There are however quite general problems where k-means++ fails to get anywhere close to the optimum if k is larger than a problem-dependent value. I illustrate this in a recent report ( and also propose a method which often can improve the results of k-means++ by a large margin (several percent SSE). Python software (jupyter notebooks) is available as well: feedback/criticism very welcome!
