Dirichlet Process Mixtures
On a rainy Sunday afternoon I set out to implement a Dirichlet Process mixture model (DPM). My main reason? To find clusters in streaming data. At some point, every data scientist faces this question: how many clusters do I model? The question gets more cumbersome once you work with streaming data: as more data enters, you probably need more clusters. You can handle that with heuristics, but the DPM solves it naturally.
So the DPM covers two of our concerns:

- it does not require us to fix the number of clusters in advance, and
- it extends naturally to streaming data.
Of course, there are alternatives for the first concern, picking the number of clusters. Among others: PCA, MDL, AIC, BIC. Even Wikipedia dedicates a full page to the question.
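For instance, one of those alternatives fits a finite Gaussian mixture for several candidate numbers of clusters and keeps the one with the lowest BIC. A minimal sketch of that idea, assuming scikit-learn and some placeholder data (neither is part of this project):

```python
import numpy as np
from sklearn.mixture import GaussianMixture

# Placeholder data: two blobs in 2-d; in practice this would be your own dataset.
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0, 1, (100, 2)), rng.normal(5, 1, (100, 2))])

# Fit a finite mixture for each candidate k and keep the lowest BIC.
bics = {}
for k in range(1, 11):
    gm = GaussianMixture(n_components=k, random_state=0).fit(X)
    bics[k] = gm.bic(X)

best_k = min(bics, key=bics.get)
print("BIC picks k =", best_k)
```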
The best way to understand the Dirichlet process is to understand how it generates data. Later on, we will discuss how to learn the model.
A Dirichlet process assumes it has infinitely many clusters to draw from. The best intuition I’ve heard compares this to entering a Chinese restaurant, based on the seemingly infinite supply of tables at Chinese restaurants. The analogy works as follows: the tables represent clusters, and the customers represent our data. When a person enters the restaurant, they choose to join an existing table k with probability proportional to the number of people already sitting at that table (the N_k); otherwise, they sit at a new table, with probability proportional to alpha. In formulas, for the N-th customer:

P(join existing table k) = N_k / (N - 1 + alpha)
P(start a new table) = alpha / (N - 1 + alpha)
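Here is a small standalone sketch of that seating rule (not the project's code; the function name and defaults are mine):

```python
import numpy as np

def crp_assignments(N, alpha, seed=0):
    """Simulate the table choices of N customers in a Chinese restaurant process."""
    rng = np.random.default_rng(seed)
    tables = []        # tables[k] = number of customers already at table k
    assignments = []
    for n in range(N):
        # existing table k has weight N_k, a new table has weight alpha
        weights = np.array(tables + [alpha], dtype=float)
        k = rng.choice(len(weights), p=weights / (n + alpha))
        if k == len(tables):
            tables.append(0)   # the customer opens a new table
        tables[k] += 1
        assignments.append(k)
    return assignments

print(crp_assignments(10, alpha=1.0))   # e.g. [0, 0, 1, 0, 1, 2, ...]
```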
To narrow this metaphor down to our case of data points: for every new point, the DPM picks a cluster with the above formula, and every cluster has a Gaussian distribution associated with it. So in the code, dpm.z[i] is an integer that indicates the cluster number of point i. Let's call it dpm.z[i] = k; then dpm.q[k] is a Gaussian distribution over the data points for that cluster.
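A forward simulation of this generative story might look as follows. This is a sketch with 1-d Gaussians and made-up base-measure parameters, not the author's dpm_class.py; the lists z and means merely play the roles of dpm.z and dpm.q:

```python
import numpy as np

def generate_dpm_data(N, alpha, base_mean=0.0, base_std=5.0, obs_std=1.0, seed=0):
    """Draw N points from a Dirichlet process mixture of 1-d Gaussians."""
    rng = np.random.default_rng(seed)
    z, means, counts = [], [], []      # assignments, cluster means, cluster sizes
    x = np.empty(N)
    for i in range(N):
        # pick a cluster with the Chinese-restaurant probabilities
        weights = np.array(counts + [alpha], dtype=float)
        k = rng.choice(len(weights), p=weights / weights.sum())
        if k == len(means):                          # a brand-new cluster
            means.append(rng.normal(base_mean, base_std))
            counts.append(0)
        counts[k] += 1
        z.append(k)
        x[i] = rng.normal(means[k], obs_std)         # draw the point from its cluster
    return x, z, means

x, z, means = generate_dpm_data(200, alpha=1.0)
print(len(means), "clusters generated")
```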
The DPM does not require us to explicitly choose the number of mixtures. That solves our first concern. Unavoidably, the DPM does require a hyperparameter that controls the number of clusters. It is called alpha. Now alpha relates the number of clusters to the number of observations by the following formula:

E[number of clusters] = sum_{i=1..N} alpha / (alpha + i - 1) ≈ alpha * log(1 + N/alpha)
Fortunately, the variance of this number is quite large, so changing alpha does not influence our results too much. (Source)
To give you some intuition, this plot shows the expected number of clusters when N = 100.
[Figure: expected number of clusters as a function of alpha, for N = 100]
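You can reproduce the numbers behind that plot directly from the formula above; a small sketch:

```python
import numpy as np

def expected_num_clusters(N, alpha):
    """E[number of clusters] after N draws, with concentration parameter alpha."""
    i = np.arange(1, N + 1)
    return np.sum(alpha / (alpha + i - 1))

for alpha in [0.1, 0.5, 1.0, 2.0, 5.0, 10.0]:
    print(f"alpha = {alpha:4.1f}  ->  E[K] = {expected_num_clusters(100, alpha):5.2f}")
```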
We want to know two things: the cluster assignment of every data point, and the parameters (mean and variance) of every cluster.

Given the cluster assignments, we can calculate the mean and variance of a cluster with simple formulas: the sample mean and the sample variance.
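For example, a tiny helper that computes these per-cluster statistics from the assignments (a sketch, not the project's code):

```python
import numpy as np

def cluster_stats(x, z):
    """Sample mean and variance for every cluster, given assignments z (like dpm.z)."""
    x, z = np.asarray(x), np.asarray(z)
    return {int(k): (x[z == k].mean(), x[z == k].var()) for k in np.unique(z)}

x = [1.0, 1.2, 0.8, 5.0, 5.3]
z = [0, 0, 0, 1, 1]
print(cluster_stats(x, z))   # cluster 0: mean 1.0, var ~0.027; cluster 1: mean 5.15, var 0.0225
```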
But how to find the cluster assignments?
The answer is Gibbs sampling: it is too hard to infer all cluster assignments at once. Remember, we could have infinitely many clusters. So we iteratively re-assign each point to a cluster, one point at a time. When that iteration becomes stable, we conclude that we have found a valid sample of the cluster assignments.
Honestly, people write complete books and teach entire courses on Gibbs sampling. Or rather its overarching concept: Markov Chain Monte Carlo sampling. So I’ll make no attempt to explain it formally. For now, we assume that as we iterate long enough over our cluster assignments, we find a valid sample.
Also note that Expectation Maximization won't help us here. The Dirichlet process is a non-parametric model that assumes an infinite number of clusters; EM only deals with the finite case.
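To make the idea concrete, here is a heavily simplified, standalone sketch of one Gibbs sweep. It is not the sampler in dpm_class.py: it works in 1-d, keeps the observation noise fixed, and summarizes each cluster by the sample mean of its remaining members. It only shows the shape of the update: an existing cluster gets weight proportional to N_k times the likelihood of x_i, and a brand-new cluster gets weight proportional to alpha times the prior predictive.

```python
import numpy as np
from scipy.stats import norm

def gibbs_sweep(x, z, alpha, obs_std=1.0, base_mean=0.0, base_std=5.0, seed=0):
    """One sweep of Gibbs sampling over the cluster assignments z."""
    rng = np.random.default_rng(seed)
    z, N = list(z), len(x)
    for i in range(N):
        z[i] = -1                                  # remove point i from its cluster
        labels = sorted(set(z) - {-1})
        weights, candidates = [], []
        for k in labels:
            members = [x[j] for j in range(N) if z[j] == k]
            # existing cluster: N_k times the likelihood of x_i under that cluster
            weights.append(len(members) * norm.pdf(x[i], np.mean(members), obs_std))
            candidates.append(k)
        # new cluster: alpha times the prior predictive of x_i
        weights.append(alpha * norm.pdf(x[i], base_mean,
                                        np.sqrt(base_std ** 2 + obs_std ** 2)))
        candidates.append(max(labels) + 1 if labels else 0)
        probs = np.array(weights) / np.sum(weights)
        z[i] = candidates[rng.choice(len(candidates), p=probs)]
    return z

# Two well-separated 1-d blobs; start with every point in a single cluster.
rng = np.random.default_rng(42)
x = np.concatenate([rng.normal(0, 1, 50), rng.normal(6, 1, 50)])
z = [0] * len(x)
for sweep in range(20):
    z = gibbs_sweep(x, z, alpha=1.0, seed=sweep)
print("clusters found:", sorted(set(z)))
```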
This is the Gibbs sampler learning on a fixed dataset:
The reason for doing this project was the case of streaming data. Here are two examples.

You'll see that as data comes in, the DPM learns more clusters. For the clusters it has already found, the incoming data makes it more confident.
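A rough sketch of what such a stream could look like, reusing the gibbs_sweep function from the sketch above. The real knobs live in main_dpm.py (N_start, points_to_add, interval_to_add); everything else here, including the toy stream, is made up for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy stream: points drawn around three centres the model has never seen.
stream = rng.normal(rng.choice([0.0, 5.0, 10.0], size=300), 1.0)

N_start, points_to_add = 50, 25
x = list(stream[:N_start])
z = [0] * N_start
for start in range(N_start, len(stream), points_to_add):
    batch = stream[start:start + points_to_add]
    x.extend(batch)
    z.extend([z[-1]] * len(batch))     # provisional assignment for the new points
    for sweep in range(5):             # a few Gibbs sweeps per incoming batch
        z = gibbs_sweep(np.array(x), z, alpha=1.0, seed=start + sweep)
    print(f"N = {len(x):4d}, clusters = {len(set(z))}")
```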
Of course, these demos show only a small example of a stream. Concerning the stream, you can play with:

- main_dpm.py >> points_to_add
- main_dpm.py >> interval_to_add
- main_dpm.py >> N_start

You'll notice that the Gibbs sampler can make seemingly irrelevant clusters. In the gif, you'll see a cluster with only a few points assigned to it. In the logs, from line to line, you'll see a sudden increase in clusters. This is because of the sampling that Gibbs sampling does. Say we do a Gibbs update on point x_i; then we sample from the conditional probability p(z_i | ...) (see line 78 of dpm_class.py >> k_new = ...). Because we sample, rather than pick the most probable cluster, we can get sudden clusters with only a few data points.
This also has a positive side to it: these redundant clusters allow the DPM to escape poor local minima, as opposed to EM (for the finite case), which often gets stuck in them. (Read Murphy, chapter 25.2.4, for more information.)
# Further reading
As always, I am curious about any comments and questions. Reach me at romijndersrob@gmail.com.