Jekyll2022-10-13T20:30:11+00:00http://adamlineberry.ai/feed.xmlAdam LineberryData Science BlogAdam LineberryXGBoost2022-10-07T00:00:00+00:002022-10-07T00:00:00+00:00http://adamlineberry.ai/notes/xgboost<h2 id="derivation-of-xgboost">Derivation of XGBoost</h2>
<p>Consider a dataset of \(n\) observations \(\mathcal{D} = \{ (x_i, y_i) \}_{i=1}^n\) where \(x_i \in \mathbb{R}^d\) are the observed features and \(y_i \in \mathbb{R}\) are the targets.</p>
<p>An additive tree model produces a prediction by summing the predictions of \(K\) trees:</p>
\[\hat{y_i} = \sum_{k=1}^K f_k(x_i)
\\
f(x) = w_{q(x)} \quad q : \mathbb{R}^d \rightarrow \{1, \dots, T\} \quad w \in \mathbb{R}^T\]
<p>Where \(w\) is a vector of \(T\) leaf scores and \(q(x)\) maps data to a leaf node index.</p>
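<p>As a minimal sketch of the additive model, each tree below is a one-split stump: a leaf-index function \(q\) paired with a score vector \(w\). The helper name <code class="language-plaintext highlighter-rouge">make_stump</code>, the thresholds, and the leaf scores are made up for illustration and are not part of any library.</p>

```python
def make_stump(feature, threshold, w):
    """Return f(x) = w[q(x)] for a tree with a single split (T = 2 leaves)."""
    def q(x):
        # Map an input vector to a leaf index in {0, 1}.
        return 0 if x[feature] < threshold else 1
    return lambda x: w[q(x)]

# Hypothetical ensemble of K = 2 trees.
trees = [
    make_stump(feature=0, threshold=2.0, w=[-0.5, 0.5]),
    make_stump(feature=1, threshold=0.0, w=[0.1, -0.2]),
]

def predict(x):
    # The ensemble prediction is the sum of the leaf scores picked by each tree.
    return sum(f(x) for f in trees)

y_hat_a = predict([3.0, -1.0])  # leaf 1 of tree 1, leaf 0 of tree 2
y_hat_b = predict([1.0, 1.0])   # leaf 0 of tree 1, leaf 1 of tree 2
```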
<p>The objective used to learn the model consists of a loss function that measures the difference between predictions \(\hat{y_i}\) and targets \(y_i\) combined with a regularization term \(\Omega\) that penalizes model complexity:</p>
\[\mathcal{L} = \sum_i l(y_i, \hat{y_i}) + \sum_k \Omega(f_k)\]
<p>The regularization term for a single tree is defined as:</p>
\[\Omega(f_t) = \gamma T + \frac{1}{2} \lambda \sum_{j=1}^T w_j^2\]
<p>The model is learned in a greedy, additive fashion, learning a new tree at each iteration. The prediction and corresponding objective at iteration \(t\) are as follows:</p>
\[\begin{align*}
\hat{y}_i^{(t)} &= \hat{y}_i^{(t-1)} + f_t(x_i)
\\
\mathcal{L}^{(t)} &= \sum_i l(y_i, \hat{y}_i^{(t-1)} + f_t(x_i)) + \Omega(f_t)
\end{align*}\]
<p>The objective is approximated with a second-order Taylor expansion:</p>
\[\mathcal{L}^{(t)} \simeq \sum_i \Bigl[ l(y_i, \hat{y}_i^{(t-1)}) + g_i f_t(x_i) + \frac{1}{2} h_i f_t^2(x_i) \Bigr] + \Omega(f_t) \\
g_i = \frac{\partial l(y_i, \hat{y}_i^{(t-1)})}{\partial \hat{y}_i^{(t-1)}}
\quad
h_i = \frac{\partial^2 l(y_i, \hat{y}_i^{(t-1)})}{\partial (\hat{y}_i^{(t-1)})^2}\]
<p>Remove constants with respect to the tree being learned \(f_t\):</p>
\[\tilde{\mathcal{L}}^{(t)} = \sum_i \Bigl[ g_i f_t(x_i) + \frac{1}{2} h_i f_t^2(x_i) \Bigr] + \gamma T + \frac{1}{2} \lambda \sum_{j=1}^T w_j^2\]
<p>Next, let’s define \(I_j = \{ i \lvert q(x_i) = j \}\) as the set of instance indices belonging to leaf node \(j\) and rearrange the objective to loop over leaf nodes instead of instances:</p>
\[\tilde{\mathcal{L}}^{(t)} = \sum_j \Bigl[
(\sum_{i \in I_j} g_i) w_j
+ \frac{1}{2} (\sum_{i \in I_j} h_i + \lambda) w_j^2
\Bigr] + \gamma T\]
<p>From here, we want to solve for optimal leaf scores \(w_j\) that minimize the objective. The loss contributed by a particular leaf \(j\) is:</p>
\[\mathcal{L}_j = (\sum_{i \in I_j} g_i) w_j
+ \frac{1}{2} (\sum_{i \in I_j} h_i + \lambda) w_j^2\]
<p>And the optimal leaf score for that leaf is the value that minimizes that leaf’s contribution to the overall objective \(w_j^* = \arg \min_{w_j} \: \mathcal{L}_j\).</p>
<p>Set the derivative equal to zero and solve for \(w_j^*\):</p>
\[\frac{\partial \mathcal{L}_j}{\partial w_j} \bigg\rvert_{w_j = w_j^*} =
\sum_{i \in I_j} g_i + (\sum_{i \in I_j} h_i + \lambda) w_j^* = 0
\\
w_j^* = - \frac{\sum_{i \in I_j} g_i}
{\sum_{i \in I_j} h_i + \lambda}\]
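<p>The closed-form leaf score can be sanity-checked numerically. The sketch below uses made-up gradient statistics for a single leaf and confirms that nudging \(w_j\) away from \(w_j^*\) in either direction increases that leaf’s contribution to the objective. The function names are illustrative only.</p>

```python
def optimal_leaf_weight(grads, hessians, lam=1.0):
    """w* = -sum(g) / (sum(h) + lambda) for the instances in one leaf."""
    return -sum(grads) / (sum(hessians) + lam)

def leaf_objective(w, grads, hessians, lam=1.0):
    """Contribution of one leaf to the approximate objective."""
    return sum(grads) * w + 0.5 * (sum(hessians) + lam) * w ** 2

# Hypothetical gradient statistics for the instances in one leaf.
grads = [0.4, -0.6, 0.3, 0.5]
hessians = [0.24, 0.24, 0.21, 0.25]

w_star = optimal_leaf_weight(grads, hessians)
at_optimum = leaf_objective(w_star, grads, hessians)

# The objective is a convex quadratic in w, so any perturbation makes it worse.
worse_up = leaf_objective(w_star + 0.01, grads, hessians)
worse_down = leaf_objective(w_star - 0.01, grads, hessians)
print(w_star, at_optimum, worse_up, worse_down)
```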
<p>We can plug in \(w_j^*\) into the objective to obtain a scoring function to measure the quality of a tree. This score is similar to an impurity measure, but more generalized:</p>
\[\begin{align*}
\tilde{\mathcal{L}}^{(t)} &= \sum_j \Bigl[
(\sum_{i \in I_j} g_i) w_j
+ \frac{1}{2} (\sum_{i \in I_j} h_i + \lambda) w_j^2
\Bigr] + \gamma T
\\
&= \sum_j \Bigl[
(\sum_{i \in I_j} g_i) \Biggl(- \frac{\sum_{i \in I_j} g_i}
{\sum_{i \in I_j} h_i + \lambda} \Biggr)
+ \frac{1}{2} (\sum_{i \in I_j} h_i + \lambda)
{\Biggl(- \frac{\sum_{i \in I_j} g_i}
{\sum_{i \in I_j} h_i + \lambda} \Biggr)}^2
\Bigr] + \gamma T
\\
&= \sum_j \Bigl[
-\frac{(\sum_{i \in I_j} g_i)^2}
{\sum_{i \in I_j} h_i + \lambda}
+ \frac{1}{2} \frac{(\sum_{i \in I_j} g_i)^2}
{\sum_{i \in I_j} h_i + \lambda}
\Bigr] + \gamma T
\\
&= - \frac{1}{2} \sum_{j=1}^T
\frac{(\sum_{i \in I_j} g_i)^2}
{\sum_{i \in I_j} h_i + \lambda}
+ \gamma T
\end{align*}\]
<p>It is infeasible to enumerate all possible tree structures, so we use a greedy algorithm that builds tree \(f_t\) by iteratively adding branches (i.e., choosing a feature and a corresponding split value) according to the magnitude of the loss reduction. With the scores defined above, the change in the objective from splitting a node into left and right children is:</p>
\[\mathcal{L}_{split} = \mathcal{L}_{left} + \mathcal{L}_{right} - \mathcal{L}_{root}\]
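<p>Here is a rough sketch of the greedy split search on a single feature, assuming the exact (sort-and-scan) strategy. It evaluates the change in objective \(\mathcal{L}_{left} + \mathcal{L}_{right} - \mathcal{L}_{root}\) at every candidate threshold and keeps the most negative one. The function names and toy data are my own; this is not how the xgboost package is implemented internally.</p>

```python
def leaf_score(G, H, lam=1.0):
    """Per-leaf part of the tree quality score: -1/2 * G^2 / (H + lambda)."""
    return -0.5 * G * G / (H + lam)

def best_split(x, g, h, lam=1.0, gamma=0.0):
    """Scan sorted values of one feature; return (threshold, change in objective)."""
    order = sorted(range(len(x)), key=lambda i: x[i])
    G_total, H_total = sum(g), sum(h)
    root = leaf_score(G_total, H_total, lam)
    best_change, best_threshold = 0.0, None
    G_left = H_left = 0.0
    for pos in range(len(order) - 1):
        i, nxt = order[pos], order[pos + 1]
        G_left += g[i]
        H_left += h[i]
        if x[i] == x[nxt]:
            continue  # cannot place a threshold between equal values
        left = leaf_score(G_left, H_left, lam)
        right = leaf_score(G_total - G_left, H_total - H_left, lam)
        # Splitting replaces one leaf with two, so the gamma * T term grows by gamma.
        change = left + right - root + gamma
        if change < best_change:
            best_change, best_threshold = change, (x[i] + x[nxt]) / 2
    return best_threshold, best_change

# Toy data: labels are 0, 0, 1, 1 and the previous model predicts p = 0.5 everywhere,
# so g = p - y and h = p * (1 - p).
x = [1.0, 2.0, 3.0, 4.0]
g = [0.5, 0.5, -0.5, -0.5]
h = [0.25, 0.25, 0.25, 0.25]
threshold, change = best_split(x, g, h)
print(threshold, change)
```

<p>A negative change means the split lowers the objective; if no candidate produces a negative change, the function returns <code class="language-plaintext highlighter-rouge">None</code> for the threshold and the node is left unsplit.</p>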
<h2 id="xgboost-for-binary-classification">XGBoost for Binary Classification</h2>
<p>The equations above are general; we are interested in deriving them for the binary classification task. Equations of interest are the optimal leaf score and the tree quality score. For these, we first need gradient statistics \(g_i\) and \(h_i\) for a particular datapoint.</p>
<p>Shorten the notation for this section by dropping the \(i\) subscripts.</p>
<p>We will use negative log likelihood (a.k.a. cross entropy) as our loss function for binary classification:</p>
\[J = -[y \log(p) + (1-y) \log(1-p)]\]
<p>The gradient statistics are the first and second derivatives of the loss function with respect to the previous iteration’s model output. An important note is that the model is learned in log odds space rather than probability space: \(\hat{y} = \log(\frac{p}{1-p})\). Also note that inversely, \(p = \frac{e^{\hat{y}}}{1+e^{\hat{y}}} = \frac{1}{1+e^{-\hat{y}}}\).</p>
<p>First, we must re-express the loss as a function of log odds:</p>
\[\begin{align*}
l(y, \hat{y}) &= -[y \log(p) + (1-y) \log(1-p)]
\\
&= -y \hat{y} + \log(1 + e^{\hat{y}})
\end{align*}\]
<p>Now, take its derivative:</p>
\[\begin{align*}
g &= \frac{dJ}{d \hat{y}}
\\
&= \frac{d}{d \hat{y}} (-y \hat{y}) + \frac{d}{d \hat{y}} \log(1+e^{\hat{y}})
\\
&= -y + \frac{1}{1+e^{\hat{y}}}\frac{d}{d \hat{y}}(1 + e^{\hat{y}})
\\
&= -y + \frac{e^{\hat{y}}}{1+e^{\hat{y}}}
\\
&= p - y
\end{align*}\]
<p>Now, calculate the second derivative:</p>
\[\begin{align*}
h &= \frac{d^2 J}{d \hat{y}^2}
\\
&= \frac{d g}{d \hat{y}}
\\
&= \frac{e^{\hat{y}}}{1 + e^{\hat{y}}} \frac{1}{1 + e^{\hat{y}}}
\\
&= p (1 - p)
\end{align*}\]
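<p>The two results \(g = p - y\) and \(h = p(1 - p)\) are easy to verify with finite differences. The sketch below compares the analytic expressions against central-difference estimates of the loss written in log odds space. The test point and step size are arbitrary choices.</p>

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def logloss(y, y_hat):
    """Negative log likelihood as a function of the log odds y_hat."""
    return -y * y_hat + math.log(1.0 + math.exp(y_hat))

def grad_hess(y, y_hat):
    """Analytic first and second derivatives with respect to y_hat."""
    p = sigmoid(y_hat)
    return p - y, p * (1.0 - p)

y, y_hat, eps = 1.0, 0.3, 1e-4
g, h = grad_hess(y, y_hat)

# Central finite differences of the loss around y_hat.
g_num = (logloss(y, y_hat + eps) - logloss(y, y_hat - eps)) / (2 * eps)
h_num = (logloss(y, y_hat + eps) - 2 * logloss(y, y_hat) + logloss(y, y_hat - eps)) / eps ** 2
print(g, g_num, h, h_num)
```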
<h3 id="tree-quality-score-intuition">Tree quality score intuition</h3>
<p>The tree quality score for binary classification can be expressed as:</p>
\[\tilde{\mathcal{L}}^{(t)} = - \frac{1}{2} \sum_{j=1}^T
\frac{(\sum_{i \in I_j} (p_i^{t-1} - y_i))^2}
{\sum_{i \in I_j} p_i^{t-1} (1 - p_i^{t-1}) + \lambda}
+ \gamma T\]
<p>When the algorithm is evaluating ways to split a parent node into two child nodes, it will choose the split that minimizes the above expression. That is, it favors a large numerator and a small denominator.</p>
<h4 id="gradient-intuition">Gradient intuition</h4>
<p>It will want to define the split such that the numerator (gradient) term \((\sum_{i \in I_j} (p_i^{t-1} - y_i))^2\) is maximized for a child node (technically, maximizing the sum of this term for both children). Given that the \(p_i\)’s are all somewhere \(\in [0, 1]\), this term is maximized when all instances in a node are of the same class (i.e., all \(y_i = 1\) or all \(y_i = 0\)). This makes intuitive sense because we know a decision tree should be built in such a way that leaf nodes become more and more pure.</p>
<h4 id="hessian-intuition">Hessian intuition</h4>
<p>Similarly, to minimize the denominator (Hessian), all \(p_i^{t-1}\)’s should be very near \(0\) or \(1\) (the max of this function is at \(p_i^{t-1} = 0.5\)). Such extreme instance probabilities can be considered “mature” since the previous iteration of the model was quite certain of its predictions. The instance-level Hessians are also sometimes viewed as “weights” – if the model is already quite certain of its prediction for a particular instance, that instance should not be given as much consideration for the remaining training iterations. If the model is already certain about a set of instances, the algorithm has no problem with blowing up the logits even further to minimize the loss. This behavior can lead to overfitting, however. The \(\lambda\) term in the denominator, which is usually set to \(1\), prevents the term from exploding to a very large number (nearing infinity), and the xgboost package also supplies a <code class="language-plaintext highlighter-rouge">min_child_weight</code> hyperparameter that limits splitting to nodes that have a Hessian sum greater than the specified value.</p>
<h3 id="optimal-leaf-score-intuition">Optimal leaf score intuition</h3>
<p>The optimal leaf score for leaf \(j\) in the binary classification setting can now be expressed as:</p>
\[w_j^* = - \frac{\sum_{i \in I_j} (p_i^{t-1} - y_i)}
{\sum_{i \in I_j} p_i^{t-1} (1 - p_i^{t-1}) + \lambda}\]
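<p>To make the formula concrete, the sketch below evaluates it for two hypothetical leaves that each hold four positive instances. In the first, the previous model was uncertain (\(p = 0.5\)); in the second, it was confidently wrong (\(p = 0.1\)). The probabilities are made up, and \(\lambda = 1\) is assumed.</p>

```python
def leaf_score(p_prev, y, lam=1.0):
    """Optimal leaf score for log loss: -sum(p - y) / (sum(p * (1 - p)) + lambda)."""
    G = sum(p - t for p, t in zip(p_prev, y))
    H = sum(p * (1.0 - p) for p in p_prev)
    return -G / (H + lam)

labels = [1, 1, 1, 1]

# Previous model was uncertain about these instances.
uncertain = leaf_score([0.5, 0.5, 0.5, 0.5], labels)

# Previous model was confidently wrong: larger residuals and a smaller Hessian sum.
confident_wrong = leaf_score([0.1, 0.1, 0.1, 0.1], labels)

print(uncertain, confident_wrong)
```

<p>Both scores are positive logit adjustments, pushing the predictions toward class 1. The second is larger because the residuals are larger and the Hessian sum in the denominator is smaller.</p>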
<p>Remember that \(w_j^*\) is a logit \(\in (-\infty, \infty)\), not a probability.</p>
<p>The numerator (gradient) is basically saying to set the score in proportion to the summed residual \(y_i - p_i^{t-1}\) of the previous model’s predictions on that set of instances, which makes intuitive sense.</p>
<p>The denominator (hessian) is saying to blow this score up if the previous model was already certain about these instances, or step a little more carefully if it was uncertain.</p>Adam LineberryXGBoost Derivation +Getting Up to Speed with SHAP for Model Interpretability2020-01-17T00:00:00+00:002020-01-17T00:00:00+00:00http://adamlineberry.ai/shap<p class="notice">In this post I’ll give a brief overview of SHAP and explain why you should probably add it to your data science toolkit. I’ll also provide some explanation of the basic concepts from game theory that SHAP is built on. The intent of this post isn’t to provide a full treatment of SHAP, but rather to get a practitioner up to speed on the value proposition and theoretical intuition.</p>
<p>SHAP is an effective approach to model interpretability and explainability. It builds on the concept of Shapley Values, which was introduced by Lloyd Shapley in the 1950s. More recently, papers published in 2017<sup>1</sup> and 2018<sup>2</sup> and the introduction of the SHAP Python library<sup>6</sup> have brought Shapley Values into mainstream machine learning.</p>
<p>The important issues in model interpretability can be broken down into two parts:</p>
<ol>
<li><strong>Global Interpretability:</strong> Which features are important to my model? Understanding this gives us a better insight into the underlying process that’s being modeled. It’s also useful for feature selection: if a feature is unimportant, remove it. A less well-known but extremely handy use is to check for data leakage.</li>
<li><strong>Local Interpretability:</strong> Why did my model make the prediction that it did on a particular data instance (data point)? In many applications, the prediction alone is not enough; it needs to be accompanied by an explanation. In some cases the explanation is required by law. Sometimes the business isn’t interested in the predictions but rather the driving factors; in these cases the model is trained solely for the explanation.</li>
</ol>
<p>SHAP provides an elegant solution for both. Some data scientists may have experience with using Random Forest feature importances for global interpretability, and instance-level logistic regression activations for local interpretability. SHAP seems to trump these classical approaches in most cases.</p>
<p>From a theoretical standpoint, SHAP is well-grounded and one of the most robust model explanation methods out there. SHAP provides some theoretical guarantees that support its trustworthiness: local accuracy, missingness, and consistency. I’ll refer the reader to the excellent references below for more details.</p>
<p>From a practical standpoint, the TreeSHAP implementation is fast and integrates seamlessly with cutting-edge tree-based models like xgboost and lightgbm. The nice integration with highly performant GBT models is perhaps its strongest selling point. You can generate a variety of highly informative plots such as the one shown below. A data scientist can learn much more from these plots than standard feature importance plots (plus, SHAP can generate feature importance plots too).</p>
<p><img src="http://adamlineberry.ai/images/shap/summary-plot.png" alt="" class="align-center" /></p>
<figcaption>Sample SHAP Summary Plot<sup>5</sup></figcaption>
<h2 id="shapley-values">Shapley Values</h2>
<p>Here’s the thought experiment that motivates Shapley Values: suppose three people enter a room and play a cooperative game that yields a payout, let’s say $100. Let’s call the players Player A, Player B, and Player C. Assuming each person has a different skillset and contributes in different ways, how should the payout be divided between the players in a fair way? For the purposes of this post, let’s focus on how much payout Player A should receive.</p>
<p>We could approach the problem in an iterative way:</p>
<ol>
<li>Have A enter the room alone, play the game, and record the payout.</li>
<li>Have B enter the room with A, play the game, and record the payout.</li>
<li>Have C enter the room with A and B, play the game, and record the payout.</li>
</ol>
<p>The additional amount each player contributed to the payout when they entered the room could be used to split up the winnings. But there’s a problem. Imagine that Player A and Player B have very similar skillsets. Whichever player enters the room first would make a greater contribution because the second player wouldn’t have much more to add. So, maybe we should also look at what happens when B enters first followed by A.</p>
<p>You can see where this is leading. We need to observe all possible entry orderings of this coalition and average each player’s contribution across them.</p>
<p>Now let’s analyze this a bit further. The figure below shows all possible room entry sequences, sorted by when A enters. The coalitions are organized into <strong>blocks</strong> (via horizontal lines) based on the <em>set</em> of players in the room before and after A enters. It’s important to recognize that <em>the entry order of players who enter before and after A doesn’t matter to A’s calculation.</em> Only the sets matter. For example, there are two coalitions where the set of players \(\{B, C\}\) enter the room after A and two coalitions where \(\{B, C\}\) enter before A. When computing A’s contribution, it doesn’t matter what order the players enter after A has entered; at this point A has already made its contribution and it isn’t influenced by what happens next. Similarly, it doesn’t matter what order the players enter the room before A; from A’s perspective, when he enters the room he is playing with the same team whether B came in first or C.</p>
<p><img src="http://adamlineberry.ai/images/shap/coalitions.png" alt="" width="70" class="align-center" /></p>
<p>The equation to compute Player A’s Shapley Value is presented below. The set notation and factorials can make this equation daunting at first, so I’ve added some annotation and color coding to assist in unpacking this thing.</p>
<p><img src="http://adamlineberry.ai/images/shap/shapley-equation.png" alt="" class="align-center" /></p>
<p>\(F\) is the set of all players, so in our example \(F = \{A, B, C \}\). The notation \(F\setminus\{A\}\) means the set \(F\) without \(A\) which is \(\{ B, C \}\). The value function is represented by \(f(\cdot)\), which in our example outputs the payout from a given coalition.</p>
<p>The summation is over subsets \(S \subseteq F \setminus \{A\}\). In our example there are 4 subsets/terms in the sum:</p>
\[\emptyset \\
\{B\} \\
\{C\} \\
\{B,C\}\]
<p>As indicated in the annotation, the numerator \(\lvert S \rvert ! (\lvert F \rvert - \lvert S \rvert - 1)!\) is the number of orderings of the players before A multiplied by the number of orderings of the players after A. Ultimately this term computes the number of redundant coalitions where players \(S\) enter the room before A, which is the size of the coalition blocks in the figure.</p>
<p>The denominator \(\lvert F \lvert !\) computes the total number of coalitions and is used to compute the average contribution of A. One can move \(\frac{1}{\lvert F \lvert ! }\) outside of the summation to make this more clear.</p>
<p>After careful unpacking, hopefully it’s more clear now that this equation is iterating over redundant blocks of entry patterns, computing A’s contribution multiplied by the size of the block, and then taking the average.</p>
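<p>The equivalence between “average over all entry orders” and the block-weighted sum over subsets can be checked by brute force. The sketch below uses a made-up payout table for the three-player game (every number is hypothetical) and computes each player’s Shapley value both ways.</p>

```python
from itertools import combinations, permutations
from math import factorial

players = ("A", "B", "C")

# Hypothetical payout for every possible set of players in the room.
payout = {
    frozenset(): 0,
    frozenset("A"): 40, frozenset("B"): 40, frozenset("C"): 10,
    frozenset("AB"): 50, frozenset("AC"): 60, frozenset("BC"): 60,
    frozenset("ABC"): 100,
}

def shapley_by_orderings(player):
    """Average the player's marginal contribution over all entry orders."""
    orderings = list(permutations(players))
    total = 0.0
    for order in orderings:
        before = frozenset(order[:order.index(player)])
        total += payout[before | {player}] - payout[before]
    return total / len(orderings)

def shapley_by_subsets(player):
    """Sum over subsets S of the other players, weighted by block size."""
    others = [p for p in players if p != player]
    n = len(players)
    total = 0.0
    for size in range(len(others) + 1):
        for subset in combinations(others, size):
            s = frozenset(subset)
            block = factorial(len(s)) * factorial(n - len(s) - 1)
            total += block * (payout[s | {player}] - payout[s])
    return total / factorial(n)

values = {p: shapley_by_subsets(p) for p in players}
print(values)
```

<p>With these payouts, A and B have overlapping skills (together they earn barely more than either alone), and the three Shapley values add up to the full $100 payout.</p>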
<h2 id="application-to-machine-learning">Application to Machine Learning</h2>
<p>The extension of this intuition to machine learning is simple. The players are features in a model, the model is the value function \(f(\cdot)\) and the payout is the model prediction. To be clear, the players are instance feature values and the payout is the model’s prediction for that particular instance. For example the players could be the covariates \(x = [ age = 45, sex = male, weight = 185 ]\), and the payout could be \(f(x) = 0.98\).</p>
<p>While the intuition may be simple, the implementation is quite complex and would be the subject of another post. First of all, if you train a model on a dataset with \(p\) features, you can’t directly make predictions with \(< p\) features. There are some computationally expensive solutions that require training separate models for each feature subset, but of course that isn’t practical.</p>
<p>This line of thought leads to marginalizing out the effect of feature value subsets and computing expected values over your data, such as<sup>3</sup>:</p>
\[f(S)=\int f(x_{1},\ldots,x_{p})d\mathbb{P}_{x\notin{}S}\]
<p>And<sup>3</sup></p>
\[f(S)=f(\{x_{1},x_{3}\})=\int_{\mathbb{R}}\int_{\mathbb{R}}\hat{f}(x_{1},X_{2},x_{3},X_{4})d\mathbb{P}_{X_2X_4}\]
<p>For the case where there are four features and you’re trying to compute the value of a coalition of only two feature values \(S = \{x_1, x_3 \}\).</p>
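<p>One simple way to approximate such an integral is to average the model’s output over a background dataset, holding the coalition’s feature values fixed and filling in the remaining features from each background row. The sketch below does exactly that. It assumes the absent features can be sampled independently of the known ones, and the function name, toy model, and data are all hypothetical. This is not the TreeSHAP algorithm.</p>

```python
def coalition_value(model, x, known, background):
    """Estimate f(S): fix the features indexed by `known` to the values in x and
    average the model output over background rows for the remaining features."""
    total = 0.0
    for row in background:
        mixed = [x[j] if j in known else row[j] for j in range(len(x))]
        total += model(mixed)
    return total / len(background)

# Toy model with four features (a made-up linear function).
def model(v):
    return v[0] + 2 * v[1] + 3 * v[2] + 4 * v[3]

x = [1.0, 2.0, 3.0, 4.0]            # the instance being explained
background = [[0.0, 0.0, 0.0, 0.0],  # stand-in for a sample of training data
              [2.0, 2.0, 2.0, 2.0]]

# Value of the coalition S = {x_1, x_3}, i.e. zero-based feature indices 0 and 2.
value = coalition_value(model, x, known={0, 2}, background=background)
print(value)
```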
<p>In practice, a data scientist will likely implement the computationally efficient TreeSHAP<sup>2</sup> algorithm from the SHAP Python library<sup>6</sup>. This algorithm cleverly exploits the structure of tree-based models such as Gradient Boosted Trees and Random Forests to compute Shapley values efficiently and accurately.</p>
<h2 id="conclusion">Conclusion</h2>
<p>This post introduced SHAP and provided motivation for why a practicing data scientist may want to use it. The primary focus was on building the intuition behind Shapley Values by discussing a simple example and carefully unpacking the Shapley value equation. Finally, the connection was made to machine learning and implementation was discussed briefly.</p>
<h2 id="references">References</h2>
<p>[1] Scott M. Lundberg, Su-In Lee, <a href="https://arxiv.org/abs/1705.07874">A Unified Approach to Interpreting Model
Predictions</a></p>
<p>[2] Scott M. Lundberg, Gabriel G. Erion, Su-In Lee, <a href="https://arxiv.org/abs/1802.03888">Consistent Individualized Feature Attribution for Tree Ensembles</a></p>
<p>[3] Christoph Molnar, <a href="https://christophm.github.io/interpretable-ml-book/shapley.html">5.9 Shapley Values</a></p>
<p>[4] Christoph Molnar, <a href="https://christophm.github.io/interpretable-ml-book/shap.html">5.10 SHAP (SHapley Additive exPlanations)</a></p>
<p>[5] Scott Lundberg, <a href="https://towardsdatascience.com/interpretable-machine-learning-with-xgboost-9ec80d148d27">Interpretable Machine Learning with XGBoost</a></p>
<p>[6] Scott Lundberg, <a href="https://github.com/slundberg/shap">SHAP Github Project</a></p>Adam LineberryValue proposition, intuition, and analytical breakdown of SHAPLogistic Regression Deep Dive2019-11-14T00:00:00+00:002019-11-14T00:00:00+00:00http://adamlineberry.ai/logistic-regression<p>Logistic regression is possibly the most well-known machine learning model for classification tasks. The classic case is binary classification, but it can easily be extended to multiclass or even multilabel classification settings. It’s quite popular in the data science, machine learning, and statistics community for many reasons:</p>
<ul>
<li>The mathematics are relatively simple to understand</li>
<li>It is quite interpretable, both globally and locally; it is not considered a “black box” algorithm.</li>
<li>Training and inference are very fast</li>
<li>It has a minimal number of model parameters, which limits its ability to overfit</li>
</ul>
<p>In many cases, logistic regression can perform quite well on a task. However, more complex architectures will typically perform better if tuned properly. Thus, practitioners and researchers alike commonly use logistic regression as a baseline model.</p>
<p>Logistic regression’s architecture is a basic form of a neural network and it is optimized the same way as a neural network. This makes it a valuable case study for those interested in deep learning.</p>
<p>In this post I will be deriving logistic regression’s objective function from Maximum Likelihood Estimation (MLE) principles, deriving the gradient of the objective with respect to the model parameters, and visualizing how a gradient descent update shifts the decision boundary for a misclassified point.</p>
<h2 id="objective-function">Objective Function</h2>
<p>Consider a dataset of \(n\) observations \(\{ (x_i, y_i) \}_{i=1}^n\) where \(x_i \in \mathbb{R}^d\) are the observed features and \(y_i \in \{0, 1\}\) are the corresponding binary labels.</p>
<p>Logistic regression is a discriminative classifier (as opposed to a generative one), meaning that it models the posterior \(p(y \lvert x)\) directly. It models this posterior as a linear function of \(x\) “squashed” into \([0, 1]\) by the sigmoid function (denoted as \(\sigma\)).</p>
<p>The model is parameterized by a vector \(w \in \mathbb{R}^{d+1}\). For notational convenience we will assume each \(x_i\) is concatenated with a \(1\) so that the model’s bias term is contained in the final component of the vector \(w_{d+1}\).</p>
\[p(y=1 \lvert x) = \sigma(w^\top x) \\
\sigma(w^\top x) = \frac{1}{1 + \exp(-w^\top x)} \\\]
<p>Using the notational convenience that \(y \in \{0, 1\}\), we can write the model equation more generally:</p>
\[p(y \lvert x) = \sigma(w^\top x)^y(1-\sigma(w^\top x))^{1-y}\]
<p>Notice the equivalence with the <a href="https://en.wikipedia.org/wiki/Bernoulli_distribution#Properties">Bernoulli PMF</a>. This means logistic regression is modeling the data with a Bernoulli likelihood. Assuming data are iid, the likelihood of the full dataset is given by:</p>
\[p(y_1, \dots, y_n \lvert x_1, \dots, x_n) =
\prod_{i=1}^n \sigma(w^\top x_i)^{y_i}(1-\sigma(w^\top x_i))^{1-y_i} \\\]
<p>There is no closed form solution to logistic regression, so the log likelihood is instead maximized iteratively with a gradient-based method. The log likelihood objective function \(J\) is given by:</p>
\[\begin{align}
J &= \log p(y_1, \dots, y_n \lvert x_1, \dots, x_n) \\
&= \log \prod_{i=1}^n \sigma(w^\top x_i)^{y_i}(1-\sigma(w^\top x_i))^{1-y_i} \\
&= \sum_{i=1}^n \log \Bigl[ \sigma(w^\top x_i)^{y_i}(1-\sigma(w^\top x_i))^{1-y_i} \Bigr] \\
&= \sum_{i=1}^n \Bigl[ \log \sigma(w^\top x_i)^{y_i} + \log (1-\sigma(w^\top x_i))^{1-y_i} \Bigr] \\
&= \sum_{i=1}^n \Bigl[ y_i \log \sigma(w^\top x_i) + (1-y_i) \log (1-\sigma(w^\top x_i)) \Bigr]
\end{align}\]
<h2 id="gradient-of-the-objective-function">Gradient of the Objective Function</h2>
<p>In this section, we’ll be deriving the gradient of the objective function with respect to the model parameters \(w\). The derivation makes liberal use of the chain rule and other basic multivariate calculus rules.</p>
<p>The gradient of a sum is equal to the sum of the gradients, so for simplicity let’s consider a single data point.</p>
<p>Simplify notation:</p>
\[\sigma = \sigma(w^\top x)\]
<p>The derivative of the sigmoid function will come in handy:</p>
\[\frac{d}{dz} \sigma (z) = \sigma(z) (1 - \sigma(z))\]
<p>Derivation:</p>
\[\begin{align*}
\nabla_w J &= \nabla_w y \log \sigma + \nabla_w (1-y) \log (1-\sigma) \\
&= \frac{y}{\sigma}\nabla_w\sigma + \frac{1-y}{1-\sigma}\nabla_w(1-\sigma) \\
&= \frac{y}{\sigma}\nabla_w\sigma - \frac{1-y}{1-\sigma}\nabla_w\sigma \\
&= \bigg(\frac{y}{\sigma} - \frac{1-y}{1-\sigma}\bigg)\nabla_w\sigma \\
&= \frac{y(1-\sigma) - (1-y)\sigma}{\sigma(1-\sigma)} \nabla_w\sigma \\
&= \frac{y - \sigma y - \sigma + \sigma y}{\sigma(1-\sigma)} \nabla_w\sigma \\
&= \frac{y - \sigma}{\sigma(1-\sigma)}\nabla_w\sigma \\
&= \frac{y - \sigma}{\sigma(1-\sigma)} \sigma(1-\sigma) x \\
&= (y - \sigma) x
\end{align*}\]
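<p>As a quick check on the derivation, the sketch below compares \((y - \sigma) x\) against a central-difference estimate of the single-point log likelihood. The weights and data point are arbitrary values chosen for illustration.</p>

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def log_likelihood(w, x, y):
    """Log likelihood of a single labeled point under the logistic model."""
    s = sigmoid(dot(w, x))
    return y * math.log(s) + (1 - y) * math.log(1 - s)

def gradient(w, x, y):
    """Analytic gradient derived above: (y - sigma(w^T x)) * x."""
    s = sigmoid(dot(w, x))
    return [(y - s) * xj for xj in x]

def numerical_gradient(w, x, y, eps=1e-6):
    """Central finite differences, one coordinate of w at a time."""
    grad = []
    for j in range(len(w)):
        w_plus = list(w)
        w_minus = list(w)
        w_plus[j] += eps
        w_minus[j] -= eps
        grad.append((log_likelihood(w_plus, x, y) - log_likelihood(w_minus, x, y)) / (2 * eps))
    return grad

w = [0.5, -1.0, 0.25]   # last component plays the role of the bias
x = [2.0, 1.0, 1.0]     # features with a trailing 1 appended
analytic = gradient(w, x, y=1)
numeric = numerical_gradient(w, x, y=1)
print(analytic, numeric)
```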
<h2 id="visualizing-a-gradient-update">Visualizing a Gradient Update</h2>
<p>A nice property of linear models like logistic regression is that the model parameters \(w\) and the data \(x\) share the same space. Take a look at the gradient:</p>
\[\nabla_w J = (y - \sigma(w^\top x)) x\]
<p>Since \(y\) and \(\sigma(w^\top x)\) are both scalars, the gradient (for a particular data point) is just a scaled version of \(x\). To visualize this, let’s set the stage with a simple 2 dimensional dataset and model. Consider the following binary classification data:</p>
<p><img src="http://adamlineberry.ai/images/logistic-regression/data.png" alt="" class="align-center" /></p>
<p class="notice">Note: For simplicity in the example below, I’m using a logistic regression model <em>without</em> a bias term. This results in a non-affine decision boundary (ie the hyperplane must pass through the origin).</p>
<p>Let’s say the logistic regression model has been trained on this data and has learned model parameters \(w = [1, 1]^\top\). In the default setting (before any threshold tuning), the logistic regression decision boundary is defined by \(p(y=1 \lvert x) = \sigma(w^\top x) = 0.5\), which is equivalent to \(w^\top x = 0\) (you can verify this equivalence by looking at the sigmoid function plot below). This means data is classified as Class 1 if \(w^\top x > 0\) and Class 0 if \(w^\top x < 0\).</p>
<p><img src="http://adamlineberry.ai/images/logistic-regression/sigmoid.png" alt="" class="align-center" /></p>
<p>In effect, \(w\) defines a hyperplane (the decision boundary) in \(\mathbb{R}^d\), \(H = \{ x \in \mathbb{R}^d \lvert w^\top x = 0 \}\). \(w\) and \(H\) are plotted with the data below.</p>
<p><img src="http://adamlineberry.ai/images/logistic-regression/db.png" alt="" class="align-center" /></p>
<p>Now, let’s introduce a new data point, \(x = [-3, 1]^\top\) belonging to Class 1 (orange), which is incorrectly classified by the learned model:</p>
<p><img src="http://adamlineberry.ai/images/logistic-regression/new-point.png" alt="" class="align-center" /></p>
<p>Let’s consider what happens if the model is trained on just this incorrectly classified data point. Computing the gradient:</p>
\[x = \begin{bmatrix} -3 \\ 1 \end{bmatrix} \\
y = 1 \\
\sigma(w^\top x) = \sigma \bigg( \begin{bmatrix} 1 & 1 \end{bmatrix} \cdot \begin{bmatrix} -3 \\ 1 \end{bmatrix} \bigg) = 0.12 \\
y - \sigma(w^\top x) = 0.88 \\
\nabla_w J = (y - \sigma(w^\top x)) x = 0.88 \begin{bmatrix} -3 \\ 1 \end{bmatrix} = \begin{bmatrix} -2.64 \\ 0.88 \end{bmatrix}\]
<p>Now we can plot the gradient alongside the data and the existing decision boundary.</p>
<p><img src="http://adamlineberry.ai/images/logistic-regression/gradient.png" alt="" class="align-center" /></p>
<p>The gradient update equation at iteration \(t\) is</p>
<p class="notice">Note: If you’ve seen this equation in the past, the addition sign might throw you off. Usually we <em>minimize</em> the <em>negative</em> log likelihood, but in this case we’re going to <em>maximize</em> the <em>positive</em> log likelihood.</p>
\[w_{t+1} = w_t + \eta \nabla_w J\]
<p>Where \(\eta\) is a learning rate. Let’s compute the updated model parameters:</p>
\[\eta = \text{1e-1} \\
w_{t+1} = \begin{bmatrix} 1 \\ 1 \end{bmatrix} + \text{1e-1} \begin{bmatrix} -2.64 \\ 0.88 \end{bmatrix} = \begin{bmatrix} 0.74 \\ 1.09 \end{bmatrix}\]
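<p>The same arithmetic can be reproduced in a few lines of Python. The sketch uses the values from the example above (\(w = [1, 1]^\top\), \(x = [-3, 1]^\top\), \(y = 1\), \(\eta = 0.1\)) and also reports the predicted probability for the point before and after the update.</p>

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

w = [1.0, 1.0]    # learned parameters before the update
x = [-3.0, 1.0]   # the misclassified Class 1 point
y = 1
eta = 0.1         # learning rate

p_before = sigmoid(dot(w, x))
grad = [(y - p_before) * xj for xj in x]

# Gradient ascent step on the log likelihood.
w_new = [wj + eta * gj for wj, gj in zip(w, grad)]
p_after = sigmoid(dot(w_new, x))

print([round(gj, 2) for gj in grad])
print([round(wj, 2) for wj in w_new])
print(round(p_before, 2), round(p_after, 2))
```

<p>A single step does not fix the classification, since the predicted probability is still below \(0.5\), but it moves in the right direction.</p>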
<p>Armed with the updated model parameters, we can draw the updated decision boundary:</p>
<p><img src="http://adamlineberry.ai/images/logistic-regression/shifted.png" alt="" class="align-center" /></p>
<p>As you can see, adding a fraction of the gradient to \(w\) effectively rotated the decision boundary towards the misclassified point.</p>
<p>Recalling the gradient equation, let’s consider some other cases for the gradient and how we can interpret the math. Hopefully you can visualize how the gradient will attempt to rotate or even “flip” the decision boundary based on the class and location of the data point.</p>
\[\nabla_w J = (y - \sigma(w^\top x)) x\]
<table>
<thead>
<tr>
<th>description</th>
<th>\(y\)</th>
<th>\(\sigma(w^\top x)\)</th>
<th>\(\nabla_w J\)</th>
<th>comments</th>
</tr>
</thead>
<tbody>
<tr>
<td>model incorrect on a class 1 point</td>
<td>1</td>
<td>0.1</td>
<td>\(0.9x\)</td>
<td>The case illustrated above</td>
</tr>
<tr>
<td>model correct on a class 1 point</td>
<td>1</td>
<td>0.99</td>
<td>\(0.01x\)</td>
<td>Model is confident and correct; very small gradient update</td>
</tr>
<tr>
<td>model incorrect on a class 0 point</td>
<td>0</td>
<td>0.7</td>
<td>\(-0.7x\)</td>
<td>Rotates (or pushes) the decision boundary the opposite direction</td>
</tr>
<tr>
<td>model correct on a class 0 point</td>
<td>0</td>
<td>0.01</td>
<td>\(-0.01x\)</td>
<td>Model is confident and correct; very small gradient update</td>
</tr>
</tbody>
</table>Adam LineberryDerivation of objective and gradient, visualizing a gradient updateNotes on PCA Theory2019-11-07T00:00:00+00:002019-11-07T00:00:00+00:00http://adamlineberry.ai/notes/pca-theory<p>These are non-comprehensive notes on the theoretical machinery behind Principal Component Analysis (PCA). This post is essentially an abridged summary of Tim Roughgarden and Gregory Valiant’s lecture notes for Stanford CS168<sup>1,2</sup>.</p>
<p>Consider a dataset \(\{ \mathbf{x}_i \}_{i=1}^n\) where \(\mathbf{x}_i \in \mathbb{R}^d\) that has been preprocessed such that each \(\mathbf{x}_i\) has been shifted by the sample mean \(\bar{\mathbf{x}} = \frac{1}{n} \sum_{i=1}^n \mathbf{x}_i\). The resulting dataset has a new sample mean equal to the \(\mathbf{0}\) vector.</p>
<p>The goal of PCA is to find a set of \(k\) orthonormal vectors that form a basis for a \(k\)-dimensional subspace that minimizes the squared distances between the data and the data’s projection onto this subspace. This subspace also maximizes the variance of the data’s projection onto it. These orthonormal vectors are called the principal components.</p>
<p>For derivation purposes it is useful to consider finding only the first principal component. Solving for the remaining principal components follows naturally.</p>
<p>For notation simplicity, every time I mention a vector \(\mathbf{v}\) it is assumed to be a unit vector.</p>
<h2 id="objective-function">Objective Function</h2>
<p>The objective function is:</p>
\[\underset{\mathbf{v}}{\mathrm{argmin}}
\frac{1}{n}
\sum_{i=1}^{n}
\big(
\text{distance between }
\mathbf{x}_i
\text{ and line spanned by }
\mathbf{v}
\big)^2\]
<h2 id="connection-to-dot-product">Connection to Dot Product</h2>
<p><img src="http://adamlineberry.ai/images/pca/dot-product-projection.png" alt="" width="400" class="align-center" /></p>
<figcaption>Relationship between dot product and distance to projection<sup>1</sup></figcaption>
<p>We want to minimize \(\text{dist}(\mathbf{x}_i \leftrightarrow \text{line})\). By the Pythagorean Theorem:</p>
\[\begin{align}
||\mathbf{x}_i||^2 &=
\text{dist}(\mathbf{x}_i \leftrightarrow \text{line})^2 +
\langle \mathbf{x}_i, \mathbf{v} \rangle^2 \\
\text{dist}(\mathbf{x}_i \leftrightarrow \text{line})^2 &=
||\mathbf{x}_i||^2 -
\langle \mathbf{x}_i, \mathbf{v} \rangle^2
\end{align}\]
<p>\(\| \mathbf{x}_i \|^2\) is a constant with respect to optimization of \(\mathbf{v}\), so minimizing \(\text{dist}(\mathbf{x}_i \leftrightarrow \text{line})\) is equivalent to maximizing the squared dot product \(\langle \mathbf{x}_i, \mathbf{v} \rangle^2\):</p>
\[\underset{\mathbf{v}}{\mathrm{argmax}}
\frac{1}{n}
\sum_{i=1}^{n}
\langle \mathbf{x}_i, \mathbf{v} \rangle^2\]
<p>Note that this maximizes the variance of the projections onto \(\mathbf{v}\).</p>
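<p>The Pythagorean relationship is easy to confirm numerically. In the sketch below, the distance to the line is computed independently (as the norm of the residual after subtracting the projection) and compared against \(\|\mathbf{x}\|^2 - \langle \mathbf{x}, \mathbf{v} \rangle^2\). The point and the unit vector are arbitrary choices.</p>

```python
def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

x = [2.0, 1.0]
v = [0.6, 0.8]   # unit vector: 0.36 + 0.64 = 1

# Length of the projection of x onto the line spanned by v.
proj_len = dot(x, v)

# Residual between x and its projection; its norm is the distance to the line.
residual = [xi - proj_len * vi for xi, vi in zip(x, v)]
dist_sq = dot(residual, residual)

# Squared distance predicted by the Pythagorean identity.
dist_sq_pythagoras = dot(x, x) - proj_len ** 2
print(dist_sq, dist_sq_pythagoras)
```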
<h2 id="matrix-notation-and-connection-to-mathbfxtopmathbfx">Matrix Notation and Connection to \(\mathbf{X}^{\top}\mathbf{X}\)</h2>
<p>Assemble the \(\mathbf{x}_i\)’s into a matrix:</p>
\[\mathbf{X} =
\begin{bmatrix}
-\mathbf{x}_1- \\
-\mathbf{x}_2- \\
\vdots \\
-\mathbf{x}_n-
\end{bmatrix}\]
<p>If we take the inner product of \(\mathbf{Xv}\) with itself, we get the PCA objective function:</p>
\[(\mathbf{Xv})^\top (\mathbf{Xv}) =
\mathbf{v}^\top \mathbf{X}^\top \mathbf{Xv} =
\sum_{i=1}^{n} \langle \mathbf{x}_i, \mathbf{v} \rangle^2\]
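<p>The matrix \(\mathbf{A} = \mathbf{X}^\top \mathbf{X}\) is symmetric and, when the columns of \(\mathbf{X}\) are mean-centered, it equals the covariance matrix of the data up to a factor of \(1/n\). A positive constant factor does not change the maximizer, so it can be dropped.</p>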
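<p>As a quick sanity check of the matrix identity, here is a small sketch that assumes NumPy and uses made-up data. It also compares \(\mathbf{X}^\top \mathbf{X} / n\) against <code class="language-plaintext highlighter-rouge">np.cov</code> for mean-centered data.</p>

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(50, 4))           # hypothetical data matrix
X = X - X.mean(axis=0)                 # mean-center each column
n = X.shape[0]

v = rng.normal(size=4)
v /= np.linalg.norm(v)

A = X.T @ X
lhs = (X @ v) @ (X @ v)                          # (Xv)^T (Xv)
quad = v @ A @ v                                 # v^T X^T X v
rhs = sum(np.dot(x_i, v) ** 2 for x_i in X)      # sum_i <x_i, v>^2

# bias=True makes np.cov normalize by n rather than n - 1
matches_cov = np.allclose(A / n, np.cov(X, rowvar=False, bias=True))
print(lhs, quad, rhs, matches_cov)
```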
<p>The new objective is:</p>
\[\underset{\mathbf{v}}{\mathrm{argmax}} \;
\mathbf{v}^\top \mathbf{Av}\]
<h2 id="understanding-mathbfa">Understanding \(\mathbf{A}\)</h2>
<p>Diagonal matrices (i.e., zero everywhere except the diagonal) expand space along the standard axes (i.e., in the directions of the standard basis vectors).</p>
<p><img src="http://adamlineberry.ai/images/pca/diagonal-expansion.png" alt="" width="400" class="align-center" /></p>
<figcaption>How a diagonal matrix expands space<sup>2</sup></figcaption>
<p>I won’t prove it here, but an important note for the rest of the derivation is the fact that the solution to \(\mathrm{argmax}_\mathbf{v} \; \mathbf{v}^\top \mathbf{Dv}\) for a diagonal matrix \(\mathbf{D}\) is the standard basis vector corresponding to the dimension with the largest diagonal entry in \(\mathbf{D}\).</p>
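<p>Short of a proof, the claim is easy to probe empirically. The sketch below assumes NumPy and a made-up diagonal matrix; it evaluates \(\mathbf{v}^\top \mathbf{Dv}\) on many random unit vectors and compares against the standard basis vector of the largest diagonal entry.</p>

```python
import numpy as np

rng = np.random.default_rng(2)
D = np.diag([0.5, 3.0, 1.5])                     # hypothetical diagonal matrix

# Many random unit vectors
V = rng.normal(size=(10_000, 3))
V /= np.linalg.norm(V, axis=1, keepdims=True)
values = np.einsum("ij,jk,ik->i", V, D, V)       # v^T D v for each row v

# Standard basis vector for the largest diagonal entry
e_best = np.eye(3)[np.argmax(np.diag(D))]
best_value = e_best @ D @ e_best

print(values.max(), best_value)
```

<p>No random unit vector exceeds the value attained by the basis vector, since \(\mathbf{v}^\top \mathbf{Dv} = \sum_j d_j v_j^2\) is a weighted average of the diagonal entries when \(\mathbf{v}\) has unit length.</p>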
<p>Every symmetric matrix can be expressed as a diagonal matrix \(\mathbf{D}\) sandwiched between an orthogonal matrix \(\mathbf{Q}\) and its transpose<sup>[<a href="https://en.wikipedia.org/wiki/Symmetric_matrix#Decomposition">source</a>]</sup>. The lecture notes consider the case where \(\mathbf{Q}\) is a rotation matrix, but note that orthogonal matrices can take on other forms such as reflections and permutations.</p>
\[\begin{align}
\mathbf{A} &= \mathbf{QDQ}^\top \\
\mathbf{X}^\top \mathbf{X} &= \mathbf{QDQ}^\top
\end{align}\]
<p>This means that symmetric matrices still expand space orthogonally, but in directions rotated from the standard basis.</p>
<p><img src="http://adamlineberry.ai/images/pca/symmetric-expansion.png" alt="" width="400" class="align-center" /></p>
<figcaption>How a symmetric matrix expands space<sup>2</sup></figcaption>
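<p>The decomposition \(\mathbf{A} = \mathbf{QDQ}^\top\) can be computed with <code class="language-plaintext highlighter-rouge">numpy.linalg.eigh</code>, which is intended for symmetric matrices. This is a sketch on hypothetical random data, not code from the lecture notes.</p>

```python
import numpy as np

rng = np.random.default_rng(3)
X = rng.normal(size=(100, 3))          # hypothetical data
A = X.T @ X                            # symmetric by construction

# eigh returns eigenvalues in ascending order and eigenvectors as columns of Q
eigvals, Q = np.linalg.eigh(A)
D = np.diag(eigvals)

reconstructs = np.allclose(Q @ D @ Q.T, A)       # A = Q D Q^T
orthogonal = np.allclose(Q.T @ Q, np.eye(3))     # Q^T Q = I
print(reconstructs, orthogonal)
```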
<h2 id="eigenvectors">Eigenvectors</h2>
<p>The eigenvectors of a matrix are those vectors that simply get stretched (or scaled) by the linear transformation of the matrix. A vector \(\mathbf{v}\) is an eigenvector of matrix \(\mathbf{A}\) if the following equality holds:</p>
\[\mathbf{Av} = \lambda \mathbf{v}\]
<p>Where \(\lambda\) is the eigenvalue associated with \(\mathbf{v}\), which indicates how much \(\mathbf{v}\) is scaled (or stretched) by \(\mathbf{A}\).</p>
<p>Now consider the eigenvectors and eigenvalues of \(\mathbf{A}\). We know that \(\mathbf{A}\) stretches space orthogonally along some set of axes. These axes are the eigenvectors, and the eigenvalues tell you how much space is stretched in that direction.</p>
<p>As discussed previously, the solution is the unit vector pointing in the direction of maximum stretch. This is an eigenvector of \(\mathbf{A}\).</p>
<p>Now consider the PCA objective and remember that if \(\mathbf{v}\) is a unit vector and an eigenvector of \(\mathbf{A}\) then \(\mathbf{Av} = \lambda_{\mathbf{v}} \mathbf{v}\) and \(\mathbf{v}^\top \mathbf{v} = 1\):</p>
\[\begin{align}
\mathbf{v} &= \underset{\mathbf{v}}{\mathrm{argmax}} \;
\mathbf{v}^\top \mathbf{Av}
\\
&= \underset{\mathbf{v}}{\mathrm{argmax}} \;
\mathbf{v}^\top \lambda_{\mathbf{v}} \mathbf{v}
\\
&= \underset{\mathbf{v}}{\mathrm{argmax}} \;
\lambda_{\mathbf{v}} \mathbf{v}^\top \mathbf{v}
\\
&= \underset{\mathbf{v}}{\mathrm{argmax}} \;
\lambda_{\mathbf{v}}
\end{align}\]
<p>This says the first principal component is the eigenvector of \(\mathbf{A}\) with the largest eigenvalue. The remaining principal components can be found by taking the remaining eigenvectors in decreasing order of their eigenvalues.</p>
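<p>Putting the pieces together, PCA by eigendecomposition might look like the sketch below. It assumes NumPy, and the synthetic data (stretched along the direction \((1, 1)/\sqrt{2}\)) and all variable names are illustrative choices.</p>

```python
import numpy as np

rng = np.random.default_rng(4)

# Hypothetical 2-D data with most variance along (1, 1) / sqrt(2)
Z = rng.normal(size=(500, 2)) * np.array([3.0, 0.5])
R = np.array([[1.0, -1.0], [1.0, 1.0]]) / np.sqrt(2)
X = Z @ R.T
X -= X.mean(axis=0)                    # mean-center the data

A = X.T @ X
eigvals, eigvecs = np.linalg.eigh(A)   # ascending eigenvalues

# Sort eigenpairs by decreasing eigenvalue; columns are principal components
order = np.argsort(eigvals)[::-1]
eigvals, components = eigvals[order], eigvecs[:, order]

v1 = components[:, 0]                  # first principal component
objective_at_v1 = v1 @ A @ v1          # equals the largest eigenvalue

# No random unit vector should beat the first principal component
V = rng.normal(size=(2000, 2))
V /= np.linalg.norm(V, axis=1, keepdims=True)
random_objectives = np.einsum("ij,jk,ik->i", V, A, V)
print(objective_at_v1, eigvals[0], random_objectives.max())
```

<p>The sign of an eigenvector is arbitrary, so <code class="language-plaintext highlighter-rouge">v1</code> may point along either \((1, 1)/\sqrt{2}\) or its negation.</p>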
<h2 id="references">References</h2>
<p>[1] Tim Roughgarden, Gregory Valiant, <a href="https://web.stanford.edu/class/cs168/l/l7.pdf">CS168: The Modern Algorithmic Toolbox Lecture #7: Understanding and Using Principal
Component Analysis (PCA)</a></p>
<p>[2] Tim Roughgarden, Gregory Valiant, <a href="https://web.stanford.edu/class/cs168/l/l8.pdf">CS168: The Modern Algorithmic Toolbox Lecture #8: How PCA Works</a></p>Adam LineberryThese are non-comprehensive notes on the theoretical machinery behind Principal Component Analysis (PCA)A Quick Primer on KL Divergence2019-07-07T00:00:00+00:002019-07-07T00:00:00+00:00http://adamlineberry.ai/vae-series/KL-divergence<p class="notice--info">This is the first post in my series: <a href="http://adamlineberry.ai/vae-series">From KL Divergence to Variational Autoencoder in PyTorch</a>. The next post in the series is <a href="http://adamlineberry.ai/vae-series/variational-inference">Latent Variable Models, Expectation Maximization, and Variational Inference</a>.</p>
<hr />
<p>The Kullback-Leibler divergence, better known as <em>KL divergence</em>, is a way to measure the “distance” between two probability distributions over the same variable. In this post we will consider distributions \(q\) and \(p\) over the random variable \(z\).</p>
<p>It’s beneficial to be able to recognize the different forms of the KL divergence equation when studying derivations or writing your own equations.</p>
<p>For discrete random variables it takes the forms:</p>
\[KL[ q \lVert p ] = \sum\limits_{z} q(z) \log\frac{q(z)}{p(z)} = -\sum\limits_{z} q(z)\log\frac{p(z)}{q(z)}\]
<p>For continuous random variables it takes the forms:</p>
\[KL[ q \lVert p ] = \int q(z) \log \frac{q(z)}{p(z)}dz = - \int q(z) \log \frac{p(z)}{q(z)}dz\]
<p>And in general it can be written as an expected value:</p>
\[KL[ q \lVert p ] = \mathbb{E_{q(z)}} \log \frac{q(z)}{p(z)} = - \mathbb{E_{q(z)}} \log \frac{p(z)}{q(z)}\]
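<p>The three forms can be compared numerically. Below is a minimal sketch, assuming NumPy and two made-up discrete distributions; the expectation form is approximated by averaging over samples drawn from \(q\).</p>

```python
import numpy as np

# Hypothetical discrete distributions over three values of z
q = np.array([0.1, 0.4, 0.5])
p = np.array([0.3, 0.3, 0.4])

kl_sum = np.sum(q * np.log(q / p))       # sum_z q(z) log(q(z) / p(z))
kl_neg = -np.sum(q * np.log(p / q))      # flipped fraction, negative out front

# Expectation form: average log(q/p) over samples z ~ q
rng = np.random.default_rng(0)
z = rng.choice(len(q), size=200_000, p=q)
kl_mc = np.mean(np.log(q[z] / p[z]))

print(kl_sum, kl_neg, kl_mc)
```

<p>The first two values agree up to floating point error, while the sampled estimate fluctuates around them.</p>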
<p>To build some intuition, let’s focus on the following form:</p>
\[KL[ q \lVert p ] = \mathbb{E_{q(z)}} \log \frac{q(z)}{p(z)}\]
<p>Notice that the term \(\log q(z)/p(z)\) is the difference between two log probabilities: \(\log q(z) - \log p(z)\). So, the intuition stems from the fact that KL divergence is the expected difference in log probabilities over \(z\). Although not entirely technically correct, imagine the following to help build an intuition: consider two, perhaps similar, univariate probability density functions \(q(z)\) and \(p(z)\) and imagine sliding across the domain of \(z\) and observing the difference \(\log q(z) - \log p(z)\) at every point, weighted by \(q(z)\). This is kind of how KL divergence quantifies the “distance” between two distributions.</p>
<p>Now, a couple of important properties that I won’t prove:</p>
\[KL[q||p] \neq KL[p||q]\]
\[KL[q||p] \geq 0 \quad \forall q, p\]
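<p>Neither property is proved here, but both are easy to observe on examples. The sketch below assumes NumPy; the helper <code class="language-plaintext highlighter-rouge">kl</code> and the randomly drawn distributions are hypothetical.</p>

```python
import numpy as np

def kl(a, b):
    """KL[a || b] for discrete distributions stored along the last axis."""
    return np.sum(a * np.log(a / b), axis=-1)

q = np.array([0.1, 0.4, 0.5])
p = np.array([0.3, 0.3, 0.4])
kl_qp = kl(q, p)
kl_pq = kl(p, q)                       # differs from kl_qp: asymmetry

# Non-negativity on many random pairs of distributions
rng = np.random.default_rng(0)
a = rng.dirichlet(2 * np.ones(5), size=1000)
b = rng.dirichlet(2 * np.ones(5), size=1000)
min_kl = kl(a, b).min()

print(kl_qp, kl_pq, min_kl)
```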
<p>The asymmetric property raises the question: should I use \(KL[q\|p]\) or \(KL[p\|q]\)? This leads to the subject of forward versus reverse KL divergence.</p>
<h2 id="forward-vs-reverse-kl-divergence">Forward vs. Reverse KL Divergence</h2>
<p>In practice, KL divergence is typically used to learn an approximate probability distribution \(q\) to estimate a theoretical but intractable distribution \(p\). Typically \(q\) will be of simpler form than \(p\), since \(p\)’s complexity is what drives us to approximate it in the first place. As a simple example, \(p\) could be a bimodal distribution and \(q\) a unimodal one. When thinking about forward versus reverse KL, think of \(p\) as fixed and \(q\) as something fluid that we are free to mold to \(p\).</p>
<p>Forward KL takes the form</p>
\[KL[ p || q ] = \sum\limits_{z}p(z) \log\frac{p(z)}{q(z)}\]
<p>As you can see from this equation and the figure below, there is a penalty anywhere \(p(z) > 0\) that \(q\) is not covering. In fact, if \(q(z)=0\) in a region where \(p(z)>0\), the KL divergence blows up because \(\lim_{q(z) \to 0} \log \frac{p(z)}{q(z)} = \infty\). This results in learning a \(q\) that spreads out to cover all regions where \(p\) has any density. This is known as “zero-avoiding” behavior.</p>
<p><img src="http://adamlineberry.ai/images/vae/forward-KL.png" alt="" width="400" class="align-center" /></p>
<figcaption>Illustration of the "zero-avoiding" behavior of forward KL. Shows a reasonable distribution q with a high forward KL divergence (top), and a different distribution q with a lower forward KL divergence (bottom).</figcaption>
<p>Reverse KL takes the form</p>
\[KL[ q || p ] = \sum\limits_{z}q(z) \log\frac{q(z)}{p(z)}\]
<p>As seen from the equation and the figure below, reverse KL has a much different behavior. Now, the KL divergence will blow up anywhere \(p(z)=0\) unless the weighting term \(q(z)=0\). In other words, \(q(z)\) is encouraged to be zero everywhere that \(p(z)\) is zero. This is called “zero-forcing” behavior.</p>
<p>For example, if \(p\) has probability density in two disjoint regions in space, a \(q\) with limited complexity may not be able to span the zero-probability space between these regions. In this case, the learned \(q\) would only have density in one of the two dense regions of \(p\).</p>
<p><img src="http://adamlineberry.ai/images/vae/reverse-KL.png" alt="" width="400" class="align-center" /></p>
<figcaption>Illustration of the "zero-forcing" behavior of reverse KL. Shows a reasonable distribution q with a high reverse KL divergence (top), and a different distribution q with a lower reverse KL divergence (bottom).</figcaption>
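<p>The two behaviors can be reproduced with a brute-force experiment. The following sketch assumes NumPy and makes several illustrative choices that are not from the original post: a bimodal \(p\) with modes at \(\pm 3\), a single Gaussian \(q\), a discretized domain, and a grid search over the mean and standard deviation of \(q\). Everything is kept in log space to avoid dividing by zero.</p>

```python
import numpy as np

z = np.linspace(-10.0, 10.0, 2001)     # grid standing in for the domain of z

def log_normalize(log_w):
    """Shift log weights so the probabilities sum to one on the grid."""
    return log_w - np.logaddexp.reduce(log_w)

def log_gauss(m, s):
    """Log probabilities of a Gaussian (mean m, std s) discretized on the grid."""
    return log_normalize(-0.5 * ((z - m) / s) ** 2)

def kl(log_a, log_b):
    """KL[a || b] for discretized distributions given as log probabilities."""
    return np.sum(np.exp(log_a) * (log_a - log_b))

# Hypothetical bimodal target p: equal mixture of Gaussians at -3 and +3
log_p = log_normalize(np.logaddexp(-0.5 * ((z + 3) / 0.5) ** 2,
                                   -0.5 * ((z - 3) / 0.5) ** 2))

means = np.linspace(-4.0, 4.0, 33)
stds = np.linspace(0.3, 4.0, 38)
candidates = [(m, s) for m in means for s in stds]

forward = [kl(log_p, log_gauss(m, s)) for m, s in candidates]   # KL[p || q]
reverse = [kl(log_gauss(m, s), log_p) for m, s in candidates]   # KL[q || p]

m_fwd, s_fwd = candidates[int(np.argmin(forward))]
m_rev, s_rev = candidates[int(np.argmin(reverse))]
print("forward KL picks mean, std:", m_fwd, s_fwd)
print("reverse KL picks mean, std:", m_rev, s_rev)
```

<p>Under these assumptions, forward KL should select a wide \(q\) centered between the modes, while reverse KL should select a narrow \(q\) sitting on one of the two modes (which one is arbitrary by symmetry).</p>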
<h2 id="conclusion">Conclusion</h2>
<p>KL divergence is roughly a measure of distance between two probability distributions. There are different forms of the KL divergence equation. You can bring a negative out front by flipping the fraction inside the logarithm. You can also write it as an expectation.</p>
<p>Numerous machine learning models and algorithms use KL divergence as part of their loss function. By exploiting the structure of the specific model at hand, the KL divergence equation can often be simplified and optimized via gradient descent.</p>
<p>KL divergence is asymmetric and it’s important to understand the differences between forward and reverse KL.</p>
<p>My <a href="http://adamlineberry.ai/vae-series/variational-inference">next post</a> builds on KL divergence to explore latent variable models, expectation maximization, variational inference, and the ELBO.</p>
<h2 id="resources">Resources</h2>
<p>[1] Eric Jang, <a href="https://blog.evjang.com/2016/08/variational-bayes.html">A Beginner’s Guide to Variational Methods: Mean-Field Approximation</a></p>Adam LineberryIntuitive introduction to KL divergence, including discussion on its asymmetryLatent Variable Models, Expectation Maximization, and Variational Inference2019-07-07T00:00:00+00:002019-07-07T00:00:00+00:00http://adamlineberry.ai/vae-series/variational-inference<p class="notice--info">This is the second post in my series: <a href="http://adamlineberry.ai/vae-series">From KL Divergence to Variational Autoencoder in PyTorch</a>. The previous post in the series is <a href="http://adamlineberry.ai/vae-series/kl-divergence">A Quick Primer on KL Divergence</a> and the next post is <a href="http://adamlineberry.ai/vae-series/vae-theory">Variational Autoencoder Theory</a>.</p>
<hr />
<p>Latent variable models are a powerful form of [typically unsupervised] machine learning used for a variety of tasks such as clustering, dimensionality reduction, data generation, and topic modeling. The basic premise is that there is some latent and unobserved variable \(z_{i}\) that causes the observed data point \(x_{i}\). Here is the graphical model (or Bayesian network) representation:</p>
<p><img src="http://adamlineberry.ai/images/vae/graphical-model.png" alt="" width="200" class="align-center" /></p>
<p>Latent variable models model the probability distribution:</p>
\[p_{\theta}(x, z) = p_{\theta}(x \lvert z)p_{\theta}(z)\]
<p>and are trained by maximizing the marginal likelihood:</p>
\[p_{\theta}(x) = \int p_{\theta}(x \lvert z)p_{\theta}(z)dz\]
<p>The introduction of latent variables allows us to more accurately model the data and discover valuable insights from the latent variables themselves. In the topic modeling case we know beforehand that each document in a corpus tends to have a focus on a particular topic or subset of topics. For example, articles in a newspaper typically address topics such as politics, business, or sports. Real world corpora encountered in industry, such as customer support transcripts, product reviews, or legal contracts, can be more complex and ambiguous but the concept still applies. By structuring a model to incorporate this knowledge we are able to more accurately calculate the probability of a document, and perhaps more importantly, discover the topics being discussed in a corpus and provide topic assignment to individual documents.</p>
<p>The learned probability distributions such as \(p_{\theta}(x)\), \(p_{\theta}(z \lvert x)\), and \(p_{\theta}(x \lvert z)\) can be used directly for tasks like anomaly detection or data generation. More commonly though, the main contribution is inference of the latent variables themselves from these distributions. In the Gaussian mixture model (GMM) the latent variables are the cluster assignments. In latent Dirichlet allocation (LDA) the latent variables are the topic assignments. In the variational autoencoder (VAE) the latent variables are the compressed representations of that data.</p>
<h2 id="marginal-likelihood-training">Marginal likelihood training</h2>
<p>Latent variable models are trained by maximizing the marginal likelihood of the observed data under the model parameters \(\theta\). Since the logarithm is a monotonically increasing function, the marginal log likelihood is maximized instead since the logarithm simplifies the computation.</p>
\[\theta = \underset{\theta}{\mathrm{argmax}}\ p_{\theta}(x) = \underset{\theta}{\mathrm{argmax}}\ \log p_{\theta}(x)\]
<p>Given a training dataset \(X\) consisting of \(N\) data points \(\{x_1, x_2, \dots, x_N\}\), the marginal log likelihood is expressed as</p>
\[\begin{align}
\log p_{\theta}(X) &= \log \prod_{i=1}^{N} p_{\theta}(x_i) \tag{1} \\
&= \sum_{i=1}^{N} \log p_{\theta}(x_{i}) \\
&= \sum_{i=1}^{N} \log \int p_{\theta}(x_i, z_i)dz_i \\
&= \sum_{i=1}^{N} \log \int p_{\theta}(x_i \lvert z_i)p_{\theta}(z_i)dz_i
\end{align}\]
<p>Ideally we would maximize this expression directly, but the integral is typically intractable. For example, if \(z\) is high dimensional, the integral takes the form \(\int\int\int\dots\int\).</p>
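<p>When the latent variable is discrete, the integral becomes a finite sum and everything is computable. Here is a minimal sketch for a one-dimensional, two-component Gaussian mixture, assuming NumPy; the parameter values and observations are invented for illustration.</p>

```python
import numpy as np

def gaussian_pdf(x, mean, std):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

# Hypothetical parameters theta of a two-component 1-D GMM
weights = np.array([0.3, 0.7])         # p(z = k)
means = np.array([-2.0, 1.0])
stds = np.array([0.5, 1.5])

x = np.array([-2.1, 0.3, 1.8, 4.0])    # hypothetical observations

# joint[i, k] = p(x_i | z = k) p(z = k)
joint = gaussian_pdf(x[:, None], means[None, :], stds[None, :]) * weights[None, :]

marginal = joint.sum(axis=1)                   # p(x_i) = sum_z p(x_i, z)
log_likelihood = np.sum(np.log(marginal))      # log p(X), as in Eq. 1
posterior = joint / marginal[:, None]          # p(z | x_i)

print(log_likelihood)
print(posterior)
```

<p>With a continuous, high-dimensional \(z\) the sum over <code class="language-plaintext highlighter-rouge">k</code> has no such finite counterpart, which is what makes the general case hard.</p>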
<p>As previously discussed, we must also be able to compute the posterior of the latent variables in order to gain utility from these models:</p>
\[p_{\theta}(z \lvert x) = \frac{p_{\theta}(x \lvert z)p_{\theta}(z)}{p_{\theta}(x)}\]
<p>Again, this calculation is typically intractable because \(p_{\theta}(x)\) appears in the denominator. There are two main approaches to handling this issue: Monte Carlo sampling and variational inference. We will be focusing on variational inference in this post.</p>
<h2 id="derivation-of-variational-lower-bound">Derivation of Variational Lower Bound</h2>
<p>To start, let’s assume that the posterior \(p_{\theta}(z \lvert x)\) is intractable. To deal with this we will introduce a new distribution \(q_{\phi}(z)\). We would like \(q_{\phi}(z)\) to closely approximate \(p_{\theta}(z \lvert x)\) and we are free to choose any form we like for \(q\). For example, we could choose \(q\) to be static or conditional on \(x\) in some way (as you might guess, \(q\) <strong>is</strong> typically conditioned on \(x\)). A good approximation can be seen as one that minimizes the KL divergence between the distributions (for a primer on KL divergence, see <a href="http://adamlineberry.ai/vae-series/kl-divergence">this post</a>):</p>
<p class="notice">Note: For simplicity I will be using summations instead of integrals in the derivation.</p>
\[KL[q_{\phi}(z) \lVert p_{\theta}(z \lvert x)] =
-\underset{z}{\sum} q_{\phi}(z) \log \frac{p_{\theta}(z \lvert x)}{q_{\phi}(z)}\]
<p>Now, substituting \(p_{\theta}(z \lvert x) = p_{\theta}(x,z) / p_{\theta}(x)\) and arranging variables in a convenient way:</p>
\[\begin{align}
&= -\underset{z}{\sum} q_{\phi}(z) \log \bigg( \frac{p_{\theta}(x,z)}{q_{\phi}(z)} \cdot \frac{1}{p_{\theta}(x)} \bigg) \\
&= -\underset{z}{\sum} q_{\phi}(z) \bigg( \log\frac{p_{\theta}(x,z)}{q_{\phi}(z)} - \log p_{\theta}(x) \bigg) \\
&= -\underset{z}{\sum} q_{\phi}(z) \log \frac{p_{\theta}(x,z)}{q_{\phi}(z)} + \underset{z}{\sum} q_{\phi}(z) \log p_{\theta}(x)
\end{align}\]
<p>Note that in the second term, \(\underset{z}{\sum} q_{\phi}(z) \log p_{\theta}(x)\), \(\log p_{\theta}(x)\) is constant w.r.t. the summation so it can be moved outside, leaving \(\log p_{\theta}(x) \underset{z}{\sum} q_{\phi}(z)\). By definition of a probability distribution, \(\underset{z}{\sum} q_{\phi}(z) = 1\), so the term ultimately simplifies to \(\log p_{\theta}(x)\). So, we are left with:</p>
\[KL[q_{\phi}(z) \lVert p_{\theta}(z \lvert x)] =
-\underset{z}{\sum} q_{\phi}(z) \log \frac{p_{\theta}(x,z)}{q_{\phi}(z)} + \log p_{\theta}(x)\]
<p>Rearranging for clarity:</p>
\[\log p_{\theta}(x) = KL[q_{\phi}(z) \lVert p_{\theta}(z \lvert x)] + \underset{z}{\sum} q_{\phi}(z) \log \frac{p_{\theta}(x,z)}{q_{\phi}(z)} \tag{2}\]
<p>Now, let’s circle back to Eq. 1. Notice that we have derived an expression for the marginal log likelihood \(\log p_{\theta}(x)\) composed of two terms. The first term is the KL divergence between our variational distribution \(q_{\phi}(z)\) and the intractable posterior \(p_{\theta}(z \lvert x)\). The second term is called the <strong>variational lower bound</strong> or evidence lower bound (the acronym <strong>ELBO</strong> is frequently used in the literature).</p>
\[\begin{align}
\mathcal{L} &= \underset{z}{\sum} q_{\phi}(z) \log \frac{p_{\theta}(x,z)}{q_{\phi}(z)} \\
&= \mathbb{E_{q_{\phi}(z)}} \log \frac{p_{\theta}(x,z)}{q_{\phi}(z)}
\end{align}\]
<p>Looking at Eq. 2 and noting that KL divergence is non-negative, you can see that \(\mathcal{L}\) must be a lower bound for the marginal log likelihood: \(\mathcal{L} \leq \log p_{\theta}(x)\). Variational inference methods focus on the tractable task of maximizing the ELBO instead of maximizing the likelihood directly.</p>
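<p>Eq. 2 can be verified on a toy problem where the posterior is available. The sketch below assumes NumPy and a made-up joint \(p_{\theta}(x, z)\) for one fixed observation \(x\) and a latent variable with three states.</p>

```python
import numpy as np

# Hypothetical values of p(x, z) for one fixed x and z in {0, 1, 2}
joint = np.array([0.10, 0.25, 0.05])
evidence = joint.sum()                 # p(x) = sum_z p(x, z)
posterior = joint / evidence           # p(z | x)
log_evidence = np.log(evidence)

q = np.array([0.5, 0.3, 0.2])          # an arbitrary variational distribution

elbo = np.sum(q * np.log(joint / q))
kl = np.sum(q * np.log(q / posterior))

# Choosing q equal to the posterior closes the gap entirely
elbo_at_posterior = np.sum(posterior * np.log(joint / posterior))

print(log_evidence, kl + elbo, elbo_at_posterior)
```

<p>The sum of the KL term and the ELBO reproduces \(\log p_{\theta}(x)\), and the ELBO never exceeds it.</p>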
<h2 id="expectation-maximization">Expectation Maximization</h2>
<p>In the simplest case, when \(p_{\theta}(z \lvert x)\) is tractable (e.g., GMMs), the expectation maximization (EM) algorithm can be applied. First, parameters \(\theta\) are randomly initialized. EM then exploits the tractable posterior by holding \(\theta\) fixed and updating \(\phi\) by simply setting \(q_{\phi}(z) = p_{\theta}(z \lvert x)\) in the <em>E-step</em>.</p>
<p>Notice that since we are holding \(\theta\) fixed, the left hand side of Eq. 2 is a constant during this step, and the update to \(\phi\) sets the KL term to zero. This means the ELBO term is equal to the log likelihood, which is the best possible optimization step. It’s interesting because, in this interpretation, the EM algorithm does not bother with the ELBO directly in the E-step and instead maximizes it indirectly by minimizing the KL term.</p>
<p>In the <em>M-step</em>, \(\phi\) is fixed and \(\theta\) is updated by maximizing the ELBO. Isolating the terms that depend on \(\theta\):</p>
\[\begin{align}
\mathcal{L} &= \mathbb{E_{q_{\phi}(z)}} \log \frac{p_{\theta}(x,z)}{q_{\phi}(z)} \\
&= \mathbb{E_{q_{\phi}(z)}} \big( \log p_{\theta}(x,z) - \log q_{\phi}(z) \big) \\
&= \mathbb{E_{q_{\phi}(z)}} \log p_{\theta}(x,z) - \mathbb{E_{q_{\phi}(z)}} \log q_{\phi}(z)
\end{align}\]
<p>Since the second term does not depend on \(\theta\), we see that the M-step is maximizing the expected joint log likelihood of the data:</p>
\[\theta = \underset{\theta}{\mathrm{argmax}}\ \mathbb{E_{q_{\phi}(z)}} \log p_{\theta}(x,z)\]
<p>Although I won’t prove it here, EM has some nice convergence guarantees: an iteration never decreases the marginal likelihood, and under mild regularity conditions the algorithm converges to a local maximum or a saddle point of the marginal likelihood.</p>
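<p>The E-step and M-step described above can be sketched for a one-dimensional, two-component GMM. This is an illustrative implementation assuming NumPy; the synthetic data, the initial parameters, and the fixed iteration count are arbitrary choices. The M-step uses the standard closed-form GMM updates.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical data drawn from two Gaussians
x = np.concatenate([rng.normal(-2.0, 0.7, 300), rng.normal(3.0, 1.0, 700)])

def log_gauss(x, mean, std):
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)

# Initial theta
weights = np.array([0.5, 0.5])
means = np.array([-1.0, 1.0])
stds = np.array([1.0, 1.0])

log_likelihoods = []
for _ in range(50):
    # log p(x_i, z = k) under the current theta
    log_joint = log_gauss(x[:, None], means, stds) + np.log(weights)
    log_marginal = np.logaddexp.reduce(log_joint, axis=1)    # log p(x_i)
    log_likelihoods.append(log_marginal.sum())

    # E-step: set q(z) to the tractable posterior p(z | x_i)
    resp = np.exp(log_joint - log_marginal[:, None])

    # M-step: maximize E_q[log p(x, z)] over theta in closed form
    nk = resp.sum(axis=0)
    weights = nk / len(x)
    means = (resp * x[:, None]).sum(axis=0) / nk
    stds = np.sqrt((resp * (x[:, None] - means) ** 2).sum(axis=0) / nk)

print(weights, means, stds)
```

<p>The list <code class="language-plaintext highlighter-rouge">log_likelihoods</code> records the marginal log likelihood before each update, which makes it easy to confirm that it does not decrease from one iteration to the next.</p>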
<h2 id="conclusion">Conclusion</h2>
<p>In this post we introduced latent variable models and provided some insight on their utility in real-world scenarios. Maximum likelihood training is typically intractable so we derived the variational lower bound (or ELBO) which is maximized instead.</p>
<p>In the simplest case, when \(p_{\theta}(z \lvert x)\) is tractable (e.g., GMMs), we showed how the expectation maximization algorithm can be applied.</p>
<p>However, there are plenty of cases where the posterior \(p_{\theta}(z \lvert x)\) is not tractable. A more recent approach to solving this problem is to use deep neural networks to jointly learn \(q_{\phi}(z \lvert x)\) and \(p_{\theta}(x \lvert z)\) with an ELBO loss function, such as in the variational autoencoder. For more on this see my <a href="http://adamlineberry.ai/vae-series/vae-theory">post on variational autoencoder theory</a>, where we will further refine the theory presented here to form the basis for the variational autoencoder.</p>
<h2 id="resources">Resources</h2>
<p>[1] Volodymyr Kuleshov, Stefano Ermon, <a href="https://ermongroup.github.io/cs228-notes/learning/latent/">Learning in latent variable models</a></p>
<p>[2] Ali Ghodsi, <a href="https://youtu.be/uaaqyVS9-rM">Lec : Deep Learning, Variational Autoencoder, Oct 12 2017 [Lect 6.2]</a></p>
<p>[3] Daniil Polykovskiy, Alexander Novikov, National Research University Higher School of Economics, Coursera, <a href="https://www.coursera.org/learn/bayesian-methods-in-machine-learning">Bayesian Methods for Machine Learning</a></p>Adam LineberryIntroduction to latent variable models, derivation of the ELBO, and the relationship with Expectation MaximizationVariational Autoencoder Code and Experiments2019-07-07T00:00:00+00:002019-07-07T00:00:00+00:00http://adamlineberry.ai/vae-series/vae-experiments<p class="notice--info">This is the fourth and final post in my series: <a href="http://adamlineberry.ai/vae-series">From KL Divergence to Variational Autoencoder in PyTorch</a>. The previous post in the series is <a href="http://adamlineberry.ai/vae-series/vae-theory">Variational Autoencoder Theory</a>.</p>
<hr />
<p>In this post we will build and train a variational autoencoder (VAE) in PyTorch, tying everything back to the theory derived in my <a href="http://adamlineberry.ai/vae-series/vae-theory">post on VAE theory</a>. The first half of the post provides discussion on the key points in the implementation. The second half provides the code itself along with some annotations.</p>
<p>The VAE in this post is trained on the MNIST dataset on a laptop CPU. The images (originally 28x28) are flattened into a 784-dimensional vector for simplicity. The MNIST pixel intensity values, originally continuous \(\in [0,1]\), are binarized such that each pixel value is \(\in \{0,1\}\).</p>
<p>Before diving into the code, let’s set the stage by recapping the theory that has led us to this point.</p>
<p>In variational inference for latent variable models, learning a model to maximize the marginal likelihood directly is intractable so we turn to maximizing a lower bound of it instead (referred to as the evidence lower bound, or “ELBO”). We won’t go into any further details on variational inference since it is covered in depth in my <a href="http://adamlineberry.ai/vae-series/variational-inference">post on variational inference</a>. The ELBO is then arranged in a particular way to form the objective function for the VAE:</p>
\[\begin{align}
\mathcal{L} &= \mathbb{E_{q_{\phi}(z \lvert x)}} \log p_{\theta}(x \lvert z) -
KL[q_{\phi}(z \lvert x) \lVert p_{\theta}(z)] \\
&= \sum_i \big[ \mathbb{E_{q_{\phi}(z_i \lvert x_i)}} \log p_{\theta}(x_i \lvert z_i) -
KL[q_{\phi}(z_i \lvert x_i) \lVert p_{\theta}(z_i)] \big]
\end{align}\]
<p>The basic intuition behind this objective is that the first term acts as a reconstruction loss and the KL term acts as a regularizer. This intuition is discussed in much more detail in the previous post.</p>
<p>The VAE sets a unit diagonal Gaussian prior on the latent variable: \(p_{\theta}(z) = \mathcal{N}(0, I)\), and learns the distributions \(q_{\phi}(z \lvert x)\) and \(p_{\theta}(x \lvert z)\) jointly in a single neural network. The first half of the network that maps data into a distribution over latent space is known as the <em>probabilistic encoder</em>. The second half of the network that maps samples from the latent space back into the original space is known as the <em>probabilistic decoder</em>.</p>
<p><img src="http://adamlineberry.ai/images/vae/vae-architecture.png" alt="" class="align-center" /></p>
<figcaption>Illustration of the VAE model architecture<sup>3</sup></figcaption>
<h2 id="from-the-elbo-objective-to-a-pytorch-loss-function">From the ELBO objective to a PyTorch loss function</h2>
<p>In this section we will walk carefully from the theoretical ELBO objective function to specific PyTorch commands. We will focus on the objective one term at a time.</p>
<h3 id="first-term-reconstruction">First term (reconstruction)</h3>
<p>The first term of the ELBO objective is the expected reconstruction log probability:</p>
\[\mathbb{E_{q_{\phi}(z \lvert x)}} \log p_{\theta}(x \lvert z)\]
<p>Since the data is binary in this experiment, we will construct \(p_{\theta}(x \lvert z)\) to model a multivariate factorized Bernoulli distribution. (Note, the distribution chosen to model the reconstruction is dataset-specific. If you have continuous data then a diagonal Gaussian may be more appropriate.) This means that, for each data point, we view the 784 binary pixel values as independent Bernoulli observations. As such, the decoder network will output 784 Bernoulli parameters. The Bernoulli parameter is the probability of success in a binary outcome trial \(p \in [0, 1]\) (e.g., the probability of heads when flipping a biased coin).</p>
<p>Let’s take the \(j^{th}\) pixel of the \(i^{th}\) image as an example and call it \(x_{ij}\). Since we’re dealing with binary pixel values, \(x_{ij} \in \{0,1\}\) can be interpreted as the result of a Bernoulli trial. The model’s output will be the Bernoulli parameter corresponding to that pixel; let’s call that specific output \(p_{ij} \in [0,1]\). The likelihood of that pixel \(p_{\theta}(x_{ij} \lvert z_i)\) is then given by the Bernoulli PMF:</p>
\[p_{ij}^{x_{ij}}(1-p_{ij})^{1-x_{ij}}\]
<p>Since the first term in the objective deals with the log probability, we can write the log likelihood instead:</p>
\[x_{ij} \log p_{ij} + (1-x_{ij}) \log (1-p_{ij})\]
<p>This equation may look familiar. The negative of it is commonly known as binary cross entropy and is implemented in PyTorch by <a href="https://pytorch.org/docs/stable/nn.html?highlight=binary_cross_entropy#torch.nn.BCELoss"><code class="language-plaintext highlighter-rouge">torch.nn.BCELoss</code></a>.</p>
<p>Now, the log likelihood of the full data point \(x_i\) is given by</p>
\[\begin{align}
\log p_{\theta}(x_i \lvert z_i) &= \log \prod_{j=1}^{784} p_{\theta}(x_{ij} \lvert z_i) \\
&= \sum_{j=1}^{784} \log p_{\theta}(x_{ij} \lvert z_i) \\
&= \sum_{j=1}^{784} \bigg[ x_{ij} \log p_{ij} + (1-x_{ij}) \log (1-p_{ij}) \bigg]
\end{align}\]
<p>In PyTorch the negative of the final expression is computed by <a href="https://pytorch.org/docs/stable/nn.html#binary-cross-entropy"><code class="language-plaintext highlighter-rouge">torch.nn.functional.binary_cross_entropy</code></a> with <code class="language-plaintext highlighter-rouge">reduction='sum'</code>. Since we are training in minibatches, we want the sum of log probabilities for all pixels in that minibatch. This is accomplished by simply passing full batches through the same function call. You can think of the operation performed as first summing the 784 values for each data point and then summing over data points in the batch. In reality, a <code class="language-plaintext highlighter-rouge">(batch_size, 784)</code> size tensor of cross entropy values will be computed and then summed over all axes.</p>
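<p>To make the summation concrete without depending on PyTorch, here is a NumPy sketch of the same arithmetic. The random "pixels" <code class="language-plaintext highlighter-rouge">x</code> and "decoder outputs" <code class="language-plaintext highlighter-rouge">p</code> are hypothetical placeholders, and this is an illustration of the formula rather than the library's implementation.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, n_pixels = 4, 784

# Hypothetical binarized images and Bernoulli parameters from a decoder
x = (rng.random((batch_size, n_pixels)) > 0.5).astype(float)
p = rng.uniform(0.05, 0.95, size=(batch_size, n_pixels))

# Per-pixel Bernoulli log likelihood: x log p + (1 - x) log(1 - p)
log_lik = x * np.log(p) + (1 - x) * np.log(1 - p)

per_image = log_lik.sum(axis=1)        # log p(x_i | z_i), summed over 784 pixels
bce_sum = -log_lik.sum()               # summed binary cross entropy over the batch

# The same per-pixel values obtained from the Bernoulli PMF directly
pmf = p ** x * (1 - p) ** (1 - x)
matches_pmf = np.allclose(np.log(pmf), log_lik)

print(per_image, bce_sum, matches_pmf)
```

<p>Summing per image and then over the batch gives the same number as summing the whole <code class="language-plaintext highlighter-rouge">(batch_size, 784)</code> array at once.</p>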
<p>The expectation of the log likelihood over \(q_{\phi}(z \lvert x)\) is satisfied by simply sampling one point from \(q_{\phi}(z \lvert x)\) and passing it through the decoder. Note that there are no additional complexities here; this is a basic forward pass. As discussed in the previous post, this is the Monte Carlo approximation of the expected value of a function.</p>
<h3 id="second-term-kl-divergence-regularization">Second term (KL divergence, regularization)</h3>
<p>The second term of the ELBO objective is the negative KL divergence between the variational posterior and the prior on the latent variable \(z\):</p>
\[-KL[q_{\phi}(z \lvert x) \lVert p_{\theta}(z)]\]
<p>Since we have defined the prior to be a diagonal unit Gaussian and we have defined the variational posterior to also be a diagonal Gaussian, this KL term has a clean closed-form solution. The solution is essentially just a function of the means and variances of the two distributions. The KL divergence (without the leading minus sign) simplifies to</p>
\[KL[q_{\phi}(z \lvert x) \lVert p_{\theta}(z)] = -\frac{1}{2} \sum_{j=1}^{J} (1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2)\]
<p>Where \(J\) is the size of the latent space (number of dimensions), and \(\mu\) and \(\sigma^2\) are the mean and variance vectors output from the probabilistic encoder.</p>
<p>In order to compute this, the forward pass of the network must also return mean and variance vectors output from the encoder, not just the reconstruction portion. In other words, the full model must return the outputs from both the encoder and the decoder.</p>
<p>The KL term can be computed across a minibatch with the following:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">-</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">torch</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">logvar</span> <span class="o">-</span> <span class="n">mu</span><span class="p">.</span><span class="nb">pow</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span> <span class="o">-</span> <span class="n">logvar</span><span class="p">.</span><span class="n">exp</span><span class="p">())</span>
</code></pre></div></div>
<p>Where <code class="language-plaintext highlighter-rouge">mu</code> and <code class="language-plaintext highlighter-rouge">logvar</code> are tensors of means and log variances across the minibatch, respectively. Both of these tensors will have size <code class="language-plaintext highlighter-rouge">(batch_size, latent_space_size)</code>.</p>
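<p>The closed-form expression can be checked against a direct numerical integration of the KL divergence between a diagonal Gaussian and the unit Gaussian. The sketch below assumes NumPy instead of PyTorch, and the values of <code class="language-plaintext highlighter-rouge">mu</code> and <code class="language-plaintext highlighter-rouge">logvar</code> are made up.</p>

```python
import numpy as np

# Hypothetical encoder outputs: batch of 2, latent space of size 2
mu = np.array([[0.5, -1.0], [0.0, 2.0]])
logvar = np.array([[0.0, -1.0], [0.4, -2.0]])

# Closed form, summed over latent dimensions and the minibatch
kl_closed = -0.5 * np.sum(1 + logvar - mu ** 2 - np.exp(logvar))

# Numerical check: integrate q(z) * (log q(z) - log p(z)) per dimension
z = np.linspace(-12.0, 12.0, 48001)
dz = z[1] - z[0]

def kl_numeric_1d(m, lv):
    var = np.exp(lv)
    log_q = -0.5 * np.log(2 * np.pi * var) - 0.5 * (z - m) ** 2 / var
    log_p = -0.5 * np.log(2 * np.pi) - 0.5 * z ** 2
    return np.sum(np.exp(log_q) * (log_q - log_p)) * dz

# Diagonal Gaussians factorize, so the KL is a sum over dimensions
kl_numeric = sum(kl_numeric_1d(m, lv) for m, lv in zip(mu.ravel(), logvar.ravel()))

print(kl_closed, kl_numeric)
```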
<h3 id="putting-the-terms-together">Putting the terms together</h3>
<p>In the following implementation, the binary cross entropy (BCE) and the KL divergence are calculated across the minibatch separately and simply summed at the end. The sum is the negative ELBO, which is the loss minimized during training.</p>
<h2 id="sampling-from-the-encoder">Sampling from the encoder</h2>
<p>A key step in the flow of the VAE is sampling a data point from the encoder \(q_{\phi}(z_i \lvert x_i)\). The reparameterization trick is used to perform this sampling without introducing a discontinuity in the network (as discussed in the previous post).</p>
\[z_i = g_{\phi}(x_i, \epsilon_i) = \mu_{\phi}(x_i) + diag(\sigma_{\phi}(x_i)) \cdot \epsilon_i \\
\epsilon_i \sim \mathcal{N}(0, I)\]
<p>In the forward pass, the vectors of means and log variances are collected from the encoder. These vectors are used to generate a sample as follows:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">reparameterize</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span><span class="p">):</span>
    <span class="n">std</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="mf">0.5</span><span class="o">*</span><span class="n">logvar</span><span class="p">)</span>
    <span class="n">eps</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">std</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">mu</span> <span class="o">+</span> <span class="n">eps</span><span class="o">*</span><span class="n">std</span>
</code></pre></div></div>
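<p>To see why this formulation matters, the sketch below checks that gradients reach both <code class="language-plaintext highlighter-rouge">mu</code> and <code class="language-plaintext highlighter-rouge">logvar</code> after sampling. It uses a standalone copy of the function (without <code class="language-plaintext highlighter-rouge">self</code>) and made-up zero-valued inputs, so treat it as an illustration rather than part of the model.</p>

```python
import torch

def reparameterize(mu, logvar):
    # standalone copy of the method above, for illustration only
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

torch.manual_seed(0)
# made-up encoder outputs that track gradients
mu = torch.zeros(3, 2, requires_grad=True)
logvar = torch.zeros(3, 2, requires_grad=True)

z = reparameterize(mu, logvar)
z.sum().backward()

# d(sum z)/d(mu) is 1 everywhere; d(sum z)/d(logvar) is 0.5 * eps * std
print(mu.grad)
print(logvar.grad)
```

<p>Because the noise enters only through <code class="language-plaintext highlighter-rouge">eps</code>, the sample is an ordinary differentiable expression of the encoder outputs and backpropagation works as usual.</p>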
<h2 id="experiment-results">Experiment results</h2>
<h3 id="data-generation">Data generation</h3>
<p>At various points during training, I sampled a grid of points from the latent space. Along each axis, the grid coordinates are probabilities linearly spaced between 0.05 and 0.95, transformed through the inverse Gaussian CDF. This results in a grid of points at evenly spaced quantiles of the Gaussian. In plain English (sort of), this means slicing the Gaussian into chunks of equal probability mass.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">from</span> <span class="nn">scipy.stats</span> <span class="kn">import</span> <span class="n">norm</span>
<span class="n">n</span> <span class="o">=</span> <span class="mi">20</span>
<span class="c1"># norm.ppf is the inverse Gaussian CDF; staying within 0.05..0.95 avoids the infinite tails
</span><span class="bp">self</span><span class="p">.</span><span class="n">grid_x</span> <span class="o">=</span> <span class="n">norm</span><span class="p">.</span><span class="n">ppf</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.95</span><span class="p">,</span> <span class="n">n</span><span class="p">))</span>
<span class="bp">self</span><span class="p">.</span><span class="n">grid_y</span> <span class="o">=</span> <span class="n">norm</span><span class="p">.</span><span class="n">ppf</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.95</span><span class="p">,</span> <span class="n">n</span><span class="p">))</span>
</code></pre></div></div>
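<p>The effect of the inverse CDF on the grid spacing can be seen with a small standalone sketch. It assumes a two-dimensional latent space and uses the standard library <code class="language-plaintext highlighter-rouge">statistics.NormalDist.inv_cdf</code> as a stand-in for <code class="language-plaintext highlighter-rouge">norm.ppf</code>, so it runs without SciPy; the variable names are illustrative.</p>

```python
from statistics import NormalDist

n = 20
# probabilities linearly spaced between 0.05 and 0.95 (same endpoints as above)
probs = [0.05 + i * (0.95 - 0.05) / (n - 1) for i in range(n)]
# inverse Gaussian CDF; standard library stand-in for scipy.stats.norm.ppf
grid = [NormalDist().inv_cdf(p) for p in probs]
# every (x, y) pair on the grid is one latent point that could be decoded
latent_points = [(x, y) for y in grid for x in grid]

# neighbouring grid values are closer together near the mode than in the tails
center_gap = grid[n // 2] - grid[n // 2 - 1]
edge_gap = grid[1] - grid[0]
print(len(latent_points), round(center_gap, 3), round(edge_gap, 3))
```

<p>The grid is denser near zero, where the prior places most of its mass, and sparser toward the tails.</p>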
<p>The following GIFs show the maturation of the model’s latent space and data-generating capabilities at various points throughout training. At the beginning of the animations, the generated data are mostly noise, but as training (and the animation) progresses, recognizable shapes begin to emerge. Keep in mind that the images you’re seeing here are essentially “fake”, in that they are not images from any dataset.</p>
<p>Animation throughout the entire training process:</p>
<p><img src="http://adamlineberry.ai/images/vae/datagen_tracking.gif" alt="" width="500" class="align-center" /></p>
<p>Animation for just the early stages of training:
<img src="http://adamlineberry.ai/images/vae/datagen_tracking_early.gif" alt="" width="500" class="align-center" /></p>
<p>The final state of the learned manifold after training has completed:
<img src="http://adamlineberry.ai/images/vae/datagen_final.png" width="500" class="align-center" /></p>
<p>As you can see, there are regions dedicated to individual digits with smooth transitions in between. I tried hand-drawing the boundaries between digits to aid the visualization:
<img src="http://adamlineberry.ai/images/vae/datagen_final_handdrawn_partitions.png" width="500" class="align-center" /></p>
<h3 id="data-reconstruction">Data reconstruction</h3>
<p>At various points throughout training I also tracked how well the model was reconstructing five hand-selected images:</p>
<p><img src="http://adamlineberry.ai/images/vae/output_6_0.png" alt="png" class="align-center" /></p>
<p>The following animation shows how the model’s ability to reconstruct data improves over the training process:</p>
<p><img src="http://adamlineberry.ai/images/vae/recon_tracking_early.gif" alt="" width="320" class="align-center" /></p>
<h3 id="anomaly-detection">Anomaly detection</h3>
<p>Anomalous data can be detected by leveraging the probabilistic nature of the VAE. One way to detect anomalies is to measure the KL divergence between the encoder distribution \(q_{\phi}(z_i \lvert x_i)\) and the prior \(p_{\theta}(z)\) for a given input and compare it to the distribution of values observed across the training (or test) set.</p>
<p>I computed this KL divergence for every point in the training set and plotted the resulting distribution:</p>
<p><img src="http://adamlineberry.ai/images/vae/kl_dist.png" alt="" width="400" class="align-center" /></p>
<p>I then generated a noise sample:</p>
<p><img src="http://adamlineberry.ai/images/vae/noise.png" alt="" class="align-center" /></p>
<p>And calculated its KL divergence: 51.763. As you can see from the distribution plot, this value is a significant outlier and would be easy to detect using automated anomaly detection systems.</p>
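<p>A minimal sketch of how such a check could be automated is shown below. The encoder outputs are randomly generated stand-ins rather than values from the trained model, and the 99th percentile threshold is an assumed rule chosen only for illustration.</p>

```python
import torch

def kl_per_sample(mu, logvar):
    # same closed form as the loss, but summed over latent dimensions only,
    # which leaves one KL value per data point
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

torch.manual_seed(0)
# hypothetical encoder outputs standing in for a real training set
train_mu = 0.5 * torch.randn(1000, 2)
train_logvar = -1.0 + 0.1 * torch.randn(1000, 2)
train_kl = kl_per_sample(train_mu, train_logvar)

# assumed rule: flag anything above the 99th percentile of training values
threshold = torch.quantile(train_kl, 0.99)

# hypothetical encoder output for an out-of-distribution input
query_kl = kl_per_sample(torch.tensor([[6.0, -6.0]]), torch.tensor([[-1.0, -1.0]]))
is_anomaly = bool(query_kl.item() > threshold.item())
print(threshold.item(), query_kl.item(), is_anomaly)
```

<p>In practice the threshold would be tuned on held-out data, since the right cutoff depends on how many false alarms are acceptable.</p>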
<h2 id="pytorch-code">PyTorch Code</h2>
<p>The data loading, data transformation, model architecture, loss function, and training loop are presented in this section. The key points of the implementation are discussed in detail above, but additional code annotation is provided for clarity. For the full code, including visualization generation and experiment execution, please see <a href="https://github.com/acetherace/alcore/tree/master/notebooks/VAE.ipynb">this notebook</a> on GitHub.</p>
<h3 id="imports-and-helpers">Imports and Helpers</h3>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.utils.data</span>
<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">functional</span> <span class="k">as</span> <span class="n">F</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span><span class="p">,</span> <span class="n">optim</span>
<span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">datasets</span><span class="p">,</span> <span class="n">transforms</span>
<span class="kn">from</span> <span class="nn">torchvision.utils</span> <span class="kn">import</span> <span class="n">save_image</span>
<span class="kn">from</span> <span class="nn">fastprogress</span> <span class="kn">import</span> <span class="n">master_bar</span><span class="p">,</span> <span class="n">progress_bar</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">from</span> <span class="nn">scipy.stats</span> <span class="kn">import</span> <span class="n">norm</span>
<span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
<span class="kn">import</span> <span class="nn">imageio</span>
<span class="o">%</span><span class="n">matplotlib</span> <span class="n">inline</span>
</code></pre></div></div>
<p>Set configuration parameters for model training:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">128</span>
<span class="n">epochs</span> <span class="o">=</span> <span class="mi">10</span>
<span class="n">seed</span> <span class="o">=</span> <span class="mi">199</span>
<span class="n">log_interval</span> <span class="o">=</span> <span class="mi">10</span>
<span class="n">device</span> <span class="o">=</span> <span class="s">'cpu'</span>
<span class="n">torch</span><span class="p">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>
<span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>
</code></pre></div></div>
<h3 id="load-and-prep-data">Load and Prep Data</h3>
<p>I added the <code class="language-plaintext highlighter-rouge">lambda x: x.round()</code> transformation to convert the images into binary form. We’re assuming the data likelihood follows a Bernoulli distribution, and this connection is clearer when the data is binary.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">xforms</span> <span class="o">=</span> <span class="n">transforms</span><span class="p">.</span><span class="n">Compose</span><span class="p">([</span>
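    <span class="c1"># ToTensor scales pixel values to [0, 1]
</span>    <span class="n">transforms</span><span class="p">.</span><span class="n">ToTensor</span><span class="p">(),</span>
    <span class="c1"># rounding then thresholds each pixel at 0.5, giving a binary image
</span>    <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="p">.</span><span class="nb">round</span><span class="p">()</span>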
<span class="p">])</span>
</code></pre></div></div>
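<p>The effect of the rounding step can be seen on a handful of made-up pixel intensities. This is a small illustration only; the values below are hypothetical and not taken from MNIST.</p>

```python
import torch

# hypothetical pixel intensities in [0, 1], as ToTensor would produce
pixels = torch.tensor([0.0, 0.2, 0.49, 0.51, 0.8, 1.0])

# the same operation the lambda in the transform applies to each image
binary = pixels.round()
print(binary)
```

<p>Every value ends up as either 0 or 1, which matches the support of a Bernoulli random variable.</p>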
<p>I hand-picked five images to use for visualizing reconstruction performance throughout training.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># download=True fetches the dataset if it is not already on disk
</span><span class="n">ds</span> <span class="o">=</span> <span class="n">datasets</span><span class="p">.</span><span class="n">MNIST</span><span class="p">(</span><span class="s">'../data'</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">xforms</span><span class="p">)</span>
<span class="n">recon_base_imgs</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">12</span><span class="p">,</span> <span class="mi">15</span><span class="p">,</span> <span class="mi">22</span><span class="p">]:</span>
    <span class="n">img</span> <span class="o">=</span> <span class="n">ds</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
    <span class="n">recon_base_imgs</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>
<span class="n">recon_base_img</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">np</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">numpy</span><span class="p">())</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">recon_base_imgs</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="k">del</span> <span class="n">ds</span>
<span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">()</span>
<span class="n">ax</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">recon_base_img</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s">'gray'</span><span class="p">,</span> <span class="n">interpolation</span><span class="o">=</span><span class="s">'none'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>
<p><img src="http://adamlineberry.ai/images/vae/output_6_0.png" alt="png" /></p>
<p>Instantiate data loaders.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">train_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">DataLoader</span><span class="p">(</span>
    <span class="n">datasets</span><span class="p">.</span><span class="n">MNIST</span><span class="p">(</span><span class="s">'../data'</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
                   <span class="n">transform</span><span class="o">=</span><span class="n">xforms</span><span class="p">),</span>
    <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">test_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">DataLoader</span><span class="p">(</span>
    <span class="n">datasets</span><span class="p">.</span><span class="n">MNIST</span><span class="p">(</span><span class="s">'../data'</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">xforms</span><span class="p">),</span>
    <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="nb">len</span><span class="p">(</span><span class="n">train_loader</span><span class="p">.</span><span class="n">dataset</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="p">.</span><span class="n">dataset</span><span class="p">)</span>
<span class="c1"># (60000, 10000)
</span></code></pre></div></div>
<h3 id="define-model-and-training-functions">Define Model and Training Functions</h3>
<p>The encoder and decoder modules are defined separately as <code class="language-plaintext highlighter-rouge">VAEEncoder</code> and <code class="language-plaintext highlighter-rouge">BernoulliVAEDecoder</code>, respectively. The <code class="language-plaintext highlighter-rouge">BernoulliVAE</code> class combines them to form the full model. It allows for a variable number of hidden layers and hidden layer sizes in both the encoder and decoder. It uses the ReLU activation function at each hidden layer.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">VAEEncoder</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="s">"""
    Standard encoder module for variational autoencoders with tabular input and
    diagonal Gaussian posterior.
    """</span>

    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data_size</span><span class="p">,</span> <span class="n">hidden_sizes</span><span class="p">,</span> <span class="n">latent_size</span><span class="p">):</span>
        <span class="s">"""
        Args:
            data_size (int): Dimensionality of the input data.
            hidden_sizes (list[int]): Sizes of hidden layers (not including the
                input layer or the latent layer).
            latent_size (int): Size of the latent space.
        """</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">data_size</span> <span class="o">=</span> <span class="n">data_size</span>
        <span class="c1"># construct the encoder
</span>        <span class="n">encoder_szs</span> <span class="o">=</span> <span class="p">[</span><span class="n">data_size</span><span class="p">]</span> <span class="o">+</span> <span class="n">hidden_sizes</span>
        <span class="n">encoder_layers</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="k">for</span> <span class="n">in_sz</span><span class="p">,</span> <span class="n">out_sz</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">encoder_szs</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">encoder_szs</span><span class="p">[</span><span class="mi">1</span><span class="p">:]):</span>
            <span class="n">encoder_layers</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_sz</span><span class="p">,</span> <span class="n">out_sz</span><span class="p">))</span>
            <span class="n">encoder_layers</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">ReLU</span><span class="p">())</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">encoder_layers</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">encoder_mu</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">encoder_szs</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">latent_size</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">encoder_logvar</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">encoder_szs</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">latent_size</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">gaussian_param_projection</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">encoder_mu</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="bp">self</span><span class="p">.</span><span class="n">encoder_logvar</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">reparameterize</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span><span class="p">):</span>
        <span class="n">std</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="mf">0.5</span><span class="o">*</span><span class="n">logvar</span><span class="p">)</span>
        <span class="n">eps</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">std</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">mu</span> <span class="o">+</span> <span class="n">eps</span><span class="o">*</span><span class="n">std</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">encode</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">gaussian_param_projection</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">reparameterize</span><span class="p">(</span><span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">z</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span>


<span class="k">class</span> <span class="nc">BernoulliVAEDecoder</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="s">"""
    VAE decoder module that models a diagonal multivariate Bernoulli
    distribution with a feed-forward neural net.
    """</span>

    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data_size</span><span class="p">,</span> <span class="n">hidden_sizes</span><span class="p">,</span> <span class="n">latent_size</span><span class="p">):</span>
        <span class="s">"""
        Args:
            data_size (int): Dimensionality of the input data.
            hidden_sizes (list[int]): Sizes of hidden layers (not including the
                input layer or the latent layer).
            latent_size (int): Size of the latent space.
        """</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="c1"># construct the decoder
</span>        <span class="n">hidden_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">latent_size</span><span class="p">]</span> <span class="o">+</span> <span class="n">hidden_sizes</span>
        <span class="n">decoder_layers</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="k">for</span> <span class="n">in_sz</span><span class="p">,</span> <span class="n">out_sz</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">hidden_sizes</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">hidden_sizes</span><span class="p">[</span><span class="mi">1</span><span class="p">:]):</span>
            <span class="n">decoder_layers</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_sz</span><span class="p">,</span> <span class="n">out_sz</span><span class="p">))</span>
            <span class="n">decoder_layers</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">ReLU</span><span class="p">())</span>
        <span class="n">decoder_layers</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_sizes</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">data_size</span><span class="p">))</span>
        <span class="n">decoder_layers</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Sigmoid</span><span class="p">())</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">decoder_layers</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">z</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>


<span class="k">class</span> <span class="nc">BernoulliVAE</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="s">"""
    VAE module that combines a `VAEEncoder` and a `BernoulliVAEDecoder`, resulting
    in a full VAE.
    """</span>

    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data_size</span><span class="p">,</span> <span class="n">encoder_szs</span><span class="p">,</span> <span class="n">latent_size</span><span class="p">,</span> <span class="n">decoder_szs</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="c1"># if decoder_szs not specified, assume symmetry
</span>        <span class="k">if</span> <span class="n">decoder_szs</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
            <span class="n">decoder_szs</span> <span class="o">=</span> <span class="n">encoder_szs</span><span class="p">[::</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
        <span class="c1"># construct the encoder
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">VAEEncoder</span><span class="p">(</span><span class="n">data_size</span><span class="o">=</span><span class="n">data_size</span><span class="p">,</span> <span class="n">hidden_sizes</span><span class="o">=</span><span class="n">encoder_szs</span><span class="p">,</span>
                                  <span class="n">latent_size</span><span class="o">=</span><span class="n">latent_size</span><span class="p">)</span>
        <span class="c1"># construct the decoder
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">BernoulliVAEDecoder</span><span class="p">(</span><span class="n">data_size</span><span class="o">=</span><span class="n">data_size</span><span class="p">,</span> <span class="n">latent_size</span><span class="o">=</span><span class="n">latent_size</span><span class="p">,</span>
                                           <span class="n">hidden_sizes</span><span class="o">=</span><span class="n">decoder_szs</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">data_size</span> <span class="o">=</span> <span class="n">data_size</span>

    <span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">z</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">z</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">p_x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">p_x</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span>
</code></pre></div></div>
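<p>Both constructors rely on the same idiom for supporting a variable number of hidden layers: pairing each layer size with the next one. The sketch below isolates that idiom. The helper name <code class="language-plaintext highlighter-rouge">build_mlp</code> and the layer sizes are hypothetical and chosen only for illustration.</p>

```python
import torch
from torch import nn

def build_mlp(sizes):
    # hypothetical helper: one Linear + ReLU pair per consecutive pair of sizes
    layers = []
    for in_sz, out_sz in zip(sizes[:-1], sizes[1:]):
        layers.append(nn.Linear(in_sz, out_sz))
        layers.append(nn.ReLU())
    return nn.Sequential(*layers)

# [784, 400, 200] yields the pairs (784, 400) and (400, 200)
encoder_body = build_mlp([784, 400, 200])

# a dummy batch of five flattened 28x28 images
x = torch.zeros(5, 784)
h = encoder_body(x)
print(len(encoder_body), tuple(h.shape))
```

<p>Adding or removing an entry in the size list changes the depth of the network without touching the construction code.</p>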
<p>The loss function is discussed in detail above.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Reconstruction + KL divergence losses summed over all elements and batch
</span><span class="k">def</span> <span class="nf">loss_function</span><span class="p">(</span><span class="n">recon_x</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span><span class="p">):</span>
    <span class="n">BCE</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">binary_cross_entropy</span><span class="p">(</span><span class="n">recon_x</span><span class="p">,</span> <span class="n">x</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">784</span><span class="p">),</span> <span class="n">reduction</span><span class="o">=</span><span class="s">'sum'</span><span class="p">)</span>
    <span class="c1"># see Appendix B from VAE paper:
</span>    <span class="c1"># Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
</span>    <span class="c1"># https://arxiv.org/abs/1312.6114
</span>    <span class="c1"># KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
</span>    <span class="n">KLD</span> <span class="o">=</span> <span class="o">-</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">torch</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">logvar</span> <span class="o">-</span> <span class="n">mu</span><span class="p">.</span><span class="nb">pow</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span> <span class="o">-</span> <span class="n">logvar</span><span class="p">.</span><span class="n">exp</span><span class="p">())</span>
    <span class="k">return</span> <span class="n">BCE</span> <span class="o">+</span> <span class="n">KLD</span>
</code></pre></div></div>
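<p>Because both terms are summed rather than averaged, the returned value scales with the batch size. The sketch below evaluates a standalone copy of the loss on a made-up batch where the answer is known in closed form: blank images, a decoder that outputs 0.5 everywhere, and a posterior equal to the prior. The two-dimensional latent size is an assumption for the example.</p>

```python
import torch
from torch.nn import functional as F

def loss_function(recon_x, x, mu, logvar):
    # standalone copy of the loss above, for illustration only
    bce = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld

batch = 4
x = torch.zeros(batch, 1, 28, 28)        # blank binary images
recon_x = torch.full((batch, 784), 0.5)  # maximally uncertain reconstruction
mu = torch.zeros(batch, 2)               # posterior equal to the prior,
logvar = torch.zeros(batch, 2)           # so the KL term is exactly zero

loss = loss_function(recon_x, x, mu, logvar)
# dividing by the batch size gives a per-image value that is comparable across runs
per_sample = loss / batch
print(loss.item(), per_sample.item())
```

<p>Here each pixel contributes \(\log 2\) to the BCE, so the per-image loss is \(784 \log 2\), and reporting the per-image value makes losses comparable across batch sizes.</p>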
<p>The following function executes one epoch of training. It also generates visualizations every <code class="language-plaintext highlighter-rouge">figure_interval</code> batches.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="n">epoch</span><span class="p">,</span> <span class="n">mb</span><span class="p">,</span> <span class="n">figure_interval</span><span class="p">,</span> <span class="n">viz_helper</span><span class="p">):</span>
    <span class="n">model</span><span class="p">.</span><span class="n">train</span><span class="p">()</span>
    <span class="n">train_loss</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="n">pb</span> <span class="o">=</span> <span class="n">progress_bar</span><span class="p">(</span><span class="n">train_loader</span><span class="p">,</span> <span class="n">parent</span><span class="o">=</span><span class="n">mb</span><span class="p">)</span>
    <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">_</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">pb</span><span class="p">):</span>
        <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
        <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">784</span><span class="p">)</span>
        <span class="n">optimizer</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
        <span class="n">recon_batch</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
        <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_function</span><span class="p">(</span><span class="n">recon_batch</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span><span class="p">)</span>
        <span class="n">loss</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
        <span class="n">train_loss</span> <span class="o">+=</span> <span class="n">loss</span><span class="p">.</span><span class="n">item</span><span class="p">()</span>
        <span class="n">optimizer</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
        <span class="k">if</span> <span class="n">batch_idx</span> <span class="o">%</span> <span class="n">figure_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
            <span class="n">viz_helper</span><span class="p">.</span><span class="n">execute</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">train_loss</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_loader</span><span class="p">.</span><span class="n">dataset</span><span class="p">)</span>
</code></pre></div></div>
<p>Function to perform evaluation on the test set.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">test</span><span class="p">(</span><span class="n">epoch</span><span class="p">,</span> <span class="n">mb</span><span class="p">):</span>
    <span class="n">model</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
    <span class="n">test_loss</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
        <span class="n">pb</span> <span class="o">=</span> <span class="n">progress_bar</span><span class="p">(</span><span class="n">test_loader</span><span class="p">,</span> <span class="n">parent</span><span class="o">=</span><span class="n">mb</span><span class="p">)</span>
        <span class="k">for</span> <span class="n">data</span><span class="p">,</span> <span class="n">_</span> <span class="ow">in</span> <span class="n">pb</span><span class="p">:</span>
            <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
            <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">784</span><span class="p">)</span>
            <span class="n">recon_batch</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
            <span class="n">test_loss</span> <span class="o">+=</span> <span class="n">loss_function</span><span class="p">(</span><span class="n">recon_batch</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span><span class="p">).</span><span class="n">item</span><span class="p">()</span>
    <span class="k">return</span> <span class="n">test_loss</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="p">.</span><span class="n">dataset</span><span class="p">)</span>
</code></pre></div></div>
<p>Function to fit the model over a number of epochs and generate visualizations.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">fit</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">epochs</span><span class="p">,</span> <span class="n">figure_interval</span><span class="p">,</span> <span class="n">viz_helper</span><span class="p">):</span>
    <span class="n">mb</span> <span class="o">=</span> <span class="n">master_bar</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">epochs</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
    <span class="n">viz_helper</span><span class="p">.</span><span class="n">execute</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>
    <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="n">mb</span><span class="p">:</span>
        <span class="n">trn_loss</span> <span class="o">=</span> <span class="n">train</span><span class="p">(</span><span class="n">epoch</span><span class="p">,</span> <span class="n">mb</span><span class="p">,</span> <span class="n">figure_interval</span><span class="o">=</span><span class="n">figure_interval</span><span class="p">,</span> <span class="n">viz_helper</span><span class="o">=</span><span class="n">viz_helper</span><span class="p">)</span>
        <span class="n">tst_loss</span> <span class="o">=</span> <span class="n">test</span><span class="p">(</span><span class="n">epoch</span><span class="p">,</span> <span class="n">mb</span><span class="p">)</span>
        <span class="n">mb</span><span class="p">.</span><span class="n">write</span><span class="p">(</span><span class="sa">f</span><span class="s">'epoch </span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s">, train loss: </span><span class="si">{</span><span class="nb">round</span><span class="p">(</span><span class="n">trn_loss</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span><span class="si">}</span><span class="s">, test loss: </span><span class="si">{</span><span class="nb">round</span><span class="p">(</span><span class="n">tst_loss</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span><span class="si">}</span><span class="s">'</span><span class="p">)</span>
</code></pre></div></div>
<h3 id="vae-with-20-d-latent-space">VAE with 20-d Latent Space</h3>
<p>Train a VAE with a 20 dimensional latent space. This VAE will be used to generate the data reconstruction visualizations.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">viz_helper_20d</span> <span class="o">=</span> <span class="n">VAEVizHelper</span><span class="p">(</span><span class="n">recon_base_imgs</span><span class="p">,</span> <span class="n">datagen_tracking</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span> <span class="o">=</span> <span class="n">BernoulliVAE</span><span class="p">(</span><span class="n">data_size</span><span class="o">=</span><span class="mi">784</span><span class="p">,</span> <span class="n">encoder_szs</span><span class="o">=</span><span class="p">[</span><span class="mi">400</span><span class="p">],</span> <span class="n">latent_size</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span>
<span class="n">decoder_szs</span><span class="o">=</span><span class="p">[</span><span class="mi">400</span><span class="p">]).</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span>
<span class="c1"># BernoulliVAE(
#   (encoder): VAEEncoder(
#     (encoder): Sequential(
#       (0): Linear(in_features=784, out_features=400, bias=True)
#       (1): ReLU()
#     )
#     (encoder_mu): Linear(in_features=400, out_features=20, bias=True)
#     (encoder_logvar): Linear(in_features=400, out_features=20, bias=True)
#   )
#   (decoder): BernoulliVAEDecoder(
#     (decoder): Sequential(
#       (0): Linear(in_features=20, out_features=400, bias=True)
#       (1): ReLU()
#       (2): Linear(in_features=400, out_features=784, bias=True)
#       (3): Sigmoid()
#     )
#   )
# )
</span></code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fit</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="n">viz_helper_20d</span><span class="p">)</span>
</code></pre></div></div>
<p>Total time: 02:10</p>
<p>epoch 1, train loss: 157.707501, test loss: 116.365121</p>
<p>epoch 2, train loss: 108.549474, test loss: 102.344373</p>
<p>epoch 3, train loss: 99.825192, test loss: 96.556084</p>
<p>epoch 4, train loss: 95.784532, test loss: 94.104183</p>
<p>epoch 5, train loss: 93.294786, test loss: 91.994745</p>
<p>epoch 6, train loss: 91.638687, test loss: 90.58567</p>
<p>epoch 7, train loss: 90.407814, test loss: 89.906129</p>
<p>epoch 8, train loss: 89.389802, test loss: 88.779571</p>
<p>epoch 9, train loss: 88.574026, test loss: 88.075248</p>
<p>epoch 10, train loss: 87.911918, test loss: 87.646744</p>
<h3 id="vae-with-2-d-latent-space">VAE with 2-d Latent Space</h3>
<p>Train a VAE with 2 dimensional latent space. This model will be used to generate the visualizations of data generation across the latent manifold. It is much easier to visualize a 2-d manifold.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span> <span class="o">=</span> <span class="n">BernoulliVAE</span><span class="p">(</span><span class="n">data_size</span><span class="o">=</span><span class="mi">784</span><span class="p">,</span> <span class="n">encoder_szs</span><span class="o">=</span><span class="p">[</span><span class="mi">400</span><span class="p">,</span><span class="mi">150</span><span class="p">],</span> <span class="n">latent_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
<span class="n">decoder_szs</span><span class="o">=</span><span class="p">[</span><span class="mi">150</span><span class="p">,</span><span class="mi">400</span><span class="p">]).</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span>
<span class="c1"># BernoulliVAE(
#   (encoder): VAEEncoder(
#     (encoder): Sequential(
#       (0): Linear(in_features=784, out_features=400, bias=True)
#       (1): ReLU()
#       (2): Linear(in_features=400, out_features=150, bias=True)
#       (3): ReLU()
#     )
#     (encoder_mu): Linear(in_features=150, out_features=2, bias=True)
#     (encoder_logvar): Linear(in_features=150, out_features=2, bias=True)
#   )
#   (decoder): BernoulliVAEDecoder(
#     (decoder): Sequential(
#       (0): Linear(in_features=2, out_features=150, bias=True)
#       (1): ReLU()
#       (2): Linear(in_features=150, out_features=400, bias=True)
#       (3): ReLU()
#       (4): Linear(in_features=400, out_features=784, bias=True)
#       (5): Sigmoid()
#     )
#   )
# )
</span></code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fit</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="n">viz_helper_2d</span><span class="p">)</span>
</code></pre></div></div>
<p>Total time: 02:50</p>
<p>epoch 1, train loss: 164.487584, test loss: 159.490324</p>
<p>epoch 2, train loss: 155.384058, test loss: 152.910404</p>
<p>epoch 3, train loss: 150.389809, test loss: 149.377793</p>
<p>epoch 4, train loss: 147.043944, test loss: 146.478713</p>
<p>epoch 5, train loss: 144.72857, test loss: 144.263316</p>
<p>epoch 6, train loss: 142.759334, test loss: 143.399154</p>
<p>epoch 7, train loss: 141.358421, test loss: 141.273972</p>
<p>epoch 8, train loss: 140.11891, test loss: 141.396691</p>
<p>epoch 9, train loss: 139.222043, test loss: 140.119026</p>
<p>epoch 10, train loss: 138.443796, test loss: 140.24413</p>
<h2 id="conclusion">Conclusion</h2>
<p>In this post we drew the final connections between the abstract theory of variational autoencoders and a concrete implementation in PyTorch. By sampling a grid from the latent space and using the probabilistic decoder to map these samples into synthetic digits, we saw how the model has learned a highly structured latent space with smooth transitions between digit classes. We also discussed a simple example demonstrating how the VAE can be used for anomaly detection.</p>
<h2 id="resources">Resources</h2>
<p>[1] PyTorch, <a href="https://github.com/pytorch/examples/tree/master/vae">Basic VAE Example</a></p>
<p>Adam Lineberry. Detailed walkthrough of a PyTorch VAE implementation trained on MNIST, including visualizations of data generation, reconstruction, and anomaly detection.</p>
<p><strong>Variational Autoencoder Theory</strong>, 2019-07-07, http://adamlineberry.ai/vae-series/vae-theory</p>
<p class="notice--info">This is the third post in my series: <a href="http://adamlineberry.ai/vae-series">From KL Divergence to Variational Autoencoder in PyTorch</a>. The previous post in the series is <a href="http://adamlineberry.ai/vae-series/variational-inference">Latent Variable Models, Expectation Maximization, and Variational Inference</a> and the next post is <a href="http://adamlineberry.ai/vae-series/vae-code-experiments">Variational Autoencoder Code and Experiments</a>.</p>
<hr />
<p>The Variational Autoencoder has taken the machine learning community by storm since Kingma and Welling’s seminal paper was released in 2013<sup>1</sup>. It was one of the first model architectures in the mainstream to establish a strong connection between deep learning and Bayesian statistics. Quite frankly, it’s also just really cool. A VAE trained on image data results in the ability to create spectacular visualizations of the latent factors it learns and the realistic images it can generate. In the world of data science it’s an excellent bridge between statistics and computer science. It’s interesting to think about and tinker with, and it makes a great sandbox to learn and build intuition about deep learning and statistics.</p>
<figure class="half" style="display:flex">
<img src="http://adamlineberry.ai/images/vae/datagen_final.png" height="100" />
<img src="http://adamlineberry.ai/images/vae/frey_face.png" height="100" />
<figcaption>(left) Synthesized digits from MNIST sampled from a grid on the learned latent manifold. Notice the smooth transitions between digits. (right) Synthesized faces sampled from a grid on the manifold of a VAE trained on the Frey Face dataset<sup>1</sup>. Notice that the VAE has learned interpretable latent factors: left-to-right adjusts head orientation, top-to-bottom adjusts level of frowning or smiling. </figcaption>
</figure>
<p>It isn’t just a playground though; there are extremely valuable applications for the VAE on real world problems. It can be used for representation learning/feature engineering/dimensionality reduction to improve performance on downstream tasks such as classification models or recommender systems. You can also leverage its probabilistic nature to perform anomaly detection. Its data generation capability also lends itself to assist in the training of reinforcement learning systems.</p>
<p>The VAE seems very similar to other autoencoders. At a high level, an autoencoder is a deep neural network that is trained to reconstruct its own input. There are many variations of this fundamental idea that accomplish different end tasks, such as the vanilla autoencoder, the denoising autoencoder, and the sparse autoencoder. But the VAE stands apart from the rest in that it is a fully probabilistic model.</p>
<p>In this post we are going to introduce the theory of the VAE by building on concepts introduced in the previous post, such as variational inference and maximizing the Evidence Lower Bound (ELBO).</p>
<div class="notice">
<p><strong>Table of contents:</strong></p>
<ol>
<li>Derivation of the VAE objective function</li>
<li>Intuition behind the VAE objective function</li>
<li>Model architecture</li>
<li>Optimization</li>
<li>Practical uses of the VAE</li>
</ol>
</div>
<h2 id="derivation-of-the-vae-objective-function">Derivation of the VAE objective function</h2>
<p>As discussed in my <a href="http://adamlineberry.ai/vae-series/variational-inference">post on variational inference</a>, the intractable data likelihood which we would like to maximize can be decomposed into the following expression:</p>
\[\log p_{\theta}(x) = KL[q_{\phi}(z) \lVert p_{\theta}(z \lvert x)] + \underset{z}{\sum} q_{\phi}(z) \log \frac{p_{\theta}(x,z)}{q_{\phi}(z)}\]
<p>The focus of variational inference methods, including the VAE, is to maximize the second term in this expression, commonly known as the ELBO or variational lower bound:</p>
\[\mathcal{L} = \underset{z}{\sum} q_{\phi}(z) \log \frac{p_{\theta}(x,z)}{q_{\phi}(z)}\]
<p>In order to set the stage for the VAE, let’s rearrange \(\mathcal{L}\) slightly by first writing it as an expectation, factoring the joint distribution with the product rule, splitting up the logarithm, and recognizing a KL divergence term:</p>
\[\begin{align}
\mathcal{L} &= \underset{z}{\sum} q_{\phi}(z) \log \frac{p_{\theta}(x,z)}{q_{\phi}(z)} \\
&= \mathbb{E}_{q_{\phi}(z)} \log \frac{p_{\theta}(x,z)}{q_{\phi}(z)} \\
&= \mathbb{E}_{q_{\phi}(z)} \log \frac{p_{\theta}(x \lvert z)p_{\theta}(z)}{q_{\phi}(z)} \\
&= \mathbb{E}_{q_{\phi}(z)} \log p_{\theta}(x \lvert z) + \mathbb{E}_{q_{\phi}(z)} \log \frac{p_{\theta}(z)}{q_{\phi}(z)} \\
&= \mathbb{E}_{q_{\phi}(z)} \log p_{\theta}(x \lvert z) -
KL[q_{\phi}(z) \lVert p_{\theta}(z)]
\end{align}\]
<p>Since \(q\) is intended to approximate the posterior \(p_{\theta}(z \lvert x)\) we will choose \(q\) to be conditional on \(x\): \(q_{\phi}(z) = q_{\phi}(z \lvert x)\). Now we’re ready to write down the objective function for the VAE:</p>
\[\mathcal{L} = \mathbb{E}_{q_{\phi}(z \lvert x)} \log p_{\theta}(x \lvert z) -
KL[q_{\phi}(z \lvert x) \lVert p_{\theta}(z)] \tag{1}\]
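<p>The two identities used in this derivation can be checked numerically on a toy problem. The sketch below assumes a made-up model in which \(z\) takes only three values, so every sum is exact; the probabilities are illustrative and not taken from any real dataset.</p>

```python
import math

# Hypothetical toy model: z takes 3 values and x is a single fixed observation.
p_z = [0.5, 0.3, 0.2]            # prior p(z)
p_x_given_z = [0.9, 0.2, 0.4]    # likelihood p(x | z) of the observed x
q = [0.6, 0.1, 0.3]              # an arbitrary approximate posterior q(z | x)

# Exact quantities, available only because z is tiny and discrete.
p_xz = [pz * px for pz, px in zip(p_z, p_x_given_z)]   # joint p(x, z)
p_x = sum(p_xz)                                        # evidence p(x)
posterior = [j / p_x for j in p_xz]                    # true posterior p(z | x)


def kl(a, b):
    """KL[a || b] for two discrete distributions given as lists."""
    return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)


# ELBO in its original form: sum_z q(z) log(p(x, z) / q(z)).
elbo = sum(qi * math.log(j / qi) for qi, j in zip(q, p_xz))

# ELBO as in equation (1): expected log-likelihood minus KL to the prior.
expected_loglik = sum(qi * math.log(px) for qi, px in zip(q, p_x_given_z))
elbo_eq1 = expected_loglik - kl(q, p_z)

print(f"log p(x)           = {math.log(p_x):.6f}")
print(f"ELBO + KL[q||post] = {elbo + kl(q, posterior):.6f}")
print(f"ELBO (eq. 1 form)  = {elbo_eq1:.6f}")
```

<p>The first two printed values agree, which is the decomposition of \(\log p_{\theta}(x)\), and the third equals the ELBO itself, which is the rearrangement into equation (1).</p>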
<h2 id="intuition-behind-the-vae-objective-function">Intuition behind the VAE objective function</h2>
<p>It’s easy to get lost in the weeds here, so let’s zoom back out to the big picture for a moment: we want to learn a latent variable model of our data that maximizes the likelihood of the observed data \(x\). We have already shown that it is intractable to maximize this likelihood directly, so we have turned to approximating \(p_{\theta}(z \lvert x)\) with a new distribution \(q_{\phi}\) and maximizing the ELBO instead.</p>
<p>The practical items we would like to extract from this model are the ability to map data into latent space using \(q_{\phi}(z \lvert x)\) for exploration and/or dimensionality reduction, and the ability to synthesize new data by sampling from the latent space according to \(p_{\theta}(z)\) and then generating new data from \(p_{\theta}(x \lvert z)\).</p>
<p>Now, let’s begin unpacking the objective function by defining the prior on \(z\). The VAE sets this prior to a diagonal unit Gaussian: \(p_{\theta}(z) = \mathcal{N}(0, I)\). It can be shown that a simple Gaussian such as this can be mapped into very complicated distributions as long as the mapping function is sufficiently complex (e.g. a neural network)<sup>2</sup>. This choice also simplifies the optimization problem as we will see shortly.</p>
<p>Next, let’s discuss the first term in the objective.</p>
\[\mathbb{E}_{q_{\phi}(z \lvert x)} \log p_{\theta}(x \lvert z)\]
<p>We want to learn two distributions, \(q\) and \(p\). The \(q\) we learn should be able to map data points \(x_i\) into a latent representation \(z_i\) from which \(p_{\theta}(x \lvert z)\) is able to successfully reconstruct the original data point \(x_i\). This term plays the same role as the standard reconstruction loss (e.g., MSE) used in vanilla autoencoders. In fact, when \(p_{\theta}(x \lvert z)\) is a Gaussian with fixed isotropic variance, maximizing this term is equivalent to minimizing the squared reconstruction error.</p>
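<p>The connection to MSE can be made concrete under one assumption that is not fixed by the VAE itself: a Gaussian decoder with mean \(\mu_{\theta}(z)\) and a fixed isotropic covariance \(\sigma^2 I\) over \(D\) data dimensions. The symbols \(\mu_{\theta}\), \(\sigma^2\) and \(D\) are introduced here only for illustration.</p>
\[\begin{align}
\log p_{\theta}(x \lvert z) &= \log \mathcal{N}(x; \mu_{\theta}(z), \sigma^2 I) \\
&= -\frac{1}{2\sigma^2} \lVert x - \mu_{\theta}(z) \rVert^2 - \frac{D}{2} \log (2 \pi \sigma^2)
\end{align}\]
<p>The second term is constant with respect to \(\theta\) and \(z\), so maximizing the log-likelihood is the same as minimizing the squared error between \(x\) and the decoder mean, scaled by \(1 / (2\sigma^2)\).</p>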
<p>Simultaneously, the KL term is pushing \(q\) to look like our Gaussian prior \(p_{\theta}(z)\).</p>
\[- KL[q_{\phi}(z \lvert x) \lVert p_{\theta}(z)]\]
<p>This term is commonly interpreted as a form of regularization. It prevents the model from memorizing the training data and forces it to learn an informative latent manifold that pairs nicely with \(p_{\theta}(x \lvert z)\). Without it, the greedy model would learn distributions \(q_{\phi}(z \lvert x)\) with zero variance, essentially degrading to a vanilla autoencoder. By enforcing \(q_{\phi}(z \lvert x)\) to have some variance, the learned \(p_{\theta}(x \lvert z)\) must be robust against small changes in \(z\). This results in a smooth latent space \(z\) that can be reliably sampled from to generate new, realistic data, whereas sampling from the latent space of a vanilla autoencoder will almost always return junk<sup>5</sup>.</p>
<h2 id="model-architecture">Model architecture</h2>
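<p>We choose \(q_{\phi}(z \lvert x)\) to be a diagonal multivariate Gaussian whose parameters depend on \(x\), so that, across all values of \(x\), the encoder describes an infinite mixture of diagonal Gaussians:</p>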
\[q_{\phi}(z \lvert x) = \mathcal{N}(\mu_{\phi}(x), diag(\sigma^2_{\phi}(x)))\]
<p>Where the Gaussian parameters \(\mu\) and \(\sigma^2\) are modeled as parametric functions of \(x\). Note that \(\sigma^2\) is a vector of the diagonal elements of the covariance matrix. This choice provides us with a flexible distribution on \(z\) which is data point-specific because of its explicit conditioning on \(x\).</p>
<p>The VAE models the parameters of \(q\), \(\{\mu_{\phi}(x), \sigma^2_{\phi}(x)\}\), with a neural network that outputs a vector of means \(\mu\) and a vector of variances \(\sigma^2\) for each data point \(x_i\).</p>
<p>Similarly, the distribution \(p_{\theta}(x \lvert z)\) is modeled as a distribution that factorizes across the dimensions of \(x\), with a neural network that outputs its parameters as a function of \(z\). Depending on the type of data, this distribution is typically chosen to be Gaussian or Bernoulli. When working with binary data (as in the next post), the Bernoulli is used:</p>
\[p_{\theta}(x \lvert z) = \operatorname{Bern}(h_{\theta}(z))\]
<p>Where \(h_{\theta}(z)\) is an MLP mapping from the latent dimension to the data dimension. The output vector of \(h_{\theta}(z)\) contains Bernoulli parameters that are used to form the probability distribution \(p_{\theta}(x \lvert z)\).</p>
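<p>As a rough illustration of how Bernoulli parameters turn into a likelihood, the sketch below evaluates \(\log p_{\theta}(x \lvert z)\) for a binary vector given a vector of probabilities standing in for \(h_{\theta}(z)\). The function name, the <code class="language-plaintext highlighter-rouge">eps</code> clamp, and the example vectors are assumptions made for this example; the negative of this quantity is the summed binary cross-entropy.</p>

```python
import math


def bernoulli_log_likelihood(x, probs, eps=1e-12):
    """log p(x | z) for binary x under independent Bernoulli outputs.

    `probs` plays the role of the decoder output h_theta(z). `eps` is a
    hypothetical clamp added here only to avoid taking log(0).
    """
    total = 0.0
    for xi, pi in zip(x, probs):
        pi = min(max(pi, eps), 1.0 - eps)
        total += xi * math.log(pi) + (1 - xi) * math.log(1.0 - pi)
    return total


x = [1, 0, 1, 1]                 # a made-up 4-pixel binary "image"
good = [0.9, 0.1, 0.8, 0.95]     # decoder output that matches x well
bad = [0.2, 0.7, 0.5, 0.1]       # decoder output that matches x poorly

print(bernoulli_log_likelihood(x, good))
print(bernoulli_log_likelihood(x, bad))
```

<p>The decoder output that agrees with \(x\) receives the higher log-likelihood, which is what the first term of the objective rewards.</p>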
<p>Distributions \(p_{\theta}(x \lvert z)\) and \(q_{\phi}(z \lvert x)\) are learned jointly in the same neural network:</p>
<p><img src="http://adamlineberry.ai/images/vae/vae-architecture.png" alt="" class="align-center" /></p>
<figcaption>Illustration of the VAE model architecture<sup>3</sup></figcaption>
<p>It is clear how the VAE model architecture closely resembles that of standard autoencoders. The first half of the network which is modeling \(q_{\phi}(z \lvert x)\) is known as the <em>probabilistic encoder</em> and the second half of the network which models \(p_{\theta}(x \lvert z)\) is known as the <em>probabilistic decoder</em>. This interpretation further extends the analogy between VAEs and standard autoencoders, but it should be noted that the mechanics and motivations are actually quite different.</p>
<p>The neural network weights are updated via SGD to maximize the objective function discussed previously:</p>
\[\mathcal{L} = \mathbb{E}_{q_{\phi}(z \lvert x)} \log p_{\theta}(x \lvert z) -
KL[q_{\phi}(z \lvert x) \lVert p_{\theta}(z)]\]
<h2 id="optimization">Optimization</h2>
<p>Let’s first describe the overall flow and inner workings of this neural network. Data points \(x_i\) are fed into the encoder which produces vectors of means and variances defining a diagonal Gaussian distribution at the center of the network. A latent variable \(z_i\) is then sampled from \(q_{\phi}(z_i \lvert x_i)\) and fed into the decoder. The decoder outputs another set of parameters defining \(p_{\theta}(x_i \lvert z_i)\) (as discussed previously, these parameters could be means and variances of another Gaussian, or the parameters of a multivariate Bernoulli). During training, the likelihood of the data point \(x_i\) under \(p_{\theta}(x_i \lvert z_i)\) can then be calculated using the Bernoulli PMF or Gaussian PDF, and maximized with gradient-based optimization (in practice, gradient descent on the negative objective).</p>
<p>In addition to maximizing the data likelihood, which corresponds to the first term in the objective function, the KL divergence between the encoder distribution \(q_{\phi}(z \lvert x)\) and the prior \(p_{\theta}(z)\) is also minimized. Thankfully, since we have chosen Gaussians for both the prior and the approximate posterior \(q_{\phi}\), the KL divergence term has a closed form solution which can be optimized directly.</p>
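<p>For reference, the closed form for a diagonal Gaussian against the unit Gaussian prior is \(-\frac{1}{2} \sum_j (1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2)\) (Appendix B of Kingma and Welling<sup>1</sup>). The sketch below, with illustrative function names and made-up inputs, compares it against a brute-force numerical integral in one dimension.</p>

```python
import math


def kl_to_unit_gaussian(mu, logvar):
    """KL[N(mu, diag(exp(logvar))) || N(0, I)] in closed form."""
    return -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def kl_numeric_1d(mu, logvar, lo=-12.0, hi=12.0, steps=50_000):
    """Midpoint-rule estimate of the same KL for a one-dimensional z."""
    var = math.exp(logvar)
    dz = (hi - lo) / steps
    total = 0.0
    for i in range(steps):
        z = lo + (i + 0.5) * dz
        log_q = -0.5 * (math.log(2 * math.pi * var) + (z - mu) ** 2 / var)
        log_p = -0.5 * (math.log(2 * math.pi) + z * z)
        total += math.exp(log_q) * (log_q - log_p) * dz
    return total


print(kl_to_unit_gaussian([0.0, 0.0], [0.0, 0.0]))   # q equals the prior
print(kl_to_unit_gaussian([1.5], [-0.7]))            # closed form
print(kl_numeric_1d(1.5, -0.7))                      # numerical integral
```

<p>The inputs are a mean vector and a log-variance vector, mirroring the <code class="language-plaintext highlighter-rouge">mu</code> and <code class="language-plaintext highlighter-rouge">logvar</code> outputs of the encoder.</p>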
<p>Performing gradient descent on the first term also presents additional complications. For one, computing the actual expectation over \(q_{\phi}\) requires an intractable integral (i.e., computing \(\log p_{\theta}(x \lvert z)\) for all possible values of \(z\)). Instead, this expectation is approximated by Monte Carlo sampling. The Monte Carlo approximation states that the expectation of a function can be approximated by the average value of the function across \(N_s\) samples from the distribution:</p>
\[\mathbb{E}_{q_{\phi}(z \lvert x)} \log p_{\theta}(x \lvert z) \approx
\frac{1}{N_s}\sum_{s=1}^{N_s} \log p_{\theta}(x \lvert z_s)\]
<p>In the case of the VAE we approximate the expectation using the single sample from \(q_{\phi}(z \lvert x)\) that we’ve already discussed. A single sample gives an unbiased but noisy estimate of the expectation; the noise averages out over the many gradient updates in the training loop.</p>
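<p>A small sketch of the Monte Carlo approximation, using only the Python standard library. The Gaussian parameters are made up, and \(f(z) = z^2\) is a stand-in for \(\log p_{\theta}(x \lvert z)\), chosen because its exact expectation \(\mu^2 + \sigma^2\) is known.</p>

```python
import random


def monte_carlo_expectation(f, sampler, n_samples):
    """Approximate E[f(z)] by averaging f over draws from `sampler`."""
    return sum(f(sampler()) for _ in range(n_samples)) / n_samples


random.seed(0)
mu, sigma = 0.5, 1.2                        # hypothetical q(z | x) = N(mu, sigma^2)
sampler = lambda: random.gauss(mu, sigma)   # draws one z from q
f = lambda z: z * z                         # stand-in for log p(x | z)

exact = mu ** 2 + sigma ** 2                # E[z^2] for a Gaussian
for n in (1, 100, 10_000):
    print(n, monte_carlo_expectation(f, sampler, n), "exact:", exact)
```

<p>With one sample the estimate can be far from the exact value, and it tightens as the number of samples grows.</p>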
<p>Another gradient descent-related complication is the sampling step that occurs between the encoder and the decoder. Without getting into the details, directly sampling \(z\) from \(q_{\phi}(z \lvert x)\) is a stochastic operation that gradients cannot be backpropagated through.</p>
<p><img src="http://adamlineberry.ai/images/vae/architecture-no-reparam.png" alt="" width="500" class="align-center" /></p>
<figcaption>
Diagram of the VAE without the reparameterization trick. Dashed arrows represent the sampling operation.
</figcaption>
<p>The neat solution to this is called the <em>reparameterization trick</em>, which moves the stochastic operation to an input layer and results in continuous linkage between the encoder and decoder, allowing for backpropagation all the way through the encoder. Instead of sampling directly from the encoder \(z_i \sim q_{\phi}(z_i \lvert x_i)\), we can represent \(z_i\) as a deterministic function of \(x_i\) and some noise \(\epsilon_i\):</p>
\[z_i = g_{\phi}(x_i, \epsilon_i) = \mu_{\phi}(x_i) + diag(\sigma_{\phi}(x_i)) \cdot \epsilon_i \\
\epsilon_i \sim \mathcal{N}(0, I)\]
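<p>You can show that \(z\) defined in this way follows the distribution \(q_{\phi}(z \lvert x)\): an affine transformation of a standard Gaussian is again Gaussian, here with mean \(\mu_{\phi}(x_i)\) and covariance \(diag(\sigma^2_{\phi}(x_i))\).</p>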
<p><img src="http://adamlineberry.ai/images/vae/architecture-with-reparam.png" alt="" width="550" class="align-center" /></p>
<figcaption>
Diagram of the VAE with the reparameterization trick. Dashed arrows represent the sampling operation.
</figcaption>
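<p>The reparameterized sample can be sketched in a few lines of plain Python. The encoder outputs below are made-up numbers, and the log-variance parameterization is an assumption chosen to match the <code class="language-plaintext highlighter-rouge">logvar</code> naming used in the code post; the empirical mean and variance of the samples should approach \(\mu\) and \(\sigma^2\).</p>

```python
import math
import random


def reparameterize(mu, logvar, rng=random):
    """Return z = mu + sigma * eps with eps ~ N(0, I), element-wise."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


random.seed(0)
mu = [1.0, -2.0]                          # made-up encoder means
logvar = [math.log(0.25), math.log(4.0)]  # made-up encoder log-variances
samples = [reparameterize(mu, logvar) for _ in range(50_000)]

for d in range(len(mu)):
    col = [s[d] for s in samples]
    mean = sum(col) / len(col)
    var = sum((c - mean) ** 2 for c in col) / len(col)
    print(f"dim {d}: mean {mean:.3f} (target {mu[d]}), "
          f"var {var:.3f} (target {math.exp(logvar[d]):.2f})")
```

<p>All of the randomness enters through \(\epsilon\), so \(z\) is a differentiable function of \(\mu\) and \(\sigma\); an automatic differentiation framework can then propagate gradients from the decoder back into the encoder.</p>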
<h2 id="practical-uses-of-vae">Practical uses of the VAE</h2>
<p>Probably the most famous use of the VAE is to generate/synthesize/hallucinate new data. The synthesis procedure is very simple: draw a random sample from the prior \(p_{\theta}(z)\), and feed that sample through the decoder \(p_{\theta}(x \lvert z)\) to produce a new \(x\). Since the decoder outputs distribution parameters and not real data, you can take the most probable \(x\) from this distribution. When the decoder is Gaussian, this equates to simply taking the mean vector. When it’s Bernoulli, simply round the probabilities to the nearest integer \(\in \{0, 1\}\). Note that for data generation purposes, you can effectively throw away the encoder.</p>
<p>Another practical use is representation learning. It is certainly possible that using the latent representation of your data will improve performance of downstream tasks, such as clustering or classification. After training the VAE you can transform your data by passing it through the encoder and taking the most probable latent vectors \(z\) (which equates to taking the mean vector output by the encoder). Data outside of the training set can also be transformed by a previously-trained VAE. Of course, performance will be best when the new data is similar to the training data, i.e., comes from the same domain or natural distribution. As an extreme example, it probably wouldn’t make much sense to transform medical image data using a VAE that was trained on MNIST.</p>
<p>Yet another use is anomaly detection. There are various ways to leverage the probabilistic nature of the VAE to determine when a new data point is very improbable and therefore anomalous. Some examples:</p>
<ul>
<li>Pass the new data through the encoder and measure the KL divergence between the encoder’s distribution and the prior. A high KL divergence would indicate that the new data is dissimilar to the data the VAE saw during training.</li>
<li>Pass the new data through the full VAE and measure the reconstruction probability. Data with a very low reconstruction probability is dissimilar from the training set.</li>
<li>With some additional work it’s possible to compute the actual log likelihood \(\log p_{\theta}(x)\) for new data. This approach requires leveraging importance sampling to efficiently compute a new expectation. Look out for more details on this approach in a future post.</li>
</ul>
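<p>A minimal sketch of the first approach in the list above, scoring each point by the KL divergence between its encoder distribution and the prior. The encoder outputs are made-up numbers standing in for a trained VAE, and the quantile-based threshold in <code class="language-plaintext highlighter-rouge">flag_anomalies</code> is one hypothetical way to pick a cutoff, not a prescribed method.</p>

```python
import math


def kl_to_unit_gaussian(mu, logvar):
    """KL[N(mu, diag(exp(logvar))) || N(0, I)] in closed form."""
    return -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def flag_anomalies(train_scores, new_scores, quantile=0.99):
    """Flag new points whose score exceeds a quantile of the training scores."""
    ranked = sorted(train_scores)
    idx = min(int(quantile * len(ranked)), len(ranked) - 1)
    threshold = ranked[idx]
    return threshold, [s > threshold for s in new_scores]


# Made-up (mu, logvar) pairs standing in for the outputs of a trained encoder.
train_encodings = [([0.1 * i - 0.5, 0.2], [-0.1, 0.0]) for i in range(10)]
new_encodings = [
    ([0.0, 0.1], [-0.1, 0.0]),     # resembles the training encodings
    ([4.0, -3.5], [-2.0, -2.0]),   # far from the prior
]

train_scores = [kl_to_unit_gaussian(m, lv) for m, lv in train_encodings]
new_scores = [kl_to_unit_gaussian(m, lv) for m, lv in new_encodings]
threshold, flags = flag_anomalies(train_scores, new_scores)
print(threshold, flags)
```

<p>The same scaffolding would apply to the second approach by swapping the KL score for a negative reconstruction log-probability.</p>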
<h2 id="conclusion">Conclusion</h2>
<p>In this post we introduced the VAE and showed how it is a modern extension of the same theory that motivates the classical expectation maximization algorithm. We also derived the VAE’s objective function and explained some of the intuition behind it.</p>
<p>Some of the important details regarding the neural network architecture and optimization were discussed. We saw how the probabilistic encoder and probabilistic decoder are modeled as neural networks and how the reparameterization trick is used to allow for backpropagation through the entire network.</p>
<p>To see the VAE in action, check out my <a href="http://adamlineberry.ai/vae-series/vae-code-experiments">next post</a> which draws a strong connection between the theory presented here and actual PyTorch code and presents the results of several interesting experiments.</p>
<h2 id="resources">Resources</h2>
<p>[1] Diederik P. Kingma, Max Welling, <a href="https://arxiv.org/abs/1312.6114">Auto-Encoding Variational Bayes</a></p>
<p>[2] Carl Doersch, <a href="https://arxiv.org/abs/1606.05908">Tutorial on Variational Autoencoders</a></p>
<p>[3] Rebecca Vislay Wade, <a href="https://www.kaggle.com/rvislaywade/visualizing-mnist-using-a-variational-autoencoder">Visualizing MNIST with a Deep Variational Autoencoder</a></p>
<p>[4] Volodymyr Kuleshov, Stefano Ermon, <a href="https://ermongroup.github.io/cs228-notes/extras/vae/">The variational auto-encoder</a></p>
<p>[5] Irhum Shafkat, <a href="https://towardsdatascience.com/intuitively-understanding-variational-autoencoders-1bfe67eb5daf">Intuitively Understanding Variational Autoencoders</a></p>
<p>[6] Daniil Polykovskiy, Alexander Novikov, National Research University Higher School of Economics, Coursera, <a href="https://www.coursera.org/learn/bayesian-methods-in-machine-learning">Bayesian Methods for Machine Learning</a></p>
<p>[7] Martin Krasser, <a href="http://krasserm.github.io/2018/04/03/variational-inference/">From expectation maximization to stochastic variational inference</a></p>Adam LineberryFormulation of the VAE objective, neural network architecture design, optimization, and practical usesBlog Post Series: From KL Divergence to Variational Autoencoder in PyTorch2019-07-07T00:00:00+00:002019-07-07T00:00:00+00:00http://adamlineberry.ai/vae-series<p>In this series of four posts, I attempt to build up the theory, mathematics, and intuition of variational autoencoders (VAE), starting with some basic fundamentals and then moving closer and closer to a full PyTorch implementation with each post.</p>
<p><img src="http://adamlineberry.ai/images/vae/vae-architecture.png" alt="" class="align-center" /></p>
<figcaption>Illustration of the VAE model architecture</figcaption>
<p>The ultimate goal of the series is to provide the full picture of variational autoencoders, all the way from expected values to Python classes. The first two posts discuss the general theory and derivations of variational inference and understanding the evidence lower bound (ELBO). The third post evolves general theory to VAE-specific theory. The fourth post establishes the final connections between theory and code, provides a full VAE implementation written in PyTorch, and shows some interesting experiments.</p>
<figure class="half" style="display:flex">
<img src="http://adamlineberry.ai/images/vae/datagen_final.png" height="100" />
<img src="http://adamlineberry.ai/images/vae/frey_face.png" height="100" />
</figure>
<p>Posts:</p>
<ol>
<li><a href="http://adamlineberry.ai/vae-series/kl-divergence"><strong>A Quick Primer on KL Divergence</strong></a>. KL divergence is a fundamental tool that is used everywhere in this series. This is a quick introduction for those who may not be familiar already.</li>
<li><a href="http://adamlineberry.ai/vae-series/variational-inference"><strong>Latent Variable Models, Expectation Maximization, and Variational Inference</strong></a>. This post dives into latent variable models and how to train them. It also introduces expectation maximization (EM), which is closely related to the VAE.</li>
<li><a href="http://adamlineberry.ai/vae-series/vae-theory"><strong>Variational Autoencoder Theory</strong></a>. Develops the theory of variational inference into the VAE objective function and discusses how the VAE model architecture is designed to achieve specific probabilistic goals.</li>
<li><a href="http://adamlineberry.ai/vae-series/vae-code-experiments"><strong>Variational Autoencoder Code and Experiments</strong></a>. The culmination of the series, this post draws the final connections between the theoretical framework of variational autoencoders and a PyTorch implementation.</li>
</ol>Adam LineberryLanding page for the blog post seriesInteresting Details about ROC Curve Calculations2019-05-20T00:00:00+00:002019-05-20T00:00:00+00:00http://adamlineberry.ai/how-auroc-is-calculated<p>The area under the receiver operating characteristic curve (commonly known as “AUC” or “AUROC”) is a widely used metric for evaluating binary classifiers. Most data scientists are familiar with the famous curve itself, which plots the true positive rate against the false positive rate, and are familiar with integrals (i.e., area under the curve). So, it’s a pretty straightforward concept theoretically, but how is it actually calculated for a real dataset and a real model? That is what we’ll be digging into in this post. There’s some interesting intuition to be gained by understanding the exact implementation (which is quite simple).</p>
<p>Quick disclaimer here: It is not the intent of this post to show how these calculations are implemented in production; there are variations and optimizations to the methodology and code presented. Rather, the intent is to show a basic, easy to understand implementation with the objective of building the reader’s intuition.</p>
<p><img src="http://adamlineberry.ai/images/2019-05-20-how-auroc-is-calculated/Roccurves.png" alt="Sample ROC Curve" height="200" class="align-center" /></p>
<figcaption>Sample ROC Curve</figcaption>
<h2 id="general-discussion">General Discussion</h2>
<p>Before jumping into the code, let’s take a stroll down conversation street and provide a general, high-level, and undoubtedly hand-wavy treatment of the famed receiver operating characteristic (ROC) curve. The origin of the name (and the method) traces its roots back to World War II. Radar operators (or receivers) sat in front of a display and were tasked with sounding an alarm whenever an enemy aircraft was detected. Of course, radar signals can be quite noisy and it was difficult to distinguish between an enemy bomber and something far less menacing, such as a flock of geese. So, in effect, these radar operators were functioning as binary classifiers. There was a dire need to identify as many enemy aircraft as possible (recall, or true positive rate), while minimizing the number of times the base went into high alert over an innocent flock of geese (false positive rate). Thus, the ROC curve was introduced as a method to analyze the performance of radar operators.</p>
<p><img src="http://adamlineberry.ai/images/2019-05-20-how-auroc-is-calculated/700px-Precisionrecall.svg.png" alt="" height="100" width="300" class="align-center" /></p>
<figcaption>Binary Classification Space</figcaption>
<p>The idealized ROC curve is continuous across all possible classification thresholds. Points that are plotted on the ROC curve correspond to particular classification thresholds \(T \in (-\infty, \infty)\). In the real world we are dealing with a discrete number of data points with which we would like to estimate the ROC curve for a classifier of interest. This manifests itself in ROC curves that can look a bit jumpy rather than smooth. Instead of considering all possible thresholds, we only have \(N\) thresholds to consider, where \(N\) is the number of data points in the dataset we are evaluating.</p>
<p>The way I like to think about calculating ROC and AUC is to consider a simple table with columns for \(Y\) and \(\hat{Y}\), sorted descending by \(\hat{Y}\). You then iterate over rows of this table and the threshold you consider at any given moment is wherever the cursor of your iterator is. There is no need to quantify this threshold (e.g., \(T=0.75\)), it is simply something that classifies all data points above it as the positive class and all data points below it as the negative class. From here it is easy to calculate and record the FPR and TPR for this threshold. This (FPR, TPR) pair will then become a data point plotted on the ROC curve. When you are finished iterating over your data points, you have \(N\) (FPR, TPR) data points which are plotted to form the full ROC curve. This exact algorithm is implemented in code later in this post.</p>
<p>The only thing that matters in calculating the ROC curve and its AUC is the rank ordering of the predictions. Typically, normalized model outputs \(p \in [0, 1]\) are used for this, but as I will show in this post, unnormalized model outputs, such as outputs from a linear layer before sigmoid application, \(s \in (-\infty, \infty)\) are equally valid.</p>
<p>A common mistake to be avoided at all costs is calculating AUC using binarized predictions, e.g., \(\hat{Y} \in \{0, 1\}\) instead of scores or probabilities \(\hat{Y} \in (-\infty, \infty)\). The scary thing about this mistake is that most implementations like scikit-learn’s <code class="language-plaintext highlighter-rouge">roc_auc_score</code> will not throw an error. The computation can still be performed, but with only two distinct prediction values the sorting step no longer yields a meaningful ranking: the ROC curve collapses to a single operating point joined to \((0, 0)\) and \((1, 1)\) by straight lines, and the resulting number reflects only that one threshold rather than the classifier’s ranking quality.</p>
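<p>Both points can be checked on a toy example. The helper <code class="language-plaintext highlighter-rouge">pairwise_auc</code> below is a hypothetical sketch, not scikit-learn’s implementation: it uses the equivalent view of AUC as the fraction of (positive, negative) pairs that are ranked correctly, with ties counted as half. The labels and scores are made up.</p>

```python
import numpy as np

def pairwise_auc(y_true, y_score):
    """AUC as the fraction of (positive, negative) pairs ranked correctly; ties count half."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    wins = (pos[:, None] > neg[None, :]).mean()
    ties = (pos[:, None] == neg[None, :]).mean()
    return wins + 0.5 * ties

y_true = np.array([1, 0, 1, 1, 0, 0])
logits = np.array([2.0, 1.5, 0.5, -0.2, -1.0, -3.0])  # made-up unnormalized scores
probas = 1.0 / (1.0 + np.exp(-logits))                # sigmoid preserves the ordering
binarized = (probas > 0.5).astype(int)                # the mistake: thresholded labels

auc_logits = pairwise_auc(y_true, logits)        # 7/9
auc_probas = pairwise_auc(y_true, probas)        # 7/9, identical ranking
auc_binarized = pairwise_auc(y_true, binarized)  # 2/3, ranking information lost
```

<p>The logits and the probabilities give the same AUC because the sigmoid is monotonic, while the binarized predictions give a different value that depends entirely on the 0.5 cutoff.</p>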
<p>Many references will describe the computation of the area under the ROC curve using an integral and leave it at that. An integral may be a technically correct description, but it doesn’t give the reader any intuition about how this area calculation is actually performed. It’s actually quite simple. Once you understand the algorithm described above, you can see that the ROC curve itself is really just a bunch of right angles. Thus, the area under the curve can be calculated as the sum of the area of several rectangles.</p>
<p><a href="https://www.r-bloggers.com/calculating-auc-the-area-under-a-roc-curve/">
<img src="http://adamlineberry.ai/images/2019-05-20-how-auroc-is-calculated/roc-curve-rectangles.png" alt="ROC Curve Rectangles" height="200" class="align-center" /></a></p>
<figcaption>ROC Curve Composed of Rectangles</figcaption>
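<p>The table-iteration algorithm and the rectangle sum described above can be sketched in plain NumPy. This is a minimal illustration rather than a production implementation: the function name <code class="language-plaintext highlighter-rouge">roc_points_and_auc</code> is hypothetical, and the sketch assumes there are no tied scores and that both classes are present.</p>

```python
import numpy as np

def roc_points_and_auc(y_true, y_score):
    """Walk down the labels sorted by descending score, recording (FPR, TPR) per row."""
    y_true = np.asarray(y_true)
    order = np.argsort(-np.asarray(y_score, dtype=float))  # sort descending by score
    y_sorted = y_true[order]
    n_pos = y_sorted.sum()
    n_neg = len(y_sorted) - n_pos

    tp = fp = 0
    fpr_prev = 0.0
    points, area = [], 0.0
    for label in y_sorted:
        # Everything at or above the cursor is classified as positive.
        if label == 1:
            tp += 1
        else:
            fp += 1
        tpr, fpr = tp / n_pos, fp / n_neg
        points.append((fpr, tpr))
        # A step to the right adds a rectangle of width (fpr - fpr_prev) and height tpr.
        area += (fpr - fpr_prev) * tpr
        fpr_prev = fpr
    return points, area

y_true = np.array([1, 0, 1, 1, 0, 0])
y_score = np.array([0.9, 0.8, 0.7, 0.6, 0.4, 0.2])  # made-up scores
points, auc_value = roc_points_and_auc(y_true, y_score)
```

<p>Rows with a positive label move the curve straight up and add no area; rows with a negative label move it to the right and add one rectangle. For this toy data the rectangles sum to 7/9.</p>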
<h2 id="tutorial">Tutorial</h2>
<p>In this section we will illustrate the concepts discussed above with a Python implementation.</p>
<p>Imports…</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="n">pd</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>
<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">Dataset</span><span class="p">,</span> <span class="n">DataLoader</span>
<span class="kn">from</span> <span class="nn">fastprogress</span> <span class="kn">import</span> <span class="n">master_bar</span><span class="p">,</span> <span class="n">progress_bar</span>
<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">train_test_split</span>
<span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">roc_auc_score</span><span class="p">,</span> <span class="n">roc_curve</span><span class="p">,</span> <span class="n">auc</span>
</code></pre></div></div>
<h3 id="generate-synthetic-dataset">Generate Synthetic Dataset</h3>
<p>To keep things on the data front simple, I generate 100 data points each from two 2-dimensional Gaussians. Care is taken to ensure the classes are not linearly separable.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">x0</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">multivariate_normal</span><span class="p">([</span><span class="o">-</span><span class="mf">1.5</span><span class="p">,</span><span class="mi">0</span><span class="p">],</span> <span class="p">[[</span><span class="mi">1</span><span class="p">,</span><span class="mi">0</span><span class="p">],[</span><span class="mi">0</span><span class="p">,</span><span class="mi">1</span><span class="p">]],</span> <span class="mi">100</span><span class="p">)</span>
<span class="n">x1</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">multivariate_normal</span><span class="p">([</span><span class="mf">1.5</span><span class="p">,</span><span class="mi">0</span><span class="p">],</span> <span class="p">[[</span><span class="mi">1</span><span class="p">,</span><span class="mi">0</span><span class="p">],[</span><span class="mi">0</span><span class="p">,</span><span class="mi">1</span><span class="p">]],</span> <span class="mi">100</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span><span class="mi">5</span><span class="p">))</span>
<span class="n">ax</span><span class="p">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">x0</span><span class="p">[:,</span><span class="mi">0</span><span class="p">],</span><span class="n">x0</span><span class="p">[:,</span><span class="mi">1</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s">'class 0'</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">x1</span><span class="p">[:,</span><span class="mi">0</span><span class="p">],</span><span class="n">x1</span><span class="p">[:,</span><span class="mi">1</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s">'class 1'</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="s">'Synthetically Generated Dataset'</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">'feature1'</span><span class="p">);</span> <span class="n">ax</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">'feature2'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>
<p><img src="http://adamlineberry.ai/images/2019-05-20-how-auroc-is-calculated/output_3_0.png" alt="png" /></p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">X</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">x0</span><span class="p">,</span><span class="n">x1</span><span class="p">]),</span> <span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s">'feature1'</span><span class="p">,</span><span class="s">'feature2'</span><span class="p">])</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="mi">0</span><span class="p">]</span><span class="o">*</span><span class="mi">100</span> <span class="o">+</span> <span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">*</span><span class="mi">100</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">y</span><span class="p">.</span><span class="n">shape</span>
<span class="c1"># ((200, 2), (200,))
</span></code></pre></div></div>
<p>Train/validation split.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">X_train</span><span class="p">,</span> <span class="n">X_valid</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">y_valid</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">test_size</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">199</span><span class="p">)</span>
</code></pre></div></div>
<h3 id="pytorch-data-and-model">PyTorch Data and Model</h3>
<p>I define a PyTorch implementation of logistic regression to model the data.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">SimpleDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">X</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">X</span><span class="p">.</span><span class="n">values</span><span class="p">).</span><span class="nb">type</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">FloatTensor</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">y</span><span class="p">).</span><span class="nb">type</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">FloatTensor</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">X</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">X</span><span class="p">[</span><span class="n">index</span><span class="p">],</span> <span class="bp">self</span><span class="p">.</span><span class="n">y</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># instantiate datasets and dataloaders for train and valid data
</span><span class="n">train_ds</span> <span class="o">=</span> <span class="n">SimpleDataset</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
<span class="n">valid_ds</span> <span class="o">=</span> <span class="n">SimpleDataset</span><span class="p">(</span><span class="n">X_valid</span><span class="p">,</span> <span class="n">y_valid</span><span class="p">)</span>
<span class="n">train_dl</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">train_ds</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">150</span><span class="p">)</span>
<span class="n">valid_dl</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">valid_ds</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">50</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">LogisticRegression</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="c1"># single linear layer. non-linearity is handled
</span>        <span class="c1"># by the loss function
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">lin</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">lin</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span>
</code></pre></div></div>
<h3 id="train-model">Train Model</h3>
<p>The model is trained on a CPU for 100 epochs at a fairly low learning rate for this data.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">(</span><span class="n">train_ds</span><span class="p">.</span><span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">num_epochs</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">lr</span> <span class="o">=</span> <span class="mf">3e-2</span>
<span class="n">optim</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">)</span>
<span class="n">criterion</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">()</span>
<span class="n">final_actn</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Sigmoid</span><span class="p">()</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span><span class="p">.</span><span class="n">train</span><span class="p">()</span>
<span class="n">train_losses</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">valid_losses</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">mb</span> <span class="o">=</span> <span class="n">master_bar</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">num_epochs</span><span class="p">))</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="n">mb</span><span class="p">:</span>
    <span class="n">model</span><span class="p">.</span><span class="n">train</span><span class="p">()</span>  <span class="c1"># back to training mode after the previous epoch's eval()</span>
    <span class="n">pb</span> <span class="o">=</span> <span class="n">progress_bar</span><span class="p">(</span><span class="n">train_dl</span><span class="p">,</span> <span class="n">parent</span><span class="o">=</span><span class="n">mb</span><span class="p">)</span>
    <span class="n">train_batch_losses</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">pb</span><span class="p">:</span>
        <span class="n">optim</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
        <span class="n">out</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
        <span class="n">loss</span> <span class="o">=</span> <span class="n">criterion</span><span class="p">(</span><span class="n">out</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
        <span class="n">loss</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
        <span class="n">optim</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
        <span class="n">train_batch_losses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss</span><span class="p">.</span><span class="n">item</span><span class="p">())</span>
    <span class="n">train_losses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">train_batch_losses</span><span class="p">).</span><span class="n">mean</span><span class="p">())</span>
    <span class="n">model</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
    <span class="n">valid_batch_losses</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>  <span class="c1"># gradients are not needed for validation</span>
        <span class="k">for</span> <span class="n">X_val</span><span class="p">,</span> <span class="n">y_val</span> <span class="ow">in</span> <span class="n">valid_dl</span><span class="p">:</span>
            <span class="n">out</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">X_val</span><span class="p">)</span>
            <span class="n">val_loss</span> <span class="o">=</span> <span class="n">criterion</span><span class="p">(</span><span class="n">out</span><span class="p">,</span> <span class="n">y_val</span><span class="p">).</span><span class="n">item</span><span class="p">()</span>
            <span class="n">valid_batch_losses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">val_loss</span><span class="p">)</span>
    <span class="n">valid_losses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">valid_batch_losses</span><span class="p">).</span><span class="n">mean</span><span class="p">())</span>
</code></pre></div></div>
<p>Total time: 00:00</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">()</span>
<span class="n">x</span> <span class="o">=</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">train_losses</span><span class="p">))</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">train_losses</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">'train'</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">valid_losses</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">'valid'</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="s">'Training and Validation Losses'</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">'epoch'</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">'loss'</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>
<p><img src="http://adamlineberry.ai/images/2019-05-20-how-auroc-is-calculated/output_15_0.png" alt="png" /></p>
<p>As seen in the plot above, the training loss was still decreasing when training was stopped, and the model was beginning to overfit slightly.</p>
<h3 id="calculate-auc-using-scikit-learn-function">Calculate AUC using Scikit-Learn Function</h3>
<p>Score the validation set using the trained model:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">valid_scores</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">model</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
<span class="k">for</span> <span class="n">X_val</span><span class="p">,</span> <span class="n">y_val</span> <span class="ow">in</span> <span class="n">valid_dl</span><span class="p">:</span>
    <span class="n">out</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">X_val</span><span class="p">)</span>
    <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">).</span><span class="n">tolist</span><span class="p">()</span>
    <span class="n">valid_scores</span><span class="p">.</span><span class="n">extend</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
</code></pre></div></div>
<p>Because the model was not defined with a sigmoid output layer, its raw outputs are unnormalized scores (logits) emitted by the single linear layer:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">valid_scores</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">valid_scores</span><span class="p">)</span>
<span class="n">valid_scores</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([ 1.3256, -3.9111, 0.9515, 3.7141, -3.9993, -3.3840, -1.6937, -1.7872,
5.4634, -2.6962, 0.3090, -3.8332, 2.0432, -0.4319, -1.3281, -1.3519,
1.3732, 2.6428, 0.5165, -0.6518, 1.5274, 4.4482, -1.7946, -1.2051,
-0.7633, 2.7398, -2.3134, 2.7641, 4.1584, -0.0191, -2.0982, 2.8374,
-1.0771, -2.8697, 2.5235, -2.8222, 4.1701, -0.9285, 4.1537, -2.7113,
2.5709, -3.7759, 3.6061, 1.5652, -1.8460, 1.0918, -0.2882, 3.0891,
5.1594, -2.1279])
</code></pre></div></div>
<p>These unnormalized scores are mapped into probabilities using the sigmoid function:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">valid_probas</span> <span class="o">=</span> <span class="n">final_actn</span><span class="p">(</span><span class="n">valid_scores</span><span class="p">)</span>
<span class="n">valid_probas</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([0.7901, 0.0196, 0.7214, 0.9762, 0.0180, 0.0328, 0.1553, 0.1434, 0.9958,
0.0632, 0.5766, 0.0212, 0.8853, 0.3937, 0.2095, 0.2056, 0.7979, 0.9336,
0.6263, 0.3426, 0.8216, 0.9884, 0.1425, 0.2306, 0.3179, 0.9393, 0.0900,
0.9407, 0.9846, 0.4952, 0.1093, 0.9447, 0.2541, 0.0537, 0.9258, 0.0561,
0.9848, 0.2832, 0.9845, 0.0623, 0.9290, 0.0224, 0.9736, 0.8271, 0.1363,
0.7487, 0.4284, 0.9564, 0.9943, 0.1064])
</code></pre></div></div>
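<p>As a quick sanity check, the first few probabilities can be reproduced by hand. This is a minimal sketch that assumes <code>final_actn</code> is the standard logistic sigmoid \(\sigma(z) = 1 / (1 + e^{-z})\). The <code>sigmoid</code> helper is written out here for illustration and is not part of the original notebook:</p>

```python
import math

def sigmoid(z):
    """Logistic sigmoid: maps any real-valued score into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

# First three unnormalized scores from the tensor printed above
first_scores = [1.3256, -3.9111, 0.9515]

# Round to 4 decimals to match the precision of the printed tensors
first_probas = [round(sigmoid(z), 4) for z in first_scores]
print(first_probas)
```

<p>The rounded values should line up with the first three entries of the probability tensor above.</p>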
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">valid_scores</span> <span class="o">=</span> <span class="n">valid_scores</span><span class="p">.</span><span class="n">numpy</span><span class="p">()</span>
<span class="n">valid_probas</span> <span class="o">=</span> <span class="n">valid_probas</span><span class="p">.</span><span class="n">numpy</span><span class="p">()</span>
</code></pre></div></div>
<p>Calculate the AUC using the normalized model outputs \(\hat{Y} \in [0, 1]\), as is typically done:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">roc_auc_score</span><span class="p">(</span><span class="n">y_valid</span><span class="p">,</span> <span class="n">valid_probas</span><span class="p">)</span>
<span class="c1"># 0.9759615384615384
</span></code></pre></div></div>
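<p>In contrast, calculate the AUC using the unnormalized outputs \(\hat{Y} \in (-\infty, \infty)\). Notice that the AUC is exactly the same: the sigmoid is strictly increasing, so it preserves the rank order of the scores, and AUC depends only on that order.</p>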
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">roc_auc_score</span><span class="p">(</span><span class="n">y_valid</span><span class="p">,</span> <span class="n">valid_scores</span><span class="p">)</span>
<span class="c1"># 0.9759615384615384
</span></code></pre></div></div>
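<p>The same invariance can be checked without a trained model. The sketch below uses made-up labels and logits (not the data from this post) and a hypothetical <code>pairwise_auc</code> helper that computes AUC as the fraction of (positive, negative) pairs in which the positive example receives the higher score:</p>

```python
import math

def pairwise_auc(y_true, scores):
    """AUC as the share of (positive, negative) pairs ranked correctly.

    Ties between a positive and a negative score count as half a win.
    """
    pos = [s for s, y in zip(scores, y_true) if y == 1]
    neg = [s for s, y in zip(scores, y_true) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Toy labels and logits, invented purely for illustration
toy_labels = [1, 0, 1, 1, 0, 0]
toy_logits = [2.1, -0.3, 0.4, -1.2, -2.5, 0.9]
toy_probas = [sigmoid(z) for z in toy_logits]

auc_from_logits = pairwise_auc(toy_labels, toy_logits)
auc_from_probas = pairwise_auc(toy_labels, toy_probas)
print(auc_from_logits, auc_from_probas)
```

<p>Because only the comparisons <code>p &gt; n</code> matter, any strictly increasing transform of the scores should leave the result unchanged.</p>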
<h3 id="manually-construct-roc-curve-and-auc-calculation">Manually Construct ROC Curve and AUC Calculation</h3>
<p>To begin our manual calculation, let’s toss the model probabilities and true values into a dataframe:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">auc_data</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">({</span><span class="s">'probas'</span><span class="p">:</span><span class="n">valid_probas</span><span class="p">,</span> <span class="s">'y_true'</span><span class="p">:</span><span class="n">y_valid</span><span class="p">})</span>
<span class="n">auc_data</span><span class="p">.</span><span class="n">head</span><span class="p">()</span>
</code></pre></div></div>
<div>
<style scoped="">
.dataframe tbody tr th:only-of-type {
vertical-align: middle;
}
.dataframe tbody tr th {
vertical-align: top;
}
.dataframe thead th {
text-align: right;
}
</style>
<table border="1" class="dataframe">
<thead>
<tr style="text-align: right;">
<th></th>
<th>probas</th>
<th>y_true</th>
</tr>
</thead>
<tbody>
<tr>
<th>0</th>
<td>0.790115</td>
<td>1</td>
</tr>
<tr>
<th>1</th>
<td>0.019625</td>
<td>0</td>
</tr>
<tr>
<th>2</th>
<td>0.721424</td>
<td>1</td>
</tr>
<tr>
<th>3</th>
<td>0.976202</td>
<td>1</td>
</tr>
<tr>
<th>4</th>
<td>0.017999</td>
<td>0</td>
</tr>
</tbody>
</table>
</div>
<p>Sort the data by the model probabilities:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">auc_data</span><span class="p">.</span><span class="n">sort_values</span><span class="p">(</span><span class="n">by</span><span class="o">=</span><span class="s">'probas'</span><span class="p">,</span> <span class="n">ascending</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">inplace</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">auc_data</span> <span class="o">=</span> <span class="n">auc_data</span><span class="p">.</span><span class="n">reset_index</span><span class="p">(</span><span class="n">drop</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">auc_data</span><span class="p">.</span><span class="n">head</span><span class="p">()</span>
</code></pre></div></div>
<div>
<table border="1" class="dataframe">
<thead>
<tr style="text-align: right;">
<th></th>
<th>probas</th>
<th>y_true</th>
</tr>
</thead>
<tbody>
<tr>
<th>0</th>
<td>0.995779</td>
<td>1</td>
</tr>
<tr>
<th>1</th>
<td>0.994287</td>
<td>1</td>
</tr>
<tr>
<th>2</th>
<td>0.988435</td>
<td>1</td>
</tr>
<tr>
<th>3</th>
<td>0.984784</td>
<td>1</td>
</tr>
<tr>
<th>4</th>
<td>0.984608</td>
<td>1</td>
</tr>
</tbody>
</table>
</div>
<p>Create a simple “rank” column:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">auc_data</span><span class="p">[</span><span class="s">'rank'</span><span class="p">]</span> <span class="o">=</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="nb">len</span><span class="p">(</span><span class="n">auc_data</span><span class="p">)</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span>
<span class="n">auc_data</span><span class="p">.</span><span class="n">head</span><span class="p">()</span>
</code></pre></div></div>
<div>
<table border="1" class="dataframe">
<thead>
<tr style="text-align: right;">
<th></th>
<th>probas</th>
<th>y_true</th>
<th>rank</th>
</tr>
</thead>
<tbody>
<tr>
<th>0</th>
<td>0.995779</td>
<td>1</td>
<td>1</td>
</tr>
<tr>
<th>1</th>
<td>0.994287</td>
<td>1</td>
<td>2</td>
</tr>
<tr>
<th>2</th>
<td>0.988435</td>
<td>1</td>
<td>3</td>
</tr>
<tr>
<th>3</th>
<td>0.984784</td>
<td>1</td>
<td>4</td>
</tr>
<tr>
<th>4</th>
<td>0.984608</td>
<td>1</td>
<td>5</td>
</tr>
</tbody>
</table>
</div>
<p><strong>Delete the model probabilities data</strong> in order to illustrate the point that they aren’t needed for ROC or AUC calculations (after they’ve been used to rank order):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">auc_data</span><span class="p">.</span><span class="n">drop</span><span class="p">(</span><span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s">'probas'</span><span class="p">],</span> <span class="n">inplace</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">auc_data</span><span class="p">.</span><span class="n">head</span><span class="p">()</span>
</code></pre></div></div>
<div>
<table border="1" class="dataframe">
<thead>
<tr style="text-align: right;">
<th></th>
<th>y_true</th>
<th>rank</th>
</tr>
</thead>
<tbody>
<tr>
<th>0</th>
<td>1</td>
<td>1</td>
</tr>
<tr>
<th>1</th>
<td>1</td>
<td>2</td>
</tr>
<tr>
<th>2</th>
<td>1</td>
<td>3</td>
</tr>
<tr>
<th>3</th>
<td>1</td>
<td>4</td>
</tr>
<tr>
<th>4</th>
<td>1</td>
<td>5</td>
</tr>
</tbody>
</table>
</div>
<p>Precompute a cumulative sum of the true values. This will come in handy later when we’re performing the calculations.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">auc_data</span><span class="p">[</span><span class="s">'y_true_cumsum'</span><span class="p">]</span> <span class="o">=</span> <span class="n">auc_data</span><span class="p">[</span><span class="s">'y_true'</span><span class="p">].</span><span class="n">cumsum</span><span class="p">()</span>
<span class="n">auc_data</span><span class="p">.</span><span class="n">head</span><span class="p">()</span>
</code></pre></div></div>
<div>
<table border="1" class="dataframe">
<thead>
<tr style="text-align: right;">
<th></th>
<th>y_true</th>
<th>rank</th>
<th>y_true_cumsum</th>
</tr>
</thead>
<tbody>
<tr>
<th>0</th>
<td>1</td>
<td>1</td>
<td>1</td>
</tr>
<tr>
<th>1</th>
<td>1</td>
<td>2</td>
<td>2</td>
</tr>
<tr>
<th>2</th>
<td>1</td>
<td>3</td>
<td>3</td>
</tr>
<tr>
<th>3</th>
<td>1</td>
<td>4</td>
<td>4</td>
</tr>
<tr>
<th>4</th>
<td>1</td>
<td>5</td>
<td>5</td>
</tr>
</tbody>
</table>
</div>
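<p>To see why the rank and the cumulative sum are enough, here is a small sketch on hypothetical labels (not the data from this post) that are already sorted from highest to lowest score. If the cutoff is placed just below row \(r\), then \(r\) points are predicted positive, the cumulative sum counts how many of them are truly positive, and the remainder are false positives:</p>

```python
from itertools import accumulate

# Hypothetical 0/1 labels, already sorted from highest to lowest model score
sorted_labels = [1, 1, 0, 1, 0]

# rank r = number of points predicted positive when the cutoff sits below row r
ranks = list(range(1, len(sorted_labels) + 1))
# cumulative sum of the labels = number of true positives at that cutoff
true_pos = list(accumulate(sorted_labels))
# whatever is predicted positive but is not a true positive is a false positive
false_pos = [r - tp for r, tp in zip(ranks, true_pos)]

for r, tp, fp in zip(ranks, true_pos, false_pos):
    print(f"top {r}: TP={tp}, FP={fp}")
```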
<p>Now it’s time for the real computation. As discussed previously, we are going to iterate over the sorted predictions, treat each one as a classification threshold, and compute statistics at each iteration.</p>
<p>As a refresher, recall that the ROC curve plots True Positive Rate (TPR) vs. False Positive Rate (FPR).</p>
<ul>
<li>True Positive Rate (TPR), Recall, “Probability of Detection”</li>
<li>False Positive Rate (FPR), “Probability of False Alarm”</li>
</ul>
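<p>For reference, the standard definitions in terms of confusion-matrix counts (true positives \(TP\), false negatives \(FN\), false positives \(FP\), true negatives \(TN\)) are:</p>
\[\text{TPR} = \frac{TP}{TP + FN}
\quad
\text{FPR} = \frac{FP}{FP + TN}\]
<p>The denominators do not depend on the threshold: \(TP + FN\) is the total number of positive examples and \(FP + TN\) is the total number of negative examples.</p>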
<p>Precompute the number of data points in the positive and negative classes:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">n_pos</span> <span class="o">=</span> <span class="p">(</span><span class="n">auc_data</span><span class="p">[</span><span class="s">'y_true'</span><span class="p">]</span><span class="o">==</span><span class="mi">1</span><span class="p">).</span><span class="nb">sum</span><span class="p">()</span>
<span class="n">n_neg</span> <span class="o">=</span> <span class="p">(</span><span class="n">auc_data</span><span class="p">[</span><span class="s">'y_true'</span><span class="p">]</span><span class="o">==</span><span class="mi">0</span><span class="p">).</span><span class="nb">sum</span><span class="p">()</span>
</code></pre></div></div>
<p>This is the tricky bit. I did what I could to explain each step in the code comments:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">tpr</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">];</span> <span class="n">fpr</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">];</span> <span class="n">area</span> <span class="o">=</span> <span class="p">[]</span>
<span class="c1"># iterate over data points, i.e., **thresholds**
</span><span class="k">for</span> <span class="n">_</span><span class="p">,</span><span class="n">row</span> <span class="ow">in</span> <span class="n">auc_data</span><span class="p">.</span><span class="n">iterrows</span><span class="p">():</span>
    <span class="c1"># the "rank" column conveniently proxies for the number of
</span>    <span class="c1"># data points being predicted as the positive class
</span>    <span class="n">num_pred_p</span> <span class="o">=</span> <span class="n">row</span><span class="p">[</span><span class="s">'rank'</span><span class="p">]</span>
    <span class="c1"># the cumulative sum of y_true equals the number of
</span>    <span class="c1"># true positives at this threshold
</span>    <span class="n">num_tp</span> <span class="o">=</span> <span class="n">row</span><span class="p">[</span><span class="s">'y_true_cumsum'</span><span class="p">]</span>
    <span class="c1"># the number of false positives is then the difference
</span>    <span class="c1"># between the total number of predicted positives
</span>    <span class="c1"># and the number of true positives
</span>    <span class="n">num_fp</span> <span class="o">=</span> <span class="n">num_pred_p</span> <span class="o">-</span> <span class="n">num_tp</span>
    <span class="c1"># compute TPR and FPR at this threshold and store them
</span>    <span class="n">tpr_tmp</span> <span class="o">=</span> <span class="n">num_tp</span> <span class="o">/</span> <span class="n">n_pos</span>
    <span class="n">fpr_tmp</span> <span class="o">=</span> <span class="n">num_fp</span> <span class="o">/</span> <span class="n">n_neg</span>
    <span class="n">tpr</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">tpr_tmp</span><span class="p">);</span> <span class="n">fpr</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">fpr_tmp</span><span class="p">)</span>
    <span class="c1"># compute the area of the little rectangle at this threshold
</span>    <span class="n">delta_fpr</span> <span class="o">=</span> <span class="p">(</span><span class="n">fpr</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">fpr</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">])</span>
    <span class="n">area_tmp</span> <span class="o">=</span> <span class="n">tpr</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">delta_fpr</span>
    <span class="n">area</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">area_tmp</span><span class="p">)</span>
</code></pre></div></div>
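<p>The loop above can also be packaged as a small standalone function, which makes it easy to trace on a tiny input. This is only a sketch: <code>manual_roc_auc</code> is a hypothetical name, it works on a plain list rather than the dataframe, and it assumes the labels are already sorted by descending score with no tied scores and at least one example of each class:</p>

```python
def manual_roc_auc(sorted_labels):
    """Rectangle-rule ROC points and AUC from 0/1 labels sorted by descending score.

    Assumes no tied scores and at least one positive and one negative label.
    """
    n_pos = sum(sorted_labels)
    n_neg = len(sorted_labels) - n_pos
    tpr, fpr, area = [0.0], [0.0], 0.0
    num_tp = 0
    for rank, label in enumerate(sorted_labels, start=1):
        num_tp += label            # running count of true positives
        num_fp = rank - num_tp     # the rest of the predicted positives
        tpr.append(num_tp / n_pos)
        fpr.append(num_fp / n_neg)
        # the rectangle has nonzero width only when FPR advances,
        # i.e. when the row just passed was a negative example
        area += tpr[-1] * (fpr[-1] - fpr[-2])
    return fpr, tpr, area

# Hypothetical labels sorted from highest to lowest score
toy_fpr, toy_tpr, toy_auc = manual_roc_auc([1, 1, 0, 1, 0])
print(toy_fpr, toy_tpr, toy_auc)
```

<p>On this toy input, five of the six (positive, negative) pairs are ordered correctly, so the area should come out to \(5/6\).</p>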
<p>Using our hand-calculated values, let’s plot the ROC curve and compute the AUC:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">()</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">fpr</span><span class="p">,</span> <span class="n">tpr</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="s">'ROC Curve | Manual Calculation'</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">'FPR'</span><span class="p">);</span> <span class="n">ax</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">'TPR'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>
<p><img src="http://adamlineberry.ai/images/2019-05-20-how-auroc-is-calculated/output_38_0.png" alt="png" /></p>
<p>AUC, manual calculation:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">area</span><span class="p">).</span><span class="nb">sum</span><span class="p">()</span>
<span class="c1"># 0.9759615384615383
</span></code></pre></div></div>
<p>To check our work, let’s plot the ROC curve and compute the AUC using scikit-learn:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fpr_skl</span><span class="p">,</span> <span class="n">tpr_skl</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">roc_curve</span><span class="p">(</span><span class="n">y_valid</span><span class="p">,</span> <span class="n">valid_probas</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">()</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">fpr_skl</span><span class="p">,</span> <span class="n">tpr_skl</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="s">'ROC Curve | Scikit-learn Calculation'</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">'FPR'</span><span class="p">);</span> <span class="n">ax</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">'TPR'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>
<p><img src="http://adamlineberry.ai/images/2019-05-20-how-auroc-is-calculated/output_43_0.png" alt="png" /></p>
<p>AUC, scikit-learn calculation:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">auc</span><span class="p">(</span><span class="n">fpr_skl</span><span class="p">,</span> <span class="n">tpr_skl</span><span class="p">)</span>
<span class="c1"># 0.9759615384615384
</span></code></pre></div></div>
<p>Whew! It checks out.</p>
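<p>The agreement is not a coincidence. scikit-learn’s <code>auc</code> integrates with the trapezoidal rule, while the manual loop sums rectangles. When no two scores are tied, every step of the ROC curve moves either straight up or straight to the right, so each trapezoid is either empty or a rectangle. A minimal sketch on hypothetical staircase points (the helper names are made up for illustration):</p>

```python
def trapezoid_area(xs, ys):
    """Trapezoidal-rule integral of the piecewise-linear curve through (xs, ys)."""
    return sum(
        (x1 - x0) * (y0 + y1) / 2.0
        for x0, x1, y0, y1 in zip(xs, xs[1:], ys, ys[1:])
    )

def rectangle_area(xs, ys):
    """Rectangle rule using the right-hand height, as in the manual loop."""
    return sum((x1 - x0) * y1 for x0, x1, y1 in zip(xs, xs[1:], ys[1:]))

# Hypothetical staircase ROC points with no tied scores:
# each step is purely vertical (TPR rises) or purely horizontal (FPR rises)
stair_fpr = [0.0, 0.0, 0.0, 0.5, 0.5, 1.0]
stair_tpr = [0.0, 1 / 3, 2 / 3, 2 / 3, 1.0, 1.0]

trap = trapezoid_area(stair_fpr, stair_tpr)
rect = rectangle_area(stair_fpr, stair_tpr)
print(trap, rect)
```

<p>With tied scores the curve can contain diagonal segments, and the two rules would then generally differ.</p>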
<h2 id="conclusion">Conclusion</h2>
<p>In this post we covered the intuition behind ROC/AUC calculations and warned against some common mistakes. We also showed that the calculations can be performed using unnormalized model scores, performed hand calculations for a custom PyTorch logistic regression model trained on synthetic data, and verified the results against scikit-learn.</p>
<p>The notebook associated with the code in this post can be found <a href="https://github.com/acetherace/alcore/blob/master/notebooks/how-auroc-is-calculated.ipynb">here</a>.</p>Adam LineberryCode-based walkthrough showing how AUC is computed for a simple model trained on a simple dataset