- Love me some JSD. Here is a problem most people don't consider with generative modeling (e.g., AI text, image, music, video models): basically all standard pre-training algorithms for generative models (i.e., cross entropy, basically all diffusion/flow formulations) are closer to a Forward KL divergence. In other words, given limited capacity the model will try to stretch itself to cover every mode. This gives you a jack of all trades (lots of knowledge and diversity), but a master of none (you get blurry images and text filled with nonsense).
The real magic in generative modeling comes from the post-training process, which usually (e.g., RLHF) approximates Reverse KL (given limited capacity, try to perfectly cover what you can, but it's fine to drop the rest entirely). This gives amazing results, but is also the cause of AI oddities like the "AI Image Pixar Look", many of the verbal tics of LLMs, and all AI music using the same small set of voices. Jensen-Shannon Divergence sits right in the middle of Forward and Reverse KL and is what many GANs are claimed to approximate. Ideally, it is a better trade-off between diversity and fidelity.
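To make the mode-covering vs. mode-seeking point concrete, here is a minimal sketch in plain Python. It is a toy setup of my own, not anything from a real training pipeline: it assumes a bimodal target on a 20-point grid and a "limited capacity" model that is a single discretized Gaussian. The names `gauss`, `kl`, and the candidate grid are hypothetical, and a brute-force search stands in for actual training.

```python
import math

def normalize(weights):
    total = sum(weights)
    return [w / total for w in weights]

def gauss(mu, sigma, xs):
    """Discretized Gaussian on the grid xs, normalized to sum to 1."""
    return normalize([math.exp(-0.5 * ((x - mu) / sigma) ** 2) for x in xs])

def kl(p, q):
    """KL(p || q) in nats; infinite if q is zero where p is not."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi > 0:
            if qi == 0:
                return math.inf
            total += pi * (math.log(pi) - math.log(qi))
    return total

xs = list(range(20))

# Bimodal target: equal mixture of two narrow bumps at 5 and 14.
target = [0.5 * a + 0.5 * b
          for a, b in zip(gauss(5, 1, xs), gauss(14, 1, xs))]

# "Limited capacity" model family: one Gaussian, (mu, sigma) on a coarse grid.
candidates = [(mu, sigma) for mu in xs for sigma in (0.5, 1, 2, 3, 4, 5)]

# Forward KL: KL(target || model), what cross-entropy training minimizes.
fwd = min(candidates, key=lambda c: kl(target, gauss(c[0], c[1], xs)))

# Reverse KL: KL(model || target), the mode-seeking direction.
rev = min(candidates, key=lambda c: kl(gauss(c[0], c[1], xs), target))

print("forward KL fit (mu, sigma):", fwd)
print("reverse KL fit (mu, sigma):", rev)
```

On this toy problem the forward-KL fit should come out wide and centered between the two modes (covering both, but putting mass where the target has almost none), while the reverse-KL fit should lock onto a single mode with a narrow width and ignore the other one.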
- It has applications outside of machine learning too! I used symmetric Kullback–Leibler divergence for a project with photon-number-resolving single-photon detectors during my PhD. I used it with an adjacency matrix to split a Gaussian mixture model (modelling some data with multivariate Gaussians) into a series of clusters.
https://snsphd.online/chapter_04/section_05_results/#photon-...
- I've been working on a field guide with colleagues. I'm interested in whether this is helpful for folks wanting a more applied view:
https://lospino.so/statistics/jensen-shannon-divergence/
Feedback welcome both from initiates (on helpfulness) and experts (on correctness)!
- For those wanting alternatives to KL-divergence, the KL and Jensen–Shannon divergences are both F-divergences: https://en.wikipedia.org/wiki/F-divergence
- This looks interesting and I'm curious if anyone has more context for why it's on the frontpage today.
- Every now and then, a random math or science concept hits front page. Usually, people chime in with interesting perspectives on it. Guess we'll see.
- I’d like to know what the advantage is over KL divergence. It seems like the important idea is symmetry? Not clear to me why that matters; I’d love to know what application this is used for.
- There are many applications. I mainly see it used for detecting drift in datasets for ML models. It has a nice benefit over the KL divergence in the case where the two distributions you're measuring have no overlap (KL is infinite there, but JS stays finite and just returns its maximum, which is 1 when measured in bits). Also, when taking its square root you get a distance (a true metric that satisfies the triangle inequality) rather than a divergence, which allows you to compare it to JSD measurements of other distributions.
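A quick sketch of the no-overlap case, in plain Python with made-up histograms (the `reference` and `drifted` bins are purely illustrative, and `kl_bits` / `jsd_bits` are hypothetical helper names, not a library API):

```python
import math

def kl_bits(p, q):
    """KL(p || q) in bits; infinite if q is zero somewhere p is not."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi > 0:
            if qi == 0:
                return math.inf
            total += pi * math.log2(pi / qi)
    return total

def jsd_bits(p, q):
    """Jensen-Shannon divergence in bits, always within [0, 1]."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_bits(p, m) + 0.5 * kl_bits(q, m)

def js_distance(p, q):
    """Square root of the JSD, which is a proper metric."""
    return math.sqrt(jsd_bits(p, q))

# Two histograms over the same four bins with no overlap at all.
reference = [0.5, 0.5, 0.0, 0.0]
drifted = [0.0, 0.0, 0.5, 0.5]

print("KL :", kl_bits(reference, drifted))    # infinite
print("JSD:", jsd_bits(reference, drifted))   # the maximum, 1 bit
print("JSD of a distribution with itself:", jsd_bits(reference, reference))
```

For drift monitoring you would bin a reference window and a live window the same way and track `js_distance` over time; the threshold for "drifted" is a judgment call that depends on your data.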
- > Also, when taking its square root you get a distance
Easy conversion into a distance metric is hugely valuable because it makes the measure usable in KNN-based dimensionality reduction algos (and I'm sure other things I don't understand, as a non-mathematician)
Here's a library that the creator of UMAP provides (UMAP being a workhorse of dimensional reduction algos), for doing approx nearest neighbor search: https://pynndescent.readthedocs.io/en/latest/api.html#pynnde...
- Iirc (and I could be wrong, this is from memory) JS divergence is what is minimized in GANs (where we simultaneously train a generator and real/synthetic classifier with the goal of each trying to beat the other to converge on real looking synthetic data), at least for some training methods.
I don’t think GANs are used much now in comparison to diffusion models, but as recently as a few years ago they were the standard way to make fake data, a la “this face does not exist”
- The Hacker News hive mind is real!
I was just reading about JSD the other day after reading about KL divergence...seems like a nifty measurement device for things like sim-to-real evaluations in robots (the reason I was going down this rabbit hole.)
I think the appeal over raw KL is that JSD behaves a bit nicer when the simulated and real distributions don't perfectly overlap...which is basically always true in the real world!
- Currently piloting the use of JSD for a synthetic audience survey application, measuring how closely the synthetic response distribution matches a human panel.
Been knee-deep trying to understand this world, so seeing this on Hacker News today is kind of scary.
- There is so much I don't understand
- Every time I end up on a Wikipedia page for some math or CS term I just give up on reading and search for another source, any at all. I know it is supposed to be an encyclopedia, and I am sure the definitions are technically correct, but it just isn't what most people need. I remember the Wikibooks project tried to bridge that gap but it never got popular enough. I guess it is just easier to compose short notes than to write a full-blown manual, and it's much harder to split such work.
- Why not use this instead of KL in reinforcement learning?
- To minimise the KL you just calculate the surprisal. The integral can be approximated by sampling over your training data. It's a direct expression of the information loss between your real data and your fitted probability distribution.
Calculating the JSD could be more difficult, because the expression uses a mixture of the 'true' and 'fitted' distributions. You can still simulate this, but half the time you'd be fitting the model to itself, and I just don't see why that would be useful.
I think the JSD is most useful when you need an actual metric, but as long as you have a fitted and target distribution the KL divergence is a natural fit since you can interpret the result as information loss.
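Here is a rough sketch of the "just average the surprisal over samples" point, in plain Python. The three-outcome `true_dist` and the two candidate models are invented for illustration. Since KL(true || model) equals the cross-entropy minus the entropy of the true distribution, and that entropy does not depend on the model, ranking models by average surprisal on the data ranks them by KL without ever needing the true probabilities.

```python
import math
import random

def avg_surprisal(samples, model):
    """Mean of -log model[x] over the samples, in nats."""
    return sum(-math.log(model[x]) for x in samples) / len(samples)

def cross_entropy(p, q):
    """Exact cross-entropy -sum p log q, in nats."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)

# Hypothetical "real" distribution; in practice you only ever see samples.
true_dist = [0.6, 0.3, 0.1]

# Two candidate fitted models; model_b happens to match the truth.
model_a = [0.5, 0.3, 0.2]
model_b = [0.6, 0.3, 0.1]

rng = random.Random(0)  # fixed seed so the sketch is repeatable
samples = rng.choices(range(3), weights=true_dist, k=20000)

est_a = avg_surprisal(samples, model_a)
est_b = avg_surprisal(samples, model_b)
exact_a = cross_entropy(true_dist, model_a)

print("sampled surprisal, model_a:", est_a)
print("exact cross-entropy, model_a:", exact_a)
print("sampled surprisal, model_b:", est_b)
```

The JSD has no such shortcut: the mixture in its definition needs the density of both distributions at each sample point, which is usually exactly what you don't have for the real data.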
- It's been used, along with every other divergence and distance you can think of.
In practice, which divergence you use doesn't seem to be very important. The KL is the one with the strongest theoretical foundation though, i.e. it will work in the limit of infinite data. The important aspect seems to be that neural networks are Lipschitz bounded, and that this is the most important constraint preventing collapse.
- JSD is just symmetrized KL, it's the forward KL + reverse KL.
In reinforcement learning, what we usually want is to find the optimal action, i.e. the action that maximizes the reward. This translates to so-called "mode-seeking" optimization, which is the reverse KL.
- JSD is slightly different to forward KL + reverse KL (which is unbounded, whereas JSD measured in bits is in the range [0, 1]).
One way to interpret JSD(P, Q): Associate the distributions P and Q with two target classes, respectively. Pick a target class based on a fair coin flip. Then sample either from distribution P or distribution Q, depending on the outcome of the coin flip. The JSD is the mutual information between the resulting sample (which follows the mixture distribution) and the target class.
Alternative intuition: Suppose we want to measure the correlation between a feature X and a binary target class Y. We have a tabular data set with two columns X and Y, whose rows correspond to individual samples. JSD is the mutual information between the feature X and the target class Y, but after we resample our data (rows) to ensure that we have a balanced representation of the target class Y. If we measure the JSD in bits, then JSD = 1 - H(Y|X) for balanced classes, so the quantity 2^(JSD-1) = 2^(-H(Y|X)) is the geometric-mean probability assigned to the correct class given X: 1 when X determines Y, and 0.5 when X is uninformative. That makes it a rough proxy for (not exactly equal to) the fraction of times X correctly predicts Y.
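A small numerical check of both points, as a sketch in plain Python with made-up distributions `P` and `Q` (the helper names are hypothetical): the coin-flip mutual information matches the JSD, and forward KL + reverse KL can go well past 1 bit while the JSD cannot.

```python
import math

def kl_bits(p, q):
    """KL(p || q) in bits; infinite if q is zero somewhere p is not."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi > 0:
            if qi == 0:
                return math.inf
            total += pi * math.log2(pi / qi)
    return total

def jsd_bits(p, q):
    """Jensen-Shannon divergence in bits."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_bits(p, m) + 0.5 * kl_bits(q, m)

def coin_flip_mi_bits(p, q):
    """I(X; Y) where Y is a fair coin and X is drawn from p or q accordingly."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]  # marginal of X
    mi = 0.0
    for dist in (p, q):                  # Y = 0 uses p, Y = 1 uses q
        for x, px in enumerate(dist):
            joint = 0.5 * px             # P(X = x, Y = y)
            if joint > 0:
                mi += joint * math.log2(joint / (m[x] * 0.5))
    return mi

P = [0.98, 0.01, 0.01]
Q = [0.01, 0.01, 0.98]

jsd = jsd_bits(P, Q)
mi = coin_flip_mi_bits(P, Q)
forward_plus_reverse = kl_bits(P, Q) + kl_bits(Q, P)

print("JSD (bits):", jsd)
print("coin-flip mutual information (bits):", mi)
print("forward KL + reverse KL (bits):", forward_plus_reverse)
print("2^(JSD-1):", 2 ** (jsd - 1))
```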