TPU#

SGLang 通过 SGLang-JAX 后端支持高性能 TPU 推理,该后端专门为 Google Cloud TPU 进行了优化。基于 JAX 的实现为 TPU 硬件上的大语言模型 (LLM) 服务负载提供了卓越的吞吐量和低延迟。

有关 TPU 的特定问题或功能需求,请访问 sglang-jax GitHub issues 页面

注意: SGLang 的 TPU 支持是通过 SGLang-JAX 后端实现的,这是一个专门的基于 JAX 的推理引擎,维护在独立仓库 sgl-project/sglang-jax 中。

系统要求#

支持的 TPU 硬件#

TPU 型号

HBM 显存

可用平台

TPU v6e

32 GB

Google Cloud

TPU v7

每个核心 96 GB

Google Cloud

软件要求#

  • Python: 3.12 或更高版本

  • JAX: 支持 TPU 的最新版本

  • 环境: Google Cloud TPU VM 或兼容的 TPU 运行时

  • 可选: 用于简化云部署的 SkyPilot

功能支持矩阵#

SGLang-JAX 为生产级 LLM 服务提供全面的 TPU 优化功能

功能

支持状态

描述

高吞吐连续批处理 (Continuous Batching)

动态请求批处理,最大化 TPU 利用率

基数树 (Radix Tree) KV 缓存

请求间高效的内存前缀共享

FlashAttention 后端

针对长序列优化的 TPU Attention 算子

张量并行 (Tensor Parallelism)

将模型分布在多个 TPU 核心上

分页注意力 (Paged Attention)

灵活的分页 KV 缓存管理

投机采样 (EAGLE/EAGLE3)

兼容模型可提升 20-40% 的吞吐量

分块预填充 (Chunked Prefill)

预填充-解码混合批处理

OpenAI 兼容 API

OpenAI API 的无缝替代方案

数据并行注意力 (Data Parallel Attention)

🚧

开发中 - 结合数据并行的注意力计算

量化

🚧

开发中 - 减少内存占用的模型量化

Multi-LoRA

🚧

开发中 - 同时提供多个 LoRA 适配器服务

Attention 后端对比#

后端

分页注意力 (Paged Attention)

投机采样

MLA

滑动窗口

FlashAttention (fa)

原生支持

注意: 推荐在生产环境负载中使用 FlashAttention 后端,因其具有卓越的内存效率和性能。

优化模型列表#

以下模型已在 TPU 部署中经过测试和优化

模型家族

性能状态

Qwen 3

⭐ 生产环境推荐

Qwen 3 MoE

⭐ 最佳性能

Qwen 2

待改进

Qwen 2 MoE

待改进

Qwen 1.5

待改进

Llama/LLaMA

待改进

Grok-2

待改进

Gemma 2

已在 TPU 验证

Bailing MoE

待改进

安装#

方法 2:从源码安装#

git clone https://github.com/sgl-project/sglang-jax
cd sglang-jax
uv venv --python 3.12 && source .venv/bin/activate
uv pip install -e "python[all]"

方法 3:使用 Docker#

注意: TPU 的 Docker 支持目前正在开发中。请使用 PyPI 或源码安装方法。

方法 4:通过 SkyPilot 使用云端 TPU#

SkyPilot 提供了在 Google Cloud TPU 上的简易部署方案

  1. 安装 SkyPilot 并配置 GCP 访问权限(参见 SkyPilot 文档

  2. 创建一个 SkyPilot 配置文件

SkyPilot YAML: sglang-jax.sky.yaml
# sglang-jax.sky.yaml
resources:
   accelerators: tpu-v6e-4
   accelerator_args:
      tpu_vm: True
      runtime_version: v2-alpha-tpuv6e

run: |
  git clone https://github.com/sgl-project/sglang-jax.git
  cd sglang-jax
  uv venv --python 3.12
  source .venv/bin/activate
  uv pip install -e "python[all]"
  1. 启动您的 TPU 集群

# Standard deployment
sky launch -c sglang-jax sglang-jax.sky.yaml --infra=gcp

# With spot instances for cost savings
sky launch -c sglang-jax sglang-jax.sky.yaml --infra=gcp --use-spot

启动推理引擎服务#

基础示例:Qwen-7B#

JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 -u -m sgl_jax.launch_server \
    --model-path Qwen/Qwen-7B-Chat \
    --trust-remote-code \
    --dist-init-addr=0.0.0.0:10011 \
    --nnodes=1 \
    --tp-size=4 \
    --device=tpu \
    --random-seed=3 \
    --node-rank=0 \
    --mem-fraction-static=0.8 \
    --max-prefill-tokens=8192 \
    --download-dir=/tmp \
    --dtype=bfloat16 \
    --skip-server-warmup \
    --host 0.0.0.0 \
    --port 30000

关键参数说明

  1. JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache - 启用 JIT 编译缓存,加速后续运行时的服务器启动速度

  2. --tp-size=4 - 张量并行大小;请将其与您的 TPU 核心数匹配(通常为 1、4 或 8)

  3. --device=tpu - 指定 TPU 设备(这是 sglang-jax 的默认设置)

  4. --dtype=bfloat16 - 使用 bfloat16 精度,这是 TPU 优化的格式

  5. --mem-fraction-static=0.8 - 为静态内存分配 80% 的 TPU HBM(可在 0.2 到 0.9 之间调节)

  6. --max-prefill-tokens=8192 - 预填充阶段处理的最大 Token 数量

高性能配置:Qwen3-8B#

适用于追求最佳吞吐量的生产环境负载

python3 -u -m sgl_jax.launch_server \
    --model-path Qwen/Qwen3-8B \
    --trust-remote-code \
    --tp-size=4 \
    --device=tpu \
    --mem-fraction-static=0.8 \
    --chunked-prefill-size=2048 \
    --dtype=bfloat16 \
    --max-running-requests=256 \
    --page-size=128 \
    --attention-backend=fa

高级:投机采样 (EAGLE3)#

投机采样可为兼容模型提升 20-40% 的吞吐量

python3 -u -m sgl_jax.launch_server \
    --model-path Qwen/Qwen3-32B \
    --trust-remote-code \
    --device=tpu \
    --tp-size=4 \
    --mem-fraction-static=0.8 \
    --max-prefill-tokens=4096 \
    --attention-backend=fa \
    --dtype=bfloat16 \
    --port=30000 \
    --host=0.0.0.0 \
    --disable-overlap-schedule \
    --speculative-algorithm=EAGLE3 \
    --speculative-draft-model-path=AngelSlim/Qwen3-32B_eagle3 \
    --page-size=64 \
    --speculative-eagle-topk=1 \
    --speculative-num-steps=3 \
    --speculative-num-draft-tokens=4

注意: 投机采样目前支持 Qwen3 和 LLaMA 模型系列。详见 投机采样文档 以获取详细配置指南。

多节点分布式服务#

适用于需要多个 TPU VM 的大型模型

# Node 0 (coordinator)
python3 -m sgl_jax.launch_server \
    --model-path MODEL_PATH \
    --dist-init-addr=NODE0_IP:10011 \
    --nnodes=2 \
    --node-rank=0 \
    --tp-size=8 \
    [other parameters...]

# Node 1 (worker)
python3 -m sgl_jax.launch_server \
    --model-path MODEL_PATH \
    --dist-init-addr=NODE0_IP:10011 \
    --nnodes=2 \
    --node-rank=1 \
    --tp-size=8 \
    [other parameters...]

使用请求进行基准测试#

吞吐量测试#

基础吞吐量基准测试

python3 -m sgl_jax.bench_serving \
    --backend sgl-jax \
    --dataset-name random \
    --num-prompts=100 \
    --random-input=512 \
    --random-output=128 \
    --max-concurrency=8 \
    --random-range-ratio=1 \
    --warmup-requests=0

延迟测试#

测量单批次延迟

python3 -m sgl_jax.bench_one_batch_server \
    --base-url http://127.0.0.1:30000 \
    --model-path Qwen/Qwen-7B-Chat \
    --batch-size=32 \
    --input-len=256 \
    --output-len=32

综合基准测试脚本#

用于在不同配置下进行系统性的性能评估

#!/bin/bash
set -e

backend=${1:-sgl-jax}
num_prompts_per_concurrency=3
input_seq_lens=(1024 4096 8192)
output_seq_lens=(1 1024)
max_concurrencies=(8 16 32 64 128 256)

for input_seq_len in "${input_seq_lens[@]}"; do
    for output_seq_len in "${output_seq_lens[@]}"; do
        echo "======================================="
        echo "Testing ISL/OSL: $input_seq_len/$output_seq_len"
        echo "======================================="
        for max_concurrency in "${max_concurrencies[@]}"; do
            num_prompts=$((num_prompts_per_concurrency * max_concurrency))
            python3 -m sgl_jax.bench_serving \
                --backend ${backend} \
                --dataset-name random \
                --num-prompts ${num_prompts} \
                --random-input ${input_seq_len} \
                --random-output ${output_seq_len} \
                --max-concurrency ${max_concurrency} \
                --random-range-ratio 1 \
                --disable-ignore-eos \
                --warmup-requests 0
        done
    done
done

获取所有基准测试参数的详细帮助

python3 -m sgl_jax.bench_serving --help

参见 基准测试与性能分析指南 了解高级基准测试技术及使用 JAX Profiler 进行性能分析的方法。

性能优化#

内存优化#

减少内存占用

  • 降低 --mem-fraction-static (从 0.8 → 0.5 → 0.3)

  • 减小 --max-prefill-tokens (从 16384 → 8192 → 4096)

  • 减少 --max-running-requests

处理 OOM 错误

  • 从保守的内存设置开始 (--mem-fraction-static=0.5)

  • 逐渐增加直到找到最佳平衡点

  • 增加 --page-size 以获得更好的内存局部性 (1 → 16 → 64 → 128)

吞吐量优化#

最大化每秒 Token 数

  • 使用 FlashAttention 后端:--attention-backend=fa

  • 为 Qwen3 模型启用投机采样 (EAGLE3) (提升 20-40%)

  • --max-running-requests 增加到 256+

  • --mem-fraction-static 设置为 0.8+(如果内存允许)

  • 使用较大的页面大小 (64-128)

  • 启用分块预填充:--chunked-prefill-size=2048

延迟优化#

最小化首字延迟 (TTFT) 和 Token 间延迟

  • --page-size 降低至 1-4

  • 降低 --max-running-requests (16-32) 以使用较小的批次

  • 减小 --chunked-prefill-size

  • 使用保守的内存设置以避免 GC 停顿

TPU 特定优化#

  1. JIT 编译缓存

    export JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache
    

    务必设置此环境变量以缓存编译后的算子,从而加速服务器启动。

  2. 数据类型优化: 使用 --dtype=bfloat16 进行 TPU 原生优化。TPU 是专门为 bfloat16 计算设计的。

  3. 张量并行:--tp-size 匹配到您的 TPU 核心配置(1、4 或 8)以实现最佳模型分布。

  4. Attention 后端: 在生产负载中始终使用 --attention-backend=fa (FlashAttention)。

故障排除#

OOM(内存不足)错误#

如果遇到内存不足错误

  1. --mem-fraction-static 从 0.8 降低到 0.5 或更低

  2. --max-prefill-tokens 从 8192 降低到 4096 或 2048

  3. 降低 --max-running-requests 以减小并发批次大小

  4. 增加 --page-size 以提高内存布局效率

编译时间过长#

如果服务器启动时间过长

  1. 确保正确设置了 JAX_COMPILATION_CACHE_DIR

  2. 理解首次运行需要进行 JIT 编译(这是正常现象)

  3. 后续运行将通过缓存的编译内容显著加快速度

  4. 考虑使用 --skip-server-warmup 将编译推迟到第一个请求时

吞吐量低#

如果未达到预期吞吐量

  1. 验证 --tp-size 是否匹配您的 TPU 核心配置

  2. 检查是否启用了 --attention-backend=fa

  3. 增加 --max-running-requests 以允许形成更大的批次

  4. 考虑为兼容模型启用投机采样

  5. 确保内存设置允许足够大的批次大小

连接问题#

如果客户端无法连接到服务器

  1. 确保设置了 --host=0.0.0.0 以允许外部访问(不只是 127.0.0.1

  2. 验证防火墙规则允许指定端口(默认:30000)的流量

  3. 检查服务器进程是否正在运行:curl https://:30000/health

高级功能#

投机采样#

SGLang-JAX 为 Qwen3 和 LLaMA 模型系列支持 EAGLE 和 EAGLE3 投机采样算法。投机采样可以在不影响输出质量的情况下提高 20-40% 的吞吐量。

详见 投机采样文档 了解详细配置和支持的模型组合。

分块预填充 (Chunked Prefill)#

启用预填充-解码混合批处理以获得更好的 TPU 利用率

--chunked-prefill-size=2048 --enable-mixed-chunk

这允许调度器在同一个批次中混合预填充操作和解码操作,从而提高整体吞吐量。

自定义 Attention 后端#

SGLang-JAX 支持基于插件的 Attention 后端系统。您可以实现针对特定用例优化的自定义 Attention 算子。

详见 Attention 后端文档 了解实现细节。

环境验证#

在部署前验证您的 TPU 设置

python -c "from sgl_jax import check_env; check_env.check_env()"

该命令检查:

  • 已安装的包版本

  • TPU 设备可用性及规格

  • 系统资源与配置

  • 设置的兼容性

参与贡献#

欢迎通过贡献来改进 SGLang-JAX 中的 TPU 支持!

贡献领域#

查看 开发路线图 以了解计划中的功能,并寻找贡献新功能的机会。

目前的贡献领域包括:

  • 针对特定 TPU 代际的性能优化

  • 支持额外的模型架构

  • 文档改进和示例

  • 缺陷报告与修复

  • 基准测试结果与性能分析

如何贡献#

  1. 访问 sglang-jax 仓库

  2. 阅读 贡献指南

  3. 加入 SGL-JAX Slack 社区 进行讨论

  4. sglang-jax/issues 提交问题

在 TPU 上进行测试#

针对需要 TPU 访问权限进行测试的贡献者

参考资料#

文档#

外部资源#