LLM推理过程
大约 3 分钟
最近真是高产,其实这些东西也断断续续的学了很久了
直到最近一周,才有一种连点成线的感觉
Tokenize: 把文本切分为token序列
文本在预处理阶段,经过tokenize(分词)和embedding(嵌入),转为机器可以理解、计算的向量:
- 加载
tokenizer,格式如下(Base64编码):
X24= 1107
aWdo= 1108
IHRoYW4= 1109
- 构造前缀树(Trie)
# 词表: "a", "ab", "abc", "b"
root
├── a (是完整token)
│ ├── b (是完整token)
│ │ └── c (是完整token)
└── b (是完整token)
- 匹配算法(以BPE为例)
- 从当前位置开始,在前缀树中查找最长可能匹配
- 找到"abc"匹配成功
- 输出对应的Token ID
- 处理未知字符(OOV)
- 退回到字节级别编码(如UTF-8字节)
- 使用特殊token(如
<unk>)表示未知 - 或者拆分成更小的子词单元
Embedding: 对每个token结合语义和位置信息进行embedding
- input embedding: 对语义进行嵌入
- position encoding: 添加位置信息
Transformer: 计算KV
在得到向量后,向量将作为decoder结构的输入,在N个Transformer Decoder Blocks上进行计算
- Self-Attention(自注意力): 使用带有
Causal Mask的自注意力进行计算,确保模型在训练和推理过程中,都不能看到未来的token,进而达到训推一致。 - 概率分布: 最终通过
LM Head (Linear + Softmax) → logits得到下一个Token出现的概率分布
输出Token选择
- Temperature: 温度越大概率分布越平缓,因此创造性越强;反之就越稳定,不容易出现幻觉
- Top-k: 考虑最高概率的
k个token - Top-p: 动态调整候选集大小,只累积到概率和达到p的最小token集合 这里直接上一个代码更好理解
# 简化版流程
def select_next_token(logits, temperature=1.0, top_k=None, top_p=None):
"""
logits: [vocab_size] 每个token的未归一化分数
返回: 选择的token_id
"""
# 1. 应用温度参数
logits = logits / temperature
# 2. 转换为概率分布
probs = softmax(logits)
# 3. 应用top-k过滤(如果启用)
if top_k is not None:
probs = apply_top_k(probs, top_k)
# 4. 应用top-p过滤(如果启用)
if top_p is not None:
probs = apply_top_p(probs, top_p)
# 5. 从处理后的分布中采样
next_token = sample_from_distribution(probs)
return next_token
Prefill 和 Decode两阶段
| 项目 | Prefill 入口 | Decode 入口 |
|---|---|---|
| 输入长度 | >1(通常是几十到几万) | =1 |
| past_kv 参数 | None 或空 | 非空(来自上一步) |
| 是否生成 KV Cache | 是(初始化) | 是(追加) |
| 是否并行 | 是 | 否(自回归) |
| 在代码中的标志 | 第一次 forward 调用 | 第二次及以后的 forward 调用 |
Prefill
在实际的请求中,第一次推理时,整个prompt都是新的输入,所以一次性送入模型,最终获取第一个输出的Token,这个过程就是Prefill。
# 假设 prompt_tokens = [1061, 338, 2945, 0] # "What is AI?"
logits, kv_cache = model.forward(
input_ids=prompt_tokens, # ← 这就是 Prefill 的入口!
past_kv=None # 没有历史缓存
)
next_token = sample(logits[-1]) # 取最后一个位置的 logits 预测下一个 token
在这个调用中:
- 模型对所有 L 个 token 并行计算 attention。
- 生成每一层的 K、V,并保存为 kv_cache。
- 输出 logits 的最后一个位置用于预测第 L 个 token(即第一个生成的 token)。
Decode
在第一个token生成之后,每次只输入一个新的token,并附带之前的KV Cache
# 第一次 decode:生成第 L 个 token 后,继续生成第 L+1 个
input_id = next_token.unsqueeze(0) # shape: [1]
logits, kv_cache = model.forward(
input_ids=input_id, # ← 只传入一个 token!
past_kv=kv_cache # ← 复用之前 Prefill 生成的缓存
)
next_token = sample(logits[0]) # logits 只有一个位置
正因为两个过程,都是执行的同样的forward所以才有chunked prefill的优化空间