Selection of K in K-means Clustering, Reloaded
This article follows up on the series devoted to k-means clustering at The Data Science Lab. Previous posts have dealt with how to implement Lloyd's algorithm for clustering in Python, described an improved initialization algorithm for proper seeding of the initial clusters, k-means++, and introduced the gap statistic as a method of finding the optimal K for k-means clustering.
Although the gap statistic, based on a paper by Tibshirani et al., was shown to find optimal values for the number of clusters in a variety of cases where the clusters were globular and mildly disjointed, its performance might be hampered by the need to perform Monte Carlo simulations to estimate the reference datasets. A reader of this blog, Jonathan Stray, pointed out a potentially superior method for selecting the K in k-means clustering, so let us implement it and compare.
An alternative approach to finding the optimal K
The approach suggested by our reader is based on a publication by Pham, Dimov and Nguyen from 2004. The article is very much worth reading, as it includes an explanation of the drawbacks of the standard k-means algorithm as well as a comprehensive survey on different methods that have been proposed for selecting an optimal number of clusters.
In section 3 of the paper, the authors justify the introduction of a function f(K) to evaluate the quality of the resulting clustering and help decide on the optimal value of K for each data set. Quoting from the paper:
A data set with N objects could be grouped into any number of clusters between 1 and N, which would correspond to the lowest and the highest levels of detail respectively. By specifying different K values, it is possible to assess the results of grouping objects into various numbers of clusters. From this evaluation, more than one K value could be recommended to users, but the final selection is made by them.
The goal of a clustering algorithm is to identify regions in which the data points are concentrated. It is also important to analyze the internal distribution of each cluster as well as its relation to other clusters in the data set. The distortion of a cluster is a measure of the distance between the points in a cluster and its centroid:

I_j = \sum_{\mathbf{x}_t \in \text{cluster } j} \lVert \mathbf{x}_t - \mathbf{w}_j \rVert^2,

where \mathbf{w}_j denotes the centroid of cluster j. The global impact of all clusters' distortions is given by the quantity

S_K = \sum_{j=1}^{K} I_j.
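For concreteness, here is a minimal sketch (my own illustration, not code from the post) of how S_K translates into numpy, assuming the layout used by the classes further below, where mu holds the centroids and clusters maps each cluster index to its list of points:

import numpy as np

def global_distortion(mu, clusters):
    # S_K: for every cluster, add up the squared distances of its points to the centroid
    return sum(np.linalg.norm(np.array(x) - np.array(mu[j]))**2
               for j in range(len(mu)) for x in clusters[j])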
The authors Pham et al. proceed to discuss further constraints that the sought-after function should satisfy for it to be informative for the problem of selecting K. They finally arrive at the following definition:

f(K) = \begin{cases} 1 & \text{if } K = 1, \\ \dfrac{S_K}{a_K S_{K-1}} & \text{if } S_{K-1} \neq 0, \; K > 1, \\ 1 & \text{if } S_{K-1} = 0, \; K > 1, \end{cases}

with the weight factor defined recursively as

a_K = \begin{cases} 1 - \dfrac{3}{4 N_d} & \text{if } K = 2, \\ a_{K-1} + \dfrac{1 - a_{K-1}}{6} & \text{if } K > 2, \end{cases}

where N_d is the number of dimensions (attributes) of the data set. With this definition, f(K) is the ratio of the real distortion to the estimated distortion and decreases when there are areas of concentration in the data distribution. Values of K that yield small f(K) can be regarded as giving well-defined clusters.
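As a quick worked example (my own numbers, purely illustrative): for two-dimensional data, N_d = 2, the recursion gives a_2 = 1 - 3/8 = 0.625, a_3 = 0.625 + (1 - 0.625)/6 = 0.6875, a_4 \approx 0.7396, and so on, with a_K slowly approaching 1 as K grows. The weight factor thus discounts the measured distortion less and less for larger K.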
A Python implementation of the Pham et al. f(K)
Our implementation of the Pham et al. procedure builds on the KMeans and KPlusPlus Python classes defined in our article on the k-means++ algorithm. We define a new class that inherits from KPlusPlus and contains a function to compute f(K):
class DetK(KPlusPlus):
    def fK(self, thisk, Skm1=0):
        X = self.X
        Nd = len(X[0])
        a = lambda k, Nd: 1 - 3/(4*Nd) if k == 2 else a(k-1, Nd) + (1-a(k-1, Nd))/6
        self.find_centers(thisk, method='++')
        mu, clusters = self.mu, self.clusters
        Sk = sum([np.linalg.norm(mu[i]-c)**2 \
                 for i in range(thisk) for c in clusters[i]])
        if thisk == 1:
            fs = 1
        elif Skm1 == 0:
            fs = 1
        else:
            fs = Sk/(a(thisk,Nd)*Skm1)
        return fs, Sk
Note the recursive definition of a_K (variable a in the code snippet above) and the fact that the computation of f(K) for a given K requires knowing the value of S_{K-1}, which is passed as an input parameter to the function.
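To illustrate how S_{K-1} is threaded through successive calls, here is a minimal usage sketch (assuming the data-generation interface of the earlier posts, where DetK(2, N=300) builds a two-dimensional set of 300 points, as also used further below):

# Sketch: evaluate f(K) for K = 1, ..., 9, passing the previous S_K forward each time
det = DetK(2, N=300)
fs, Sk = [], 0
for k in range(1, 10):
    det.init_centers(k)
    fk, Sk = det.fK(k, Skm1=Sk)
    fs.append(fk)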
This article aims to show that the Pham et al. procedure works and is computationally more efficient than the gap statistic. Therefore, we will code up the algorithm for the gap statistic within the same class DetK, so that we can run both procedures simultaneously. The full code is below the fold:
class DetK(KPlusPlus):
    def fK(self, thisk, Skm1=0):
        X = self.X
        Nd = len(X[0])
        a = lambda k, Nd: 1 - 3/(4*Nd) if k == 2 else a(k-1, Nd) + (1-a(k-1, Nd))/6
        self.find_centers(thisk, method='++')
        mu, clusters = self.mu, self.clusters
        Sk = sum([np.linalg.norm(mu[i]-c)**2 \
                 for i in range(thisk) for c in clusters[i]])
        if thisk == 1:
            fs = 1
        elif Skm1 == 0:
            fs = 1
        else:
            fs = Sk/(a(thisk,Nd)*Skm1)
        return fs, Sk

    def _bounding_box(self):
        X = self.X
        xmin, xmax = min(X,key=lambda a:a[0])[0], max(X,key=lambda a:a[0])[0]
        ymin, ymax = min(X,key=lambda a:a[1])[1], max(X,key=lambda a:a[1])[1]
        return (xmin,xmax), (ymin,ymax)

    def gap(self, thisk):
        X = self.X
        (xmin,xmax), (ymin,ymax) = self._bounding_box()
        self.init_centers(thisk)
        self.find_centers(thisk, method='++')
        mu, clusters = self.mu, self.clusters
        Wk = np.log(sum([np.linalg.norm(mu[i]-c)**2/(2*len(c)) \
                    for i in range(thisk) for c in clusters[i]]))
        # Create B reference datasets
        B = 10
        BWkbs = zeros(B)
        for i in range(B):
            Xb = []
            for n in range(len(X)):
                Xb.append([random.uniform(xmin,xmax), \
                          random.uniform(ymin,ymax)])
            Xb = np.array(Xb)
            kb = DetK(thisk, X=Xb)
            kb.init_centers(thisk)
            kb.find_centers(thisk, method='++')
            ms, cs = kb.mu, kb.clusters
            BWkbs[i] = np.log(sum([np.linalg.norm(ms[j]-c)**2/(2*len(c)) \
                              for j in range(thisk) for c in cs[j]]))
        Wkb = sum(BWkbs)/B
        sk = np.sqrt(sum((BWkbs-Wkb)**2)/float(B))*np.sqrt(1+1/B)
        return Wk, Wkb, sk

    def run(self, maxk, which='both'):
        ks = range(1,maxk)
        fs = zeros(len(ks))
        Wks,Wkbs,sks = zeros(len(ks)+1),zeros(len(ks)+1),zeros(len(ks)+1)
        # Special case K=1
        self.init_centers(1)
        if which == 'f':
            fs[0], Sk = self.fK(1)
        elif which == 'gap':
            Wks[0], Wkbs[0], sks[0] = self.gap(1)
        else:
            fs[0], Sk = self.fK(1)
            Wks[0], Wkbs[0], sks[0] = self.gap(1)
        # Rest of Ks
        for k in ks[1:]:
            self.init_centers(k)
            if which == 'f':
                fs[k-1], Sk = self.fK(k, Skm1=Sk)
            elif which == 'gap':
                Wks[k-1], Wkbs[k-1], sks[k-1] = self.gap(k)
            else:
                fs[k-1], Sk = self.fK(k, Skm1=Sk)
                Wks[k-1], Wkbs[k-1], sks[k-1] = self.gap(k)
        if which == 'f':
            self.fs = fs
        elif which == 'gap':
            G = []
            for i in range(len(ks)):
                G.append((Wkbs-Wks)[i] - ((Wkbs-Wks)[i+1]-sks[i+1]))
            self.G = np.array(G)
        else:
            self.fs = fs
            G = []
            for i in range(len(ks)):
                G.append((Wkbs-Wks)[i] - ((Wkbs-Wks)[i+1]-sks[i+1]))
            self.G = np.array(G)

    def plot_all(self):
        X = self.X
        ks = range(1, len(self.fs)+1)
        fig = plt.figure(figsize=(18,5))
        # Plot 1
        ax1 = fig.add_subplot(131)
        ax1.set_xlim(-1,1)
        ax1.set_ylim(-1,1)
        ax1.plot(zip(*X)[0], zip(*X)[1], '.', alpha=0.5)
        tit1 = 'N=%s' % (str(len(X)))
        ax1.set_title(tit1, fontsize=16)
        # Plot 2
        ax2 = fig.add_subplot(132)
        ax2.set_ylim(0, 1.25)
        ax2.plot(ks, self.fs, 'ro-', alpha=0.6)
        ax2.set_xlabel('Number of clusters K', fontsize=16)
        ax2.set_ylabel('f(K)', fontsize=16)
        foundfK = np.where(self.fs == min(self.fs))[0][0] + 1
        tit2 = 'f(K) finds %s clusters' % (foundfK)
        ax2.set_title(tit2, fontsize=16)
        # Plot 3
        ax3 = fig.add_subplot(133)
        ax3.bar(ks, self.G, alpha=0.5, color='g', align='center')
        ax3.set_xlabel('Number of clusters K', fontsize=16)
        ax3.set_ylabel('Gap', fontsize=16)
        foundG = np.where(self.G > 0)[0][0] + 1
        tit3 = 'Gap statistic finds %s clusters' % (foundG)
        ax3.set_title(tit3, fontsize=16)
        ax3.xaxis.set_ticks(range(1,len(ks)+1))
        plt.savefig('detK_N%s.png' % (str(len(X))), \
                    bbox_inches='tight', dpi=100)
For a first experiment comparing the Pham et al. and the gap statistic approaches, we create a data set comprising 300 points around 2 Gaussian-distributed clusters. We run both methods to select K, spanning the values 1 \leq K \leq 9. (The function run from class DetK takes a maximum value maxk as input and checks all values such that K < maxk.) Note that every run of the k-means clustering algorithm for a different value of K is preceded by the k-means++ initialization algorithm, to prevent landing at suboptimal clustering solutions.
To run a full comparison of both methods, the following simple commands are invoked:
kpp = DetK(2, N=300)
kpp.run(10)
kpp.plot_all()
This produces the following result plots:
According to Pham et al., lower values of f(K), and especially values f(K) < 0.85, are an indication of cluster-like features in the data at that particular K. In the case of K = 2, the global minimum of f(K) in the central plot leaves no doubt that this is the right value to choose for this particular data configuration. The gap statistic, depicted in the plot on the right, yields the same result of K = 2. Remember that the optimal K with the gap statistic is the smallest value for which the gap quantity becomes positive.
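For reference, and assuming the kpp object and numpy import from the snippets above, the two decision rules just described can be read off programmatically from the quantities stored by run(); this simply mirrors what plot_all() does internally:

# f(K) rule: take the K at which f(K) reaches its minimum
K_f = np.where(kpp.fs == min(kpp.fs))[0][0] + 1
# Gap rule: take the smallest K for which the gap quantity is positive
K_gap = np.where(kpp.G > 0)[0][0] + 1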
Similarly, we can analyze a data set consisting of 100 points around a single cluster. The results are shown in the plots below. We observe that the f(K) function does not show any prominent valley or any value for which f(K) < 0.85 for any of the surveyed Ks. According to the Pham et al. paper, this is an indication of no clustering, as is indeed the case. The gap statistic agrees that there is no more than one cluster in this case.
Finally, let us look at two cases, both with 500 data points around 4 clusters. Below are the plots of the results:
For the data distribution on the top, one can see that the 4 clusters are positioned in such a way that they could also be interpreted as 2 clusters made of 2 subclusters each. The f(K) method detects this configuration and suggests 2 possible values of K, with a slight preference for K = 2 over K = 4. The gap statistic changes sign at K = 2, albeit barely, and it does so again, more clearly, at K = 4. In both cases, a strict application of the rules prescribed to select the correct K does lead to a rather suboptimal, or at least dubious, choice.
In the bottom plot, however, the 4 clusters are somewhat more evenly spread out and both algorithms succeed at identifying K = 4. The f(K) method still shows a relative minimum at K = 2, indicating a potentially alternative clustering.
Performance comparison of f(K) and the gap statistic
If both methods to select the optimal K for k-means clustering yield similar results, one should ask about their relative performance in real-life data science clustering problems. It is straightforward to predict that the gap statistic, with its need for running the k-means algorithm multiple times to create a Monte Carlo reference distribution, will necessarily be the poorer performer. We can easily test this hypothesis with our code by running both approaches and timing them using the IPython %time magic function. For one of the data sets analyzed above, we obtain:
%time kpp.run(10, which='f')
CPU times: user 2.72 s, sys: 0.00 s, total: 2.72 s
Wall time: 2.90 s
%time kpp.run(10, which='gap')
CPU times: user 51.30 s, sys: 0.01 s, total: 51.31 s
Wall time: 51.40 s
In this particular example, the f(K) method is more than one order of magnitude faster than the gap statistic, and the comparison looks even worse for the latter the more data we take into consideration and the larger the number B of reference data sets employed for generating the reference distribution.
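If you are not working inside IPython, a rough equivalent using the standard library's time module would look like the sketch below (illustrative only; the absolute numbers will of course depend on the machine and the data set):

import time

for which in ('f', 'gap'):
    t0 = time.time()
    kpp.run(10, which=which)
    print '%s took %.2f s' % (which, time.time() - t0)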
Table-top data experiment take-away message
The estimation of the optimal number of clusters within a set of data points is a very important problem, as most clustering algorithms need that parameter as input in order to group the data. Many methods have been proposed to find the proper K, among which the approach proposed by Pham et al. in 2004 seems to offer a very straightforward and performant solution. The estimation of the function f(K) over the desired range of test values for K offers an immediate way of assessing when cluster-like features appear, and allows one to choose among a best value and other alternatives. A comparison in performance with the gap statistic method of Tibshirani et al. concludes that the f(K) method is computationally advantageous.
I think this is great. But then, I would.
I’d love to understand better how k-means works on real data. These synthetic datasets are good for testing, but they are very low dimensional and very clean. It’s a pretty major challenge to visualize how k-means is doing in 40 dimensions… one thing that happens is the algorithm gets less stable because there are many more directions to slice the same data set into clusters. Have you seen any work explaining or evaluating this phenomenon?
You're totally right. I'm very curious to see how the method works on multi-dimensional, real data. Visualization becomes much trickier though, which is why I chose this simple 2-d case for didactic purposes. I need to find literature on the matter before I dig deeper into practical experiments. Hang in there!
i used clustering in analysis of financial services data.
it is highly effective to understand the dimensional structure of data as it occurs “naturally”.
after a stable solution is obtained, you can use the segmentation for extensive profiling.
when building on real data, a lot of time is spent trying to make the sample amenable to clustering.
for instance, you need to get set bounds on individual variables so that extreme outliers do not destabilise the solution, or keep it from converging.
also, how do you deal with missing data? do you impute? do you exclude altogether?
these decisions are subjective, driven primarily by expert knowledge of domain. due to this, clusters built by one analyst are rarely the same as those built by another.
check for multicollinearity among variables before going any further.
analysts constrain the minimum size of each cluster, e.g., no cluster should be less than 3% of the sample, or some such rule. some analysts may constrain the maximum number of clusters that can be identified. this is to make business strategies manageable/scalable. this will probably be different in the biological/engineering industries.
since many subjective decisions are involved, quality assessment of a cluster solution is different from that of a regression model. in clusters, you have to spend a lot more time looking at the results to check whether they make intuitive sense, and whether they will help you with your objective.
Hello guys, great articles! Is it possible to see the whole source code with examples in one place? I mean, are you sharing this code on GitHub or somewhere else?
Thanks a lot!
Hello, it's just one gal over here, which is why I haven't yet had time to set up a GitHub repo for all the code. But it's in the pipeline!
Very impressive. If you need an additional helping hand with the code or a GitHub repo, let me know. Good work.
These changes will make the code work:
def init_centers(self,K):
self.K = K
self.mu = random.sample(self.X,1)
while len(self.mu) < self.K:
self._dist_from_centers()
self.mu.append(self._choose_next_center())
def plot_init_centers(self):
X = self.X
fig = plt.figure(figsize=(5,5))
plt.xlim(-1,1)
plt.ylim(-1,1)
plt.plot(zip(*X)[0], zip(*X)[1], '.', alpha=0.5)
plt.plot(zip(*self.mu)[0], zip(*self.mu)[1], 'ro')
plt.savefig('kpp_init_N%s_K%s.png' % (str(self.N),str(self.K)), \
bbox_inches='tight', dpi=200)
Sorry, no changes in plot_init_centers required. But in this one:
def find_centers(self, K, method='random'):
self.method = method
X = self.X
K = self.K
self.oldmu = random.sample(X, K)
if method != '++':
# Initialize to K random centers
self.mu = random.sample(X, K)
while not self._has_converged():
self.oldmu = self.mu
# Assign all points in X to clusters
self._cluster_points()
# Reevaluate centers
self._reevaluate_centers()
Do we have any Java code available? Not good at Python.
The code isn’t running even when applying the changes from Anonymous. A github repo would be really, really great!
This code doesn’t run, despite my attempts to aggregate the required code over several posts. If you don’t want to go through the hassle of setting up a repo on git, could you email me the final working code that generated the plots in this post? I can even create the repo for you (and credit you of course!) if you’d like.
The lambda expression for 'a' in the fK() method is incorrect. All of its variables and constants are integers. Thus it only returns integer values for a.
Agreed, I had to change the integer values to floats to get the correct weights,
a = lambda k, Nd: 1.0 - 3.0/(4.0*Nd) if k == 2 else a(k-1, Nd) + (1.0-a(k-1, Nd))/6.0
Otherwise, it worked and is a great post!
So if you’re unable to get the code to run what you need to do is implement anon’s fixes:
1. add K as an arg in find_centers and init_centers
def find_centers(self, K, method='random')
def init_centers(self,K)
2. you need to import a few modules for all 3 files, KMeans, KPlusPlus, and DetK
import numpy as np
import random
import matplotlib.pyplot as plt
3. add np. to all the zeros (np.zeros instead of zeros). I think there are 4 around line 40-50
With these, you should be able to run the program. (a.run(10) does take a little while)
4. I’m a bit new to Python inheritance but make sure you add from subclass import subclass
for KPlusPlus:
from KMeans import KMeans
for DetK
from KPlusPlus import KPlusPlus
It looks like OP might have ditched this blog, but he/she did a pretty good job with this code and all the work he/she has done to present these complex algorithms in a simplified form.
Also, if you're passing a specific argument to only run one version of the algorithm (gap or f(K)), just remember to adjust the graphs in plot_all accordingly, as its default is to plot both (even if f(K) is much faster).
Page 3: "Unfortunately, this method of selecting K cannot be applied to practical problems. The data distribution in practical problems is unknown and also the number of generators cannot be specified."
That is for the method of "Values of K equated to the number of generators". You cannot use this method to find K because you cannot know how many generators there were. They are covering normal methods of finding K.
This function makes the calculation very slow: a = lambda k, Nd: 1 - 3/(4*Nd) if k == 2 else a(k-1, Nd) + (1-a(k-1, Nd))/6
It can be avoided by implementing this piece of code using memoization. That way it will store all previous results, and therefore there will be no need to compute a(k-1, Nd) again for k's that have already been computed in previous 'thisk' iterations.
Great feedback. Right, that's the problem with implicit functions. The code has been written for clarity without any intent of optimising it. Memoization is a good way to go.
Here is a simple implementation I used; is this the kind of thing you had in mind?

'''
This function allows for memoization (caching) of results for the recursive
computation of the weighting value a_k. It is used as a decorator below.
INPUT:
    func = The function that will be called
OUTPUT:
    helper = The memoized function
Adapted from
http://www.python-course.eu/python3_memoization.php
and NLTK_parsing_demos.pdf
For a more complete version of memoization see
https://wiki.python.org/moin/PythonDecoratorLibrary#Memoize
'''
def memoize(func):
    memo = {}
    def helper(*args):
        if args not in memo:
            memo[args] = func(*args)
        return memo[args]
    return helper

'''
This function computes the weighting constant used to account for the
data dimensionality. It has a memoize decorator to help with the
recursive nature of the function (see above).
INPUT:
    k = The current value of K (num_clusters) being evaluated
    Nd = The number of features in the data
OUTPUT:
    a_k = The weighting factor a_k
'''
@memoize
def get_ak(k, Nd):
    if k == 2:
        a_k = 1.0 - 3.0/(4.0 * Nd)
    else:
        a_k = get_ak(k-1, Nd) + (1.0 - get_ak(k-1, Nd))/6.0
    return a_k

You would then call

fs = Sk/(get_ak(thisk,Nd)*Skm1)

rather than

fs = Sk/(a(thisk,Nd)*Skm1)

in the fK() function.
The problem with the original lambda recursion…
a = lambda k, Nd: 1 - 3/(4*Nd) if k == 2 else a(k-1, Nd) + (1-a(k-1, Nd))/6
... is that the else branch makes 2 recursive calls to a(k-1, Nd), when only 1 is needed, whose value can be reused in the formula. With 2 recursive calls, the complexity is an exponential function of k with base 2 (i.e. for k=40 there will be about 2^40 recursions). A simple reuse of the recursion makes the complexity linear in k, as follows:

def a(k, Nd):
    if k == 2:
        return 1 - 3/(4*Nd)
    else:
        previous_a = a(k-1, Nd)
        return previous_a + (1-previous_a)/6
Must Nd be bigger than 1? What if Nd == 1?
I got a KeyError which I believe comes from the range indexing:
File “….\kmeans_seed_deptk_0413.py”, line 144, in fK
for i in range(thisk) for c in clusters[i]])
KeyError: 2
Any ideas on how to solve this?
I changed thisk to len(mu) and then it would pop up a KeyError: 1.
I need this code in MATLAB, or a Python-to-MATLAB converter. Any help?
Thanks for sharing
I am seeing best K = maxk-1 for all maxk I have tried; is anyone else seeing that?