VAE (Variational AutoEncoder)

Takeaway:

  1. What is the architecture of a VAE, and what does the generation process look like?
  2. What is the training objective of VAE?
  3. Why do we need Variational Inference in a VAE?
  4. How to implement and train a VAE model?

--- VAE Architecture¶

The overall architecture of VAE consists of three parts:

  • Encoder: A neural network that maps an input $\mathbf{x}$ to the parameters of an approximate posterior $q_\phi(\mathbf{z} \mid \mathbf{x})$, where $\phi$ refers to the parameters of the encoder network. For example, if $q_\phi(\mathbf{z} \mid \mathbf{x})$ is Gaussian, the encoder outputs parameters $\mu, \sigma$.
  • Latent Space: In training, a latent vector $\mathbf{z}$ is sampled from $q_\phi(\mathbf{z} \mid \mathbf{x})$ (which is regularized to follow a prior $p(\mathbf{z})$ such as $\mathcal{N}(0,I)$); in generation, a latent vector $\mathbf{z}$ is sampled from the prior $p(\mathbf{z})$.
  • Decoder: A neural network that takes $\mathbf{z}$ and produces the parameters of the conditional likelihood distribution $p_\theta(\mathbf{x} \mid \mathbf{z})$, where $\theta$ refers to the parameters of the decoder network. For example, if $p_\theta(\mathbf{x} \mid \mathbf{z})$ is Bernoulli, the decoder outputs probabilities $\mathbf{p}$; if it is Gaussian, the decoder outputs parameters $\mu, \sigma$.

⚠️ The outputs of the encoder and decoder are parameters of probability distributions, not the generated data itself.

The latent space captures hidden factors such as digit identity, stroke thickness, tilt, or writing style, and generative models start from this space to mimic the true underlying process of data formation. The complete generation process is then:

  1. sample latent vector from the prior $\mathbf{z} \sim \mathcal{N}(0,I)$,
  2. pass $\mathbf{z}$ through the decoder to obtain distribution parameters (thereby defining the full distribution),
  3. sample data $\mathbf{x} \sim p_\theta(\mathbf{x} \mid \mathbf{z})$ as the generated output.
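
A minimal sketch of these three steps in PyTorch (illustrative only: the decoder below is an untrained stand-in network, and a Gaussian likelihood with fixed variance $\sigma^2$ is assumed):

import torch

latent_dim, data_dim, sigma = 2, 2, 1.0
decoder = torch.nn.Sequential(                       # stand-in for a trained decoder
    torch.nn.Linear(latent_dim, 16), torch.nn.ReLU(), torch.nn.Linear(16, data_dim)
)

z  = torch.randn(100, latent_dim)                    # 1. sample z ~ N(0, I) from the prior
mu = decoder(z)                                      # 2. decoder outputs distribution parameters (here: the mean)
x  = mu + sigma * torch.randn_like(mu)               # 3. sample x ~ N(mu, sigma^2 I) as the generated data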

In the VAE’s generation process, we only use the latent space and the decoder; the encoder plays no role. The design philosophy of VAEs can be seen as starting from a generative system that maps latent variables to data through the decoder (a common pattern: GANs map latent variables to data with a generator, normalizing flows transform latent variables through invertible mappings, and diffusion models gradually transform latent noise into data). How do we optimize this decoder? We want to maximize the likelihood of the observed data (MLE). Since this likelihood cannot be computed in closed form, we turn to Bayes' rule and variational inference. It turns out that by introducing $q_\phi(\mathbf{z} \mid \mathbf{x})$, an approximation of the true posterior $p_\theta(\mathbf{z}\mid \mathbf{x})$, and optimizing it jointly with the decoder, the decoder can be pushed arbitrarily close to the maximum-likelihood solution. This approximation is exactly what the encoder models, which is why we need it. The encoder serves training and inference, but not generation.

--- Training Objective of VAE: From Maximizing Likelihood¶

We want the trained decoder to generate samples that resemble the data distribution. The model distribution $p_\theta(\mathbf{x}) = \int p_\theta(\mathbf{x} \mid \mathbf{z}) \, p(\mathbf{z}) \, d\mathbf{z}$ should be as close as possible to the true data distribution $p_\text{data}(\mathbf{x})$. A natural metric is the KL divergence between those two distributions: $$ D_{\text{KL}}(p_\text{data}(\mathbf{x}) \| p_\theta(\mathbf{x})) = \mathbb{E}_{p_\text{data}(\mathbf{x})}\big[ \log p_\text{data}(\mathbf{x}) - \log p_\theta(\mathbf{x}) \big]. $$ Since $\log p_\text{data}(\mathbf{x})$ does not depend on model parameters, minimizing the KL is equivalent to maximizing the probability that the model assigns to the observed data (Maximum Likelihood Estimation (MLE)). In fact, the definition of MLE itself also fits generative modeling: it chooses parameters $\theta$ that make the observed dataset most probable under the model $$ \max_\theta \; \mathbb{E}_{p_\text{data}(\mathbf{x})} [ \log p_\theta(\mathbf{x}) ] = \max_\theta \; \int p_\text{data}(\mathbf{x}) \, \log p_\theta(\mathbf{x}) \, d\mathbf{x} $$ where $\log p_\theta(\mathbf{x})$ is the marginal likelihood. The empirical distribution $p_\text{data}(\mathbf{x})$ does not depend on $\theta$, so optimizing $\mathbb{E}_{p_\text{data}(\mathbf{x})}[\log p_\theta(\mathbf{x})]$ is equivalent to maximizing $\log p_\theta(\mathbf{x})$ over all samples.

And for the VAE, the marginal likelihood is $$ \log p_\theta(\mathbf{x}) = \log \int p_\theta(\mathbf{x} \mid \mathbf{z}) \, p(\mathbf{z}) \, d\mathbf{z}, $$ which is ideally the training objective of the VAE. We seek to maximize this objective in training. In the following part we'll see that this objective is intractable, and introduce a lower-bound surrogate for it.

Intractable $p_\theta(\mathbf{x}) = \int p_\theta(\mathbf{x} \mid \mathbf{z}) \, p(\mathbf{z}) \, d\mathbf{z}$¶

In a VAE, the conditional distribution $p_\theta(\mathbf{x}\mid \mathbf{z})$ always has a closed-form expression, even if the decoder is a deep neural network. This is because we assume $p_\theta(\mathbf{x}\mid \mathbf{z})$ belongs to a known distribution family (e.g., Bernoulli, Gaussian), which guarantees an explicit density function. The decoder simply maps the latent variable $\mathbf{z}$ to the parameters of that distribution. For instance, in the Gaussian case the density is $\mathcal{N}(x;\mu,\sigma^2I) = \frac{1}{(2\pi\sigma^2)^{d/2}} \exp\!\left(-\tfrac{1}{2\sigma^2}\|x-\mu\|^2\right)$ and in a VAE we set $\mu=f_\theta(\mathbf{z})$, yielding $$ p_\theta(\mathbf{x}\mid \mathbf{z}) = \frac{1}{(2\pi\sigma^2)^{d/2}} \exp\!\left(-\tfrac{1}{2\sigma^2}\|\mathbf{x}-f_\theta(\mathbf{z})\|^2\right). $$ However, because the product $p_\theta(\mathbf{x} \mid \mathbf{z}) \, p(\mathbf{z})$ is generally not analytically integrable over $\mathbf{z}$, the integral $p_\theta(\mathbf{x})= \int p_\theta(\mathbf{x} \mid \mathbf{z}) \, p(\mathbf{z}) \, d\mathbf{z}$ typically has no closed-form solution, even though $p_\theta(\mathbf{x}\mid \mathbf{z})$ itself does. Intuitively, this is like facing a strange integral problem in a calculus class: you might try clever tricks to solve it, but unlike those classroom exercises, here no amount of manipulation will yield an analytic answer.
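
To make this concrete, here is a minimal sketch (illustrative, not from the original notebook; f_theta is an untrained stand-in network) showing that the conditional log-density can be evaluated exactly even when its mean comes from a neural network:

import math
import torch

d, sigma = 2, 1.0
f_theta = torch.nn.Sequential(               # stand-in decoder network (untrained)
    torch.nn.Linear(2, 16), torch.nn.ReLU(), torch.nn.Linear(16, d)
)

x = torch.randn(d)                           # an observation
z = torch.randn(2)                           # a latent vector
mu = f_theta(z)                              # decoder maps z to the Gaussian mean

# Closed-form log-density log N(x; f_theta(z), sigma^2 I)
log_p_x_given_z = (-0.5 * ((x - mu) ** 2).sum() / sigma**2
                   - 0.5 * d * math.log(2 * math.pi * sigma**2))
print(log_p_x_given_z.item())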

Generally (not restricted to VAEs), it is almost always the case that the conditional distribution $p_\theta(\mathbf{x}\mid \mathbf{z})$ has a closed-form, while the marginal distribution $p_\theta(\mathbf{x})$ does not, because $p_\theta(\mathbf{x}\mid \mathbf{z})p(\mathbf{z})$ is generally not analytically integrable. A notable exception is the linear Gaussian model. The Gaussian family is closed under linear transformations and convolution, so if $$p(\mathbf{z}) = \mathcal{N}(\mathbf{z};0,I), \quad p_\theta(\mathbf{x}\mid \mathbf{z}) = \mathcal{N}(\mathbf{x};W\mathbf{z}+b,\sigma^2 I),$$ then the marginal also has a closed form, $$p_\theta(\mathbf{x}) = \mathcal{N}(\mathbf{x};b,WW^\top+\sigma^2 I)$$
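
As a quick numerical sanity check of this closed form (a sketch under the stated Gaussian assumptions, not part of the original text), we can compare the analytic marginal $\mathcal{N}(\mathbf{x};b,WW^\top+\sigma^2 I)$ with a Monte Carlo estimate of $\int p_\theta(\mathbf{x}\mid \mathbf{z})\,p(\mathbf{z})\,d\mathbf{z}$:

import math
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
d_z, d_x, sigma = 2, 3, 0.5
W, b = torch.randn(d_x, d_z), torch.randn(d_x)
x = torch.randn(d_x)                                  # an arbitrary test point

# Analytic marginal: p(x) = N(x; b, W W^T + sigma^2 I)
marginal = MultivariateNormal(b, W @ W.T + sigma**2 * torch.eye(d_x))
log_px_exact = marginal.log_prob(x)

# Monte Carlo estimate of log p(x) = log E_{z ~ N(0,I)}[ N(x; Wz + b, sigma^2 I) ]
n = 200_000
z = torch.randn(n, d_z)
lik = MultivariateNormal(z @ W.T + b, sigma**2 * torch.eye(d_x))
log_px_mc = torch.logsumexp(lik.log_prob(x), dim=0) - math.log(n)

print(log_px_exact.item(), log_px_mc.item())          # the two values should be close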

Lower Bound Surrogate of $p_\theta(\mathbf{x})$ By Variational Inference¶

How can we maximize $\log p_\theta(\mathbf{x})$ if we cannot compute $p_\theta(\mathbf{x}) = \int p_\theta(\mathbf{x} \mid \mathbf{z}) \, p(\mathbf{z}) \, d\mathbf{z}$ in closed form? There is an alternative way to express $p_\theta(\mathbf{x})$: Bayes' rule. We explicitly write $p_\theta(\mathbf{z}\mid \mathbf{x})$ (instead of simply $p(\mathbf{z}\mid \mathbf{x})$) to emphasize that it is determined by the decoder likelihood $p_\theta(\mathbf{x}\mid \mathbf{z})$ via Bayes' rule, and therefore varies with the decoder parameters. In other words, the true posterior is not free-form but constrained by the expressiveness and parameterization of the decoder.

$$ p_\theta(\mathbf{z}\mid \mathbf{x}) = \frac{p_\theta(\mathbf{x}\mid \mathbf{z})\,p(\mathbf{z})}{p_\theta(\mathbf{x})} \quad \Rightarrow \quad p_\theta(\mathbf{x}) = \frac{p_\theta(\mathbf{x}\mid \mathbf{z})\,p(\mathbf{z})}{p_\theta(\mathbf{z}\mid \mathbf{x})}, \quad p_\theta(\mathbf{z}\mid \mathbf{x}) > 0 $$ This is where variational inference (VI) comes in. Instead of working with the intractable posterior $p_\theta(\mathbf{z}\mid \mathbf{x})$ (since $p_\theta(\mathbf{x})$ doesn't have a closed form, $p_\theta(\mathbf{z}\mid \mathbf{x}) = \frac{p_\theta(\mathbf{x}\mid \mathbf{z})\,p(\mathbf{z})}{p_\theta(\mathbf{x})}$ doesn't have a closed form either), we introduce an approximate posterior $q_\phi(\mathbf{z}\mid \mathbf{x})$ modeled by the encoder.

Variational Inference (VI) (for more info, see VI notes in Probability)

We have observed data $\mathbf{x}$, latent variables $\mathbf{z}$ with a prior $p(\mathbf{z})$, and a model likelihood $p(\mathbf{x}\mid \mathbf{z})$. The goal is to compute the posterior (as in Bayesian Parameter Estimation, Bayesian Linear Regression, etc.) $$p(\mathbf{z} \mid \mathbf{x}) = \frac{p(\mathbf{x}\mid \mathbf{z})\,p(\mathbf{z})}{p(\mathbf{x})}$$ but $p(\mathbf{x})$ and $p(\mathbf{z} \mid \mathbf{x})$ are usually intractable. Variational Inference (VI) addresses this by introducing a simpler parameterized distribution $q(\mathbf{z} \mid \mathbf{x})$ (often Gaussian) to approximate $p(\mathbf{z} \mid \mathbf{x})$, and finding its parameters by minimizing a divergence, typically the KL divergence, between the approximation and the target.

Taking the logarithm of the above, we obtain an identity that holds pointwise for each $\mathbf{z}$ (although the right-hand side involves $\mathbf{z}$, the whole expression equals the same constant $\log p_\theta(\mathbf{x})$): $$ \log p_\theta(\mathbf{x}) = \log p_\theta(\mathbf{x}\mid \mathbf{z})+\log p(\mathbf{z})-\log p_\theta(\mathbf{z}\mid \mathbf{x}) $$ For any approximate distribution $q_\phi(\mathbf{z}\mid \mathbf{x})$ (with support contained in the region where $p_\theta(\mathbf{z}\mid \mathbf{x})>0$, to ensure the logarithm is well-defined), take the expectation of both sides: $$ \begin{aligned} \mathbb{E}_{q_\phi}\big[\log p_\theta(\mathbf{x})\big] &= \mathbb{E}_{q_\phi}\!\left[\log p_\theta(\mathbf{x}\mid \mathbf{z})+\log p(\mathbf{z})-\log p_\theta(\mathbf{z}\mid \mathbf{x})\right]. \end{aligned} $$ The left-hand side is the expectation of a constant: $$ \mathbb{E}_{q_\phi}\big[\log p_\theta(\mathbf{x})\big] = \log p_\theta(\mathbf{x})\cdot \underbrace{\int q_\phi(\mathbf{z}\mid \mathbf{x})\,d\mathbf{z}}_{=1} = \log p_\theta(\mathbf{x}). $$ Thus we obtain $$ \log p_\theta(\mathbf{x}) =\mathbb{E}_{q_\phi}\!\left[\log p_\theta(\mathbf{x}\mid \mathbf{z})+\log p(\mathbf{z})-\log p_\theta(\mathbf{z}\mid \mathbf{x})\right] $$ Now we use the ELBO decomposition, a core decomposition formula in VI, to break down $\log p_\theta(\mathbf{x})$. Inside the expectation we add and subtract $\log q_\phi(\mathbf{z}\mid \mathbf{x})$: $$ \begin{aligned} \log p_\theta(\mathbf{x}) &=\mathbb{E}_{q_\phi}\!\left[\log p_\theta(\mathbf{x}\mid \mathbf{z})+\log p(\mathbf{z})-\log q_\phi(\mathbf{z}\mid \mathbf{x})\right] \\&\quad+\mathbb{E}_{q_\phi}\!\left[\log q_\phi(\mathbf{z}\mid \mathbf{x})-\log p_\theta(\mathbf{z}\mid \mathbf{x})\right] \end{aligned} $$ This yields the exact decomposition: $$ \boxed{\ \log p_\theta(\mathbf{x}) = \underbrace{\mathbb{E}_{q_\phi}\!\left[\log p_\theta(\mathbf{x}\mid \mathbf{z})+\log p(\mathbf{z})-\log q_\phi(\mathbf{z}\mid \mathbf{x})\right]}_{\text{ELBO}} \;+\; D_{\mathrm{KL}}\!\big(q_\phi(\mathbf{z}\mid \mathbf{x})\,\|\,p_\theta(\mathbf{z}\mid \mathbf{x})\big)\ } $$ Because $D_{\mathrm{KL}} \ge 0$ always holds, we obtain the lower bound $\log p_\theta(\mathbf{x}) \ge \text{ELBO}$. The lower bound ELBO is a classic objective in variational inference, and it is what we use as the training objective of the VAE. The beauty of the ELBO is that it is computable. In the VAE Loss Function part, we will see how to evaluate the ELBO in VAE training, where its negative directly serves as the training loss function.

Why maximizing the ELBO ≈ maximizing $\log p_\theta(\mathbf{x})$?

  • In the encoder learning stage, we fix $\theta$ and update only $\phi$. In this case, $\log p_\theta(\mathbf{x})$ is a constant. According to $\log p_\theta(\mathbf{x}) = \text{ELBO} + D_{\mathrm{KL}}\!\big(q_\phi(\mathbf{z}\mid \mathbf{x})\,\|\,p_\theta(\mathbf{z}\mid \mathbf{x})\big)$ maximizing the ELBO $\Leftrightarrow$ minimizing the KL divergence. In other words, updating $\phi$ is essentially making the approximate posterior $q_\phi(\mathbf{z}\mid \mathbf{x})$ as close as possible to the true posterior $p_\theta(\mathbf{z}\mid \mathbf{x})$.
  • In the decoder learning stage, we fix $\phi$ and update $\theta$. The direct effect of updating $\theta$ is to make the generative model $p_\theta(\mathbf{x}\mid \mathbf{z})$ better fit the data distribution, thereby improving $\log p_\theta(\mathbf{x})$. Indirectly, this also changes the shape of the true posterior $p_\theta(\mathbf{z}\mid \mathbf{x})$. In the subsequent encoder update, the encoder will then adjust to catch up with this new posterior. In other words, decoder learning pushes the generative model closer to maximum likelihood, while encoder learning continuously adapts to approximate the true posterior determined by the decoder.
  • Ideally, if $q_\phi$ is able to perfectly fit the true posterior, then we can achieve $q_\phi(\mathbf{z}\mid \mathbf{x}) = p_\theta(\mathbf{z}\mid \mathbf{x})$, and the KL divergence becomes zero at this point. In this case, the ELBO equals $\log p_\theta(\mathbf{x})$. In practice, however, the family of $q_\phi$ is limited and may not fully cover $p_\theta(\mathbf{z}\mid \mathbf{x})$. As a result, ELBO $< \log p_\theta(\mathbf{x})$ strictly. Nevertheless, as long as the family of $q_\phi$ is sufficiently flexible, maximizing the ELBO will make $q_\phi(\mathbf{z}\mid \mathbf{x})$ closely approximate $p_\theta(\mathbf{z}\mid \mathbf{x})$, yielding a tight lower bound; and even if the bound is not perfectly tight, increasing the ELBO will still push up $\log p_\theta(\mathbf{x})$, since ELBO always lies below it.
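
As a concrete numerical check of the bound (an illustrative sketch reusing the tractable linear Gaussian model from the previous subsection, where $\log p_\theta(\mathbf{x})$ has a closed form): a Monte Carlo estimate of the ELBO under a crude choice of $q$ stays below $\log p_\theta(\mathbf{x})$, while setting $q$ to the true posterior makes the bound tight.

import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
d_z, d_x, sigma = 2, 3, 0.5
W, b = torch.randn(d_x, d_z), torch.randn(d_x)
x = torch.randn(d_x)

prior = MultivariateNormal(torch.zeros(d_z), torch.eye(d_z))
log_px = MultivariateNormal(b, W @ W.T + sigma**2 * torch.eye(d_x)).log_prob(x)  # tractable for this model

# True posterior p(z|x) for the linear Gaussian model (standard result)
S = torch.inverse(torch.eye(d_z) + W.T @ W / sigma**2)
m = S @ W.T @ (x - b) / sigma**2

def elbo(q, n=100_000):
    z = q.sample((n,))
    log_lik = MultivariateNormal(z @ W.T + b, sigma**2 * torch.eye(d_x)).log_prob(x)
    return (log_lik + prior.log_prob(z) - q.log_prob(z)).mean()  # Monte Carlo ELBO

q_crude = MultivariateNormal(torch.zeros(d_z), torch.eye(d_z))   # q = prior (ignores x)
q_exact = MultivariateNormal(m, S)                               # q = true posterior

print(log_px.item(), elbo(q_crude).item(), elbo(q_exact).item())
# ELBO(q_crude) < log p(x); ELBO(q_exact) matches log p(x) because the KL gap is zero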

VAE Loss Function: From ELBO to a Practical Implementable Loss¶

Based on the expression of ELBO before, it can be reorganized as $$ \text{ELBO} = \mathbb{E}_{q_\phi}\!\big[\log p_\theta(\mathbf{x}\mid \mathbf{z})\big] - D_{\mathrm{KL}}\!\big(q_\phi(\mathbf{z}\mid \mathbf{x})\,\|\,p(\mathbf{z})\big) $$

The second term is the KL divergence between the approximate posterior $q_\phi(\mathbf{z}\mid \mathbf{x})$ and the latent prior $p(\mathbf{z})$; it can be computed in closed form when both are Gaussian, as sketched below.
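
For the common case $q_\phi(\mathbf{z}\mid \mathbf{x})=\mathcal{N}(\mu,\operatorname{diag}(\sigma^2))$ and $p(\mathbf{z})=\mathcal{N}(0,I)$, this KL term has the well-known closed form $$ D_{\mathrm{KL}} = \tfrac{1}{2}\sum_{j}\big(\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1\big), $$ which is typically computed from the encoder outputs as in the small sketch below (mu and logvar are assumed to be the encoder's outputs):

import torch

def kl_to_standard_normal(mu, logvar):
    # D_KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over latent dimensions
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)

# Example: a batch of 4 encoder outputs with a 2-D latent space
mu, logvar = torch.zeros(4, 2), torch.zeros(4, 2)
print(kl_to_standard_normal(mu, logvar))   # all zeros: q already equals the prior

This is exactly the kl_div term that appears in the implementation further below.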

The first term corresponds to a reconstruction loss of the input data. The reconstruction loss measures how well the decoder can reproduce the observed data $\mathbf{x}$ from a latent sample $\mathbf{z}$. In practice, the form of this loss depends on the choice of the decoder's output distribution. The three most common cases are Bernoulli, Gaussian, and Categorical, which lead respectively to binary cross-entropy (BCE), mean squared error (MSE), and cross-entropy (CE). Below we show how these arise from the likelihood formulation.

  • Bernoulli (e.g., for binary images like MNIST)

    If we assume $p_\theta(\mathbf{x} \mid \mathbf{z}) = \text{Bernoulli}(\boldsymbol{\pi})$, the decoder outputs a probability vector $\boldsymbol{\pi} \in [0, 1]^D$ through a sigmoid activation, where each element $\pi_i$ represents the probability of the $i$-th pixel being 1. Unlike in the generation process, during training, we do not sample from this distribution (to keep the computation differentiable). Instead, we directly use the probability vector $\boldsymbol{\pi}$ to compute the reconstruction loss against the original binary input $\mathbf{x} \in \{0, 1\}^D$. The reconstruction loss is given by the Binary Cross Entropy (BCE) between the probability vector $\boldsymbol{\pi}$ and the ground truth $\mathbf{x}$: $$ \mathcal{L}_\text{recon} = - \sum_{i=1}^{D} \left[ x_i \log \pi_i + (1 - x_i) \log(1 - \pi_i) \right] $$ In the VAE reconstruction term, compute the likelihood directly by the definition of Bernoulli distribution: $$ \begin{align} \mathbb{E}_{q_\phi}[\log p_\theta(\mathbf{x}\mid \mathbf{z})] & = \mathbb{E}_{q_\phi}\!\Big[\sum_{i=1}^D x_i\log \pi_i(\mathbf{z}) + (1-x_i)\log(1-\pi_i(\mathbf{z}))\Big] \nonumber\newline & \approx \frac1{L}\sum_{l=1}^L \sum_{i=1}^D \Big[x_i\log \pi_i(\mathbf{z}^{(l)}) + (1-x_i)\log(1-\pi_i(\mathbf{z}^{(l)}))\Big] \nonumber \end{align} $$ The second line is a Monte Carlo approximation during training: we sample $\mathbf{z}^{(l)} \sim q_\phi(\mathbf{z} \mid \mathbf{x})$ (via reparameterization). Therefore, the negative reconstruction term is just the binary cross-entropy. $$ -\mathbb{E}_{q_\phi}[\log p_\theta(\mathbf{x}\mid \mathbf{z})] \approx \frac1{L}\sum_{l=1}^L \operatorname{BCE}\big(\mathbf{x},\ \pi(\mathbf{z}^{(l)})\big) $$ Also note that in practice, there is no need to average over multiple samples per datapoint, since the batch dimension already introduces averaging; thus, we can directly use binary cross-entropy as the reconstruction loss.

  • Gaussian

    For continuous data such as grayscale images or audio, the likelihood is typically modeled as a Gaussian distribution $p_\theta(\mathbf{x} \mid \mathbf{z}) = \mathcal{N}(\mu(\mathbf{z}), \sigma^2 I)$ where the decoder outputs the mean vector $\mu \in \mathbb{R}^D$, and optionally the (log) variance. A common simplification is to fix $\sigma$ to a constant (e.g., 1) for stability; we stick to a constant variance here. During training, the mean $\mu$ is treated as the reconstruction of the input $\mathbf{x} \in \mathbb{R}^D$, and we use MSE as the reconstruction loss $$ \mathcal{L}_\text{recon} = \|\mathbf{x} - \mu\|^2 $$ which corresponds to the negative log-likelihood of a Gaussian distribution (up to an additive constant and, for fixed $\sigma$, a multiplicative factor): $$ \begin{align} -\mathbb{E}_{q_\phi}[\log p_\theta(\mathbf{x}\mid \mathbf{z})] & = \frac{1}{2\sigma^2}\,\mathbb{E}_{q_\phi}\!\big[\|\mathbf{x}-\mu(\mathbf{z})\|^2\big] \;+\; \text{const} \nonumber\newline & \approx \frac{1}{L}\sum_{l=1}^L \frac{1}{2\sigma^2}\,\|\mathbf{x}-\mu(\mathbf{z}^{(l)})\|^2 \;+\; \text{const} \nonumber\newline & \approx \frac1{L}\sum_{l=1}^L \operatorname{MSE}\big(\mathbf{x},\ \mu(\mathbf{z}^{(l)})\big) \nonumber \end{align} $$ The second line is a Monte Carlo approximation during training: we sample $\mathbf{z}^{(l)} \sim q_\phi(\mathbf{z} \mid \mathbf{x})$ (via reparameterization). Therefore, the negative reconstruction term is just the MSE. Also note that in practice, there is no need to average over multiple samples per datapoint, since the batch dimension already introduces averaging; thus, we can directly use the MSE as the reconstruction loss.

  • Categorical

    For categorical data, such as class labels or word tokens, the likelihood is modeled as a categorical distribution $p_\theta(\mathbf{x} \mid \mathbf{z}) = \text{Categorical}(\mathbf{p})$ where the decoder outputs a probability vector $\mathbf{p} \in [0,1]^K$ via a softmax activation, representing the predicted probability over $K$ classes. During training, $\mathbf{p}$ is directly compared to the one-hot encoded ground truth $\mathbf{x} \in \{0,1\}^K$ using cross entropy loss: $$ \mathcal{L}_\text{recon} = -\sum_{k=1}^{K} x_k \log p_k $$ In the VAE reconstruction term, compute the likelihood directly: $$ \begin{align} \mathbb{E}_{q_\phi}[\log p_\theta(\mathbf{x}\mid \mathbf{z})] &= \mathbb{E}_{q_\phi}\!\Big[\sum_{k=1}^K x_k \log p_k(\mathbf{z})\Big] \nonumber\newline &\approx \frac{1}{L}\sum_{l=1}^L \sum_{k=1}^K x_k \log p_k(\mathbf{z}^{(l)}) \nonumber \end{align} $$ where $\mathbf{z}^{(l)} \sim q_\phi(\mathbf{z} \mid \mathbf{x})$ (sampled via the reparameterization trick). Therefore, the negative reconstruction term is simply the cross-entropy. $$ -\mathbb{E}_{q_\phi}[\log p_\theta(\mathbf{x}\mid \mathbf{z})] \approx \frac{1}{L}\sum_{l=1}^L \operatorname{CE}\big(\mathbf{x},\ p(\mathbf{z}^{(l)})\big) $$ Similar to the Bernoulli and Gaussian cases, in practice there is usually no need to sample multiple $\mathbf{z}$'s per datapoint, since the batch dimension already provides averaging; thus, we can directly use cross-entropy as the reconstruction loss.
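
As a compact illustration of all three cases (a sketch with made-up shapes, assuming the decoder produces raw logits for the Bernoulli and categorical likelihoods and a mean vector for the Gaussian one), the corresponding PyTorch loss calls are:

import torch
import torch.nn.functional as F

B = 8                                                # batch size
x_bin   = torch.randint(0, 2, (B, 784)).float()      # binary data (e.g., binarized MNIST)
x_cont  = torch.randn(B, 784)                        # continuous data
x_class = torch.randint(0, 10, (B,))                 # categorical data (class indices)

logits_bin   = torch.randn(B, 784)                   # decoder output: pre-sigmoid logits
mu           = torch.randn(B, 784)                   # decoder output: Gaussian mean
logits_class = torch.randn(B, 10)                    # decoder output: pre-softmax logits

# Per-datapoint reconstruction losses, averaged over the batch
recon_bce = F.binary_cross_entropy_with_logits(logits_bin, x_bin, reduction='sum') / B
recon_mse = F.mse_loss(mu, x_cont, reduction='sum') / B
recon_ce  = F.cross_entropy(logits_class, x_class, reduction='sum') / B
print(recon_bce.item(), recon_mse.item(), recon_ce.item())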

In summary, the VAE training objective is the negative ELBO: $$ \mathcal{L}_\text{VAE}(\mathbf{x}) = - \mathbb{E}_{q_\phi}[\log p_\theta(\mathbf{x} \mid \mathbf{z})] + D_{\mathrm{KL}}\!\big(q_\phi(\mathbf{z}\mid \mathbf{x})\,\|\,p(\mathbf{z})\big), $$ The first term corresponds to the reconstruction loss (BCE, MSE, or CE). The second term is KL divergence between the approximate posterior $q_\phi(\mathbf{z}\mid \mathbf{x})$ and the prior $p(\mathbf{z})$ (typically $\mathcal{N}(0,I)$), which acts as a regularizer encouraging the latent space to follow the prior distribution.

--- VAE Implementation¶

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

This is a classic example of a simple Variational Autoencoder (VAE) with a 2D input and a 2D latent space. The encoder maps the input to a hidden layer of size 16 and outputs the mean and log-variance of the latent distribution. A latent vector is sampled using the reparameterization trick. The decoder mirrors the encoder: it maps the 2D latent vector back through a hidden layer of size 16 to reconstruct the 2D input. Activation functions are ReLU.

This example demonstrates the training process of a VAE. The plotting part then walks through the generation process and shows that the distribution of the generated data resembles that of the real data.

In [6]:
class VAE(nn.Module):
    def __init__(self, latent_dim=2, input_dim=2):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim

        # Encoder: From input space to latent space (mean and log variance)
        self.fc1 = nn.Linear(input_dim, 16)
        self.fc2_mu = nn.Linear(16, latent_dim)
        self.fc2_logvar = nn.Linear(16, latent_dim)

        # Decoder: From latent space to input space (mean of Gaussian)
        self.fc3 = nn.Linear(latent_dim, 16)
        self.fc4 = nn.Linear(16, input_dim)

    def encode(self, x):
        h = torch.relu(self.fc1(x))
        mu = self.fc2_mu(h)
        logvar = self.fc2_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps * std  # Latent variable z

    def decode(self, z):
        h = torch.relu(self.fc3(z))
        return self.fc4(h)  # Output is mean of Gaussian (the same dim as the input)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Loss function: Reconstruction loss + KL divergence
def vae_loss(recon_x, x, mu, logvar):
    recon_loss = nn.functional.mse_loss(recon_x, x, reduction='sum') # MSE loss for Gaussian decoder
    # KL Divergence (between approximate posterior and unit Gaussian prior)
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
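    # Note: for a Gaussian decoder with sigma^2 = 1, the exact negative ELBO would be
    # 0.5 * recon_loss + kl_div; the 0.5 weight on kl_div below instead down-weights the
    # KL term (a beta-VAE-style trade-off), which is sufficient for this toy example.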
    return recon_loss + 0.5 * kl_div

Training

In [7]:
# Hyperparameters
latent_dim = 2
input_dim = 2
batch_size = 256
epochs = 1000
lr = 0.001

model = VAE(latent_dim, input_dim)
optimizer = optim.Adam(model.parameters(), lr=lr)

def create_data(n_samples):
    return torch.randn(n_samples, input_dim)  # 2D Gaussian data

# Training loop
for epoch in range(epochs):
    model.train()
    real_data = create_data(batch_size)
    
    optimizer.zero_grad()

    # Forward pass
    recon_data_mean, mu, logvar = model(real_data)

    # Compute loss
    loss = vae_loss(recon_data_mean, real_data, mu, logvar)

    # Backpropagation
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch} | Loss: {loss.item():.4f}")

print("Training complete!")
Epoch 0 | Loss: 516.2407
Epoch 100 | Loss: 472.4721
Epoch 200 | Loss: 421.2310
Epoch 300 | Loss: 334.9844
Epoch 400 | Loss: 332.3295
Epoch 500 | Loss: 298.6973
Epoch 600 | Loss: 309.5663
Epoch 700 | Loss: 321.7454
Epoch 800 | Loss: 321.4826
Epoch 900 | Loss: 307.2536
Training complete!
In [14]:
# Visualize the generated latent space and reconstructions
with torch.no_grad():
    model.eval()
    real_data = create_data(1000)
    recon_data_mean, _, _ = model(real_data)
    
    # Create a figure with two subplots side by side
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    
    # Left plot: Real data vs mean reconstruction from sampled latent variables
    ax1.scatter(real_data[:, 0], real_data[:, 1], label="Real Data", alpha=0.6)
    # Sample from standard normal distribution in latent space
    z = torch.randn(1000, latent_dim)
    # Pass through decoder to get reconstructions
    sampled_recon = model.decode(z)
    ax1.scatter(sampled_recon[:, 0], sampled_recon[:, 1], alpha=0.6, color='orange', label='Reconstructed Mean (new samples from latent)')
    ax1.legend()
    ax1.set_title("Real vs. Reconstructed Mean (new samples from latent)")
    
    # Right plot: Real data vs generated data from reconstructed mean
    ax2.scatter(real_data[:, 0], real_data[:, 1], label="Real Data", alpha=0.6)
    # Generate new data by sampling around reconstructed mean
    generated = torch.normal(mean=recon_data_mean, std=torch.ones_like(recon_data_mean))
    ax2.scatter(generated[:, 0], generated[:, 1], alpha=0.6, color='green', label='Generated Data (sampled from reconstructed mean)')
    ax2.legend()
    ax2.set_title("Real vs. Generated Data")
    
    plt.tight_layout()
    plt.show()
[Figure: two scatter plots. Left: real data vs. decoder means for new samples drawn from the latent prior. Right: real data vs. generated data sampled around the reconstructed means.]

--- History of VAE¶

  • 2013 – Auto-Encoding Variational Bayes (Kingma & Welling): introduced the VAE framework and the reparameterization trick.
  • 2014 – Published at ICLR; VAEs entered the mainstream. Together with GANs (Goodfellow et al.), they became one of the two main paradigms in deep generative modeling.
  • 2015–2016 – Early extensions and improvements appeared: Conditional VAE (CVAE); β-VAE (Higgins et al., 2017) enabling disentangled representations; Ladder VAE and Hierarchical VAE providing richer latent structures.
  • 2017–2018 – Integration with other generative models: VAE–GAN hybrids for improved sample quality; PixelVAE and VQ-VAE bringing autoregressive and discrete latent representations into the framework.
  • 2019 onward – Widespread applications emerged, including image generation, representation learning, anomaly detection, molecular design, and reinforcement learning. VAEs were also combined with normalizing flows and energy-based models for enhanced flexibility.

--- VAE Applications¶

While VAEs are primarily designed as generative models for learning latent representations and generating new data, they also have broader applications beyond generation — such as anomaly detection and recommendation systems.

1. Companies in the gaming or movie production industries can use VAEs to generate realistic character designs or environments: The input includes a set of images or, potentially, multi-modal data like images combined with text descriptions. This process enables the VAE to generate new, unseen designs by sampling from the learned latent space.

2. VAEs are widely used for representation learning: By learning a latent space, VAEs compress complex high-dimensional data into lower-dimensional representations, which can be used for visualization, clustering, or as features for downstream tasks. Variants such as β-VAE are often applied to learn disentangled representations, where different latent dimensions capture independent generative factors of the data.

3. VAEs are applied to detect anomalies in medical images: The VAE is trained only on normal images, learning the latent distribution and reconstruction behavior for non-anomalous data. In the inference phase, a test image (which could contain anomalies) is passed through the VAE. If the test image contains an anomaly, the reconstruction will likely have a higher reconstruction error.

4. E-commerce platforms use VAEs to create personalized product recommendations: Each user’s interactions are represented as a binary vector (a row of the interaction matrix), where 1 indicates an interaction (e.g., viewed, purchased) and 0 indicates no interaction. In the training phase, the user-product interaction matrix is the input, and the decoder tries to reconstruct it. In the inference phase, the interaction vector of a new user is the input, and the decoder outputs a vector of interaction probabilities for each product; products with high probability are recommended. Note that the same set of products must be used for training and inference, so this approach is subject to the cold-start problem for products not seen during training.