VQ-VAE (Vector-Quantized Variational Autoencoder) is a standard approach in the ML literature for quantizing data [1]. Quantization is needed whenever we want to run an autoregressive transformer on data that isn’t naturally tokenized, which is the case in most production models for image, video, and audio generation.

In this blog post we demonstrate an alternative to VQ-VAE, FSQ (Finite Scalar Quantization) [2], which works better on the MNIST dataset. FSQ avoids VQ-VAE's learnable codebook and its loss balancing, which gives better reconstruction and codebook utilization with simpler training. The goal of this post is simply to popularize FSQ, which I have also seen work better in my own research.

VQ-VAE

VQ-VAE learns a discrete codebook of embeddings to tokenize continuous data. The encoder produces a continuous representation, and for each spatial position we pick the nearest codebook embedding. The index of that embedding is the discrete token.

The core challenge is training the codebook effectively. We need the codebook embeddings to be useful for reconstruction, but we also need the encoder to actually use them. This requires three losses working together: reconstruction, codebook, and commitment. The reconstruction loss is straightforward: standard MSE between the input and the decoded output. The codebook loss moves the embeddings toward the (detached) encoder outputs, while the commitment loss pulls the encoder outputs toward the (detached) embeddings, so that both components learn together rather than drifting independently.
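Written out, the combined objective looks roughly like the following. The notation is borrowed from the original paper [1] rather than from this post: sg[·] is the stop-gradient operator (`.detach()` in PyTorch), z_e(x) is the encoder output, e is the selected codebook embedding, and β is a commitment weight (the paper uses β = 0.25; the weight used in this post's experiments is not stated here).

```latex
\mathcal{L}
  = \underbrace{\lVert x - \hat{x} \rVert_2^2}_{\text{reconstruction}}
  + \underbrace{\lVert \mathrm{sg}[z_e(x)] - e \rVert_2^2}_{\text{codebook}}
  + \beta \, \underbrace{\lVert z_e(x) - \mathrm{sg}[e] \rVert_2^2}_{\text{commitment}}
```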

Here’s the implementation:

def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, dict]:
    z_flat = rearrange(z, 'b c h w -> (b h w) c')
    
    # Find nearest codebook embedding for each spatial position
    distances = (z_flat.pow(2).sum(1, keepdim=True) 
                + self.embeddings.weight.pow(2).sum(1)
                - 2 * z_flat @ self.embeddings.weight.t())
    
    indices = distances.argmin(1)  # Discrete codes
    quantized = self.embeddings(indices)  # Look up embeddings
    quantized = rearrange(quantized, '(b h w) c -> b c h w', ...)
    
    # Two losses: move codebook toward encoder, and vice versa
    commitment_loss = F.mse_loss(quantized.detach(), z)
    codebook_loss = F.mse_loss(quantized, z.detach())
    
    # Straight-through: use quantized in forward, copy gradients in backward
    quantized = z + (quantized - z).detach()
    
    return quantized, {...}
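To make the moving parts concrete, here is a small self-contained sketch of how the two quantizer losses could be combined with the reconstruction loss in a single training step. `TinyVQ`, the `beta` weight of 0.25 (the value from [1]), the tensor shapes, and the identity "decoder" are all assumptions for illustration; this is not the code from the gist.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyVQ(nn.Module):
    """Hypothetical minimal quantizer mirroring the forward pass above."""

    def __init__(self, num_codes: int = 16, dim: int = 4, beta: float = 0.25):
        super().__init__()
        self.embeddings = nn.Embedding(num_codes, dim)
        self.beta = beta  # commitment weight (assumed value, taken from [1])

    def forward(self, z: torch.Tensor):
        b, c, h, w = z.shape
        z_flat = z.permute(0, 2, 3, 1).reshape(-1, c)  # (b*h*w, c)
        weight = self.embeddings.weight
        # Squared Euclidean distance to every codebook embedding
        distances = (z_flat.pow(2).sum(1, keepdim=True)
                     + weight.pow(2).sum(1)
                     - 2 * z_flat @ weight.t())
        indices = distances.argmin(1)
        quantized = self.embeddings(indices).view(b, h, w, c).permute(0, 3, 1, 2)
        commitment_loss = F.mse_loss(z, quantized.detach())  # pulls encoder
        codebook_loss = F.mse_loss(quantized, z.detach())    # pulls codebook
        quantized_st = z + (quantized - z).detach()          # straight-through
        vq_loss = codebook_loss + self.beta * commitment_loss
        return quantized_st, indices.view(b, h, w), vq_loss


torch.manual_seed(0)
vq = TinyVQ()
z = torch.randn(2, 4, 7, 7, requires_grad=True)  # stand-in for encoder output
x_target = torch.randn(2, 4, 7, 7)               # stand-in for the decoder target
quantized, indices, vq_loss = vq(z)
recon_loss = F.mse_loss(quantized, x_target)     # identity "decoder" for brevity
total_loss = recon_loss + vq_loss
total_loss.backward()
```

Note that the reconstruction gradient reaches `z` through the straight-through term but never reaches the codebook, which is why the separate codebook loss is needed at all.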

The main challenge with VQ-VAE is balancing these losses and avoiding codebook collapse (where only a few embeddings get used). FSQ offers a simpler alternative.
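Codebook collapse is easy to check for. A minimal sketch, assuming you have the integer token indices produced by the quantizer over a validation set (the function name `codebook_utilization` is my own, not from the gist):

```python
import torch


def codebook_utilization(indices: torch.Tensor, num_codes: int) -> float:
    """Fraction of codebook entries that are selected at least once."""
    counts = torch.bincount(indices.flatten(), minlength=num_codes)
    return (counts > 0).float().mean().item()


healthy = torch.arange(512)                      # every code used once
collapsed = torch.tensor([3, 3, 7, 3, 7, 7])     # only codes 3 and 7 used
print(codebook_utilization(healthy, 512))    # 1.0
print(codebook_utilization(collapsed, 512))  # 2 / 512
```

A value close to 1.0 means nearly every code is in use, while a collapsed codebook shows up as a value close to 0.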

FSQ

Finite Scalar Quantization (FSQ) eliminates the learnable codebook entirely. Instead of learning embedding vectors, FSQ quantizes each dimension independently to a fixed set of levels. For example, levels=[8,8,8] gives 8 × 8 × 8 = 512 codes without any learnable embedding parameters.
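The implicit codebook is just the Cartesian product of the per-dimension levels, so a token index can be computed with mixed-radix arithmetic. The sketch below shows one possible convention; the helper names `digits_to_index` and `index_to_digits` are hypothetical, and the ordering of dimensions is an assumption that may differ from the gist.

```python
import math

import torch

levels = [8, 8, 8]
num_codes = math.prod(levels)  # 8 * 8 * 8 = 512


def _basis(levels: list[int]) -> torch.Tensor:
    # Place value of each dimension: [1, L0, L0 * L1, ...]
    return torch.tensor([math.prod(levels[:i]) for i in range(len(levels))])


def digits_to_index(digits: torch.Tensor, levels: list[int]) -> torch.Tensor:
    """Map per-dimension level indices (0 .. L_i - 1) to one flat token index."""
    return (digits * _basis(levels)).sum(dim=-1)


def index_to_digits(index: torch.Tensor, levels: list[int]) -> torch.Tensor:
    """Inverse of digits_to_index."""
    return (index.unsqueeze(-1) // _basis(levels)) % torch.tensor(levels)


digits = torch.tensor([[0, 0, 0], [1, 2, 3], [7, 7, 7]])
flat = digits_to_index(digits, levels)   # 0, 1 + 2*8 + 3*64 = 209, 511
roundtrip = index_to_digits(flat, levels)
```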

Let’s look at the FSQQuantizer.forward() method:

def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, dict]:
    z = rearrange(z, 'b c h w -> b h w c')
    z = self.project_in(z)  # Project to codebook_dim if needed
    
    # Bound and quantize with straight-through
    eps = 1e-3
    half_l = (self._levels - 1) * (1 + eps) / 2
    offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
    shift = torch.atanh(offset / half_l)
    bounded = torch.tanh(z + shift) * half_l - offset
    codes = (bounded.round() + (bounded - bounded.detach())) / (self._levels // 2)
    
    out = self.project_out(codes)
    out = rearrange(out, 'b h w c -> b c h w')
    return out, {'indices': flat_indices, 'losses': {}, ...}

The encoder produces a continuous representation, which is projected to the codebook dimension. Each dimension is then independently bounded using tanh and rounded to one of its fixed levels. With levels=[8,8,8], each of the 3 dimensions can take 8 discrete values, giving 512 total codes. The straight-through estimator (bounded.round() + (bounded - bounded.detach())) enables gradient flow: use discrete codes in the forward pass, but copy gradients through the continuous values during backpropagation.
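Here is a standalone sketch of just the bound-and-round step, without the projections, so you can poke at it in isolation. The function name `fsq_quantize` and the test shapes are my own assumptions. The straight-through line is written in the more common `x + (round(x) - x).detach()` form, which has the same forward value and the same gradient as the expression above.

```python
import torch


def fsq_quantize(z: torch.Tensor, levels: list[int], eps: float = 1e-3) -> torch.Tensor:
    """Bound each channel with tanh, round to integers, normalize to about [-1, 1]."""
    lv = torch.tensor(levels, dtype=z.dtype)
    half_l = (lv - 1) * (1 + eps) / 2
    # Even level counts need a half-step offset so the grid is asymmetric around 0
    offset = torch.tensor([0.5 if l % 2 == 0 else 0.0 for l in levels], dtype=z.dtype)
    shift = torch.atanh(offset / half_l)
    bounded = torch.tanh(z + shift) * half_l - offset
    # Straight-through: rounded value in the forward pass, identity gradient backward
    rounded = bounded + (bounded.round() - bounded).detach()
    half_width = torch.tensor([l // 2 for l in levels], dtype=z.dtype)
    return rounded / half_width


torch.manual_seed(0)
z = (3 * torch.randn(1000, 3)).requires_grad_()   # last dim matches len(levels)
codes = fsq_quantize(z, levels=[8, 8, 8])
codes.sum().backward()                            # gradients reach z despite round()
values_per_dim = [codes[:, i].detach().unique().numel() for i in range(3)]
```

With 8 levels per dimension, each channel lands on one of at most 8 values (integers from -4 to 3, divided by 4), and `z.grad` is populated even though `round()` on its own has zero gradient.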

The main advantage of FSQ over VQ-VAE is simplicity: no codebook loss, no commitment loss, no learnable embeddings to balance. The encoder simply learns to produce values that reconstruct well when quantized to fixed levels. This avoids codebook collapse and requires fewer hyperparameters.

Results

The plots below show that, on the validation set, FSQ reaches a lower reconstruction error and much higher codebook utilization than VQ-VAE (both trained with the code linked at the end of this post).

[Figure: validation reconstruction error and codebook utilization, FSQ vs. VQ-VAE]

While MNIST is a simple dataset, these results illustrate why I prefer FSQ: simpler training with better codebook utilization.

Full Code Reference

The complete implementation code for both VQ-VAE and FSQ quantizers, including the training script and experiments on MNIST, is available as a GitHub Gist here.


  1. van den Oord, A., Vinyals, O., & Kavukcuoglu, K. (2017). Neural Discrete Representation Learning. Advances in Neural Information Processing Systems (NIPS 2017). arXiv:1711.00937

  2. Mentzer, F., Minnen, D., Agustsson, E., & Tschannen, M. (2023). Finite Scalar Quantization: VQ-VAE Made Simple. arXiv:2309.15505