长文本流水线并行#
为什么使用流水线并行?#
通过分块预填充(chunked prefill),流水线并行(pipeline parallelism)有潜力降低长文本输入的 TTFT(首字延迟)。对于每个请求,其输入 token 可以被划分为多个分块(chunks),每个分块不长于设定的分块预填充大小。同一请求的不同分块可以由不同节点同时处理,从而实现处理并行化并降低 TTFT。
此外,流水线并行仅在每个流水线阶段的边界处需要跨节点通信,与大规模张量并行(TP)相比,可以实现更好的计算-通信重叠。因此,它也是一种旨在提高吞吐量的极具前景的并行化策略。
基于异步通信的实现重构#
SGLang 已经支持流水线并行(#5724)一段时间,并使其与 PD 分离功能(PD Disaggregation, #8846)兼容,但之前的实现并不完美,仍有巨大的性能提升空间。
为了减少流水线气泡(bubbles),SGLang 现在利用异步发送(asynchronous sends)进行 PP 阶段间的通信。该方法最初在 #7979 中提出,并在 #11852 中经过重新设计并合并。
动态分块指南#
为什么需要动态分块#
固定大小的分块预填充可能会在流水线中产生气泡,尤其是在 PP 规模较大时。这种现象的主要原因是,即使每个分块大小相同,模型的运行时间也是非均匀的(由 Transformer 结构特性导致)。前缀序列长度越长,该分块的运行时间就越长。这些气泡会传播到下一个阶段,并显著降低较大 PP rank 的扩展效率。
为了解决这个问题,我们引入了动态分块机制,并使用二次函数来拟合这种情况:运行时间(前缀序列长度 + 下一分块大小)- 运行时间(前缀序列长度)= 运行时间(初始分块大小)。基于此方法,我们可以动态减小分块大小,以最大限度地减少由阶段对齐失准引起的气泡。
分块预填充大小与平滑因子#
当启用 --enable-dynamic-chunking 时,序列的每个分块大小由二次模型动态确定,该模型根据初始分块长度的估计运行时间来预测下一个分块大小。在这种情况下,我们使用 --chunked-prefill-size 来设置初始分块大小。切换到动态分块模式时,初始分块大小(--chunked-prefill-size)应设置为一个较大的值(与原始固定分块大小相当),以免产生过多的分块。
SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR 是控制动态分块算法平滑因子的参数,默认值为 0.75。它决定了在预填充阶段分块大小的变化幅度。较大的值意味着更激进的分块大小变化,这可能会带来更好的性能,但也会导致更大的分块大小波动(末尾的分块可能变得非常小,导致性能下降)和更多的总分块数。当设置为 1 时,分块大小将严格根据上述预测下一分块大小的二次模型进行调整。较小的值意味着更保守的变化。当设置为 0 时,分块大小将不会动态调整,即等同于传统的固定大小分块预填充。
流水线并行最佳实践#
调优分块预填充大小#
优化分块预填充大小对于平衡流水线效率和资源利用率至关重要。理想的大小取决于模型架构、硬件配置和典型的输入长度。我们建议从较小的分块大小(如 4K)开始,逐步增加,直到找到适合您特定用例的最佳大小。或者,您可以分析硬件能力并根据 Roofline 模型确定最佳分块大小。
启用动态分块并调整平滑因子(实验性功能)#
SGLang 还提供了动态分块方案,可进一步提升性能。该功能目前为实验性功能,需要一定量的调优实验,可能并不适用于所有工作负载。此外,微调平滑因子可以帮助针对特定工作负载和模型特性优化性能。
NVIDIA H20 案例研究#
在评估 2K 到 16K 的固定分块预填充大小时,实验结果显示:4K 分块大小为 DeepSeek-V3.1 提供了最佳的预填充 TTFT 性能;6K 分块大小为 Qwen3-235B-A22B-Thinking-2507-FP8 提供了最佳的预填充 TTFT 性能。
启用动态分块时,我们首先将最佳固定分块大小乘以 3 作为初始分块大小。通过实验,我们发现 2-3 倍的乘数可以提供适当的平衡——既避免了初始流水线气泡过多,又确保了后续分块不会随着上下文长度增加而变得太小。在使用默认动态分块平滑因子 0.75 的基础上,我们进行了参数调优,确定在 12K 初始分块大小下,0.65 的平滑因子对 DeepSeek-V3.1 效果最佳;在 18K 初始分块大小下,0.8 的平滑因子对 Qwen3-235B-A22B-Thinking-2507-FP8 效果最佳。
DeepSeek-V3.1 (128K 输入 Token 长度)#
# prefill node 0 (fixed chunked prefill size)
python3 -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3.1 --trust-remote-code \
--nnodes 4 --node-rank 0 --tp 8 --pp-size 4 \
--port 30000 --dist-init-addr 192.168.0.137:62001 \
--disable-radix-cache --mem-fraction-static 0.8 \
--attention-backend fa3 --host 0.0.0.0 --watchdog-timeout 3600 \
--max-running-requests 128 --chunked-prefill-size 4096
# prefill node 0 (with dynamic chunking)
export SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR=0.65
python3 -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3.1 --trust-remote-code \
--nnodes 4 --node-rank 0 --tp 8 --pp-size 4 \
--port 30000 --dist-init-addr 192.168.0.137:62001 \
--disable-radix-cache --mem-fraction-static 0.8 \
--attention-backend fa3 --host 0.0.0.0 --watchdog-timeout 3600 \
--max-running-requests 128 --chunked-prefill-size 12288 --enable-dynamic-chunking
Qwen3-235B-A22B-Thinking-2507-FP8 (128K 输入 Token 长度)#
# prefill node 0 (fixed chunked prefill size)
python3 -m sglang.launch_server \
--model-path Qwen/Qwen3-235B-A22B-Thinking-2507-FP8 --trust-remote-code \
--nnodes 2 --node-rank 0 --tp 4 --pp-size 2 \
--port 30000 --dist-init-addr 192.168.0.137:62001 \
--disable-radix-cache --mem-fraction-static 0.8 \
--attention-backend fa3 --host 0.0.0.0 --watchdog-timeout 3600 \
--max-running-requests 128 --chunked-prefill-size 6144
# prefill node 0 (with dynamic chunking)
export SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR=0.8
python3 -m sglang.launch_server \
--model-path Qwen/Qwen3-235B-A22B-Thinking-2507-FP8 --trust-remote-code \
--nnodes 2 --node-rank 0 --tp 4 --pp-size 2 \
--port 30000 --dist-init-addr 192.168.0.137:62001 \
--disable-radix-cache --mem-fraction-static 0.8 \
--attention-backend fa3 --host 0.0.0.0 --watchdog-timeout 3600 \
--max-running-requests 128 --chunked-prefill-size 18432 --enable-dynamic-chunking
注意:--disable-radix-cache 仅为了可复现的基准测试而启用。不建议在生产环境中使用它。
带有 PD 分离的流水线并行最佳实践#
待补充。请关注 PD 分离流水线并行的最新更新。