TPU#
SGLang 通过 SGLang-JAX 后端支持高性能 TPU 推理,该后端专门为 Google Cloud TPU 进行了优化。基于 JAX 的实现为 TPU 硬件上的大语言模型 (LLM) 服务负载提供了卓越的吞吐量和低延迟。
有关 TPU 的特定问题或功能需求,请访问 sglang-jax GitHub issues 页面。
注意: SGLang 的 TPU 支持是通过 SGLang-JAX 后端实现的,这是一个专门的基于 JAX 的推理引擎,维护在独立仓库 sgl-project/sglang-jax 中。
系统要求#
支持的 TPU 硬件#
TPU 型号 |
HBM 显存 |
可用平台 |
|---|---|---|
TPU v6e |
32 GB |
Google Cloud |
TPU v7 |
每个核心 96 GB |
Google Cloud |
软件要求#
Python: 3.12 或更高版本
JAX: 支持 TPU 的最新版本
环境: Google Cloud TPU VM 或兼容的 TPU 运行时
可选: 用于简化云部署的 SkyPilot
功能支持矩阵#
SGLang-JAX 为生产级 LLM 服务提供全面的 TPU 优化功能
功能 |
支持状态 |
描述 |
|---|---|---|
高吞吐连续批处理 (Continuous Batching) |
✅ |
动态请求批处理,最大化 TPU 利用率 |
基数树 (Radix Tree) KV 缓存 |
✅ |
请求间高效的内存前缀共享 |
FlashAttention 后端 |
✅ |
针对长序列优化的 TPU Attention 算子 |
张量并行 (Tensor Parallelism) |
✅ |
将模型分布在多个 TPU 核心上 |
分页注意力 (Paged Attention) |
✅ |
灵活的分页 KV 缓存管理 |
投机采样 (EAGLE/EAGLE3) |
✅ |
兼容模型可提升 20-40% 的吞吐量 |
分块预填充 (Chunked Prefill) |
✅ |
预填充-解码混合批处理 |
OpenAI 兼容 API |
✅ |
OpenAI API 的无缝替代方案 |
数据并行注意力 (Data Parallel Attention) |
🚧 |
开发中 - 结合数据并行的注意力计算 |
量化 |
🚧 |
开发中 - 减少内存占用的模型量化 |
Multi-LoRA |
🚧 |
开发中 - 同时提供多个 LoRA 适配器服务 |
Attention 后端对比#
后端 |
分页注意力 (Paged Attention) |
投机采样 |
MLA |
滑动窗口 |
|---|---|---|---|---|
FlashAttention (fa) |
✅ |
✅ |
❌ |
✅ |
原生支持 |
❌ |
❌ |
❌ |
❌ |
注意: 推荐在生产环境负载中使用 FlashAttention 后端,因其具有卓越的内存效率和性能。
优化模型列表#
以下模型已在 TPU 部署中经过测试和优化
模型家族 |
性能状态 |
|---|---|
⭐ 生产环境推荐 |
|
⭐ 最佳性能 |
|
待改进 |
|
待改进 |
|
待改进 |
|
待改进 |
|
待改进 |
|
已在 TPU 验证 |
|
Bailing MoE |
待改进 |
安装#
方法 1:使用 PyPI(推荐)#
pip install sglang-jax
方法 2:从源码安装#
git clone https://github.com/sgl-project/sglang-jax
cd sglang-jax
uv venv --python 3.12 && source .venv/bin/activate
uv pip install -e "python[all]"
方法 3:使用 Docker#
注意: TPU 的 Docker 支持目前正在开发中。请使用 PyPI 或源码安装方法。
方法 4:通过 SkyPilot 使用云端 TPU#
SkyPilot 提供了在 Google Cloud TPU 上的简易部署方案
安装 SkyPilot 并配置 GCP 访问权限(参见 SkyPilot 文档)
创建一个 SkyPilot 配置文件
SkyPilot YAML: sglang-jax.sky.yaml
# sglang-jax.sky.yaml
resources:
accelerators: tpu-v6e-4
accelerator_args:
tpu_vm: True
runtime_version: v2-alpha-tpuv6e
run: |
git clone https://github.com/sgl-project/sglang-jax.git
cd sglang-jax
uv venv --python 3.12
source .venv/bin/activate
uv pip install -e "python[all]"
启动您的 TPU 集群
# Standard deployment
sky launch -c sglang-jax sglang-jax.sky.yaml --infra=gcp
# With spot instances for cost savings
sky launch -c sglang-jax sglang-jax.sky.yaml --infra=gcp --use-spot
启动推理引擎服务#
基础示例:Qwen-7B#
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 -u -m sgl_jax.launch_server \
--model-path Qwen/Qwen-7B-Chat \
--trust-remote-code \
--dist-init-addr=0.0.0.0:10011 \
--nnodes=1 \
--tp-size=4 \
--device=tpu \
--random-seed=3 \
--node-rank=0 \
--mem-fraction-static=0.8 \
--max-prefill-tokens=8192 \
--download-dir=/tmp \
--dtype=bfloat16 \
--skip-server-warmup \
--host 0.0.0.0 \
--port 30000
关键参数说明
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache- 启用 JIT 编译缓存,加速后续运行时的服务器启动速度--tp-size=4- 张量并行大小;请将其与您的 TPU 核心数匹配(通常为 1、4 或 8)--device=tpu- 指定 TPU 设备(这是 sglang-jax 的默认设置)--dtype=bfloat16- 使用 bfloat16 精度,这是 TPU 优化的格式--mem-fraction-static=0.8- 为静态内存分配 80% 的 TPU HBM(可在 0.2 到 0.9 之间调节)--max-prefill-tokens=8192- 预填充阶段处理的最大 Token 数量
高性能配置:Qwen3-8B#
适用于追求最佳吞吐量的生产环境负载
python3 -u -m sgl_jax.launch_server \
--model-path Qwen/Qwen3-8B \
--trust-remote-code \
--tp-size=4 \
--device=tpu \
--mem-fraction-static=0.8 \
--chunked-prefill-size=2048 \
--dtype=bfloat16 \
--max-running-requests=256 \
--page-size=128 \
--attention-backend=fa
高级:投机采样 (EAGLE3)#
投机采样可为兼容模型提升 20-40% 的吞吐量
python3 -u -m sgl_jax.launch_server \
--model-path Qwen/Qwen3-32B \
--trust-remote-code \
--device=tpu \
--tp-size=4 \
--mem-fraction-static=0.8 \
--max-prefill-tokens=4096 \
--attention-backend=fa \
--dtype=bfloat16 \
--port=30000 \
--host=0.0.0.0 \
--disable-overlap-schedule \
--speculative-algorithm=EAGLE3 \
--speculative-draft-model-path=AngelSlim/Qwen3-32B_eagle3 \
--page-size=64 \
--speculative-eagle-topk=1 \
--speculative-num-steps=3 \
--speculative-num-draft-tokens=4
注意: 投机采样目前支持 Qwen3 和 LLaMA 模型系列。详见 投机采样文档 以获取详细配置指南。
多节点分布式服务#
适用于需要多个 TPU VM 的大型模型
# Node 0 (coordinator)
python3 -m sgl_jax.launch_server \
--model-path MODEL_PATH \
--dist-init-addr=NODE0_IP:10011 \
--nnodes=2 \
--node-rank=0 \
--tp-size=8 \
[other parameters...]
# Node 1 (worker)
python3 -m sgl_jax.launch_server \
--model-path MODEL_PATH \
--dist-init-addr=NODE0_IP:10011 \
--nnodes=2 \
--node-rank=1 \
--tp-size=8 \
[other parameters...]
使用请求进行基准测试#
吞吐量测试#
基础吞吐量基准测试
python3 -m sgl_jax.bench_serving \
--backend sgl-jax \
--dataset-name random \
--num-prompts=100 \
--random-input=512 \
--random-output=128 \
--max-concurrency=8 \
--random-range-ratio=1 \
--warmup-requests=0
延迟测试#
测量单批次延迟
python3 -m sgl_jax.bench_one_batch_server \
--base-url http://127.0.0.1:30000 \
--model-path Qwen/Qwen-7B-Chat \
--batch-size=32 \
--input-len=256 \
--output-len=32
综合基准测试脚本#
用于在不同配置下进行系统性的性能评估
#!/bin/bash
set -e
backend=${1:-sgl-jax}
num_prompts_per_concurrency=3
input_seq_lens=(1024 4096 8192)
output_seq_lens=(1 1024)
max_concurrencies=(8 16 32 64 128 256)
for input_seq_len in "${input_seq_lens[@]}"; do
for output_seq_len in "${output_seq_lens[@]}"; do
echo "======================================="
echo "Testing ISL/OSL: $input_seq_len/$output_seq_len"
echo "======================================="
for max_concurrency in "${max_concurrencies[@]}"; do
num_prompts=$((num_prompts_per_concurrency * max_concurrency))
python3 -m sgl_jax.bench_serving \
--backend ${backend} \
--dataset-name random \
--num-prompts ${num_prompts} \
--random-input ${input_seq_len} \
--random-output ${output_seq_len} \
--max-concurrency ${max_concurrency} \
--random-range-ratio 1 \
--disable-ignore-eos \
--warmup-requests 0
done
done
done
获取所有基准测试参数的详细帮助
python3 -m sgl_jax.bench_serving --help
参见 基准测试与性能分析指南 了解高级基准测试技术及使用 JAX Profiler 进行性能分析的方法。
性能优化#
内存优化#
减少内存占用
降低
--mem-fraction-static(从 0.8 → 0.5 → 0.3)减小
--max-prefill-tokens(从 16384 → 8192 → 4096)减少
--max-running-requests
处理 OOM 错误
从保守的内存设置开始 (
--mem-fraction-static=0.5)逐渐增加直到找到最佳平衡点
增加
--page-size以获得更好的内存局部性 (1 → 16 → 64 → 128)
吞吐量优化#
最大化每秒 Token 数
使用 FlashAttention 后端:
--attention-backend=fa为 Qwen3 模型启用投机采样 (EAGLE3) (提升 20-40%)
将
--max-running-requests增加到 256+将
--mem-fraction-static设置为 0.8+(如果内存允许)使用较大的页面大小 (64-128)
启用分块预填充:
--chunked-prefill-size=2048
延迟优化#
最小化首字延迟 (TTFT) 和 Token 间延迟
将
--page-size降低至 1-4降低
--max-running-requests(16-32) 以使用较小的批次减小
--chunked-prefill-size使用保守的内存设置以避免 GC 停顿
TPU 特定优化#
JIT 编译缓存
export JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache
务必设置此环境变量以缓存编译后的算子,从而加速服务器启动。
数据类型优化: 使用
--dtype=bfloat16进行 TPU 原生优化。TPU 是专门为 bfloat16 计算设计的。张量并行: 将
--tp-size匹配到您的 TPU 核心配置(1、4 或 8)以实现最佳模型分布。Attention 后端: 在生产负载中始终使用
--attention-backend=fa(FlashAttention)。
故障排除#
OOM(内存不足)错误#
如果遇到内存不足错误
将
--mem-fraction-static从 0.8 降低到 0.5 或更低将
--max-prefill-tokens从 8192 降低到 4096 或 2048降低
--max-running-requests以减小并发批次大小增加
--page-size以提高内存布局效率
编译时间过长#
如果服务器启动时间过长
确保正确设置了
JAX_COMPILATION_CACHE_DIR理解首次运行需要进行 JIT 编译(这是正常现象)
后续运行将通过缓存的编译内容显著加快速度
考虑使用
--skip-server-warmup将编译推迟到第一个请求时
吞吐量低#
如果未达到预期吞吐量
验证
--tp-size是否匹配您的 TPU 核心配置检查是否启用了
--attention-backend=fa增加
--max-running-requests以允许形成更大的批次考虑为兼容模型启用投机采样
确保内存设置允许足够大的批次大小
连接问题#
如果客户端无法连接到服务器
确保设置了
--host=0.0.0.0以允许外部访问(不只是127.0.0.1)验证防火墙规则允许指定端口(默认:30000)的流量
检查服务器进程是否正在运行:
curl https://:30000/health
高级功能#
投机采样#
SGLang-JAX 为 Qwen3 和 LLaMA 模型系列支持 EAGLE 和 EAGLE3 投机采样算法。投机采样可以在不影响输出质量的情况下提高 20-40% 的吞吐量。
详见 投机采样文档 了解详细配置和支持的模型组合。
分块预填充 (Chunked Prefill)#
启用预填充-解码混合批处理以获得更好的 TPU 利用率
--chunked-prefill-size=2048 --enable-mixed-chunk
这允许调度器在同一个批次中混合预填充操作和解码操作,从而提高整体吞吐量。
自定义 Attention 后端#
SGLang-JAX 支持基于插件的 Attention 后端系统。您可以实现针对特定用例优化的自定义 Attention 算子。
详见 Attention 后端文档 了解实现细节。
环境验证#
在部署前验证您的 TPU 设置
python -c "from sgl_jax import check_env; check_env.check_env()"
该命令检查:
已安装的包版本
TPU 设备可用性及规格
系统资源与配置
设置的兼容性
参与贡献#
欢迎通过贡献来改进 SGLang-JAX 中的 TPU 支持!
贡献领域#
查看 开发路线图 以了解计划中的功能,并寻找贡献新功能的机会。
目前的贡献领域包括:
针对特定 TPU 代际的性能优化
支持额外的模型架构
文档改进和示例
缺陷报告与修复
基准测试结果与性能分析
如何贡献#
阅读 贡献指南
加入 SGL-JAX Slack 社区 进行讨论
在 sglang-jax/issues 提交问题
在 TPU 上进行测试#
针对需要 TPU 访问权限进行测试的贡献者
参考 TPU 资源指南 获取访问 TPU 硬件的信息
使用 SkyPilot 配合抢占式实例 (Spot Instances) 进行高性价比测试
遵循 基准测试与性能分析指南 进行性能验证