Speculative Decoding: the important parts

Standard autoregressive decoding generates one token per forward pass of the model. Each forward pass has the same cost regardless of whether the token is a predictable word like “the” or a surprising one. Leviathan et al.[1] and Chen et al.[2] independently observed that a small draft model can cheaply guess several tokens ahead, and a large target model can verify multiple guesses in parallel — turning sequential token generation into a bet on the draft model’s accuracy.

The draft model is small and fast. The target model is large and slow — but a single forward pass of the target model can process multiple token positions in parallel (just like prefill). Speculative decoding exploits this:

Past sequence: [T1, T2, T3]

Draft model (serial, but cheap):
  A' = draft(T1, T2, T3)
  B' = draft(T1, T2, T3, A')
  C' = draft(T1, T2, T3, A', B')

Target model (one parallel forward pass):
  A = target(T1, T2, T3)         ← verify A'
  B = target(T1, T2, T3, A')     ← verify B' (using draft's A')
  C = target(T1, T2, T3, A', B') ← verify C' (using draft's A', B')

The key insight: the target model computes A, B, and C in a single parallel forward pass. This is the same parallel computation the target model already performs during prefill, processing multiple token positions at once. It takes the draft tokens as input (A' and B' in the diagram) and produces the target model's prediction at every position simultaneously. Each draft guess (A', B', C') is then compared with the target's token at the same position (A, B, C). Three tokens are verified for roughly the cost of one forward pass.

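Acceptance: The target model checks each draft token left to right. If A' = A, A' is accepted and the check moves on to B' against B, then C' against C. If all three match, all three draft tokens are kept.

Rejection: At the first token that doesn't match, discard it and everything after it, and keep the target model's own token at that position instead. For example, if A' = A but B' ≠ B, the output is [A', B]. C' is thrown away because C was computed from a context containing the wrong B'.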
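The acceptance and rejection rules above fit in a few lines of Python. This is a minimal sketch that assumes greedy decoding and assumes the target's tokens at every draft position are already available from the single verification pass. Token IDs are plain integers, and the function name `greedy_verify` is hypothetical.

```python
from typing import List, Sequence

def greedy_verify(draft: Sequence[int], target: Sequence[int]) -> List[int]:
    """Greedy speculative verification.

    draft:  the draft tokens [A', B', C', ...]
    target: the target model's greedy token at the same positions [A, B, C, ...],
            all produced by one parallel forward pass over the context plus the drafts.
    Returns the tokens appended to the sequence this round.
    """
    if len(draft) != len(target):
        raise ValueError("need one target token per draft position")
    out: List[int] = []
    for guess, actual in zip(draft, target):
        if guess != actual:
            out.append(actual)  # first mismatch: keep the target's token, drop the rest
            break
        out.append(guess)       # match: accept the draft token and keep checking
    return out
```

For example, `greedy_verify([5, 7, 9], [5, 8, 9])` returns `[5, 8]`: A' is accepted, B' is replaced by the target's B, and C' is discarded. Implementations commonly also feed C' into the pass and take one extra token from the final position when every draft is accepted. The sketch leaves that out to match the diagram.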

In the best case, one expensive target forward pass produces multiple tokens. In the worst case, it produces exactly one, the same as standard decoding. Measured in target forward passes, speculative decoding therefore never does worse than standard decoding. Wall-clock time is a different matter. Speculative decoding increases total arithmetic (the draft model's work plus verifying multiple positions), so it only pays off when inference is memory-bound. Autoregressive decoding is memory-bound: each forward pass is bottlenecked by memory bandwidth, and the extra compute for the additional positions fits largely within the time already spent on a single forward pass. When every draft token is rejected, the draft model's work is wasted, and that step is slightly slower than plain decoding.
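How often the bet pays off can be quantified. Leviathan et al.[1] make the simplifying assumption that each draft token is accepted independently with the same probability α. Under that assumption, with γ draft tokens per round, the expected number of tokens per target forward pass is (1 − α^(γ+1)) / (1 − α). This count includes the one extra token the target pass can yield after the last position when every draft is accepted, so it ranges from 1 (α = 0) to γ + 1 (α = 1). Below is a minimal sketch; the function name is hypothetical.

```python
def expected_tokens_per_pass(alpha: float, gamma: int) -> float:
    """Expected tokens produced per target forward pass.

    alpha: probability that any single draft token is accepted (assumed i.i.d.)
    gamma: number of draft tokens proposed per round
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be a probability")
    if alpha == 1.0:
        return gamma + 1.0  # every draft accepted, plus the extra token
    # closed form of the geometric series 1 + alpha + ... + alpha**gamma
    return (1.0 - alpha ** (gamma + 1)) / (1.0 - alpha)
```

For example, α = 0.8 and γ = 3 give about 2.95 tokens per target pass.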

Note on sampling: The description above uses exact token matching (A = A'), which only works under greedy decoding. In general, LLM decoding samples from a probability distribution (see Token Prediction), so the property speculative decoding must preserve is the target model's output distribution, not any single token. Greedy matching cannot do this: it biases the output toward the draft distribution q rather than the target distribution p. Leviathan et al.[1] generalize verification with speculative sampling: accept a draft token x with probability min(1, p(x)/q(x)), and on rejection sample a replacement from the normalized residual norm(max(0, p − q)). This guarantees the output distribution is identical to sampling from p alone.
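Below is a minimal NumPy sketch of one speculative-sampling step. It assumes p and q are probability vectors over the vocabulary at a single position and that x was sampled from q. The helper `output_distribution` computes the exact distribution of the emitted token, which makes the guarantee checkable. Both function names are hypothetical.

```python
import numpy as np

def speculative_sample(p: np.ndarray, q: np.ndarray, x: int, rng: np.random.Generator):
    """Verify one draft token x ~ q against the target distribution p.

    Returns (token, accepted).
    """
    if rng.random() < min(1.0, p[x] / q[x]):
        return x, True                    # accept with probability min(1, p(x)/q(x))
    residual = np.maximum(p - q, 0.0)     # max(0, p - q)
    residual /= residual.sum()            # norm(...)
    return int(rng.choice(len(p), p=residual)), False

def output_distribution(p: np.ndarray, q: np.ndarray) -> np.ndarray:
    """Exact probability of each token being emitted by one verification step."""
    accepted = np.minimum(p, q)           # q(x) * min(1, p(x)/q(x)) = min(p(x), q(x))
    rejected_mass = 1.0 - accepted.sum()
    residual = np.maximum(p - q, 0.0)
    if residual.sum() > 0:
        residual = residual / residual.sum()
    return accepted + rejected_mass * residual
```

The guarantee follows from the algebra. The residual max(0, p − q) has total mass 1 − Σ min(p, q), which is exactly the rejection probability. So the rejection branch contributes max(0, p − q), and min(p, q) + max(0, p − q) = p.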

Using a tree model over a draft model

In standard speculative decoding, the draft model generates A’, B’, C’ sequentially — B’ depends on A’, C’ depends on A’ and B’. Even though the draft model is small, this sequential dependency is a bottleneck. MEDUSA[3] eliminates the draft model entirely by adding lightweight prediction heads directly to the target model.

Recall from Token Prediction that the LM head projects the final hidden state Z[-1] (a vector of size d_model) to a probability distribution over the vocabulary: Z[-1] · W_vocab → softmax → probabilities. MEDUSA adds extra heads that do the same thing — take Z[-1] in, produce a vocabulary-sized probability distribution out — but each head is trained to predict a different future position. Head k predicts the token at position t+k+1 instead of t+1.

Each MEDUSA head has its own learned weight matrices, trained separately from the base model (the base model's weights stay frozen). A head is slightly more expressive than the LM head: it first transforms Z[-1] through a small residual block (linear(d→d) → SiLU → add back Z[-1]), then projects to vocabulary size. At initialization, the linear layer's weights are set to zero and the vocabulary projection is copied from the original LM head. The residual connection then passes Z[-1] through unchanged, so each head starts out behaving identically to the LM head. Training then specializes each head to predict its assigned future position:

Single forward pass of [T1, T2, T3] → Z[-1] →
    LM head        → A' (position 4)
    MEDUSA head 1  → B' (position 5)
    MEDUSA head 2  → C' (position 6)

All heads use the same Z[-1], so there is no sequential dependency. But this is an approximation: head 1 predicts B' without knowing A'. Its guesses are therefore typically less accurate than those of a sequential draft model, which conditions each token on the previous guesses.
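Below is a minimal NumPy sketch of a single head as described above. It assumes a bias on the d→d linear layer and the row-vector convention Z[-1] · W. The names `silu`, `medusa_head` and `init_head` are hypothetical, not MEDUSA's code.

```python
import numpy as np

def silu(x: np.ndarray) -> np.ndarray:
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))

def medusa_head(z: np.ndarray, W1: np.ndarray, b1: np.ndarray, W_out: np.ndarray) -> np.ndarray:
    """One MEDUSA-style head.

    z:      final hidden state Z[-1], shape (d,)
    W1, b1: the head's linear(d -> d) layer
    W_out:  the head's projection to vocabulary size, shape (d, V)
    Returns vocabulary logits; a softmax turns them into probabilities.
    """
    h = z + silu(z @ W1 + b1)  # linear -> SiLU -> add back Z[-1]
    return h @ W_out

def init_head(W_lm: np.ndarray):
    """Initialization: zero linear layer, projection copied from the LM head."""
    d = W_lm.shape[0]
    return np.zeros((d, d)), np.zeros(d), W_lm.copy()
```

Every head reads the same `z`: `[medusa_head(z, *h) for h in heads]` evaluates all heads from one hidden state, with no head waiting on another. Because SiLU(0) = 0, a freshly initialized head returns exactly `z @ W_lm`, the LM head's logits.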

Because each head predicts independently and is less accurate, the top-1 prediction is often wrong. So we take the top-k predictions from each head. With top-2 from each:

- LM head:  A₁, A₂
- Head 1:   B₁, B₂
- Head 2:   C₁, C₂

We create a tree of all possible combinations:

         root
        /    \
      A₁      A₂
     / \     / \
    B₁  B₂  B₁  B₂
   /|  /|  /|  /|
  C₁C₂C₁C₂C₁C₂C₁C₂

This tree is the Cartesian product of the top-k predictions from each head — every combination of candidates across all positions. With top-2 from 3 heads, that is 2 + 4 + 8 = 14 candidate tokens. This grows exponentially: top-3 from 5 heads would give 3 + 9 + 27 + 81 + 243 = 363 tokens, far too many for a single verification forward pass.
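The node count can be checked by enumerating the full tree. In this sketch (the function name is hypothetical), a node is identified by the tuple of 0-based candidate ranks chosen at each depth. For example, (0, 1) is the path A₁ → B₂.

```python
from itertools import product
from typing import List, Tuple

def full_candidate_tree(k: int, num_heads: int) -> List[Tuple[int, ...]]:
    """All nodes of the Cartesian-product tree, shallowest first.

    k:         top-k candidates taken from each head
    num_heads: number of heads (the LM head counts as the first)
    """
    nodes: List[Tuple[int, ...]] = []
    for depth in range(1, num_heads + 1):
        # one node per choice of rank (0..k-1) for each of the first `depth` heads
        nodes.extend(product(range(k), repeat=depth))
    return nodes
```

`len(full_candidate_tree(2, 3))` is 14 and `len(full_candidate_tree(3, 5))` is 363, matching k + k² + … + kⁿ.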

MEDUSA handles this by pruning the tree to a fixed budget (e.g., 64 nodes). Offline, on a held-out set of representative prompts (the calibration data), it measures how often each head's i-th ranked prediction is correct. For example, head 1's top-1 might be correct 70% of the time and head 2's top-2 15% of the time. These statistics are computed once before deployment, not at decode time. Each path's probability of being fully correct is estimated as the product of its nodes' individual accuracies. A greedy algorithm then repeatedly adds the node whose path has the highest estimated accuracy until the budget is full. The result can be asymmetric: A₁ might get more children than A₂ if it is statistically more likely to be correct. This pruned tree shape is chosen once and reused for all inputs and all decoding rounds.
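Below is a sketch of the greedy construction. It assumes calibration accuracies are given as `acc[depth][rank]` with 0-based ranks, and that a node's score is the product of the accuracies along its path. The numbers in the example are made up. A child's score can never exceed its parent's, so popping nodes from a max-heap in score order always adds a parent before its children. `build_pruned_tree` is a hypothetical name, not MEDUSA's implementation.

```python
import heapq
from typing import List, Sequence, Tuple

def build_pruned_tree(acc: Sequence[Sequence[float]], budget: int) -> List[Tuple[int, ...]]:
    """Greedily pick up to `budget` nodes with the highest estimated path accuracy.

    acc[d][r]: calibrated probability that the rank-r candidate of head d is correct.
    Nodes are tuples of ranks, e.g. (0, 1) = top-1 of head 0, then top-2 of head 1.
    """
    # heapq is a min-heap, so scores are stored negated
    heap = [(-a, (r,)) for r, a in enumerate(acc[0])]
    heapq.heapify(heap)
    tree: List[Tuple[int, ...]] = []
    while heap and len(tree) < budget:
        neg_score, path = heapq.heappop(heap)
        tree.append(path)
        depth = len(path)
        if depth < len(acc):
            # a child's path accuracy is the parent's times the child's own accuracy
            for r, a in enumerate(acc[depth]):
                heapq.heappush(heap, (neg_score * a, path + (r,)))
    return tree
```

With `acc = [[0.7, 0.2], [0.6, 0.15], [0.5, 0.1]]` and a budget of 4, the result is `[(0,), (0, 0), (0, 0, 0), (1,)]`. The tree is asymmetric: A₁'s branch reaches depth 3 before A₂ is added at all.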

All candidate tokens are packed into a single forward pass of the target model, using an attention mask that encodes the tree structure. To illustrate, consider a simpler two-level tree (top-2 from LM head, top-2 from head 1):

        [T1, T2, T3]  (past context, in KV cache)
        /          \
      A₁            A₂
     / \           / \
    B₁  B₂       B₃  B₄

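The 6 candidate tokens are packed into a flat sequence: [A₁, A₂, B₁, B₂, B₃, B₄]. In MEDUSA, B₃ and B₄ are the same two head-1 tokens as B₁ and B₂, since head 1's prediction does not depend on which A branch it follows. They get distinct labels because each copy occupies its own position in the packed sequence. The tree attention mask (1 = can attend, 0 = cannot):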

       A₁  A₂  B₁  B₂  B₃  B₄
  A₁ [  1   0   0   0   0   0 ]
  A₂ [  0   1   0   0   0   0 ]
  B₁ [  1   0   1   0   0   0 ]  ← B₁ attends to A₁ (parent)
  B₂ [  1   0   0   1   0   0 ]  ← B₂ attends to A₁ (parent)
  B₃ [  0   1   0   0   1   0 ]  ← B₃ attends to A₂ (parent)
  B₄ [  0   1   0   0   0   1 ]  ← B₄ attends to A₂ (parent)

(All tokens also attend to T1, T2, T3 via the KV cache.)
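The mask follows mechanically from the tree: each row marks the node itself and every ancestor. Below is a sketch that assumes the tree is given as a parent array over the packed sequence, with −1 for nodes whose parent is the past context (this representation and the function name are hypothetical). It reproduces the 6×6 matrix above.

```python
import numpy as np
from typing import Sequence

def tree_attention_mask(parents: Sequence[int]) -> np.ndarray:
    """Boolean mask over packed candidates: mask[i, j] is True if node i may attend to node j.

    parents[i] is the packed index of node i's parent, or -1 if its parent is the past context.
    Columns for the past context itself (T1, T2, T3) are omitted; every row may attend to them.
    """
    n = len(parents)
    mask = np.zeros((n, n), dtype=bool)
    for i in range(n):
        j = i
        while j != -1:         # walk from node i up to the root
            mask[i, j] = True  # i attends to itself and each ancestor
            j = parents[j]
    return mask
```

For the packed order [A₁, A₂, B₁, B₂, B₃, B₄], the parent array is `[-1, -1, 0, 0, 1, 1]`.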

Each token can only attend to its ancestors in the tree. B₁ and B₂ attend to A₁ (their parent) but not to A₂ or any B on A₂’s branch. This is neither a standard causal mask (which would let B₃ attend to A₁, B₁, B₂) nor a dense mask — it is a custom sparse mask determined by the tree topology.

Recall from Attention Mechanism that attention computes A = Q · K^T, then applies the causal mask, then softmax. The tree mask replaces this causal mask. The mask is stored as a dense boolean matrix and applied element-wise — it is not stored sparsely, even though it is sparse.
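The sketch below shows, in minimal single-head NumPy form, how the boolean tree mask replaces the causal mask. Disallowed scores are set to −∞ before the softmax, so they receive exactly zero weight. It includes the usual 1/√d scaling. For brevity it leaves out the past-context keys, which every row may attend to. The function name is hypothetical.

```python
import numpy as np

def masked_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, mask: np.ndarray):
    """Single-head attention with a dense boolean mask (True = may attend).

    Q, K, V: shape (n, d). mask: shape (n, n). Every row must allow at least one key.
    Returns (output, attention weights).
    """
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores = np.where(mask, scores, -np.inf)               # element-wise: blocked pairs get -inf
    scores = scores - scores.max(axis=-1, keepdims=True)   # stabilize the softmax
    weights = np.exp(scores)                               # exp(-inf) = 0
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights
```

Blocked entries get exactly zero weight. With the mask above, perturbing A₂'s key and value vectors changes only the outputs of A₂, B₃ and B₄.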

Since MEDUSA’s tree shape is fixed (chosen once on calibration data, as described above), this mask’s sparsity pattern is identical for every decoding round — only the token values change.

Because of the tree mask, each token’s attention output reflects only its ancestral path — B₁ sees [T1, T2, T3, A₁], while B₃ sees [T1, T2, T3, A₂]. So the logits at each position represent the target model’s prediction for the next token given that specific path. Verification walks the tree:

  1. The target model’s logits from the past context predict position t+1. Check: does A₁ or A₂ match? Say A₁ matches.
  2. The logits at the A₁ position predict position t+2 (given path [T1, T2, T3, A₁]). Check: does B₁ or B₂ match?
  3. Continue until a mismatch (truncate) or a leaf (all accepted).
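The walk can be sketched as follows. The sketch assumes greedy verification and the same hypothetical parent-array representation of the packed tree. It also assumes the verification pass has already produced the target's greedy next token after the last context token (`root_pred`) and at every packed position (`node_preds`).

```python
from typing import List, Optional, Sequence

def verify_tree(tokens: Sequence[str], parents: Sequence[int],
                root_pred: str, node_preds: Sequence[str]) -> List[str]:
    """Walk the candidate tree, returning the tokens accepted this round.

    tokens[i]:     candidate token at packed position i
    parents[i]:    packed index of its parent, or -1 for children of the past context
    root_pred:     target's greedy next token after the past context
    node_preds[i]: target's greedy next token given the path ending at node i
    """
    accepted: List[str] = []
    current, pred = -1, root_pred
    while True:
        children = [i for i, p in enumerate(parents) if p == current]
        if not children:
            return accepted           # reached a leaf: the whole path is accepted
        match: Optional[int] = next((i for i in children if tokens[i] == pred), None)
        if match is None:
            return accepted + [pred]  # mismatch: keep the target's token, truncate the rest
        accepted.append(tokens[match])
        current, pred = match, node_preds[match]
```

With tokens `["the", "a", "cat", "dog", "cat", "dog"]` for [A₁, A₂, B₁, B₂, B₃, B₄], the walk accepts `["a", "dog"]` if the target predicts "a" and then "dog" on A₂'s branch. The target's logits at a leaf could also supply one more token; this sketch stops there instead, mirroring step 3.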

This is the same principle as the earlier verification diagram — the target model processes all candidates in parallel and produces the target distribution at each position. The tree mask ensures each candidate sees only its own ancestral path, so all branches are verified simultaneously in a single forward pass.

Speculative Speculative Decoding: parallelizing drafting and verification

In both standard speculative decoding and tree-based methods like MEDUSA, drafting and verification are sequential: the draft model generates candidates, then the target model verifies them, then the draft model generates the next round. The draft model sits idle during verification because its next round depends on the outcome. It cannot start drafting until it knows how many of its tokens were accepted and which token the target substituted.

Speculative Speculative Decoding (SSD)[4] breaks this by running the draft and target models on separate hardware. While the target model verifies the current round, the draft model predicts what the verification outcome will be and pre-computes draft tokens for the likely next rounds, storing them in a speculation cache. On a cache hit, the next round’s drafts are returned immediately; on a miss, it falls back to synchronous speculative decoding.
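The cache logic can be sketched as follows. This only illustrates the idea described above and is not the SSD authors' design. The cache key (number of drafts accepted plus the token the target emits next), the class and method names, and the toy draft function are all assumptions. Also, SSD runs the precompute step on separate hardware during verification; here it is called sequentially.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Outcome = Tuple[int, int]  # hypothetical key: (draft tokens accepted, target's next token)

class SpeculationCache:
    def __init__(self, draft_fn: Callable[[List[int]], List[int]]):
        self.draft_fn = draft_fn  # drafts the next round from a full token sequence
        self.entries: Dict[Outcome, List[int]] = {}

    def precompute(self, prefix: List[int], draft: List[int],
                   likely_outcomes: Sequence[Outcome]) -> None:
        """Runs while the target verifies `draft`: draft ahead for each guessed outcome."""
        for n_acc, next_tok in likely_outcomes:
            self.entries[(n_acc, next_tok)] = self.draft_fn(prefix + draft[:n_acc] + [next_tok])

    def next_drafts(self, prefix: List[int], draft: List[int],
                    outcome: Outcome) -> Tuple[List[int], bool]:
        """Returns (next round's drafts, cache hit?) once the real outcome is known."""
        hit = self.entries.get(outcome)
        self.entries.clear()      # guesses for other outcomes are now stale
        if hit is not None:
            return hit, True      # hit: drafts are ready immediately
        n_acc, next_tok = outcome
        # miss: fall back to drafting synchronously, as in ordinary speculative decoding
        return self.draft_fn(prefix + draft[:n_acc] + [next_tok]), False
```

The design choice to clear the cache after each lookup reflects that entries keyed on one round's outcomes are useless once the next round starts.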


Acknowledgements

Thanks to Tarindu Jayatilaka for discussions about Speculative Decoding.



  1. Leviathan, Y., Kalman, M., & Matias, Y. (2023). Fast Inference from Transformers via Speculative Decoding. Proceedings of the 40th International Conference on Machine Learning (ICML).

  2. Chen, C., Borgeaud, S., Irving, G., Lespiau, J.-B., Sifre, L., & Jumper, J. (2023). Accelerating Large Language Model Decoding with Speculative Sampling. Technical Report.

  3. Cai, T., Li, Y., Geng, Z., Peng, H., Lee, J. D., Chen, D., & Dao, T. (2024). Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads. Proceedings of the 41st International Conference on Machine Learning (ICML).

  4. Bhendawade, N., Sheng, Y., Dalmia, A., Hop, P., Huang, K., Bittner, J., Mahoney, M. W., Keutzer, K., & Gholami, A. (2026). Speculative Speculative Decoding. International Conference on Learning Representations (ICLR).

  5. Leviathan, Y., Kalman, M., & Matias, Y. (2024). Looking back at speculative decoding. Google Research Blog.

  6. NVIDIA Developer Blog. An Introduction to Speculative Decoding for Reducing Latency in AI Inference.