K-Means Step by Step in F#

This article was originally published at tech.blinemedical.com

Recently we had a discussion on clustering techniques at one of our weekly tech talks and the k-means clustering algorithm came up. K-means is considered one of the simplest unsupervised machine learning techniques and I thought it would be cool to try my hand at an F# implementation of it.

In short, k-means is a way to put data into groups based on how close each point is to the center of its group. Technically it works with any number of dimensions, but for this post I’ll stick with 1-d data to make things easy.

K-means background

Imagine you have some 1-dimensional data like

1, 2, 5, 14, 17, 19, 20

And you want to group this into two groups. Pretty easily you can see that 1, 2, and 5 go into one group, and 14, 17, 19, 20 should go in another group. But why? For this small data set we intuitively figured out that 1, 2, and 5 are close together, but at the same time far away from 14, 17, 19, and 20, which are also close together. Really what we did was we calculated the centroid of each group and assigned the values to the centroid. A centroid is a fancy way of saying center point of a group. It’s calculated by taking the average of a group.

For example, if we have the group

1, 2, 5

What is its centroid? The answer is

(1 + 2 + 5) / 3 ≈ 2.667
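As a quick check of that arithmetic, here is a minimal sketch using F#'s built-in List.average. The names group and centroid are only illustrative and are not part of the implementation described later.

```fsharp
// The 1-d centroid of a group is just the mean of its values.
let group = [1.0; 2.0; 5.0]
let centroid = List.average group

printfn "%.3f" centroid   // prints 2.667
```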

So with k-means you end up with k centroids, each with a bunch of data grouped around it. A data point belongs to a centroid because it is closer to that centroid than to any other centroid. To calculate the centroids and the grouping, k-means uses an iterative process:

  • Pick k random points. So if you are doing a k=2 clustering, just pick 2 random points from your data set. These are going to be our initial centroids. It doesn’t matter if it’s right, or even if the points are the same. The algorithm is going to fine tune things for us. Technically this is called the Forgy Method of initialization. There are a bunch of other ways to initialize your centroids, but Forgy seems to be pretty common.
  • For every point, calculate its distance from each centroid. So let’s say we picked points 1 and 2 as our centroids. The distances then look like this:
         Δ1      Δ2
    1    0        1
    2    1        0
    5    4        3
    14   13       12
    17   16       15
    19   18       17
    20   19       18
  • Now what we do is assign each data point to its nearest centroid. So, obviously, point 1 is closest to the initial arbitrary point 1 centroid. And you can see that basically everything else is closer to the arbitrary point 2 centroid. So we’re left with a grouping like this now:
    data around centroid 1: 1
    data around centroid 2: 2, 5, 14, 17, 19, 20
  • Now, the next step is to calculate new centroids based on these groups. For this example, we can take the average of all the points in each group to find its new center. So the new centroids become 1 and 12.83. Now we go back to step 2 using the new centroids. If you keep track of the current centroids and the previous centroids, then you can stop the k-means calculation when the centroids don’t change between iterations. At this point everything has converged and you’ve got your k clusters.
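The first pass through those steps can be sketched in a few lines for the 1-d sample data. This is a simplified illustration on plain floats, not the DataPoint-based implementation built later in the post; the names data, step, and afterOne are assumptions made for the example.

```fsharp
let data = [1.0; 2.0; 5.0; 14.0; 17.0; 19.0; 20.0]

// One iteration: group every point under its nearest centroid,
// then average each group to get that group's new centroid.
let step (centroids: float list) (points: float list) =
    points
    |> List.groupBy (fun p -> centroids |> List.minBy (fun c -> abs (p - c)))
    |> List.map (fun (_, members) -> List.average members, members)

// Start from the arbitrary centroids 1 and 2 used in the walkthrough.
let afterOne = step [1.0; 2.0] data

afterOne
|> List.iter (fun (centroid, members) -> printfn "%.2f -> %A" centroid members)
```

The result pairs the centroid 1 with the single point 1, and the centroid 12.83 with the remaining six points, matching the walkthrough.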

Unfortunately k-means can take many iterations to fully converge, so you can add a convergence delta (stop once the centroids have moved less than some acceptable amount between the previous and current iterations) or an iteration limit (only try to cluster 100 times) so you can set bounds on the clustering.
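To make the two stopping rules concrete, here is a rough 1-d sketch that loops until the centroids move by no more than a delta or the iteration budget is used up. The names step, run, limit, and delta are illustrative assumptions, and the centroids are kept sorted so they can be compared pairwise.

```fsharp
// One 1-d iteration: group points under their nearest centroid, then average each group.
let step (centroids: float list) (points: float list) : float list =
    points
    |> List.groupBy (fun p -> centroids |> List.minBy (fun c -> abs (p - c)))
    |> List.map (fun (_, members) -> List.average members)
    |> List.sort

// Repeat until no centroid moves by more than delta, or the iteration budget runs out.
let rec run (centroids: float list) (points: float list) (limit: int) (delta: float) =
    if limit <= 0 then centroids
    else
        let next = step centroids points
        let settled =
            List.length next = List.length centroids
            && List.forall2 (fun a b -> abs (a - b) <= delta) centroids next
        if settled then next else run next points (limit - 1) delta

let final = run [1.0; 2.0] [1.0; 2.0; 5.0; 14.0; 17.0; 19.0; 20.0] 100 0.001
```

Starting from the centroids 1 and 2, this settles on roughly 2.667 and 17.5 after three iterations, well inside the limit of 100.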


The first thing to do for my implementation is to define a structure representing a single data point. I created a DataPoint class that takes a float list which represents the data point’s dimensions. So, a 1-dimensional data point would have a float list of one element. A 2-d data point has a list of two elements (x and y), etc. To give credit where credit is due, I took the dimensional array representation from another F# implementation.

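type DataPoint(input:float list) =
    member this.Data = input
    member this.Dimensions = List.length input
    override x.Equals(yobj) =
        match yobj with
        | :? DataPoint as y -> (x.Data = y.Data)
        | _ -> false
    // keep hashing consistent with the data-based Equals
    override x.GetHashCode() = hash x.Data
    static member Distance (d1:DataPoint) (d2:DataPoint) =
        // assumption: points with different dimensions are rejected rather than compared
        if List.length d1.Data <> List.length d2.Data then
            invalidArg "d2" "data points must have the same number of dimensions"
        else
            // sum the squared difference of each dimension, then take the square root
            let sums = List.fold2(fun acc d1Item d2Item ->
                                      acc + (d2Item - d1Item)**2.0
                                  ) 0.0 d1.Data d2.Data
            sqrt sums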

It has its data, the dimensions, an equality override (so I can compare data points by their data and not by reference), as well as a static helper method to calculate the Euclidean distance between two n-dimensional points. Here I’m using the fold2 method which can fold two lists simultaneously.
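To see List.fold2 on its own, here is a small sketch of the same distance calculation on plain float lists. The function name euclidean is only illustrative. Note that List.fold2 itself raises an exception when the two lists differ in length.

```fsharp
// Walk both lists in step, summing squared differences, then take the square root.
let euclidean (a: float list) (b: float list) =
    List.fold2 (fun acc x y -> acc + (y - x) ** 2.0) 0.0 a b
    |> sqrt

let d = euclidean [0.0; 0.0] [3.0; 4.0]
printfn "%.1f" d   // prints 5.0 (the classic 3-4-5 triangle)
```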

Alias data types

For fun, I wanted to use F#’s data aliasing feature, so I created some helper types to make dealing with tuples and other data pairing easier. I’ve defined the following

type Centroid = DataPoint

type Cluster = Centroid * DataPoint List

type Distance = float

type DistanceAroundCentroid = DataPoint * Centroid * Distance

type Clusters = Cluster List

type ClustersAndOldCentroids = Clusters * Centroid list

The only unfortunate thing is you need to explicitly mention aliased types for the type inference to pick them up. The reason is that F# will prefer native types to type aliases (why choose Distance over float unless you are told to), but it would’ve been nice if it had some way to prefer local type aliases in a module.
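A tiny sketch of what that looks like in practice. The functions halve and halveDistance are made up for this illustration. Both compile to the same thing, since an alias is only another name for float, and only the annotated one should carry the Distance name in its signature.

```fsharp
type Distance = float

// No annotation: the compiler infers plain float -> float.
let halve x = x / 2.0

// Explicit annotations: the signature should now read Distance -> Distance.
let halveDistance (d: Distance) : Distance = d / 2.0

printfn "%b" (halve 3.0 = halveDistance 3.0)   // prints true
```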

Centroid assignment

This next major step is to take the current raw data list and what the current centroids are and cluster them. A cluster is a centroid and its associated data points. To make things clearer, I created a helper function distFromCentroid which takes a point, a centroid, and returns a tuple of point * centroid * distance.

let private distFromCentroid pt centroid : DistanceAroundCentroid = (pt, centroid, (DataPoint.Distance centroid pt))

The bulk of the k-means work is in assignCentroids. First, for every data point, we calculate its distance from each of the current centroids. Then we select the centroid which is closest to that data point and create a list of point * centroid tuples. Once we have that list, we want to group it by the second value in the tuple: the centroid. Here, I’m leveraging the pipe operator to do the grouping and mapping. Grouping by the centroid gives us a sequence of centroid * (dataPoint * centroid) sequence items. But this isn’t quite right yet. We want to end up with a list of centroid * dataPoint list, so I pass this temporary sequence through some further mapping and formatting.

(*
    Takes the input list and the current group of centroids.
    Calculates the distance of each point from each centroid
    then assigns the data point to the centroid that is closest.
*)
let private assignCentroids (rawDataList:DataPoint list) (currentCentroids:Clusters) =
        let dataPointsWithCentroid =
            rawDataList
                |> List.map(fun pt ->
                                    let currentPointByEachCentroid = List.map(fun (centroid, _) -> distFromCentroid pt centroid) currentCentroids
                                    let (_, nearestCentroid, _) = List.minBy(fun (_, _, distance) -> distance) currentPointByEachCentroid
                                    (pt, nearestCentroid)
                            )

        (*
            |> Group all the data points by their centroid
            |> Select centroid * dataPointList
            |> For each previous centroid, calculate a new centroid based on the aggregated list
        *)
        List.toSeq dataPointsWithCentroid
            |> Seq.groupBy (fun (_, centr) -> centr)
            |> Seq.map(fun (centr, dataPointsWithCentroid) ->
                                let dataPointList = Seq.toList (Seq.map(fun (pt, _) -> pt) dataPointsWithCentroid)
                                (centr, dataPointList)
                        )
            |> Seq.map(fun (cent, dataPointList) ->
                                let newCent = calculateCentroidForPts dataPointList
                                (newCent, dataPointList)
                        )
            |> Seq.toList
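The group-then-reshape part of that pipeline can be seen in isolation with plain floats standing in for data points and centroids. This is only an illustration. The names pairs and grouped are assumptions, and the pairs are the point * centroid assignments from the first iteration of the 1-d walkthrough (truncated to four points).

```fsharp
// (point, nearest centroid) pairs, as produced by the first mapping step
let pairs = [ (1.0, 1.0); (2.0, 2.0); (5.0, 2.0); (14.0, 2.0) ]

let grouped =
    pairs
    |> Seq.groupBy snd                                   // centroid * seq of (point * centroid)
    |> Seq.map (fun (centroid, members) ->
                    centroid, members |> Seq.map fst |> Seq.toList)   // keep only the points
    |> Seq.toList

printfn "%A" grouped   // [(1.0, [1.0]); (2.0, [2.0; 5.0; 14.0])]
```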

Calculating new centroids

Since we’ve grouped data points by centroid, we now need to calculate what the new centroid for that list is. For my implementation of n-dimensional centroids I averaged each dimension together to get a new value. This per-dimension mean is the usual centroid definition for k-means with Euclidean distance, and if you ever want a different definition this is the only method that needs to be updated.

First I create an n-dimensional list of all zeros, sized by the dimensions of the first point in the list (all the dimensions should be the same so it doesn’t matter which point we choose). This is the seed list for the next fold. I’m not actually going to fold the values into a single element; I’m going to fold the data point list into a new list by adding each dimension together, index by index.

For example, if I have two points of two dimensions [1, 2] and [3, 4] I add 1 and 3 together to make the first dimension, then 2 and 4 together to make the second dimension. This gives me an intermediate vector of [4, 6]. Then I divide each dimension by the number of data points used. In our example we used two data points, so the new centroid would be [2, 3]. Since a centroid is the same as a data point (albeit one that may not exist in our original raw set), we can return the new centroid as a data point. Having everything treated as data points makes the calculations easier.
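The same worked example can be reproduced with built-in list functions. This is an alternative sketch using List.reduce and List.map2 on plain float lists, not the fold-based helper used in this post; centroidOf is a hypothetical name, and it assumes a non-empty list of points that all share the same dimensions.

```fsharp
// Sum the points dimension by dimension, then divide each total by the point count.
let centroidOf (points: float list list) =
    let n = float (List.length points)
    points
    |> List.reduce (List.map2 (+))        // [1;2] + [3;4] -> [4;6]
    |> List.map (fun total -> total / n)  // [4;6] / 2     -> [2;3]

printfn "%A" (centroidOf [[1.0; 2.0]; [3.0; 4.0]])   // [2.0; 3.0]
```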

(*
    take each float list representing an n-dimensional space for every data point
    and average each dimension together.  I.e. element 1 for each point gets averaged together
    and element 2 gets averaged together.  This new float list is the new nDimensional centroid
*)
let private calculateCentroidForPts (dataPointList:DataPoint List) =
    let firstElem = List.head dataPointList
    let nDimensionalEmptyList = List.init firstElem.Dimensions (fun i-> 0.0)
    // add every point's dimensions onto the running totals, starting from all zeros
    let addedDimensions = List.fold(fun acc (dataPoint:DataPoint) -> addListElements acc dataPoint.Data (+)) nDimensionalEmptyList dataPointList

    // scale the nDimensional sums by the size
    let newCentroid = List.map(fun pt -> pt/((float)(List.length dataPointList))) addedDimensions
    new DataPoint(newCentroid)

Here is the function to aggregate list elements together

(*  goes through two lists and adds each element together
    i.e. index 1 with index 1, index 2 with index 2
    and returns a new list of the items aggregated
*)
let private addListElements list1 list2 operator =
    let rec addListElements' list1 list2 accumulator operator =
        if List.length list1 = 0 then
            // the accumulator was built back to front, so restore the original order
            List.rev accumulator
        else
            let head1 = List.head list1
            let head2 = List.head list2
            addListElements' (List.tail list1) (List.tail list2) ((operator head1 head2)::accumulator) operator
    addListElements' list1 list2 [] operator

I’m using a hidden accumulator so that the function will be tail recursive.
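For comparison, the accumulator pattern can also be written with pattern matching instead of List.head and List.tail. This is an alternative sketch rather than the version used in this post; combineLists and loop are hypothetical names, and unlike the original it quietly stops at the end of the shorter list.

```fsharp
let combineLists op list1 list2 =
    // acc collects results in reverse; the recursive call is the last thing done,
    // so the compiler can turn the recursion into a loop.
    let rec loop xs ys acc =
        match xs, ys with
        | x :: xt, y :: yt -> loop xt yt (op x y :: acc)
        | _ -> List.rev acc
    loop list1 list2 []

printfn "%A" (combineLists (+) [1.0; 2.0] [3.0; 4.0])   // [4.0; 6.0]
```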

Cluster runner

The above code only does a single clustering iteration, though. What we need is a clustering runner that does the clustering iterations. This runner will be responsible for determining when to stop the clustering. There are four functions here:

  • getClusters and getOldCentroids are helper methods to pull data out of the tuples (fst and snd are built in functions)
  • cluster is a basic method that will cluster as many times as necessary until the centroids stop changing
  • clusterWithIterationLimit is the bulk method that takes in the raw data, the clustering amount (k), the max number of clustering iterations, and a convergence delta to avoid infinite loops

For each iteration create a new set of clusters, then test to see if the centroids are the same. If they are, we’ve converged and we can end. If not, see if the clusters are “close enough” (the convergence delta). The final base case is if we’ve clustered the max iterations then return.
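The "close enough" test is not shown in this post, so here is one hypothetical way such a check could look, with centroids represented as plain float lists. The names distance and withinDelta are assumptions, and it assumes both lists hold the same number of centroids in the same order.

```fsharp
let distance (a: float list) (b: float list) =
    List.fold2 (fun acc x y -> acc + (y - x) ** 2.0) 0.0 a b |> sqrt

// True when every centroid moved by at most delta since the previous iteration.
let withinDelta (previous: float list list) (current: float list list) (delta: float) =
    List.forall2 (fun p c -> distance p c <= delta) previous current

printfn "%b" (withinDelta [[1.0]; [12.0]] [[1.0]; [12.0005]] 0.001)   // true
printfn "%b" (withinDelta [[1.0]; [12.0]] [[1.0]; [13.0]] 0.001)      // false
```

With a delta of 0.0 this check only passes when the centroids are exactly unchanged.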

(*
    Start with the input source and the clustering amount
    but also take in a max iteration limit as well as a convergence delta

    continue to cluster until the centroids stop changing or
    the iteration limit is reached OR the convergence delta is achieved

    Returns a new "Clusters" type representing the centroid
    and all the data points associated with it
*)

let private getClusters (cluster:ClustersAndOldCentroids) = fst cluster
let private getOldCentroids (cluster:ClustersAndOldCentroids) = snd cluster

let clusterWithIterationLimit data k limit delta =
    let initialClusters:Clusters = initialCentroids (List.toArray data) k
                                    |> Seq.toList
                                    |> List.map (fun i -> (i,[]))

    let rec cluster' (clustersWithOldCentroids:ClustersAndOldCentroids) count =
        // iteration limit reached
        if count = 0 then
            getClusters clustersWithOldCentroids
        else
            let newClusters = assignCentroids data (getClusters clustersWithOldCentroids)
            let previousCentroids = getOldCentroids clustersWithOldCentroids
            let newCentroids = extractCentroidsFromClusters newClusters

            // clusters didnt change
            if listsEqual newCentroids previousCentroids then
                newClusters
            // centroids moved by less than the convergence delta
            else if centroidsDeltaMinimzed previousCentroids newCentroids delta then
                newClusters
            else
                cluster' (newClusters, newCentroids) (count - 1)

    cluster' (initialClusters, extractCentroidsFromClusters initialClusters) limit

(*
    Start with the input source and the clustering amount.
    continue to cluster until the centroids stop changing
    Returns a new "Clusters" type representing the centroid
    and all the data points associated with it
*)

let cluster data k = clusterWithIterationLimit data k Int32.MaxValue 0.0
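The initialCentroids helper used above is not shown in this post. As a rough idea of what a Forgy-style pick could look like, here is a hypothetical sketch. The name pickInitialCentroids, its signature, and the fixed seed are all assumptions made for illustration. It picks with replacement, so the same point can be chosen twice, which the algorithm description earlier says is acceptable.

```fsharp
open System

// Choose k random elements of the data set to act as the starting centroids.
let pickInitialCentroids (rng: Random) (data: 'a[]) (k: int) : 'a list =
    List.init k (fun _ -> data.[rng.Next(data.Length)])

let sample = [| 1.0; 2.0; 5.0; 14.0; 17.0; 19.0; 20.0 |]
let seeds = pickInitialCentroids (Random(42)) sample 2

printfn "%A" seeds
```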


Using the sample data from the beginning of this post, let’s run it through my k-means clusterer:

module KDataTest

open System
open KMeans

let kClusterValue = 2

let sampleData : KMeans.DataPoint list = [new DataPoint([1.0]);
                                          new DataPoint([2.0]);
                                          new DataPoint([5.0]);
                                          new DataPoint([14.0]);
                                          new DataPoint([17.0]);
                                          new DataPoint([19.0]);
                                          new DataPoint([20.0])]

KMeans.cluster sampleData kClusterValue
    |> Seq.iter(KMeans.displayClusterInfo)

Console.ReadKey() |> ignore

And the output is

Centroid [2.66666666666667], with data points:
[1], [2], [5]

Centroid [17.5], with data points:
[14], [17], [19], [20]

Full source available at our github
