🧠 AI Exploration #7: K-Means Clustering Explained

K-Means is one of the most widely used algorithms for unsupervised clustering. It partitions data into K distinct groups based on similarity, without requiring any labels.

In this post, you'll learn what K-Means is, how it works, the math behind it, how to choose K, and how to implement it in Python.


🧩 What is K-Means?

K-Means Clustering aims to group data points into K clusters such that each point belongs to the cluster with the nearest mean (centroid).


🧠 How K-Means Works (Step-by-Step)

  1. Initialize K centroids randomly
  2. Assign each point to the nearest centroid
  3. Update centroids by computing the mean of points in each cluster
  4. Repeat steps 2–3 until convergence (assignments don't change)
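The four steps above can be sketched in a few lines of NumPy. This is a minimal illustration, not a production implementation: the function name is mine, the initialization picks random data points as centroids, and it assumes no cluster ever becomes empty during an update.

```python
import numpy as np

def kmeans(X, k, n_iters=100, seed=0):
    rng = np.random.default_rng(seed)
    # Step 1: initialize K centroids as randomly chosen data points
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iters):
        # Step 2: assign each point to its nearest centroid
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Step 3: move each centroid to the mean of its assigned points
        new_centroids = np.array([X[labels == i].mean(axis=0) for i in range(k)])
        # Step 4: stop once the centroids (and thus the assignments) are stable
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return labels, centroids

# Two well-separated synthetic blobs; kmeans should recover them
rng = np.random.default_rng(42)
X = np.vstack([rng.normal(0, 0.5, (50, 2)), rng.normal(10, 0.5, (50, 2))])
labels, centroids = kmeans(X, k=2)
```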

The objective is to minimize the intra-cluster variance (or within-cluster sum of squares).


🧮 Objective Function

The K-Means algorithm minimizes the following loss:

J = \sum_{i=1}^{K} \sum_{x \in C_i} \| x - \mu_i \|^2

Where:

  • C_i is the set of points in cluster i
  • \mu_i is the centroid of cluster i
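With those definitions, the objective is easy to compute by hand. A quick check on a hypothetical 1-D toy dataset (points and assignments chosen by me for illustration):

```python
import numpy as np

# Toy data: two obvious clusters on a line
points = np.array([[1.0], [2.0], [9.0], [10.0]])
assignments = np.array([0, 0, 1, 1])  # which cluster C_i each x belongs to

# mu_i: the mean of the points assigned to cluster i
centroids = np.array([points[assignments == i].mean(axis=0) for i in range(2)])

# J: sum over clusters of squared distances to the cluster centroid
J = sum(np.sum((points[assignments == i] - centroids[i]) ** 2) for i in range(2))
print(J)  # 1.0 — each point is 0.5 from its centroid, so 4 x 0.25
```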

📉 Choosing K: The Elbow Method

The number of clusters K is a hyperparameter.

To choose K, you can:

  • Plot inertia (sum of squared distances to nearest centroid)
  • Look for an "elbow" point where the gain of adding more clusters drops off
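A sketch of how that inertia curve can be collected with scikit-learn, whose fitted KMeans exposes the within-cluster sum of squares as `inertia_` (the synthetic 3-blob dataset is an illustrative choice):

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt

# Synthetic data with 3 true clusters
X, _ = make_blobs(n_samples=300, centers=3, random_state=42)

ks = range(1, 9)
inertias = []
for k in ks:
    km = KMeans(n_clusters=k, n_init=10, random_state=42).fit(X)
    inertias.append(km.inertia_)  # within-cluster sum of squares

plt.plot(ks, inertias, marker='o')
plt.xlabel('K (number of clusters)')
plt.ylabel('Inertia')
plt.title('Elbow Method')
plt.show()
```

For this data the curve should flatten sharply after K=3, which is the elbow.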

Figure: Elbow point where the gain of adding more clusters drops off. (Source: Medium)


🧪 Code Example: Clustering Iris Data with K-Means

from sklearn.datasets import load_iris
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# 📥 Load Iris data
iris = load_iris()
X = iris.data
labels = iris.target
features = iris.feature_names

# ๐Ÿ” Apply KMeans clustering
kmeans = KMeans(n_clusters=3, n_init=10, random_state=42)
clusters = kmeans.fit_predict(X)

# 📊 Visualize clusters
df = pd.DataFrame(X, columns=features)
df['Cluster'] = clusters

sns.pairplot(df, hue='Cluster', palette='Set2', corner=True)
plt.suptitle('K-Means Clustering on Iris Dataset', y=1.02)
plt.tight_layout()
plt.show()

This example clusters the Iris dataset into 3 groups without using the true species labels - demonstrating the power of unsupervised learning to discover structure.

📊 The pair plot below shows how K-Means separated the Iris samples into three distinct groups based on feature similarities. Notably, the clusters align well with the actual species, especially when petal length and petal width are involved.
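That alignment with the species labels can also be checked quantitatively, for example with the adjusted Rand index from scikit-learn (continuing the example above; the exact score depends on the run, but it lands well above chance):

```python
from sklearn.datasets import load_iris
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score

iris = load_iris()
clusters = KMeans(n_clusters=3, n_init=10, random_state=42).fit_predict(iris.data)

# Adjusted Rand index: 1.0 = clusters match species exactly, ~0.0 = random labeling
ari = adjusted_rand_score(iris.target, clusters)
print(f"ARI vs. true species: {ari:.2f}")
```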

📈 Along the diagonal, each subplot is a KDE (Kernel Density Estimate) plot, which visualizes how values of a specific feature are distributed within each cluster:

  1. Each colored curve represents one cluster (e.g., Cluster 0, 1, or 2).
  2. The x-axis is the feature value (e.g., petal width), while the y-axis is the estimated density.
  3. Peaks in the KDE plots show where data points concentrate - helping you see which features best separate the clusters.
  4. If the KDE curves are clearly separated, the feature contributes strongly to the clustering.

Figure: K-Means clustering on the Iris dataset (pair plot).

✅ Pros and Cons

✅ Advantages

  • Simple, fast, and easy to implement
  • Works well when clusters are spherical and well-separated
  • Scales to large datasets

โŒ Disadvantages

  • Must specify K in advance
  • Sensitive to initialization (can converge to local minima)
  • Struggles with irregular or non-spherical cluster shapes
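The initialization sensitivity is usually mitigated by smarter seeding and multiple restarts; scikit-learn exposes both as parameters on `KMeans`, and these are in fact its defaults (the 4-blob dataset here is illustrative):

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=300, centers=4, random_state=0)

# init='k-means++' spreads the initial centroids apart instead of picking
# them uniformly at random; n_init=10 reruns the whole algorithm from 10
# different seedings and keeps the lowest-inertia result
km = KMeans(n_clusters=4, init='k-means++', n_init=10, random_state=0).fit(X)
print(km.inertia_)
```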

📊 When to Use K-Means

Use K-Means when:

  • You want fast, scalable clustering
  • You have a rough idea of how many clusters exist
  • Your data has relatively uniform variance

🔚 Recap

K-Means is a go-to algorithm for many clustering problems. While simple, it's powerful when applied to the right kind of data and offers a great starting point for unsupervised learning tasks.


🔜 Coming Next

Next in this subseries of clustering techniques: DBSCAN - a density-based method for discovering clusters of arbitrary shape and detecting outliers.

Stay curious and keep exploring 👇

๐Ÿ™ Acknowledgments

Special thanks to ChatGPT for enhancing this post with suggestions, formatting, and emojis.