Structured Outputs#

You can specify a JSON schema, regular expression, or EBNF to constrain the model output. The model output is guaranteed to follow the given constraints. Only one constraint parameter (json_schema, regex, or ebnf) can be specified for a request.

SGLang supports three grammar backends:

  • Outlines: Supports JSON schema and regular expression constraints.

  • XGrammar (default): Supports JSON schema, regular expression, and EBNF constraints.

  • Llguidance: Supports JSON schema, regular expression, and EBNF constraints.

We suggest using XGrammar for its better performance and utility. XGrammar currently uses the GGML BNF format. For more details, see the XGrammar technical overview.

To use Outlines, simply add --grammar-backend outlines when launching the server. To use llguidance, add --grammar-backend llguidance when launching the server. If no backend is specified, XGrammar will be used as the default.
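For example, a minimal sketch (reusing the launch_server_cmd helper from the cells below) that starts the server with the Outlines backend instead of the default might look like:

server_process, port = launch_server_cmd(
    "python -m sglang.launch_server"
    " --model-path meta-llama/Meta-Llama-3.1-8B-Instruct"
    " --grammar-backend outlines"
    " --host 0.0.0.0"
)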

For better output quality, it is advisable to explicitly include guidance in the prompt to steer the model toward the desired format. For example, you can specify, "Please generate the output in the following JSON format: …"

OpenAI Compatible API#

[1]:
import openai
import os
from sglang.test.test_utils import is_in_ci

if is_in_ci():
    from patch import launch_server_cmd
else:
    from sglang.utils import launch_server_cmd

from sglang.utils import wait_for_server, print_highlight, terminate_process

os.environ["TOKENIZERS_PARALLELISM"] = "false"


server_process, port = launch_server_cmd(
    "python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --host 0.0.0.0"
)

wait_for_server(f"http://localhost:{port}")
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
[2025-05-15 22:31:26] server_args=ServerArgs(model_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='meta-llama/Meta-Llama-3.1-8B-Instruct', chat_template=None, completion_template=None, is_embedding=False, enable_multimodal=None, revision=None, host='0.0.0.0', port=34593, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=859282418, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, bucket_time_to_first_token=None, bucket_e2e_request_latency=None, bucket_inter_token_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=None, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', flashinfer_mla_disable_ragged=False, warmups=None, moe_dense_tp_size=None, n_share_experts_fusion=0, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, mm_attention_backend=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disaggregation_ib_device=None, 
pdlb_url=None)
[2025-05-15 22:31:37] Attention backend not set. Use fa3 backend by default.
[2025-05-15 22:31:37] Init torch distributed begin.
[2025-05-15 22:31:38] Init torch distributed ends. mem usage=0.55 GB
[2025-05-15 22:31:38] Load weight begin. avail mem=65.82 GB
[2025-05-15 22:31:39] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:26<01:20, 26.69s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:29<00:25, 12.80s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [01:01<00:21, 21.54s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [01:17<00:00, 19.26s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [01:17<00:00, 19.37s/it]

[2025-05-15 22:32:58] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=32.14 GB, mem usage=33.68 GB.
[2025-05-15 22:32:58] KV Cache is allocated. #tokens: 20480, K size: 1.25 GB, V size: 1.25 GB
[2025-05-15 22:32:58] Memory pool end. avail mem=29.34 GB
[2025-05-15 22:32:59] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=131072
[2025-05-15 22:32:59] INFO:     Started server process [51692]
[2025-05-15 22:32:59] INFO:     Waiting for application startup.
[2025-05-15 22:32:59] INFO:     Application startup complete.
[2025-05-15 22:32:59] INFO:     Uvicorn running on http://0.0.0.0:34593 (Press CTRL+C to quit)
[2025-05-15 22:33:00] INFO:     127.0.0.1:58986 - "GET /v1/models HTTP/1.1" 200 OK
[2025-05-15 22:33:00] INFO:     127.0.0.1:45834 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-05-15 22:33:00] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-15 22:33:02] INFO:     127.0.0.1:45850 - "POST /generate HTTP/1.1" 200 OK
[2025-05-15 22:33:02] The server is fired up and ready to roll!


Note that typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black, while the notebook outputs are highlighted in blue.
We are running these notebooks in a CI parallel environment, so the throughput is not representative of actual performance.

JSON#

You can directly define a JSON schema or use Pydantic to define and validate the response.

Using Pydantic

[2]:
from pydantic import BaseModel, Field


# Define the schema using Pydantic
class CapitalInfo(BaseModel):
    name: str = Field(..., pattern=r"^\w+$", description="Name of the capital city")
    population: int = Field(..., description="Population of the capital city")


response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {
            "role": "user",
            "content": "Please generate the information of the capital of France in the JSON format.",
        },
    ],
    temperature=0,
    max_tokens=128,
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "foo",
            # convert the pydantic model to json schema
            "schema": CapitalInfo.model_json_schema(),
        },
    },
)

response_content = response.choices[0].message.content
# validate the JSON response by the pydantic model
capital_info = CapitalInfo.model_validate_json(response_content)
print_highlight(f"Validated response: {capital_info.model_dump_json()}")
[2025-05-15 22:33:05] Prefill batch. #new-seq: 1, #new-token: 48, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-15 22:33:06] INFO:     127.0.0.1:45858 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Validated response: {"name":"Paris","population":2147000}

JSON Schema Directly

[3]:
import json

json_schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string", "pattern": "^[\\w]+$"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {
            "role": "user",
            "content": "Give me the information of the capital of France in the JSON format.",
        },
    ],
    temperature=0,
    max_tokens=128,
    response_format={
        "type": "json_schema",
        "json_schema": {"name": "foo", "schema": json.loads(json_schema)},
    },
)

print_highlight(response.choices[0].message.content)
[2025-05-15 22:33:06] Prefill batch. #new-seq: 1, #new-token: 19, #cached-token: 30, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-15 22:33:06] INFO:     127.0.0.1:45858 - "POST /v1/chat/completions HTTP/1.1" 200 OK
{"name": "巴黎", "population": 2147000}

EBNF#

[4]:
ebnf_grammar = """
root ::= city | description
city ::= "London" | "Paris" | "Berlin" | "Rome"
description ::= city " is " status
status ::= "the capital of " country
country ::= "England" | "France" | "Germany" | "Italy"
"""

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {"role": "system", "content": "You are a helpful geography bot."},
        {
            "role": "user",
            "content": "Give me the information of the capital of France.",
        },
    ],
    temperature=0,
    max_tokens=32,
    extra_body={"ebnf": ebnf_grammar},
)

print_highlight(response.choices[0].message.content)
[2025-05-15 22:33:06] Prefill batch. #new-seq: 1, #new-token: 27, #cached-token: 25, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-15 22:33:06] Decode batch. #running-req: 1, #token: 55, token usage: 0.00, cuda graph: False, gen throughput (token/s): 5.35, #queue-req: 0
[2025-05-15 22:33:06] INFO:     127.0.0.1:45858 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Paris is the capital of France

Regular expression#

[5]:
response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {"role": "user", "content": "What is the capital of France?"},
    ],
    temperature=0,
    max_tokens=128,
    extra_body={"regex": "(Paris|London)"},
)

print_highlight(response.choices[0].message.content)
[2025-05-15 22:33:06] Prefill batch. #new-seq: 1, #new-token: 12, #cached-token: 30, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-15 22:33:06] INFO:     127.0.0.1:45858 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Paris

Structural Tag#

[6]:
tool_get_current_weather = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The city to find the weather for, e.g. 'San Francisco'",
                },
                "state": {
                    "type": "string",
                    "description": "the two-letter abbreviation for the state that the city is"
                    " in, e.g. 'CA' which would mean 'California'",
                },
                "unit": {
                    "type": "string",
                    "description": "The unit to fetch the temperature in",
                    "enum": ["celsius", "fahrenheit"],
                },
            },
            "required": ["city", "state", "unit"],
        },
    },
}

tool_get_current_date = {
    "type": "function",
    "function": {
        "name": "get_current_date",
        "description": "Get the current date and time for a given timezone",
        "parameters": {
            "type": "object",
            "properties": {
                "timezone": {
                    "type": "string",
                    "description": "The timezone to fetch the current date and time for, e.g. 'America/New_York'",
                }
            },
            "required": ["timezone"],
        },
    },
}

schema_get_current_weather = tool_get_current_weather["function"]["parameters"]
schema_get_current_date = tool_get_current_date["function"]["parameters"]


def get_messages():
    return [
        {
            "role": "system",
            "content": f"""
# Tool Instructions
- Always execute python code in messages that you share.
- When looking for real time information use relevant functions if available else fallback to brave_search
You have access to the following functions:
Use the function 'get_current_weather' to: Get the current weather in a given location
{tool_get_current_weather["function"]}
Use the function 'get_current_date' to: Get the current date and time for a given timezone
{tool_get_current_date["function"]}
If you choose to call a function ONLY reply in the following format:
<{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}}
where
start_tag => `<function`
parameters => a JSON dict with the function argument name as key and function argument value as value.
end_tag => `</function>`
Here is an example,
<function=example_function_name>{{"example_name": "example_value"}}</function>
Reminder:
- Function calls MUST follow the specified format
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- Always add your sources when using search results to answer the user query
You are a helpful assistant.""",
        },
        {
            "role": "user",
            "content": "You are in New York. Please get the current date and time, and the weather.",
        },
    ]


messages = get_messages()

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=messages,
    response_format={
        "type": "structural_tag",
        "structures": [
            {
                "begin": "<function=get_current_weather>",
                "schema": schema_get_current_weather,
                "end": "</function>",
            },
            {
                "begin": "<function=get_current_date>",
                "schema": schema_get_current_date,
                "end": "</function>",
            },
        ],
        "triggers": ["<function="],
    },
)

print_highlight(response.choices[0].message.content)
[2025-05-15 22:33:06] Prefill batch. #new-seq: 1, #new-token: 476, #cached-token: 25, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-15 22:33:07] Decode batch. #running-req: 1, #token: 535, token usage: 0.03, cuda graph: False, gen throughput (token/s): 37.99, #queue-req: 0
[2025-05-15 22:33:07] INFO:     127.0.0.1:45858 - "POST /v1/chat/completions HTTP/1.1" 200 OK
{"timezone": "America/New_York"}
{"city": "New York", "state": "NY", "unit": "fahrenheit"}

Native API and SGLang Runtime (SRT)#

JSON#

Using Pydantic

[7]:
import requests
import json
from pydantic import BaseModel, Field

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")


# Define the schema using Pydantic
class CapitalInfo(BaseModel):
    name: str = Field(..., pattern=r"^\w+$", description="Name of the capital city")
    population: int = Field(..., description="Population of the capital city")


# Make API request
messages = [
    {
        "role": "user",
        "content": "Here is the information of the capital of France in the JSON format.\n",
    }
]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "text": text,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 64,
            "json_schema": json.dumps(CapitalInfo.model_json_schema()),
        },
    },
)
print_highlight(response.json())


response_data = json.loads(response.json()["text"])
# validate the response by the pydantic model
capital_info = CapitalInfo.model_validate(response_data)
print_highlight(f"Validated response: {capital_info.model_dump_json()}")
[2025-05-15 22:33:08] Prefill batch. #new-seq: 1, #new-token: 49, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-15 22:33:09] INFO:     127.0.0.1:45860 - "POST /generate HTTP/1.1" 200 OK
{'text': '{"name": "Paris", "population": 2147000}', 'meta_info': {'id': '0f6615cab8c548cc9883121a97d17937', 'finish_reason': {'type': 'stop', 'matched': 128009}, 'prompt_tokens': 50, 'completion_tokens': 15, 'cached_tokens': 1, 'e2e_latency': 0.26668524742126465}}
Validated response: {"name":"Paris","population":2147000}

JSON Schema Directly

[8]:
json_schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string", "pattern": "^[\\w]+$"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)

# JSON
response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "text": text,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 64,
            "json_schema": json_schema,
        },
    },
)

print_highlight(response.json())
[2025-05-15 22:33:09] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 49, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-15 22:33:09] Decode batch. #running-req: 1, #token: 63, token usage: 0.00, cuda graph: False, gen throughput (token/s): 22.25, #queue-req: 0
[2025-05-15 22:33:09] INFO:     127.0.0.1:45864 - "POST /generate HTTP/1.1" 200 OK
{'text': '{"name": "Paris", "population": 2147000}', 'meta_info': {'id': '403abf3b4d01486187a2d3739303073a', 'finish_reason': {'type': 'stop', 'matched': 128009}, 'prompt_tokens': 50, 'completion_tokens': 15, 'cached_tokens': 49, 'e2e_latency': 0.26096081733703613}}

EBNF#

[9]:
messages = [
    {
        "role": "user",
        "content": "Give me the information of the capital of France.",
    }
]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "text": text,
        "sampling_params": {
            "max_new_tokens": 128,
            "temperature": 0,
            "n": 3,
            "ebnf": (
                "root ::= city | description\n"
                'city ::= "London" | "Paris" | "Berlin" | "Rome"\n'
                'description ::= city " is " status\n'
                'status ::= "the capital of " country\n'
                'country ::= "England" | "France" | "Germany" | "Italy"'
            ),
        },
        "stream": False,
        "return_logprob": False,
    },
)

print_highlight(response.json())
[2025-05-15 22:33:09] Prefill batch. #new-seq: 1, #new-token: 15, #cached-token: 31, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-15 22:33:09] Prefill batch. #new-seq: 3, #new-token: 3, #cached-token: 135, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-15 22:33:09] INFO:     127.0.0.1:45878 - "POST /generate HTTP/1.1" 200 OK
[{'text': 'Paris is the capital of France', 'meta_info': {'id': '8c8605ff54504434b42b5b61287d894f', 'finish_reason': {'type': 'stop', 'matched': 128009}, 'prompt_tokens': 46, 'completion_tokens': 7, 'cached_tokens': 45, 'e2e_latency': 0.3752317428588867}}, {'text': 'Paris is the capital of France', 'meta_info': {'id': '3c4999e791414edf80bfe9e9691dce2b', 'finish_reason': {'type': 'stop', 'matched': 128009}, 'prompt_tokens': 46, 'completion_tokens': 7, 'cached_tokens': 45, 'e2e_latency': 0.37523746490478516}}, {'text': 'Paris is the capital of France', 'meta_info': {'id': 'f2b062b02ed64a9cbc2d4bff516e9032', 'finish_reason': {'type': 'stop', 'matched': 128009}, 'prompt_tokens': 46, 'completion_tokens': 7, 'cached_tokens': 45, 'e2e_latency': 0.3752415180206299}}]

Regular expression#

[10]:
messages = [
    {
        "role": "user",
        "content": "Paris is the capital of",
    }
]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "text": text,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 64,
            "regex": "(France|England)",
        },
    },
)
print_highlight(response.json())
[2025-05-15 22:33:09] Prefill batch. #new-seq: 1, #new-token: 10, #cached-token: 31, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-15 22:33:09] INFO:     127.0.0.1:45890 - "POST /generate HTTP/1.1" 200 OK
{'text': 'France', 'meta_info': {'id': '8996ff6691bc4a549c0d16c4144fe354', 'finish_reason': {'type': 'stop', 'matched': 128009}, 'prompt_tokens': 41, 'completion_tokens': 2, 'cached_tokens': 31, 'e2e_latency': 0.06646585464477539}}

Structural Tag#

[11]:
from transformers import AutoTokenizer

# generate an answer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
payload = {
    "text": text,
    "sampling_params": {
        "structural_tag": json.dumps(
            {
                "type": "structural_tag",
                "structures": [
                    {
                        "begin": "<function=get_current_weather>",
                        "schema": schema_get_current_weather,
                        "end": "</function>",
                    },
                    {
                        "begin": "<function=get_current_date>",
                        "schema": schema_get_current_date,
                        "end": "</function>",
                    },
                ],
                "triggers": ["<function="],
            }
        )
    },
}


# Send POST request to the API endpoint
response = requests.post(f"http://localhost:{port}/generate", json=payload)
print_highlight(response.json())
[2025-05-15 22:33:10] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 40, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-15 22:33:11] INFO:     127.0.0.1:54236 - "POST /generate HTTP/1.1" 200 OK
{'text': 'Paris is the capital of France.', 'meta_info': {'id': 'db75114db6f443e68e4d531d1ba86c47', 'finish_reason': {'type': 'stop', 'matched': 128009}, 'prompt_tokens': 41, 'completion_tokens': 8, 'cached_tokens': 40, 'e2e_latency': 0.16428756713867188}}
[12]:
terminate_process(server_process)

Offline Engine API#

[13]:
import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct", grammar_backend="xgrammar"
)
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:02<00:06,  2.23s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:03<00:02,  1.45s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:05<00:01,  1.77s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:07<00:00,  2.01s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:07<00:00,  1.92s/it]

JSON#

Using Pydantic

[14]:
import json
from pydantic import BaseModel, Field


prompts = [
    "Give me the information of the capital of China in the JSON format.",
    "Give me the information of the capital of France in the JSON format.",
    "Give me the information of the capital of Ireland in the JSON format.",
]


# Define the schema using Pydantic
class CapitalInfo(BaseModel):
    name: str = Field(..., pattern=r"^\w+$", description="Name of the capital city")
    population: int = Field(..., description="Population of the capital city")


sampling_params = {
    "temperature": 0.1,
    "top_p": 0.95,
    "json_schema": json.dumps(CapitalInfo.model_json_schema()),
}

outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print_highlight("===============================")
    print_highlight(f"Prompt: {prompt}")  # validate the output by the pydantic model
    capital_info = CapitalInfo.model_validate_json(output["text"])
    print_highlight(f"Validated output: {capital_info.model_dump_json()}")
===============================
Prompt: Give me the information of the capital of China in the JSON format.
Validated output: {"name":"Beijing","population":21500000}
===============================
Prompt: Give me the information of the capital of France in the JSON format.
Validated output: {"name":"Paris","population":2141000}
===============================
Prompt: Give me the information of the capital of Ireland in the JSON format.
Validated output: {"name":"Dublin","population":527617}
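Even with schema-constrained decoding, validating the output is still worthwhile; for example, if generation stops at the token limit, the JSON may be truncated and fail to parse. A defensive sketch (assuming the prompts, sampling_params, and CapitalInfo defined above) using Pydantic's ValidationError:

from pydantic import ValidationError

for prompt, output in zip(prompts, llm.generate(prompts, sampling_params)):
    try:
        capital_info = CapitalInfo.model_validate_json(output["text"])
        print_highlight(f"Prompt: {prompt}\nValidated output: {capital_info.model_dump_json()}")
    except ValidationError as e:
        # Truncated or otherwise non-conforming output ends up here.
        print_highlight(f"Prompt: {prompt}\nInvalid output: {e}")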

JSON Schema Directly

[15]:
prompts = [
    "Give me the information of the capital of China in the JSON format.",
    "Give me the information of the capital of France in the JSON format.",
    "Give me the information of the capital of Ireland in the JSON format.",
]

json_schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string", "pattern": "^[\\w]+$"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)

sampling_params = {"temperature": 0.1, "top_p": 0.95, "json_schema": json_schema}

outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print_highlight("===============================")
    print_highlight(f"Prompt: {prompt}\nGenerated text: {output['text']}")
===============================
Prompt: Give me the information of the capital of China in the JSON format.
Generated text: {"name": "Beijing", "population": 21500000}
===============================
Prompt: Give me the information of the capital of France in the JSON format.
Generated text: {"name": "Paris", "population": 2141000}
===============================
Prompt: Give me the information of the capital of Ireland in the JSON format.
Generated text: {"name": "Dublin", "population": 527617}

EBNF#

[16]:
prompts = [
    "Give me the information of the capital of France.",
    "Give me the information of the capital of Germany.",
    "Give me the information of the capital of Italy.",
]

sampling_params = {
    "temperature": 0.8,
    "top_p": 0.95,
    "ebnf": (
        "root ::= city | description\n"
        'city ::= "London" | "Paris" | "Berlin" | "Rome"\n'
        'description ::= city " is " status\n'
        'status ::= "the capital of " country\n'
        'country ::= "England" | "France" | "Germany" | "Italy"'
    ),
}

outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print_highlight("===============================")
    print_highlight(f"Prompt: {prompt}\nGenerated text: {output['text']}")
===============================
Prompt: Give me the information of the capital of France.
Generated text: Paris is the capital of France
===============================
Prompt: Give me the information of the capital of Germany.
Generated text: Berlin is the capital of Germany
===============================
Prompt: Give me the information of the capital of Italy.
Generated text: Paris is the capital of Italy

Regular expression#

[17]:
prompts = [
    "Please provide information about London as a major global city:",
    "Please provide information about Paris as a major global city:",
]

sampling_params = {"temperature": 0.8, "top_p": 0.95, "regex": "(France|England)"}

outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print_highlight("===============================")
    print_highlight(f"Prompt: {prompt}\nGenerated text: {output['text']}")
===============================
Prompt: Please provide information about London as a major global city:
Generated text: England
===============================
Prompt: Please provide information about Paris as a major global city:
Generated text: France

Structural Tag#

[18]:
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
prompts = [text]


sampling_params = {
    "temperature": 0.8,
    "top_p": 0.95,
    "structural_tag": json.dumps(
        {
            "type": "structural_tag",
            "structures": [
                {
                    "begin": "<function=get_current_weather>",
                    "schema": schema_get_current_weather,
                    "end": "</function>",
                },
                {
                    "begin": "<function=get_current_date>",
                    "schema": schema_get_current_date,
                    "end": "</function>",
                },
            ],
            "triggers": ["<function="],
        }
    ),
}


# Send POST request to the API endpoint
outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print_highlight("===============================")
    print_highlight(f"Prompt: {prompt}\nGenerated text: {output['text']}")
===============================
Prompt: <|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

<|eot_id|><|start_header_id|>user<|end_header_id|>

Paris is the capital of<|eot_id|><|start_header_id|>assistant<|end_header_id|>


Generated text: France.
[19]:
llm.shutdown()