LDA using Gibbs sampling in R

The setting

Latent Dirichlet Allocation (LDA) is a text mining approach made popular by David Blei. I find it easiest to understand as clustering for words. The idea is that each document in a corpus is made up of words belonging to a fixed number of topics. These topics are unobserved (latent), but if we could estimate them, we could describe and relate the documents by their topics instead of their raw text.

The papers I’ve seen on topic modeling and LDA (in the text mining and marketing literature) haven’t helped me much in understanding the process of estimating the actual model. They usually show a nice picture of the random process, throw in the topic and word probability equations and mention that they use Gibbs sampling. Implementing a basic version of the full model in R has helped me understand what’s happening under the hood. It owes a lot to this blog entry and its naive Python implementation (and example data) as well as this report featuring derivations and pseudo-code. I recommend the first eight pages of the report as a theoretical introduction.

To keep things simple, I’ll use the example from the blog. The documents are users, each of which has added an arbitrary number of interests to their profile. The interests are our word tokens. If we were working with text, these could be the words in blog articles after removing all stop words.

documents = list(
  c("Hadoop", "Big Data", "HBase", "Java", "Spark", "Storm", "Cassandra"),
  c("NoSQL", "MongoDB", "Cassandra", "HBase", "Postgres"),
  c("Python", "scikit-learn", "scipy", "numpy", "statsmodels", "pandas"),
  c("R", "Python", "statistics", "regression", "probability"),
  c("machine learning", "regression", "decision trees", "libsvm"),
  c("Python", "R", "Java", "C++", "Haskell", "programming languages"),
  c("statistics", "probability", "mathematics", "theory"),
  c("machine learning", "scikit-learn", "Mahout", "neural networks"),
  c("neural networks", "deep learning", "Big Data", "artificial intelligence"),
  c("Hadoop", "Java", "MapReduce", "Big Data"),
  c("statistics", "R", "statsmodels"),
  c("C++", "deep learning", "artificial intelligence", "probability"),
  c("pandas", "R", "Python"),
  c("databases", "HBase", "Postgres", "MySQL", "MongoDB"),
  c("libsvm", "regression", "support vector machines")
)

The model

For completeness, this is the mathematical description of the model. The generative process for a document collection $D$ under the LDA model is:

1. For $k = 1, \dots, K$:
    1. $\phi^{(k)} \sim \mathrm{Dirichlet}(\beta)$
2. For each document $d \in D$:
    1. $\theta_d \sim \mathrm{Dirichlet}(\alpha)$
    2. For each word $w_i \in d$:
        1. $z_i \sim \mathrm{Discrete}(\theta_d)$
        2. $w_i \sim \mathrm{Discrete}(\phi^{(z_i)})$

where $K$ is the number of latent topics in the collection, $\phi^{(k)}$ is a discrete probability distribution over a fixed vocabulary that represents the $k$th topic, $\theta_d$ is a document-specific distribution over the available topics, $z_i$ is the topic index for word $w_i$, and $\alpha$ and $\beta$ are hyperparameters of the symmetric Dirichlet distributions from which the discrete distributions are drawn. Symmetric means that all components of each Dirichlet parameter vector share one value, so a single scalar $\alpha$ governs the document-topic distributions and a single scalar $\beta$ governs the topic-word distributions.
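To make the generative process concrete, here is a minimal sketch that simulates a tiny corpus by following the steps above literally. It is not part of the estimation code: the vocabulary, `K_sim`, `alpha_sim`, `beta_sim`, the document sizes and the helper `rdirichlet_sym()` are hypothetical choices for illustration. Base R has no Dirichlet sampler, so the helper uses the standard construction of normalizing independent gamma draws.

```r
# Minimal sketch: simulate a toy corpus from the LDA generative process.
# All sizes and parameter values are hypothetical illustration choices.
set.seed(42)

# One draw from a symmetric Dirichlet(conc) of dimension n,
# built by normalizing independent Gamma(conc, 1) draws
rdirichlet_sym <- function(n, conc) {
  g <- rgamma(n, shape = conc, rate = 1)
  g / sum(g)
}

vocabulary <- c("Hadoop", "Spark", "R", "Python", "regression", "libsvm")
K_sim      <- 2    # number of topics
alpha_sim  <- 0.5  # concentration of the document-topic mixtures
beta_sim   <- 0.5  # concentration of the topic-word distributions
n_docs     <- 3
doc_length <- 5

# Step 1: one word distribution phi^(k) per topic (K_sim x V matrix)
phi <- t(replicate(K_sim, rdirichlet_sym(length(vocabulary), beta_sim)))

# Step 2: per document, draw theta_d, then a topic z_i and a word w_i per position
simulated_docs <- lapply(seq_len(n_docs), function(d) {
  theta_d <- rdirichlet_sym(K_sim, alpha_sim)
  z <- sample(K_sim, doc_length, replace = TRUE, prob = theta_d)
  vapply(z, function(k) sample(vocabulary, 1, prob = phi[k, ]), character(1))
})
simulated_docs
```

Estimating LDA means running this story backwards: we only see documents like `simulated_docs` and want to recover `phi` and the per-document mixtures.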

An important intuition here is that each topic is nothing more than a probability distribution over words. This means that every word can occur in every topic, but many words occur only with a very small probability. It also means that we don’t get a name for a topic, only a list of its most common words. It’s possible to choose descriptive names afterwards, but this can be difficult and very subjective.

Estimating the distribution

The task is now to estimate the document-topic and topic-word distributions from the words that we observe in each document. A popular method to do this is Gibbs sampling, which belongs to the family of Markov chain Monte Carlo (MCMC) algorithms. It works by going through all words in all documents and guessing a topic for each word based on the topics currently assigned to all other words: those in the same document and the other occurrences of the same word across the corpus. While going through the words, the changes made for previous words and during previous runs change the distributions underlying the guesses. After some time, the distributions become stable, that is, they don’t change much anymore. At this point, we can use the last estimate, or an average over the last several estimates, as the final output, for example to find topics in new documents.
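The core idea of Gibbs sampling, updating one variable at a time from its distribution conditional on the current values of all the others, can be seen on a much simpler target than LDA. The sketch below is a hypothetical toy example, not part of the LDA code: it samples from a standard bivariate normal with correlation `rho`, where both full conditionals are normal. The LDA sampler in this post follows the same pattern, only with discrete topic assignments instead of continuous values.

```r
# Minimal sketch: Gibbs sampling for a standard bivariate normal with
# correlation rho (a toy target chosen for illustration, not part of LDA).
# Full conditionals: x | y ~ N(rho * y, 1 - rho^2) and y | x ~ N(rho * x, 1 - rho^2)
set.seed(1)
rho    <- 0.8
n_iter <- 20000
draws  <- matrix(NA_real_, nrow = n_iter, ncol = 2,
                 dimnames = list(NULL, c("x", "y")))

x <- 0  # arbitrary starting point
y <- 0
for (iter in seq_len(n_iter)) {
  # Update each variable given the current value of the other one
  x <- rnorm(1, mean = rho * y, sd = sqrt(1 - rho^2))
  y <- rnorm(1, mean = rho * x, sd = sqrt(1 - rho^2))
  draws[iter, ] <- c(x, y)
}

# Drop early draws that still depend on the starting point
kept <- draws[-(1:1000), ]
colMeans(kept)                  # both close to 0
cor(kept[, "x"], kept[, "y"])   # close to rho
```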

As a starting point, we’ll need a function to randomly choose a topic for a word based on a set of weights. The weights capture how common we think the topic is. Always remember that what we optimistically call a ‘topic’ is just the number given to a distribution over words.

### Function to randomly choose a topic for a word given a weight distribution

sample_from <- function(weights){
  topic <- sample(x = seq_along(weights), size = 1, prob = weights / sum(weights))
  return(topic)
}

# Try drawing 10000 times from 3 topics
# Topic 3 is 8 times more likely than topic 1 or topic 2
test <- replicate(10000, sample_from(c(10, 10, 80)))
table(test)
## test
##    1    2    3
##  981  970 8049

The weight of each topic describes how likely it is that this word in this document belongs to that topic. Our guess depends on two things:

1. how often the word appears for each topic and
2. how often each topic already appears in the document.

Why does the word ‘topic’ come up so often in this blog post? Because the post talks a lot about the topic ‘LDA’, and the word ‘topic’ is common within that topic. These two conditional probabilities determine which topic we choose during sampling. We include a smoothing term that ensures every topic has a nonzero chance of being chosen in any document and that every word has a nonzero chance of being chosen for any topic.

# The share of this word among the words assigned to this topic
# (plus some smoothing). Uses the global counts
# topic_word_counts and topic_counts created further below.
p_word_given_topic <- function(word, topic, beta = 0.1){
  return((topic_word_counts[word, topic] + beta) /
           (topic_counts[topic] + nrow(topic_word_counts) * beta))
}

# The share of this topic among the topics in this document
# (plus some smoothing). Uses the global counts
# document_topic_counts and document_lengths created further below.
p_topic_given_document <- function(topic, d, alpha = 0.1){
  return((document_topic_counts[d, topic] + alpha) /
           (document_lengths[d] + ncol(document_topic_counts) * alpha))
}
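To see what the smoothing does, the sketch below applies the same formula as `p_word_given_topic()` to a whole (hypothetical) count matrix at once. It assumes, as in the code of this post, that a topic's total equals the column sum of the word-topic counts. Words with a zero count still get a small positive probability, and each topic's word probabilities still sum to one, because the $V\beta$ term in the denominator accounts for the $\beta$ added to each of the $V$ words.

```r
# Minimal sketch with a hypothetical 3-word x 2-topic count matrix
toy_topic_word_counts <- matrix(c(3, 0,
                                  0, 2,
                                  1, 0),
                                nrow = 3, byrow = TRUE,
                                dimnames = list(c("R", "Python", "Hadoop"), c("1", "2")))
toy_beta <- 0.1
V <- nrow(toy_topic_word_counts)  # vocabulary size

# Smoothed P(word | topic) for all words and topics at once:
# (count + beta) / (topic total + V * beta), divided column by column
toy_topic_totals <- colSums(toy_topic_word_counts)
p_word_topic <- sweep(toy_topic_word_counts + toy_beta, 2,
                      toy_topic_totals + V * toy_beta, FUN = "/")
round(p_word_topic, 3)
colSums(p_word_topic)  # each topic's word distribution sums to 1
```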

To be able to calculate these probabilities, we need to decide three things in advance:

1. the number of topics,
2. how specific the words are to each topic, and
3. how specific the topics are to the documents.

These can be guessed, determined via heuristics, or found by testing a bunch of values.

# The number of topics
K <- 4
# The smoothing parameters
# Higher alpha: documents will contain more topics
alpha <- 0.05
# Higher beta: topics will put weight on more words
beta <- 0.1
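The effect of these concentration parameters is easiest to see by drawing from symmetric Dirichlet distributions directly. In this model, alpha is the concentration of the document-topic mixtures and beta that of the topic-word distributions. The sketch below (the helper `rdirichlet_sym()`, the seed and the number of draws are illustrative assumptions) compares how much weight the dominant topic gets on average for a small and a large concentration with four topics.

```r
# Minimal sketch: how the concentration of a symmetric Dirichlet shapes
# the topic mixtures it generates (values chosen for illustration)
set.seed(7)
rdirichlet_sym <- function(n, conc) {
  g <- rgamma(n, shape = conc, rate = 1)
  g / sum(g)
}
K_demo <- 4

# Average share of the dominant topic over 2000 simulated documents
mean_max_share <- function(conc) {
  mean(replicate(2000, max(rdirichlet_sym(K_demo, conc))))
}
sparse_share <- mean_max_share(0.05)  # small concentration
smooth_share <- mean_max_share(5)     # large concentration
c(alpha_0.05 = sparse_share, alpha_5 = smooth_share)
```

With a small value such as the 0.05 used here, most simulated documents are dominated by a single topic; with a large value, the topics are mixed much more evenly.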

We’ll multiply the two conditional probabilities to create the weights for guessing the topic for a word. Blei et al. (2003) introduced the model; the derivation of why the topic weight should be calculated this way for Gibbs sampling is worked out in the report mentioned above, but I feel it’s also pretty intuitive.
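Written as a formula, using the usual collapsed Gibbs sampling notation (the symbols are my labels for the counts used in the code: $n^{-i}_{k,w}$ is how often word $w$ is assigned to topic $k$, $n^{-i}_k$ is the total for topic $k$, $n^{-i}_{d,k}$ is how often topic $k$ occurs in document $d$, $n^{-i}_d$ is the length of document $d$, all excluding the current word $w_i$, and $V$ is the vocabulary size), the weight for topic $k$ is:

$$
p(z_i = k \mid \mathbf{z}_{-i}, \mathbf{w}) \;\propto\;
\frac{n^{-i}_{k,w_i} + \beta}{n^{-i}_{k} + V\beta}
\cdot
\frac{n^{-i}_{d,k} + \alpha}{n^{-i}_{d} + K\alpha}
$$

The first factor is `p_word_given_topic()` and the second is `p_topic_given_document()`. The second denominator, $n^{-i}_d + K\alpha$, is the same for every topic, so it only rescales the weights and does not change which topic gets sampled.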

# Calculate the topic weight for topic k
# given document d and word
topic_weight <- function(d, word, k){
  return(p_word_given_topic(word, k, beta) * p_topic_given_document(k, d, alpha))
}

# Choose a new topic based on the weights of the topics for this word
choose_new_topic <- function(d, word){
  return(sample_from(sapply(1:K, function(k) topic_weight(d, word, k))))
}

Putting things together: Gibbs sampling

During sampling, we’ll calculate the conditional probabilities for a word within a topic and a topic within a document as defined above. So we need to keep track of how often each word appears for each topic and how often each topic appears in each document.

Since the denominators of these probabilities use the total number of words assigned to each topic and the total number of words in each document, we’ll also keep track of these numbers for convenience.

# Get the distinct words
distinct_words <- unique(unlist(documents))

# Matrix to count how often each topic appears in each document
document_topic_counts <- matrix(0, nrow = length(documents), ncol = K,
                                dimnames = list(1:length(documents), 1:K))
# Matrix to count how many times each word is assigned to each topic
topic_word_counts     <- matrix(0, nrow = length(distinct_words), ncol = K,
                                dimnames = list(distinct_words, 1:K))

# How often each topic is assigned to a word over all documents
topic_counts          <- rep(0, K)
# How many words each document contains
document_lengths      <- sapply(documents, length)

Before we start sampling, we randomly pick a topic for each word in each document.

# Set a seed for the random number generator
# for replicability
set.seed(123)
# For each word, choose a topic at random
document_topics = lapply(documents, function(doc) sapply(doc, function(word) sample(K, 1)))

Then we can fill the count matrices from above with the assignments that we have just created randomly.

# Go through all documents
for(d in seq_along(documents)){
  # Collect all the words and their associated topic in a data frame
  wordTopic <- data.frame(word = documents[[d]], topic = document_topics[[d]],
                          stringsAsFactors = FALSE)

  # Go through all the words
  for(i in seq_len(nrow(wordTopic))){
    # Increase the counter for this topic for this word
    topic_word_counts[wordTopic$word[i], wordTopic$topic[i]] <-
      topic_word_counts[wordTopic$word[i], wordTopic$topic[i]] + 1

    # Increase the counter for this topic (in aggregate)
    topic_counts[wordTopic$topic[i]] <- topic_counts[wordTopic$topic[i]] + 1

    # Increase the counter for this topic within this document
    document_topic_counts[d, wordTopic$topic[i]] <- document_topic_counts[d, wordTopic$topic[i]] + 1
  }
}
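As a design alternative, the same initial counts can be built without explicit loops by cross-tabulating a long table with one row per token. The sketch below uses hypothetical toy inputs (`toy_documents`, `toy_topics`) rather than the objects of this post; fixing the factor levels makes sure that topics or words with zero counts still get a row or column.

```r
# Minimal sketch: build count matrices with table() instead of loops.
# toy_documents and toy_topics are hypothetical inputs for illustration.
toy_documents <- list(c("R", "Python", "pandas"), c("Hadoop", "Spark"))
toy_topics    <- list(c(1, 1, 2), c(2, 2))
K_toy <- 2

# One row per token: document index, word, and assigned topic
tokens <- data.frame(
  doc   = rep(seq_along(toy_documents), lengths(toy_documents)),
  word  = unlist(toy_documents),
  topic = unlist(toy_topics),
  stringsAsFactors = FALSE
)
vocab <- unique(unlist(toy_documents))

# Cross-tabulate; explicit levels keep all-zero rows and columns
doc_topic  <- table(factor(tokens$doc, levels = seq_along(toy_documents)),
                    factor(tokens$topic, levels = seq_len(K_toy)))
word_topic <- table(factor(tokens$word, levels = vocab),
                    factor(tokens$topic, levels = seq_len(K_toy)))
topic_totals <- colSums(word_topic)  # words assigned to each topic
doc_topic
word_topic
```

`table()` returns a `table` object; `unclass()` turns it into a plain integer matrix if needed.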

We now have some counts to estimate the probability of each word occurring in each topic and each topic occurring in each document (even though they are based on our naive random assignment). Based on the conditional probabilities, we can calculate the weights of the topics and make a new guess for each word. The promise of Gibbs sampling is that if we repeat this often enough, the assignments are eventually drawn from the posterior distribution of the topics given the observed words, so our guesses become good.

# The number of iterations
I <- 2000
# Save our probability estimates after each iteration
# in these lists
topic_mixture_per_document <- list()
word_mixture_per_topic <- list()

# Repeat the process
for(iter in seq(I)){
  # Give feedback on the progress
  # This is a little hack that works by printing
  # the backspace character 20 times before
  # writing the new line
  if(iter %% 100 == 0){
    cat(paste0(rep("\b", 20), collapse = ""))
    cat("Iteration: ", iter)
    flush.console() # Force output print
  }

  # Go through the documents...
  for(d in seq_along(documents)){
    wordTopic <- data.frame(word = documents[[d]], topic = document_topics[[d]],
                            stringsAsFactors = FALSE)

    # ... and the words within the document
    for(i in seq_len(nrow(wordTopic))){
      # Define the word and topic explicitly for convenience and speed
      word <- wordTopic$word[i]
      topic <- wordTopic$topic[i]

      # Our probability estimate should not be based on the word
      # that we are looking at. We remove this word / topic from the counts
      # so that it doesn't influence the weights
      document_topic_counts[d, topic] <- document_topic_counts[d, topic] - 1
      topic_word_counts[word, topic] <- topic_word_counts[word, topic] - 1
      topic_counts[topic] <- topic_counts[topic] - 1
      document_lengths[d] <- document_lengths[d] - 1

      # Choose a new topic for the word based on the weights and store it
      # by position, so repeated words in a document are handled correctly
      new_topic <- choose_new_topic(d, word)
      document_topics[[d]][i] <- new_topic

      # Redo the counts to account for the new assignment
      document_topic_counts[d, new_topic] <- document_topic_counts[d, new_topic] + 1
      topic_word_counts[word, new_topic] <- topic_word_counts[word, new_topic] + 1
      topic_counts[new_topic] <- topic_counts[new_topic] + 1
      document_lengths[d] <- document_lengths[d] + 1
    }
  }
  # Save the probability estimates after the iteration:
  # document-topic mixtures are smoothed with alpha and normalized per document (row),
  # topic-word distributions are smoothed with beta and normalized per topic (column)
  topic_mixture_per_document[[iter]] <- sweep(document_topic_counts + alpha, MARGIN = 1,
                                              STATS = rowSums(document_topic_counts + alpha),
                                              FUN = "/")
  word_mixture_per_topic[[iter]] <- sweep(topic_word_counts + beta, MARGIN = 2,
                                          STATS = colSums(topic_word_counts + beta),
                                          FUN = "/")
}
## Iteration:  100Iteration:  200Iteration:  300Iteration:  400Iteration:  500Iteration:  600Iteration:  700Iteration:  800Iteration:  900Iteration:  1000Iteration:  1100Iteration:  1200Iteration:  1300Iteration:  1400Iteration:  1500Iteration:  1600Iteration:  1700Iteration:  1800Iteration:  1900Iteration:  2000

There are several ways to determine a good estimate for the distributions. The simplest is to use the probability estimate from the last iteration. Since each iteration only gives a noisy draw from the posterior, we can also average our estimates over many iterations to make them more robust. When averaging, the first iterations (the burn-in) are usually dropped, since they still depend a lot on the randomly chosen starting point. Another trick is to keep only every n-th iteration (thinning), which reduces the correlation between the retained estimates.
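The sketch below illustrates both ideas on a hypothetical, strongly autocorrelated chain (an AR(1) process standing in for MCMC output, not draws from the LDA sampler). It deliberately starts far from the chain's typical values, and it compares the lag-1 autocorrelation of all post-burn-in draws with that of the thinned draws, using the same burn-in and thinning values as the code that follows.

```r
# Minimal sketch: burn-in and thinning on a hypothetical autocorrelated
# chain (an AR(1) process standing in for MCMC output)
set.seed(2024)
n_iter <- 10000
phi_ar <- 0.9
chain <- numeric(n_iter)
chain[1] <- 10  # deliberately poor starting point
for (iter in 2:n_iter) {
  chain[iter] <- phi_ar * chain[iter - 1] + rnorm(1)
}

burn_in_demo  <- 1000
thinning_demo <- 50
kept_all  <- chain[burn_in_demo:n_iter]
kept_thin <- chain[seq(burn_in_demo, n_iter, by = thinning_demo)]

# Lag-1 autocorrelation before and after thinning
lag1 <- function(x) acf(x, lag.max = 1, plot = FALSE)$acf[2]
c(all = lag1(kept_all), thinned = lag1(kept_thin))
```

Note that thinning throws information away; its main benefits are less correlated retained draws and lower memory use, and averaging all post-burn-in draws is also a valid estimate.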

# Discard the first n iterations when aggregating the estimate
burn_in <- 1000
# For the remaining iterations, aggregate only every n-th one
thinning <- 50 # sampling lag

# Select the iterations that we want to aggregate
temp_tmpd <- topic_mixture_per_document[seq(burn_in, length(topic_mixture_per_document), thinning)]
# Average the estimates for the document-topic distribution
topic_mixture_per_document_final <- Reduce("+", temp_tmpd) / length(temp_tmpd)

# Select the iterations that we want to aggregate
temp_wmpt <- word_mixture_per_topic[seq(burn_in, length(word_mixture_per_topic), thinning)]
# Average the estimates for the word-topic distribution
word_mixture_per_topic_final <- Reduce("+", temp_wmpt) / length(temp_wmpt)

We now have the results, which brings us back to the intuition from the beginning. Each topic is just a vector of probabilities describing how likely each word is to occur within the topic.

##                   1         2         3         4
## Hadoop   0.02272727 0.1309524 0.3690476 0.4772727
## Big Data 0.01562500 0.1644345 0.3727679 0.4471726
## HBase    0.03050595 0.1495536 0.4025298 0.4174107

If we only want to cluster documents, find similar articles, or make predictions based on the text, the missing topic names don’t concern us. For interpretation, though, we can assign a topic title by looking at the most common words, or use one of several more complicated heuristics (e.g. Wikipedia) to assign a title automatically.

# Make a matrix that contains the words sorted
# by probability for each topic
wordTopicFreq <- apply(word_mixture_per_topic_final, 2, function(x)
  row.names(word_mixture_per_topic_final)[order(x, decreasing = TRUE)])
# Show the five most likely words per topic
head(wordTopicFreq, 5)
##      1             2                         3           4
## [1,] "statsmodels" "libsvm"                  "NoSQL"     "MapReduce"
## [2,] "Python"      "decision trees"          "MongoDB"   "Hadoop"
## [3,] "R"           "support vector machines" "Postgres"  "Big Data"
## [4,] "pandas"      "regression"              "Cassandra" "Spark"
## [5,] "numpy"       "machine learning"        "HBase"     "Storm"

There are several nice applications I can imagine for this kind of analysis. We could, for example, cluster users by general topics rather than by the very specific tags they have chosen. We could also show these results to headhunters who have no idea what many of the specific tags even mean.
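As a sketch of the first idea, the code below clusters a few users by their estimated topic mixtures with base R's `hclust()`. The input rows are the rounded shares printed at the end of this post for six users, reordered into the fixed topic order of `topic_names`; in practice one would use the rows of `topic_mixture_per_document_final` directly. Euclidean distance and complete linkage (the `hclust()` default) are illustrative choices, not the only reasonable ones.

```r
# Minimal sketch: group users by their topic mixtures with hierarchical clustering.
# Rows are rounded topic shares copied from the printed output of this post,
# reordered into a fixed topic order (illustrative input only).
topic_shares <- rbind(
  "User 2"  = c(0.04, 0.14, 0.42, 0.40),
  "User 3"  = c(0.77, 0.07, 0.07, 0.09),
  "User 5"  = c(0.03, 0.64, 0.10, 0.23),
  "User 11" = c(0.88, 0.04, 0.04, 0.03),
  "User 14" = c(0.06, 0.15, 0.42, 0.37),
  "User 15" = c(0.04, 0.67, 0.13, 0.16)
)
colnames(topic_shares) <- c("statistical programming", "machine learning",
                            "data engineering", "artificial intelligence")

# Euclidean distances between mixtures, complete-linkage clustering
tree <- hclust(dist(topic_shares))
clusters <- cutree(tree, k = 3)
split(names(clusters), clusters)
```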

topic_names <- c("statistical programming",
                 "machine learning",
                 "data engineering",
                 "artificial intelligence")

# Give some cleaned up output on the interests
# of each user
for(i in seq_along(documents)){
  # Sort the topics of the user
  topicRanking <- order(topic_mixture_per_document_final[i, ], decreasing = TRUE)
  # Display...
  cat(sprintf("User %i: \n %s \n %s",
              i,                                      # ... the user name,...
              paste(documents[[i]], collapse = ", "), # ... her interests...,
              # ... and aggregated topics in order together with an importance count.
              paste(
                paste(topic_names[topicRanking],
                      round(topic_mixture_per_document_final[i, ], 2)[topicRanking],
                      sep = ": "),
                collapse = ", ")
  ),
  "\n\n")
}
## User 1:
##  Hadoop, Big Data, HBase, Java, Spark, Storm, Cassandra
##  artificial intelligence: 0.47, data engineering: 0.36, machine learning: 0.16, statistical programming: 0.01
##
## User 2:
##  NoSQL, MongoDB, Cassandra, HBase, Postgres
##  data engineering: 0.42, artificial intelligence: 0.4, machine learning: 0.14, statistical programming: 0.04
##
## User 3:
##  Python, scikit-learn, scipy, numpy, statsmodels, pandas
##  statistical programming: 0.77, artificial intelligence: 0.09, machine learning: 0.07, data engineering: 0.07
##
## User 4:
##  R, Python, statistics, regression, probability
##  statistical programming: 0.72, machine learning: 0.19, data engineering: 0.06, artificial intelligence: 0.04
##
## User 5:
##  machine learning, regression, decision trees, libsvm
##  machine learning: 0.64, artificial intelligence: 0.23, data engineering: 0.1, statistical programming: 0.03
##
## User 6:
##  Python, R, Java, C++, Haskell, programming languages
##  statistical programming: 0.59, artificial intelligence: 0.2, machine learning: 0.1, data engineering: 0.1
##
## User 7:
##  statistics, probability, mathematics, theory
##  statistical programming: 0.61, machine learning: 0.21, data engineering: 0.12, artificial intelligence: 0.07
##
## User 8:
##  machine learning, scikit-learn, Mahout, neural networks
##  machine learning: 0.41, artificial intelligence: 0.33, data engineering: 0.14, statistical programming: 0.12
##
## User 9:
##  neural networks, deep learning, Big Data, artificial intelligence
##  artificial intelligence: 0.41, data engineering: 0.37, machine learning: 0.2, statistical programming: 0.02
##
## User 10:
##  Hadoop, Java, MapReduce, Big Data
##  artificial intelligence: 0.49, data engineering: 0.35, machine learning: 0.11, statistical programming: 0.06
##
## User 11:
##  statistics, R, statsmodels
##  statistical programming: 0.88, machine learning: 0.04, data engineering: 0.04, artificial intelligence: 0.03
##
## User 12:
##  C++, deep learning, artificial intelligence, probability
##  artificial intelligence: 0.31, machine learning: 0.26, data engineering: 0.23, statistical programming: 0.2
##
## User 13:
##  pandas, R, Python
##  statistical programming: 0.91, machine learning: 0.03, data engineering: 0.03, artificial intelligence: 0.03
##
## User 14:
##  databases, HBase, Postgres, MySQL, MongoDB
##  data engineering: 0.42, artificial intelligence: 0.37, machine learning: 0.15, statistical programming: 0.06
##
## User 15:
##  libsvm, regression, support vector machines
##  machine learning: 0.67, artificial intelligence: 0.16, data engineering: 0.13, statistical programming: 0.04
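Earlier, the post mentions using the final distributions to find topics in new documents. One common way to do this, sketched below, is to keep the topic-word distributions fixed and run Gibbs sampling over the new document's words only (sometimes called folding in). The topic-word matrix `phi_fixed`, the new document, the number of sweeps and the burn-in are hypothetical toy values, not results from this post.

```r
# Minimal sketch: estimate the topic mixture of a NEW document while keeping
# the topic-word distributions fixed ("folding in"). All inputs are toy values.
set.seed(99)
phi_fixed <- rbind(
  "1" = c(Hadoop = 0.45, Spark = 0.45, R = 0.05, pandas = 0.05),
  "2" = c(Hadoop = 0.05, Spark = 0.05, R = 0.45, pandas = 0.45)
)
alpha_new <- 0.1
new_doc   <- c("Hadoop", "Spark", "Hadoop", "R")
K_new     <- nrow(phi_fixed)

# Random initial topics, then Gibbs sweeps that only update this document
z <- sample(K_new, length(new_doc), replace = TRUE)
doc_counts <- tabulate(z, nbins = K_new)
n_sweeps  <- 500
burn      <- 100
theta_sum <- numeric(K_new)
for (sweep_i in seq_len(n_sweeps)) {
  for (i in seq_along(new_doc)) {
    doc_counts[z[i]] <- doc_counts[z[i]] - 1                  # remove current assignment
    w <- phi_fixed[, new_doc[i]] * (doc_counts + alpha_new)   # unnormalized topic weights
    z[i] <- sample(K_new, 1, prob = w)                        # sample() normalizes prob
    doc_counts[z[i]] <- doc_counts[z[i]] + 1                  # add new assignment
  }
  if (sweep_i > burn) {
    # Accumulate the smoothed topic mixture after burn-in
    theta_sum <- theta_sum + (doc_counts + alpha_new) /
      (length(new_doc) + K_new * alpha_new)
  }
}
theta_new <- theta_sum / (n_sweeps - burn)
round(theta_new, 2)
```

Because the weights multiply the fixed word probabilities by the document's current topic counts, with a small alpha a single word that fits topic 2 is often pulled towards the document's dominant topic.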