Drifting and Chart Transport
A summary of research progress over the last month: we isolate some failure modes of plain drifting and experiment with chart-transport methods, concluding with preliminary MNIST experiments.
Mode collapse in plain drifting
We apply MLP-drifting to a bimodal 2D Gaussian below. This demonstrates mode collapse from disjoint initialization: after converging to one mode, the model has almost no incentive to cover the other.
Drag to see different training steps; the same applies to later plots.
Generally, model-sample drifting methods without importance sampling all suffer from this reverse-KL mode-collapse problem.
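The reverse-KL mode-seeking behavior is easy to reproduce in one dimension. Below is a minimal sketch (not our drifting implementation): a unit-variance Gaussian q = N(mu, 1) is fit to a bimodal target by reparameterized gradient descent on KL(q || p). The target mixture, initialization, and step size are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)

def score_p(x):
    # Score (d/dx log p) of the bimodal target p = 0.5 N(-4, 1) + 0.5 N(4, 1),
    # written via the responsibility w(x) of the +4 mode.
    w = 1.0 / (1.0 + np.exp(-8.0 * x))
    return 8.0 * w - x - 4.0

mu = 1.0  # disjoint initialization: q starts nearer the +4 mode
lr = 0.1
for _ in range(300):
    eps = rng.standard_normal(512)
    # Under reparameterization x = mu + eps, the entropy term of KL(q || p)
    # is constant in mu, so the gradient is just -E[score_p(mu + eps)].
    grad = -np.mean(score_p(mu + eps))
    mu -= lr * grad

print(mu)  # collapses onto the +4 mode; the mode at -4 is never covered
```

Once q sits on one mode, the other mode contributes almost nothing to the gradient, so the fit is a stable (but mode-collapsed) stationary point.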
Chart transport
The main idea of our proposal is to build an autoencoder (a chart) between the latent and model spaces. This lets us drift the data latents toward the model latent (the prior) and thereby optimize a forward KL. We demonstrate this method on the same problem as above.
Mechanically (right panel), the latent score field guides the model samples (red and blue) toward the model latent distribution (black, right). This induces (left panel) the data and sample distributions to match.
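To make the mechanics concrete, here is a deliberately simplified sketch of the chart idea, not our actual method: a linear chart on Gaussian toy data, where drifting the data latents toward a standard-normal prior reduces to matching the latent covariance to the identity. The MLP chart and the learned latent score field are replaced by closed-form surrogates.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "data space": a 2D Gaussian with non-standard mean and scale.
X = rng.normal(loc=[3.0, -2.0], scale=[2.0, 0.5], size=(2000, 2))

# Linear chart (a hypothetical simplification of the MLP autoencoder):
# encoder z = (x - b) @ W, decoder x_hat = z @ inv(W) + b.
b = X.mean(axis=0)
W = np.eye(2)
S = (X - b).T @ (X - b) / len(X)  # data second moments

# Drift the data latents toward the standard-normal prior. For Gaussian
# latents this reduces to driving cov(Z) to the identity, so we descend
# || W^T S W - I ||_F^2 as a stand-in for the latent score drift.
lr = 0.01
for _ in range(800):
    cov = W.T @ S @ W
    W -= lr * 4 * S @ W @ (cov - np.eye(2))

Z = (X - b) @ W                   # data latents now match the prior
X_hat = Z @ np.linalg.inv(W) + b  # chart stays invertible: exact reconstruction
```

Because the chart is a bijection here, matching the latent distributions also matches the sample-space distributions, which is the forward-KL mechanism the plots illustrate.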
We also ran a flow-matching baseline (click to expand)
Chart transport challenge: space-filling
We realized that latent distribution matching becomes fragile when the data has low intrinsic dimension. In the experiment below, we generate multimodal Gaussian data on a 2D manifold in 3D ambient space.
Problems arise when we try to warp a thin 2D data-latent surface into a full-dimensional target latent distribution.
Note how the chart (autoencoder) struggles to warp the 2D data latent into the 3D latent target. It is known that space-filling with low-dimensional objects causes ill-conditioning (Hilbert's thirteenth problem!). This appears to be a fundamental challenge.
Note how the data latents (right) are thin surfaces that cannot warp into the Gaussian-shaped target prior. The generated samples (which can be toggled off in the plot) pass through the data samples (zoom in to see) but also include off-manifold samples.
Partial solution: stochastic encoding
To address the manifold-dimension problem, we tried a stochastic encoder. Each data sample maps to a “fiber” in latent space. We still match sample-space distributions as long as the marginal latent-fiber distributions match and distinct fibers are reconstructible. The following example demonstrates how this roughly solves the 2D-data-in-3D-ambient setup.
Note how the latents (right) now become thick 3D blobs which can radially fill the space.
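A lower-dimensional analog makes the fiber picture concrete. The sketch below is illustrative, not our implementation: 1D data on a circle in 2D ambient space, where a stochastic encoder spreads each point along its radial fiber so that the marginal latent distribution becomes exactly standard normal, while the decoder projects back to the manifold.

```python
import numpy as np

rng = np.random.default_rng(0)

# 1D data manifold in 2D ambient space (a lower-dimensional analog of
# the 2D-surface-in-3D setup): points on the unit circle.
theta = rng.uniform(0.0, 2.0 * np.pi, size=20000)
X = np.stack([np.cos(theta), np.sin(theta)], axis=1)

# Stochastic encoder: each x maps to the radial fiber {r * x : r >= 0}.
# Drawing r from a Rayleigh distribution (the radius of a 2D standard
# normal) makes the marginal latent distribution exactly N(0, I).
r = np.sqrt(-2.0 * np.log(rng.uniform(size=len(X))))
Z = r[:, None] * X

# Decoder: project back onto the manifold. Distinct fibers never mix,
# so every sample reconstructs exactly.
X_hat = Z / np.linalg.norm(Z, axis=1, keepdims=True)
```

The thin circle of data latents becomes a full 2D Gaussian blob, exactly the “radial filling” visible in the 3D plots, and reconstruction stays lossless because the fibers are disjoint.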
Stochastic encoding is not free (it requires training a latent score critic), and we find it more of a band-aid than a fix. We have some hypotheses for why it is not the full solution, but they are slightly too involved to explain here.
MNIST experiments
Flow-matching MNIST baseline
Flow matching with the same MLP architecture, 50 Euler steps.
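For reference, the sampling loop in this baseline is plain Euler integration of a velocity field. The sketch below substitutes a closed-form Gaussian-to-Gaussian velocity field (for the linear-interpolation path) in place of the learned MLP field, keeping the same 50 Euler steps; the target mean and scale are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)

m, s = 2.0, 0.5   # hypothetical 1D target N(m, s^2); MNIST uses a learned field
n_steps = 50      # same 50 Euler steps as the baseline
dt = 1.0 / n_steps

def velocity(x, t):
    # Closed-form probability-flow velocity for the linear-interpolation
    # path x_t = (1 - t) x0 + t x1 between N(0, 1) and N(m, s^2):
    # v(x, t) = m + (sigma'_t / sigma_t) (x - t m), sigma_t^2 = (1-t)^2 + (s t)^2.
    var = (1.0 - t) ** 2 + (s * t) ** 2
    return m + ((s ** 2) * t - (1.0 - t)) / var * (x - t * m)

x = rng.standard_normal(10000)  # samples from the source N(0, 1)
for k in range(n_steps):
    x = x + dt * velocity(x, k * dt)

print(x.mean(), x.std())  # close to (m, s) up to Euler and Monte Carlo error
```

Swapping `velocity` for a trained network on images gives exactly the sampler used for the MNIST baseline above.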
Real-world data likely lies on a manifold with some noise; its “manifold dimension” generally depends on a pre-chosen noise level.
Still, generation quality depends strongly on the latent dimension. Stochastic encoding mitigates the failure mode (see below), but generated samples are best when the latent dimension matches the intrinsic dimension of the data manifold and decoding is deterministic.
12-dimensional latent
Best sample quality. Twelve is known to be around the intrinsic dimension of the MNIST data (I believe CIFAR's and ImageNet's are in the 20-80 range, respectively).
Deterministic encoding
Stochastic encoding
32-dimensional latent
Sample quality is visibly worse than in the 12D case.
Deterministic encoding
Stochastic encoding
128-dimensional latent
Both methods break down, with deterministic encoding having much worse failure modes.