Softmax, can you derive the Jacobian? And should you care?

Original link: https://idlemachines.co.uk/essays/softmax

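## Softmax: A Deep Dive

Softmax is used pervasively in machine learning: for multiclass outputs, probability normalisation and computing attention weights. It converts a vector of real numbers into a probability distribution (values between 0 and 1 that sum to 1), effectively mapping the vector onto the "probability simplex".

Although it looks simple ($\mathrm{softmax}(x_i) = e^{x_i} / \sum_j e^{x_j}$), softmax exaggerates the relative differences between values, which makes predictions decisive but can cause problems for uncertainty estimates. It also couples all dimensions together: because of the sum-to-1 constraint, changing one input affects every output.

**Numerical stability** is crucial. A naive implementation can overflow when the inputs are large. The fix? Shift the inputs by their maximum ($\mathrm{softmax}(x_i) = \mathrm{softmax}(x_i - \max(x))$), which ensures that no exponent exceeds 0.

**The Jacobian** reveals how interconnected softmax is. Its structure (diagonal plus rank-1) allows efficient backpropagation without fully materialising the potentially huge matrix. Combined with the cross-entropy loss, the backward pass simplifies to the difference between the predicted probabilities and the true labels, $s - y$.

Finally, the **axis parameter** handles batch processing by normalising along the specified dimension, and **temperature scaling** controls the "sharpness" of the distribution, affecting confidence and determinism.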

Discussed on Hacker News: "Softmax, can you derive the Jacobian? Should you care?" (idlemachines.co.uk), 13 points, submitted by smaddrellmander.

Original article

Multiclass output? Softmax. Normalising probabilities? Softmax. Attention weights? Softmax. Partition function? You guessed it, Softmax. This function comes up everywhere, but how often have you really thought about what's going on inside?

What does softmax actually do to your distribution?

The softmax function is deceptively simple:

$$\mathrm{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}}$$

We take the exponential of each input and normalize by the sum of all exponentials. This transforms a vector of arbitrary real numbers into values between 0 and 1 that sum to 1. Technically this is a pseudo-probability distribution (the values aren't derived from a probability space), but it's close enough to a probability distribution, and for practical purposes it works just fine.

One useful way to think about softmax is that it maps vectors into a very specific geometric object: the probability simplex. For an n-dimensional output, this is the set of all vectors where each entry is non-negative and everything sums to 1. In 3 dimensions, this looks like a triangle sitting in 3D space; in higher dimensions, it's the same idea generalised. Softmax takes an unconstrained vector in $\mathbb{R}^n$ and smoothly maps it onto this simplex (strictly, onto its interior, since every output is positive). The constraint that all outputs must sum to 1 is exactly what creates the interactions between dimensions that we'll see later in the Jacobian.
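As a minimal sketch, the formula translates line for line into NumPy. The helper name `naive_softmax` and the input vector are our own illustrative choices, not from the original. The output lands inside the simplex: every entry is strictly between 0 and 1, and the entries sum to 1 up to floating-point rounding.

```python
import numpy as np

def naive_softmax(x):
    """Direct translation of softmax(x_i) = exp(x_i) / sum_j exp(x_j).

    Hypothetical helper for illustration only: it is NOT numerically
    stable for large inputs.
    """
    exps = np.exp(x)
    return exps / np.sum(exps)

x = np.array([2.0, -1.0, 0.5])   # an arbitrary vector in R^3
s = naive_softmax(x)

print(s)          # a point in the interior of the 3D probability simplex
print(s.sum())    # approximately 1.0
```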

Let's visualize what this actually does in a real language model scenario - predicting the next token after "the cat sat on a":

[Figure: Distribution Shift]
Left: raw logit values for candidate tokens. Right: probabilities after softmax. The highest logit ("mat" at 3.2) gets dramatically amplified to 48% probability, while others are suppressed. The transformation turns unbounded scores into a probability distribution that sums to 1.

The transformation is pretty dramatic. The relative differences between values get exaggerated, because exponentiation turns additive gaps between logits into multiplicative gaps between probabilities. The largest logit therefore dominates the output, while smaller values are squashed. This is exactly what we want for confident predictions, but it also explains why softmax can be problematic when you want uncertainty estimates: it's very opinionated about which class should win.
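A quick check with made-up logits makes this concrete. The ratio of two softmax outputs is $s_i / s_j = e^{x_i - x_j}$, so a gap of 1 in logit space becomes a factor of $e \approx 2.72$ in probability. The logits below are hypothetical, chosen only so that each is exactly 1 larger than the next.

```python
import numpy as np

def naive_softmax(x):
    exps = np.exp(x)
    return exps / exps.sum()

# Hypothetical logits: each is exactly 1 larger than the next
logits = np.array([2.0, 1.0, 0.0])
probs = naive_softmax(logits)
print(probs)  # ≈ [0.665, 0.245, 0.090]

# A logit gap of 1 becomes a probability ratio of e ≈ 2.718
print(probs[0] / probs[1], np.exp(logits[0] - logits[1]))
```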

We can see this "winner takes most" behavior even more clearly with a batch of vectors:

[Figure: Softmax Focus]
Each column is one sample, each row is a class. Left: logits with relatively similar magnitudes (green). Right: probabilities after softmax (purple). Notice how probability mass concentrates sharply on the highest logit value in each sample as we take the softmax over each column.

This focusing behavior is a bit funny. On one hand, it's what makes softmax so powerful for classification: by making the outputs more decisive, it makes it easier to predict classes and train the model. On the other hand, we're baking in a sense of difference between options that isn't always present in the original logits. That said, a trained model has learned to produce logits where these differences are meaningful.

Numerical Stability: We need to talk about overflow

A naive implementation works fine for small inputs, but as with any function involving an exponential, we should hear some alarm bells ringing. In the sigmoid we had a single exponential; here we have N of them.

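If we feed in x = [1000, 1001, 1002], every exponential overflows. In float64, e^x exceeds the largest representable number (about 1.8 × 10^308) once x is above roughly 709.78, so e^{1000}, e^{1001} and e^{1002} all evaluate to inf. The sum is also inf, so every output is inf / inf = nan. Not good. This isn't a graceful failure; this is a catastrophic failure.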
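Reproducing this takes only a few lines. This is a minimal sketch: `naive_softmax` is a hypothetical helper, and the `np.errstate` context only silences NumPy's overflow and invalid-value warnings so the printed result is easy to read.

```python
import numpy as np

def naive_softmax(x):
    exps = np.exp(x)            # overflows to inf for x above ~709.78 in float64
    return exps / np.sum(exps)

x = np.array([1000.0, 1001.0, 1002.0])

# Suppress the RuntimeWarnings; the values are still inf and nan underneath
with np.errstate(over="ignore", invalid="ignore"):
    s = naive_softmax(x)

print(s)  # [nan nan nan]
```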

The interconnected nature of the softmax makes this worse as well. In a sigmoid, a single input maps to a single output, so an overflowing activation only produces one bad value. With the softmax, every output depends on every input through the shared sum. If any single input overflows, the sum becomes inf: the overflowing entry becomes inf / inf = nan, and every other output becomes a finite number divided by inf, i.e. 0. One bad input corrupts the whole distribution. (NB: this is a simplification of the sigmoid case; once we backpropagate a NaN into the network it can cause NaNs to spread.)
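Running the same naive formula on a hypothetical input where only one entry overflows shows how far the damage spreads. The overflowing entry turns into nan, and the two perfectly ordinary entries are divided by an infinite sum and come out as exactly 0, so the output is no longer a probability distribution at all.

```python
import numpy as np

def naive_softmax(x):
    exps = np.exp(x)
    return exps / np.sum(exps)

# Only the first entry overflows; the other two are ordinary
x = np.array([1000.0, 0.0, 0.0])

with np.errstate(over="ignore", invalid="ignore"):
    s = naive_softmax(x)

# First entry: inf / inf = nan. Remaining entries: 1 / inf = 0.0
print(s)
```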

The fix - shifting the inputs

With an exponential, we are generally worried about large positive inputs. Very negative inputs are less of a problem: e^x just tends towards 0, and most frameworks will simply underflow to 0 without causing a catastrophic failure.

The other thing to note is that the softmax is invariant to shifting every input by the same constant. It ultimately just gives us a normalised distribution, and only the relative sizes of the inputs matter.

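We can use the identity that, for the exponential function and some constant c,

$$e^{x_i + c} = e^{x_i} \cdot e^{c}$$

Because the factor $e^{c}$ appears in the numerator and in every term of the denominator, it cancels:

$$\frac{e^{x_i + c}}{\sum_j e^{x_j + c}} = \frac{e^{c}\, e^{x_i}}{e^{c} \sum_j e^{x_j}} = \frac{e^{x_i}}{\sum_j e^{x_j}}$$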

So we can shift every x by any constant c without changing the output of the function. Then naturally we can ask: what value of c is best for avoiding overflow? The answer is $c = -\max(x)$. We shift the inputs down by the maximum value, ensuring that the largest input to the exponential is 0, so no exponential can exceed $e^0 = 1$ and we can guarantee no overflow. As a bonus, the denominator always contains that $e^0 = 1$ term, so it is at least 1 and we never divide by zero.
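A quick numerical check of the shift argument, on small inputs where the naive formula is still safe. The input vector and the extra shift values are arbitrary illustrative choices.

```python
import numpy as np

def naive_softmax(x):
    exps = np.exp(x)
    return exps / np.sum(exps)

x = np.array([3.0, -1.0, 2.5, 0.0])
c = -np.max(x)                  # the shift suggested above
shifted = x + c

print(shifted.max())            # 0.0, so the largest exponential is e^0 = 1

# Any constant shift leaves the output unchanged (up to rounding)
for shift in [c, -10.0, 0.0, 7.3]:
    print(shift, np.allclose(naive_softmax(x), naive_softmax(x + shift)))  # always True
```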

This is super easy to write in NumPy; all we need is:

```python
import numpy as np

def stable_softmax(x, axis=-1):  # axis=-1 means the last dimension
    shift_x = x - np.max(x, axis=axis, keepdims=True)
    exp_shift_x = np.exp(shift_x)
    return exp_shift_x / np.sum(exp_shift_x, axis=axis, keepdims=True)
```

The trick means that no exponent is ever larger than 0, and because the softmax only depends on ratios of exponentials, the output doesn't change. The keepdims=True ensures that the shapes broadcast correctly when working with batches. You might point out that if our maximum is an outlier, we shift the rest of the values far below zero and lose precision. This is a valid concern. But if this is the case, the unshifted softmax squashes those values in exactly the same way, and solving it requires fixing the outlier rather than the softmax function.
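To make the outlier concern concrete, here is a sketch with a hypothetical input containing one huge value. After the shift, the small entries sit about 1000 below zero, and $e^{-1000}$ underflows to exactly 0 in float64. But their true probabilities are also around $e^{-1000}$, far below the smallest positive float64 (roughly $4.9 \times 10^{-324}$), so no arrangement of the plain softmax could have represented them anyway.

```python
import numpy as np

def stable_softmax(x, axis=-1):
    shift_x = x - np.max(x, axis=axis, keepdims=True)
    exp_shift_x = np.exp(shift_x)       # e^-1000 and e^-999 underflow to 0.0
    return exp_shift_x / np.sum(exp_shift_x, axis=axis, keepdims=True)

# A hypothetical input with one huge outlier
x = np.array([1000.0, 0.0, 1.0])
print(x - x.max())          # the two small entries end up near -1000
print(stable_softmax(x))    # exactly [1.0, 0.0, 0.0]
```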

The Jacobian

This is where it gets a bit more interesting. The softmax is a vector function, so we have to think about the Jacobian rather than a simple derivative. For anyone unfamiliar with the term, the Jacobian is the matrix of all first-order partial derivatives of a vector function. So for a function $f: \mathbb{R}^n \to \mathbb{R}^m$, the Jacobian is the m × n matrix with entries $J_{ij} = \frac{\partial f_i}{\partial x_j}$.

The most important property of softmax — and the one people tend to miss — is that it couples all dimensions together. Increasing one input doesn’t just increase its own output — it necessarily decreases the others, because the outputs must still sum to 1. This is very different from elementwise functions like ReLU or sigmoid. This coupling is exactly what shows up in the Jacobian: the diagonal terms represent how each output responds to its own input, and the off-diagonal terms capture the competition between different entries.
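A small sketch shows the coupling directly. The inputs and the size of the nudge are arbitrary. Nudging only the first input raises the first softmax output and lowers both of the others (the changes sum to zero), while the elementwise sigmoid leaves the other two outputs untouched.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))   # max-shift for stability
    return e / e.sum()

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

x = np.array([1.0, 2.0, 0.5])
bump = np.array([0.1, 0.0, 0.0])   # nudge only the first input

d_soft = softmax(x + bump) - softmax(x)
d_sig = sigmoid(x + bump) - sigmoid(x)

print(d_soft)  # first entry goes up, the other two go down
print(d_sig)   # first entry goes up, the other two are exactly 0
```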

Computing the Jacobian

Back to the softmax: we have a function that maps from ℝⁿ to ℝⁿ, so the Jacobian is an n × n matrix. The entries of this matrix are given by:

$$J_{ij} = \frac{\partial\, \mathrm{softmax}(x_i)}{\partial x_j}$$

There are two natural cases to consider here: i = j and i ≠ j. For the case where i = j (on the diagonal) we have:

$$J_{ii} = \frac{\partial s(x_i)}{\partial x_i} = s(x_i)\,\bigl(1 - s(x_i)\bigr)$$

where we are using s(x_i) as a shorthand for softmax(x_i).

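To see why, apply the quotient rule to $s_i = \frac{e^{x_i}}{\sum_k e^{x_k}}$ (writing $s_i$ for $s(x_i)$). The derivative of the denominator with respect to $x_i$ is $e^{x_i}$, so

$$\frac{\partial s_i}{\partial x_i} = \frac{e^{x_i} \sum_k e^{x_k} - e^{x_i}\, e^{x_i}}{\left(\sum_k e^{x_k}\right)^2} = s_i - s_i^2 = s_i\,(1 - s_i)$$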

This is a nice expression, and we can see that the derivative is always positive, and tends towards 0 as the output of the softmax tends towards either 0 or 1. This should also be familiar as it's the same expression as the derivative of the sigmoid function, which makes sense as the softmax is a generalisation of the sigmoid to multiple dimensions.
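The sigmoid connection can be checked directly. For two classes with logits $[x, 0]$, the first softmax output is $e^x / (e^x + 1) = 1 / (1 + e^{-x})$, which is exactly the sigmoid of $x$. A minimal sketch (the test points are arbitrary) confirming both the value and that $s(1 - s)$ matches a finite-difference derivative of the sigmoid:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

xs = np.array([-3.0, 0.0, 0.7, 4.0])

# First output of a two-class softmax with logits [x, 0]
two_class = np.array([softmax(np.array([x, 0.0]))[0] for x in xs])

# Central finite-difference derivative of the sigmoid
h = 1e-6
sig_grad_fd = (sigmoid(xs + h) - sigmoid(xs - h)) / (2 * h)

print(np.allclose(two_class, sigmoid(xs)))                     # True
print(np.allclose(two_class * (1 - two_class), sig_grad_fd))   # True
```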

For the case where i ≠ j (off the diagonal) we have:

$$J_{ij} = \frac{\partial s(x_i)}{\partial x_j} = -s(x_i)\, s(x_j)$$

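Here, when we differentiate $s_i = \frac{e^{x_i}}{\sum_k e^{x_k}}$ with respect to $x_j$ for $j \neq i$, the numerator doesn't depend on $x_j$, so only the denominator contributes:

$$\frac{\partial s_i}{\partial x_j} = -\frac{e^{x_i}\, e^{x_j}}{\left(\sum_k e^{x_k}\right)^2} = -s_i\, s_j$$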

This derivative is always negative, and tends towards 0 as the outputs of the softmax tend towards either 0 or 1.

Combining the two cases we write the whole Jacobian as:

$$J = \mathrm{diag}(s) - s\, s^{T}$$

where diag(s) is a diagonal matrix with the outputs of the softmax on the diagonal, and s·s^T is the outer product of the softmax output with itself.
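A minimal sketch of this formula, checked against central finite differences. The helper names `softmax_jacobian` and `numerical_jacobian` are hypothetical, and the dense matrix is only sensible for small n. As a side effect the check also confirms that this Jacobian is symmetric.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

def softmax_jacobian(x):
    """Dense Jacobian J = diag(s) - s s^T (only sensible for small n)."""
    s = softmax(x)
    return np.diag(s) - np.outer(s, s)

def numerical_jacobian(f, x, h=1e-6):
    """Central finite differences: column j holds d f(x) / d x_j."""
    n = x.size
    J = np.zeros((n, n))
    for j in range(n):
        step = np.zeros(n)
        step[j] = h
        J[:, j] = (f(x + step) - f(x - step)) / (2 * h)
    return J

x = np.array([0.5, 2.0, -1.0, 1.5])
J = softmax_jacobian(x)

print(np.allclose(J, numerical_jacobian(softmax, x), atol=1e-8))  # True
print(np.allclose(J, J.T))  # True: the softmax Jacobian is symmetric
```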

The structure: diagonal plus rank-1

There’s an important structural detail hidden in this expression. The second term is an outer product, which means it has rank 1. So the full Jacobian is a diagonal matrix with a rank-1 correction. This is exactly why we can compute the backward pass efficiently. Instead of working with an n x n matrix, we only need a dot product and a few elementwise operations. The structure of the Jacobian is doing all the work for us.
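A sketch of what "diagonal plus rank-1" buys us. The vector sizes and random seed are arbitrary. Multiplying by the dense matrix and applying the two structured pieces separately, $\mathrm{diag}(s)\,v - s\,(s \cdot v)$, give the same result, but the second form never builds an n × n array.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

rng = np.random.default_rng(0)
n = 6
s = softmax(rng.normal(size=n))
v = rng.normal(size=n)

# Dense: build the full n x n matrix (O(n^2) memory and time)
J = np.diag(s) - np.outer(s, s)
dense = J @ v

# Matrix-free: elementwise product minus a scaled copy of s (O(n) memory)
matrix_free = s * v - s * np.dot(s, v)

print(np.allclose(dense, matrix_free))          # True
print(np.linalg.matrix_rank(np.outer(s, s)))    # 1
```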

Why size matters

In this form there's something that should worry any engineers out there: what shape is the Jacobian? It's n × n. In a typical transformer, n is the sequence length in the attention mechanism, or the vocabulary size in the final output layer, and in either case this risks becoming a huge matrix. We'll see in the next section how to avoid fully materialising it.
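A back-of-the-envelope sketch of how big "huge" is, taking a 50,000-token vocabulary as an illustrative size and counting only the storage for one sample's dense Jacobian:

```python
n = 50_000                 # illustrative vocabulary size
entries = n * n            # a dense n x n Jacobian for one sample

for dtype, bytes_per_entry in [("float64", 8), ("float32", 4)]:
    gigabytes = entries * bytes_per_entry / 1e9
    print(f"{dtype}: {gigabytes:.0f} GB per sample")
# float64: 20 GB per sample
# float32: 10 GB per sample
```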

The backwards pass

In the backward pass of a neural network, we have an upstream gradient $\frac{dL}{ds}$, which we want to propagate back through the softmax to get the gradient with respect to the inputs $x$. By the chain rule we have:

$$\frac{dL}{dx} = J^{T} \frac{dL}{ds}$$

where J^T is the transpose of the Jacobian. This is the canonical way to write the backwards pass, but as we noted above the Jacobian is an n x n matrix which is potentially huge.

We can write this out explicitly as

$$\frac{dL}{dx_i} = \sum_{j} J_{ji} \frac{dL}{ds_j} = \sum_{j} \frac{dL}{ds_j}\left(s_j \delta_{ij} - s_j s_i\right)$$

$$= s_i \frac{dL}{ds_i} - s_i \sum_{j} s_j \frac{dL}{ds_j} = s_i\left(\frac{dL}{ds_i} - \sum_{j} s_j \frac{dL}{ds_j}\right)$$

where $\delta_{ij}$ is the Kronecker delta (1 when i = j and 0 otherwise), which combines the diagonal and off-diagonal cases into one expression.

Note the final term in the last line is just a dot product, i.e. the weighted average of all upstream gradients under the softmax distribution.

It’s worth pausing to interpret this expression. The term inside the brackets is the difference between the local gradient and the expected gradient under the softmax distribution. In other words, we’re comparing each component to the average behaviour of the system.

This means the softmax gradient is doing two things at once:
  • scaling updates by the probability s_i
  • centering them relative to the distribution as a whole

This "centering" effect is another reflection of the coupling: updates to one component are always defined relative to the others.

That dot product is just a scalar, the same for every i, so we can ultimately write the backward pass as:

```python
import numpy as np

def softmax_backward(dL_ds, s):
    # s is the cached softmax output from the forward pass
    dot = np.sum(dL_ds * s)   # scalar: expected upstream gradient under s
    return s * (dL_ds - dot)
```

Tallying this all up, we have one dot product, one subtraction, and one elementwise multiplication. We never actually have to materialise the Jacobian at all, which matters because the full Jacobian is an n × n matrix. For a vocabulary softmax in a language model, n might be 50,000, and storing the Jacobian in float64 would require 20 GB (10 GB in float32) just for one sample.
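A gradient check makes a good sanity test for this backward pass. This is a sketch under an assumed toy loss, $L(x) = w \cdot \mathrm{softmax}(x)$ with a random weight vector $w$, chosen so that $dL/ds = w$ is known exactly. The analytic gradient should match central finite differences, and because the $s_i$ sum to 1, the input gradients always sum to zero: the "centering" in action.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

def softmax_backward(dL_ds, s):
    dot = np.sum(dL_ds * s)
    return s * (dL_ds - dot)

rng = np.random.default_rng(1)
x = rng.normal(size=5)
w = rng.normal(size=5)

def loss(x):
    # Hypothetical scalar loss L(x) = w . softmax(x), so dL/ds = w
    return np.dot(w, softmax(x))

analytic = softmax_backward(w, softmax(x))

# Central finite differences, one input direction at a time
h = 1e-6
numeric = np.array([(loss(x + h * e) - loss(x - h * e)) / (2 * h) for e in np.eye(5)])

print(np.allclose(analytic, numeric, atol=1e-8))  # True
print(np.isclose(analytic.sum(), 0.0))            # True: the gradients are centred
```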

Linking with the cross-entropy loss

The softmax function is often used in conjunction with the cross-entropy loss, which we'll talk about in detail in the next post, but it's worth noting here the super elegant result we get if we consider both functions together. The backward pass of the softmax is a bit involved: it's okay, but there are a few moving parts. In all honesty, we can say the same about the cross-entropy. That is, until we think about them together and find that the combination gives a very simple expression for the backward pass: just the difference between the predicted probabilities and the true labels.

$$\frac{dL}{dx} = s - y$$

where s is the output of the softmax and y is the one-hot encoded true label. Keep this in mind: it's important to know the full derivation, but in 9 cases out of 10 this is the expression you'll actually need. There are also computational efficiency reasons why these operations get fused in real systems, but we'll talk about that in the next post.
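The $s - y$ result can be checked numerically by chaining the pieces we already have. Assuming the usual cross-entropy with a one-hot label, $L = -\sum_j y_j \log s_j$, the upstream gradient is $dL/ds = -y / s$. Feeding that through the softmax backward pass gives $s \odot (-y/s + 1) = s - y$, since the dot product works out to $-1$. The logits and label below are arbitrary.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

def softmax_backward(dL_ds, s):
    dot = np.sum(dL_ds * s)
    return s * (dL_ds - dot)

x = np.array([1.2, -0.3, 2.0, 0.1])   # logits
y = np.array([0.0, 0.0, 1.0, 0.0])    # one-hot label: class 2 is correct

s = softmax(x)
dL_ds = -y / s                        # gradient of L = -sum(y * log(s)) w.r.t. s
dL_dx = softmax_backward(dL_ds, s)

print(np.allclose(dL_dx, s - y))      # True
```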

The batch dimension - why axis matters

Up to this point we've been thinking about softmax as operating on a single vector, but in practice we almost never do this. We're always working with batches of data, or in the case of attention mechanisms, we're dealing with sequences and multiple heads. This is where the axis parameter becomes critical.

Consider a batch of predictions with shape (batch_size, n_classes). We want to apply softmax across the classes for each sample independently. In numpy this means using axis=1:

```python
import numpy as np

def stable_softmax(x, axis=-1):
    x_shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x_shifted)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

# Batch of 3 samples, 4 classes each
logits = np.array([
    [2.0, 1.0, 0.5, 0.1],
    [0.1, 3.0, 0.2, 0.3],
    [1.0, 0.5, 2.5, 0.8]
])

probs = stable_softmax(logits, axis=1)
print(probs.sum(axis=1))  # [1. 1. 1.]: each row sums to 1
```

The keepdims=True argument is crucial here — it ensures that after the max and sum operations, the result still has the same number of dimensions as the input, just with size 1 along the reduction axis. This allows the broadcasting to work correctly in the division step.
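To see what goes wrong without keepdims, here is a sketch on a square batch (chosen deliberately, because that is where the bug is silent). With a non-square batch, the same mistake would typically raise a broadcasting error instead. Without keepdims, the reductions have shape (3,) rather than (3, 1), so they broadcast across columns: entry (i, j) is shifted and divided using row j's statistics instead of row i's.

```python
import numpy as np

x = np.arange(9, dtype=float).reshape(3, 3)   # square batch: 3 samples, 3 classes

# Correct: reductions have shape (3, 1) and broadcast along each row
m = np.max(x, axis=1, keepdims=True)
e = np.exp(x - m)
good = e / np.sum(e, axis=1, keepdims=True)

# Buggy: reductions have shape (3,) and broadcast across columns instead
m_bad = np.max(x, axis=1)
e_bad = np.exp(x - m_bad)
bad = e_bad / np.sum(e_bad, axis=1)

print(good.sum(axis=1))   # each row sums to 1
print(bad.sum(axis=1))    # rows do not sum to 1, and no error is raised
```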

In attention mechanisms you might see axis=2 or axis=3 (or, equivalently, axis=-1), depending on how the attention tensor is laid out. The softmax normalises over the key positions, so with a (batch, heads, queries, keys) layout that is axis=3. The key insight is that softmax always normalises along one specific axis, turning the values along it into a probability distribution.
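A minimal sketch of the attention case, assuming a (batch, heads, queries, keys) layout with made-up sizes and random scores. Normalising over the last axis gives each query a set of weights over the keys that sums to 1, independently for every batch element and head.

```python
import numpy as np

def stable_softmax(x, axis=-1):
    x_shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x_shifted)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

# Hypothetical attention scores: 2 sequences, 4 heads, 5 queries, 5 keys
rng = np.random.default_rng(0)
scores = rng.normal(size=(2, 4, 5, 5))

weights = stable_softmax(scores, axis=3)   # normalise over keys (same as axis=-1)

print(weights.shape)                           # (2, 4, 5, 5)
print(np.allclose(weights.sum(axis=3), 1.0))   # True: each query's weights sum to 1
```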

[Figure: Softmax Axis Comparison]
Left: original logits (green). Middle: softmax with axis=1 (each row sums to 1). Right: softmax with axis=0 (each column sums to 1). The axis parameter determines which dimension gets normalized.

Temperature scaling

The softmax function concentrates the probability distribution on the largest logits, but sometimes we want to control how concentrated it is. This is where temperature comes in: by dividing the logits by a temperature parameter τ before applying the softmax, we can control the "sharpness" of the output distribution.

$$\mathrm{softmax}(x_i, \tau) = \frac{e^{x_i / \tau}}{\sum_{j} e^{x_j / \tau}}$$

Lower temperatures make the distribution sharper (more "winner-takes-all"), and as τ → 0 it tends towards a one-hot vector on the largest logit (assuming a unique maximum), while higher temperatures flatten it towards uniform. Here's what happens as we scan from high temperature (smooth/uniform) to low temperature (sharp/one-hot):
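A minimal sketch of temperature scaling, reusing the max-shift for stability. The helper name `softmax_with_temperature` and the logits are our own illustrative choices. The same logits go from nearly uniform at τ = 5 to nearly one-hot at τ = 0.1.

```python
import numpy as np

def softmax_with_temperature(x, tau=1.0):
    """softmax(x / tau), with the max-shift for numerical stability."""
    z = x / tau
    e = np.exp(z - np.max(z))
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.0])   # hypothetical logits

for tau in [5.0, 1.0, 0.1]:
    p = softmax_with_temperature(logits, tau)
    print(f"tau={tau}: {np.round(p, 3)}")
# tau=5.0 gives about [0.40, 0.33, 0.27]  (close to uniform)
# tau=1.0 gives about [0.67, 0.24, 0.09]
# tau=0.1 gives about [1.00, 0.00, 0.00]  (close to one-hot)
```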

[Figure: Softmax Temperature]
Temperature controls how "confident" the softmax is. At T=5.0, all tokens get similar probability. As T approaches 0.1, the distribution becomes increasingly concentrated on the winner, eventually becoming nearly one-hot.

In practice, this is just scaling the logits, but that scaling directly controls how confident (and how brittle) the output distribution becomes. It's a neat consequence of the exponential: a simple linear rescaling of the input vector produces very non-linear changes in the output. Language models often use temperature at inference time to trade off creativity against determinism.
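As a sketch of the inference-time use, assuming a toy four-token vocabulary and hypothetical logits: we draw next-token indices from the temperature-scaled distribution with NumPy's `Generator.choice`. At a very low temperature essentially every draw is the top token; at a high temperature the draws spread over all four tokens.

```python
import numpy as np

def softmax_with_temperature(x, tau=1.0):
    z = x / tau
    e = np.exp(z - np.max(z))
    return e / e.sum()

rng = np.random.default_rng(42)
logits = np.array([2.0, 1.0, 0.5, 0.1])   # hypothetical next-token logits

def sample_tokens(tau, n=1000):
    """Draw n token indices from softmax(logits / tau)."""
    p = softmax_with_temperature(logits, tau)
    return rng.choice(len(logits), size=n, p=p)

cold = sample_tokens(tau=0.05)
hot = sample_tokens(tau=5.0)

print(np.bincount(cold, minlength=4))   # essentially all samples are token 0
print(np.bincount(hot, minlength=4))    # samples spread over all four tokens
```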

Putting it all together

```python
import numpy as np

# Complete implementation with stability and batch support
def softmax(x, axis=-1):
    """Stable softmax with axis support for batch operations."""
    x_shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x_shifted)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

def softmax_backward(dL_ds, s, axis=-1):
    """Backward pass for softmax with axis support.

    Args:
        dL_ds: upstream gradient, same shape as s
        s: cached softmax output from forward pass
        axis: axis along which softmax was applied
    """
    dot = np.sum(dL_ds * s, axis=axis, keepdims=True)
    return s * (dL_ds - dot)

# Sanity checks
x = np.array([1000., 1001., 1002.])
s = softmax(x)
print(s)          # [0.09003057 0.24472847 0.66524096]: no NaNs

# Gradient check: a uniform upstream gradient should give zero gradient
# (shifting all inputs equally doesn't change the output)
dL_ds = np.ones(3)
dL_dx = softmax_backward(dL_ds, s)
print(dL_dx)      # ≈ [0. 0. 0.], zero up to floating-point rounding

# Batch example
batch_logits = np.array([
    [1000., 1001., 1002.],
    [0.5, 2.0, 1.5],
])
batch_probs = softmax(batch_logits, axis=1)
print(batch_probs.sum(axis=1))  # [1. 1.]: each row sums to 1
```

The gradient check is a neat little sanity check: if we have a uniform upstream gradient, we get zero gradient with respect to the inputs. This makes sense because shifting all inputs equally doesn't change the output of the softmax (it's translation-invariant), so the loss shouldn't change, and thus the gradient should be zero. Another way to see it: a uniform upstream gradient is what we'd get from the loss L = Σ_i s_i, which is always exactly 1 no matter what the inputs are.

If you want to work through this properly, I've put it up as an exercise (forward pass, backward pass, gradient check) here: idlemachines.co.uk/questions/softmax-forward-backward
