
This document describes and summarizes existing work in vLLM to improve general guided decoding performance. 1

This design will substantially change how logit_processor is currently handled within the vLLM architecture.

Main mega thread: vllm-project/vllm#5423

Goal:

  • Improve general TPS when using guided decoding.
  • Standardize the logit processor interface 2
  • Separate compute_logits and preparing logits into two distinct steps

Orthogonal, but still goals:

Scope: logit_processor, sampling controller interface

background

figure: flow

reference: vllm-project/vllm#5329

Currently, generation with an FSM is very slow, even with warmup steps to initialize the given FSM. This behaviour is even more pronounced when running with contexts longer than 4096 tokens.

Additionally, all outlines logit processors are considered stateful, which slows down the model executor, given that in V0 logit processors are applied row by row in a blocking fashion.

Thus, compared to sglang, vLLM V0 is currently not up to par.

plan

Implement structured decoding from the scheduler, given that we can compute the token bitmask there and broadcast it to the GPU workers.

@cadedaniel: “tree scoring in [spec decode] could use the same API as multi-path jump decoding.”

How should we handle the FSM per request?

  • Currently, users can specify a different schema per request, which means the FSM is compiled per request. This is suboptimal because it slows down general TTFT.
  • For most use cases, we should assume a single JSON schema, similar to how the system prompt is currently handled (passed during server init).
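One way to soften the per-request compilation cost described above is to cache compiled FSMs by schema. The following is a minimal sketch, not vLLM code: `compile_fsm` is a hypothetical stand-in for whatever expensive compilation step the backend performs, and `get_fsm` and `COMPILE_CALLS` are illustrative names.

```python
import hashlib
import json
from functools import lru_cache

# Counts how often the (hypothetical) expensive compilation actually runs.
COMPILE_CALLS = {"n": 0}


def compile_fsm(canonical_schema: str) -> dict:
    """Hypothetical stand-in for an expensive schema -> FSM compilation."""
    COMPILE_CALLS["n"] += 1
    digest = hashlib.sha256(canonical_schema.encode("utf-8")).hexdigest()
    return {"schema": json.loads(canonical_schema), "fsm_id": digest[:8]}


@lru_cache(maxsize=128)
def _cached_compile(canonical_schema: str) -> dict:
    # lru_cache needs a hashable key, so we cache on the canonical JSON string.
    return compile_fsm(canonical_schema)


def get_fsm(schema: dict) -> dict:
    """Return a compiled FSM, compiling only on the first request per schema."""
    # Sorting keys makes semantically identical schemas share one cache entry.
    canonical = json.dumps(schema, sort_keys=True, separators=(",", ":"))
    return _cached_compile(canonical)
```

With this shape, two requests that carry the same schema (even with a different key order) pay the compilation cost once; only genuinely new schemas affect TTFT.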

Why should we follow the plugins system?

  • If we are going with the best option, what is the reasoning behind supporting different backends?
  • Agreed on extensibility, but it seems to add additional overhead.

appendix.

The following includes background information about guided generations.

batched constrained decoding using pushdown automaton

Implemented in mlc-ai/xgrammar

Quote

calculate adaptive token bit-mask per batch

IMPORTANT

operating on string level, not token_id

GrammarMatcher FSM in xgrammar

questions

  • byte-level automaton

overhead of token_id string

Token for context-independent tokens vs dependent tokens within the generation masks

async pre-compile

synchronize apply mask for CPU GPU?

How do we apply said masks to GPU block? Zero-overhead generations?

worst-case scenario for grammar compilation?

mask gen overhead: 36 μs

does the time increase linearly with batch size?

parallelize for compilation.

do we need to parallelize on vLLM?

no, xgrammar parallelizes it with pthread

shape of masks?

bitmask, tensors of vocab size concat with recast GPU

supported tokenizers?

GLM yet to be supported (Nov 22nd)

Given that the detokenizer runs in a separate process in vLLM, can we stop duplicating this work?

Currently with xgrammar: the detokenizer is included in mask generation.

token_id tokens
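The bitmask shape discussed in the questions above can be illustrated in plain Python. This is a sketch under assumptions: it packs one bit per vocabulary entry into 32-bit words (so one row has ceil(vocab_size / 32) words), and `pack_bitmask` and `apply_bitmask` are illustrative names, not xgrammar or vLLM functions.

```python
from typing import Iterable, List

BITS_PER_WORD = 32  # assumption: one 32-bit word holds 32 vocabulary entries


def pack_bitmask(allowed_token_ids: Iterable[int], vocab_size: int) -> List[int]:
    """Pack the set of allowed token ids into ceil(vocab_size / 32) words."""
    n_words = (vocab_size + BITS_PER_WORD - 1) // BITS_PER_WORD
    words = [0] * n_words
    for token_id in allowed_token_ids:
        if not 0 <= token_id < vocab_size:
            raise ValueError(f"token id {token_id} outside vocabulary")
        words[token_id // BITS_PER_WORD] |= 1 << (token_id % BITS_PER_WORD)
    return words


def apply_bitmask(logits: List[float], words: List[int]) -> List[float]:
    """Set the logit of every token whose bit is 0 to -inf."""
    return [
        logit if (words[i // BITS_PER_WORD] >> (i % BITS_PER_WORD)) & 1 else float("-inf")
        for i, logit in enumerate(logits)
    ]
```

A batch would then be one such row per request, which keeps the data broadcast to the workers roughly 32 times smaller than a boolean mask of vocabulary size.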

future plans

  • Function calling support
  • Support more grammars (CFG, Python grammar)

compressed FSM for jump-ahead tokens.

Implemented in (Zheng et al., 2024)

Method 1: FSM-based decoding

  • intuition: Use an FSM (Willard & Louf, 2023) to guide generation by increasing the logit bias for tokens that conform to the given JSON schema. This lets us track the current state during decoding and filter out invalid tokens by applying a logit bias to the output.

  • limitation: because constructing the FSM requires token-level access, it can only transition the state one token at a time, resulting in slow decoding.

Method 2: Interleaved-based

  • intuition: break a JSON schema down into parts, each containing either a chunked-prefill part or a constrained-decoding part. The inference system then executes these parts interleaved. This is faster than per-token decoding, given that chunked-prefill components can process multiple tokens per forward pass.

    See also https://github.com/guidance-ai/guidance#guidance-acceleration using llama.cpp as backend.

  • limitation:

    • interleaved-based approaches require custom syntax, making them less expressive than regex.
    • they struggle with tokenization boundaries due to conflicts between decode and chunked-prefill segments.
    • frequent communication between the interpreter and the back-end adds overhead.

Method 3: Jump-Forward Decoding with compressed FSM

tokenization boundary handling

During decoding, it is preferred to combine multiple characters into a single token.

For example, when decoding "Hello" in the context of JSON decoding, the LLM might output the following tokens: ", He, llo, ",

This may cause strange behaviour if the last " is combined with , into a single token: with the regex "[\w\d\s]*" followed by , this leads to endless decoding, because the token ", is not valid even if the LM wants to stop.

Fix:

  • implement a re-tokenization mechanism during the jump-forward phase (append the string instead of the tokens, then re-tokenize the entire text), which adds approximately 4% of overhead
  • use a single comprehensive regex to guide the decoding phase, instead of employing multiple concatenated regexes 3
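The re-tokenization fix can be illustrated with a toy tokenizer. This is only a sketch: `VOCAB` is a made-up vocabulary and `tokenize` uses greedy longest-match, which is an assumption for illustration (real tokenizers such as BPE behave differently), but it shows why appending a string and re-tokenizing can yield different tokens than appending tokens directly.

```python
from typing import List

# Made-up vocabulary for illustration only.
VOCAB = ['"', '",', ',', 'He', 'llo', 'Hello']


def tokenize(text: str, vocab: List[str] = VOCAB) -> List[str]:
    """Greedy longest-match tokenization over a toy vocabulary."""
    tokens, i = [], 0
    while i < len(text):
        match = max((t for t in vocab if text.startswith(t, i)), key=len, default=None)
        if match is None:
            raise ValueError(f"cannot tokenize {text[i:]!r}")
        tokens.append(match)
        i += len(match)
    return tokens


def jump_forward(tokens: List[str], jump_text: str) -> List[str]:
    """Append the jumped string to the decoded text, then re-tokenize everything."""
    return tokenize("".join(tokens) + jump_text)


generated = ['"', 'He', 'llo', '"']
naive = generated + [',']                  # appending tokens keeps the old boundaries
retokenized = jump_forward(generated, ',')  # boundaries are recomputed on the full text
```

Here the naive path ends in two separate tokens, a quote followed by a comma, while the re-tokenized path merges them into the single token the tokenizer would have produced on the full text.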

Coalescence

intuition: Instead of expanding to $n$ states, we can compress certain chunks into one state to reduce the size of the FSM.

figure 1: initial FSM state

figure 2: compressed FSM state
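The compression from the initial FSM to the compressed FSM can be sketched as merging chains of states that have exactly one way in and one way out. The representation below (a dict from state to `{character: next_state}`) and the function name `compress` are assumptions for illustration, not the SGLang or outlines implementation.

```python
from typing import Dict, Set

Fsm = Dict[int, Dict[str, int]]


def compress(fsm: Fsm, start: int, finals: Set[int]) -> Fsm:
    """Merge linear chains of single-transition states into one multi-character edge."""
    # Count incoming edges so we only merge states reachable from a single edge.
    incoming: Dict[int, int] = {}
    for transitions in fsm.values():
        for nxt in transitions.values():
            incoming[nxt] = incoming.get(nxt, 0) + 1

    def is_passthrough(state: int) -> bool:
        return (
            state != start
            and state not in finals
            and len(fsm.get(state, {})) == 1
            and incoming.get(state, 0) == 1
        )

    compressed: Fsm = {}
    for state, transitions in fsm.items():
        if is_passthrough(state):
            continue  # absorbed into the edge of an earlier state
        out: Dict[str, int] = {}
        for label, nxt in transitions.items():
            while is_passthrough(nxt):
                ((next_label, next_state),) = fsm[nxt].items()
                label += next_label
                nxt = next_state
            out[label] = nxt
        compressed[state] = out
    return compressed


# Character-level FSM for the text "name": followed by either 1 or 2.
char_fsm: Fsm = {
    0: {'"': 1}, 1: {"n": 2}, 2: {"a": 3}, 3: {"m": 4},
    4: {"e": 5}, 5: {'"': 6}, 6: {":": 7}, 7: {"1": 8, "2": 8},
}
compressed_fsm = compress(char_fsm, start=0, finals={8})
```

In this toy case the seven deterministic character steps collapse into a single edge, leaving only the state where a real choice exists.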

A way to adapt character regex to work with tokens in outlines:

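import interegular
from outlines.fsm.regex import make_deterministic_fsm, create_fsm_index_tokenizer

# Assumption: `regex_str` (the character-level pattern) and `tokenizer` (an
# outlines-compatible tokenizer) are defined elsewhere; the notes do not show them.
fsm = interegular.parse_pattern(regex_str).to_fsm()

new_fsm, _ = make_deterministic_fsm(fsm)
idx, _ = create_fsm_index_tokenizer(new_fsm, tokenizer)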
stateDiagram-v2
    [*] --> InputPrompt: Start

    state "input prompt" as InputPrompt
    state "next-token probability distribution" as GetProb
    state "valid tokens" as ListTokens {
        [*] --> CheckTransitions
        CheckTransitions --> FilterTokens: Get index[0].keys()
        FilterTokens --> [*]
    }
    state "Sample Token" as SampleToken
    state "Update FSM State" as UpdateState

    InputPrompt --> GetProb: "model.generate"
    GetProb --> ListTokens: Get next-token distribution
    ListTokens --> SampleToken: Use filtered token list
    SampleToken --> UpdateState: Selected token X
    UpdateState --> [*]: new_state = index[0]["X"]
# Decode token ids back to strings so the index is human-readable.
idx_with_tokens = {
  state: {tokenizer.tokenizer.decode([key]): value for key, value in transitions.items()}
  for state, transitions in idx.items()
}

note: each state of the FSM represents a forward pass to the LM. In vanilla generation, this forward pass is necessary anyway. Thus the FSM adds no overhead for controlling the generated outputs.

From states 2 to 6, we observe that there are eight different paths that produce the same generation, name. We probably don't need to explore them all, given that they all give us the result name.

But suffice to say, we can exploit this behaviour to accelerate generation by appending any one of the following token sequences to the currently generated sequence:

  • ["name"]
  • ["n", "a", "m", "e"]
  • ["na", "m", "e"]
  • ["nam", "e"]
  • ["n", "am", "e"]
  • ["n", "ame"]
  • ["na", "me"]
  • ["n", "a", "me"]

A simplified index can be shown as:

simplified_index = {
    0: {'{"': 2},
    2: {"name": 6},
    6: {'":"': 9},
    9: {'Paul': 14, 'John': 14},
    14: {'","': 17},
    17: {'age': 20},
    20: {'":': 22},
    22: {'20': 24, '30': 24},
    24: {'}': 25},
}

That is close to a 5x speedup over plain structured generation: of the 9 states, only two (9 and 22) have more than one outgoing transition, and the other seven are deterministic. Therefore we only need to call the model twice!
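The count of model calls can be checked by walking the simplified index directly. This is a minimal sketch: `generate` and `choose` are illustrative names, and `choose` stands in for the model picking among several valid tokens (here it simply takes the first option alphabetically).

```python
from typing import Callable, Dict, List, Tuple

simplified_index: Dict[int, Dict[str, int]] = {
    0: {'{"': 2},
    2: {"name": 6},
    6: {'":"': 9},
    9: {"Paul": 14, "John": 14},
    14: {'","': 17},
    17: {"age": 20},
    20: {'":': 22},
    22: {"20": 24, "30": 24},
    24: {"}": 25},
}


def generate(
    index: Dict[int, Dict[str, int]],
    start: int,
    choose: Callable[[List[str]], str],
) -> Tuple[str, int]:
    """Walk the index, calling `choose` only where more than one token is valid."""
    state, pieces, model_calls = start, [], 0
    while state in index:
        transitions = index[state]
        if len(transitions) == 1:
            # Deterministic transition: append the token without a forward pass.
            ((token, next_state),) = transitions.items()
        else:
            model_calls += 1
            token = choose(sorted(transitions))
            next_state = transitions[token]
        pieces.append(token)
        state = next_state
    return "".join(pieces), model_calls


text, calls = generate(simplified_index, start=0, choose=lambda options: options[0])
```

With the stand-in `choose`, the walk visits all nine states but only consults the "model" at states 9 and 22.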

Guided generations with FSM.

(Willard & Louf, 2023), implemented at https://github.com/dottxt-ai/outlines

assumption: we are building against autoregressive transformer models

  • Let $\mathcal{F} \subset \mathcal{P}(\mathcal{V})$, where $\mathcal{P}$ is the power set operator, be the subset of multi-token strings that end with the token $\text{EOS} \in \mathcal{V}$.
  • The text generation task is to draw samples from $\mathcal{F}$.

Notable sampling methods include greedy decoding (recursively generating the token with the highest probability) and beam search (which uses a heuristic to find the mode of the distribution) 4

Pseudocode for the sampling procedure is as follows, where $s$ is the sequence generated so far and $s'$ is the newly sampled token:

\begin{algorithm}
\caption{LLM token sampling}
\begin{algorithmic}
\Function{sample}{$L$}
    \State $s \gets ()$
    \For{$i \gets 1, L$}
        \State $\alpha \gets \text{LM}(s, \theta)$
        \State Sample $s' \sim \text{Categorical}(\alpha)$
        \If{$s' = \text{EOS}$}
            \State \textbf{break}
        \EndIf
        \State $s \gets \text{append}(s, s')$
    \EndFor
    \State \Return $s$
\EndFunction
\end{algorithmic}
\end{algorithm}

Given that we are dealing with a finite discrete distribution, we can compute an un-normalized conditional distribution by applying a boolean mask $m: \mathcal{P}(\mathcal{V}) \to \{0,1\}^N$, which restricts the support of the original distribution:

$$
\begin{aligned}
\alpha &= \text{LM}(\tilde{S}_t, \theta) \\
\tilde{\alpha} &= m(\tilde{S}_t) \odot \alpha \\
\tilde{s}_{t+1} &\sim \text{Categorical}(\tilde{\alpha})
\end{aligned}
$$

augmentation upon sampling algorithm

\begin{algorithm}
\caption{token sampling with masking}
\begin{algorithmic}
\Function{sample}{$L$}
    \State $s \gets ()$
    \For{$i \gets 1, L$}
        \State $\alpha \gets \text{LM}(s, \theta)$
        \State Construct the mask $m(s)$
        \State $\tilde{\alpha} \gets m \odot \alpha$
        \State Sample $\tilde{s} \sim \text{Categorical}(\tilde{\alpha})$
        \If{$\tilde{s} = \text{EOS}$}
            \State \textbf{break}
        \EndIf
        \State $s \gets \text{append}(s, \tilde{s})$
    \EndFor
    \State \Return $s$
\EndFunction
\end{algorithmic}
\end{algorithm}
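The masked sampling loop translates almost line by line into Python. This is a sketch with toy stand-ins: `toy_lm` returns a uniform distribution over a four-token vocabulary and `toy_mask` is a hand-written constraint; both are hypothetical and only serve to make the loop runnable.

```python
import random
from typing import Callable, List, Sequence

EOS = 3  # assumption: the last id of a 4-token toy vocabulary is EOS


def sample_with_mask(
    lm: Callable[[Sequence[int]], List[float]],
    build_mask: Callable[[Sequence[int]], List[int]],
    max_len: int,
    eos: int,
    rng: random.Random,
) -> List[int]:
    """Token sampling with masking: alpha_tilde = m(s) * alpha, then sample."""
    s: List[int] = []
    for _ in range(max_len):
        alpha = lm(s)
        m = build_mask(s)
        alpha_tilde = [a * b for a, b in zip(alpha, m)]  # un-normalized
        if sum(alpha_tilde) <= 0:
            raise ValueError("mask removed every token with non-zero probability")
        # random.choices accepts un-normalized weights; zero-weight ids are never drawn.
        token = rng.choices(range(len(alpha)), weights=alpha_tilde, k=1)[0]
        if token == eos:
            break
        s.append(token)
    return s


def toy_lm(s: Sequence[int]) -> List[float]:
    return [0.25, 0.25, 0.25, 0.25]


def toy_mask(s: Sequence[int]) -> List[int]:
    # Allow only token 0 for the first three steps, then only EOS.
    return [1, 0, 0, 0] if len(s) < 3 else [0, 0, 0, 1]
```

Because the mask leaves exactly one valid token at each step, the output of this toy setup is the same for every random seed, which makes the effect of the mask easy to see.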

finite automaton

We define a finite-state machine, given by $(Q, \Sigma, \delta, q_0, F)$ 5, where the characters comprising the strings in $\mathcal{V}$ are drawn from $\Sigma$, i.e. $\mathcal{V} \in \mathcal{P}(\Sigma)$

> FSM masking for the regular expression ([0-9]*)?\.?[0-9]*

determinism

Looping through the vocabulary is still the biggest issue. To address it, we preprocess the vocabulary using the regex's FSM and build an index. This requires a procedure for producing matches that start at any point in the FSM.

We define finding the sub-sequences of the FSM $M$ that accept the string $v$ as follows:

\begin{algorithm}
\caption{Find sub-sequences of the FSM $M$ that accept the string $v$}
\begin{algorithmic}
\Function{FindSubSequences}{$M, v$}
    \State $M = (Q, \Sigma, \delta, q_0, F)$
    \State $\texttt{res} \gets ()$
    \For{$r \in \delta^{-1}(\cdot, v_0)$} \Comment{Loop through states that read $v_0$}
        \State $p \gets (r)$
        \For{$i \gets 1, |v| - 1$} \Comment{Walk the FSM}
            \If{$\delta(r, v_i) = \emptyset$} \Comment{The FSM does not read $v_i$}
                \State $p \gets ()$
                \State \textbf{break} \Comment{Stop walking and try the next start state}
            \EndIf
            \State $r \gets \delta(r, v_i)$
            \State $p \gets \text{append}(p, r)$
        \EndFor
        \State $\texttt{res} \gets \text{append}(\texttt{res}, p)$
    \EndFor
    \State \Return $\texttt{res}$
\EndFunction
\end{algorithmic}
\end{algorithm}

We can then define the construction of $\sigma$:

\begin{algorithm}
\caption{Construct a map from FSM states to subsets of $\mathcal{V}$}
\begin{algorithmic}
\Function{MapStatesToVocab}{$M, \mathcal{V}$}
    \State $M = (Q, \Sigma, \delta, q_0, F)$
    \State Initialize the map $\sigma$ with empty sets for each element in $Q$
    \For{$v \in \mathcal{V}$} \Comment{Loop through the vocabulary}
        \State $Z \gets \Call{FindSubSequences}{M, v}$
        \For{$z \in Z$} \Comment{Loop through state sequences accepting $v$}
            \State $\sigma(z_0) \gets \sigma(z_0) \cup \{v\}$
        \EndFor
    \EndFor
    \State \Return $\sigma$
\EndFunction
\end{algorithmic}
\end{algorithm}
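Both procedures can be sketched in Python on the FSM for ([0-9]*)?\.?[0-9]*. The four-state transition table below is a hand-written assumption for that regex, and the vocabulary is a toy one. Unlike the pseudocode, this sketch consumes $v_0$ on the first step of the walk and simply skips start states whose walk fails, rather than appending an empty path.

```python
from typing import Dict, Iterable, List, Set, Tuple

Delta = Dict[Tuple[int, str], int]


def build_delta() -> Delta:
    """Hand-written DFA for ([0-9]*)?\\.?[0-9]* with states 0..3 (an assumption)."""
    delta: Delta = {(0, "."): 2, (1, "."): 2}
    for d in "0123456789":
        delta[(0, d)] = 1  # leading digits
        delta[(1, d)] = 1
        delta[(2, d)] = 3  # digits after the dot
        delta[(3, d)] = 3
    return delta


def find_sub_sequences(delta: Delta, v: str) -> List[List[int]]:
    """Return every state path along which the FSM can read all of `v`."""
    if not v:
        return []
    res = []
    for r in sorted({q for (q, ch) in delta if ch == v[0]}):  # states that read v[0]
        path, state = [r], r
        for ch in v:
            nxt = delta.get((state, ch))
            if nxt is None:  # the FSM does not read this character
                path = []
                break
            state = nxt
            path.append(state)
        if path:
            res.append(path)
    return res


def map_states_to_vocab(
    delta: Delta, states: Iterable[int], vocab: Iterable[str]
) -> Dict[int, Set[str]]:
    """Build sigma: for each state, the vocabulary entries readable from it."""
    sigma: Dict[int, Set[str]] = {q: set() for q in states}
    for v in vocab:
        for z in find_sub_sequences(delta, v):
            sigma[z[0]].add(v)
    return sigma


delta = build_delta()
sigma = map_states_to_vocab(delta, range(4), ["A", ".", "42", ".2", "1"])
```

The resulting `sigma` is exactly the lookup needed at decoding time: given the current FSM state, the valid next tokens are read off in one dictionary access instead of scanning the vocabulary.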

Bibliography

  • Lew, A. K., Zhi-Xuan, T., Grand, G., & Mansinghka, V. K. (2023). Sequential Monte Carlo Steering of Large Language Models using Probabilistic Programs. arXiv preprint arXiv:2306.03081
  • Willard, B. T., & Louf, R. (2023). Efficient Guided Generation for Large Language Models. arXiv preprint arXiv:2307.09702
  • Zheng, L., Yin, L., Xie, Z., Sun, C., Huang, J., Yu, C. H., Cao, S., Kozyrakis, C., Stoica, I., Gonzalez, J. E., Barrett, C., & Sheng, Y. (2024). SGLang: Efficient Execution of Structured Language Model Programs. arXiv preprint arXiv:2312.07104

Remarks

  1. Benchmark script can be found at vllm-project/vllm#10046.

    Current RFC vllm-project/vllm#5423

  2. vllm-project/vllm#6273 proposed a sampling controller interface, but @cadedaniel shares some concerns wrt fast-forward tokens

  3. this phenomenon is also known as coalescence in structured generation, where it exploits deterministic structures in the desired outputs to skip expensive forward passes

  4. (Lew et al., 2023) recently proposed sequential Monte Carlo steering. The idea is to cast causal generation as a posterior inference problem in a class of discrete probabilistic sequence models.

    See also Feynman-Kac Transformer models

  5. finite state machine

    • $Q$ is a finite set of states
    • $\Sigma$ is a finite alphabet
    • $\delta: Q \times \Sigma \to Q$ is the transition function
    • $q_0 \in Q$ is the start state
    • $F \subseteq Q$ is the set of all accepting states.