构建 SQL 智能体¶
在本教程中,我们将演示如何构建一个可以回答关于 SQL 数据库问题的智能体。
在高层次上,智能体将:
- 从数据库获取可用的表
- 决定哪些表与问题相关
- 获取相关表的模式
- 基于问题和模式信息生成查询
- 使用 LLM 检查查询中常见的错误
- 执行查询并返回结果
- 纠正数据库引擎发现的错误,直到查询成功
- 基于结果制定响应
安全注意事项
构建 SQL 数据库的问答系统需要执行模型生成的 SQL 查询。这样做存在固有的风险。请确保你的数据库连接权限始终尽可能缩小到智能体需要的范围。这将减轻(但不会消除)构建模型驱动系统的风险。
1. 设置¶
首先让我们安装一些依赖项。本教程使用来自 langchain-community 的 SQL 数据库和工具抽象。我们还需要一个 LangChain 聊天模型。
Tip
注册 LangSmith 以快速发现问题并提高 LangGraph 项目的性能。LangSmith 让你可以使用追踪数据来调试、测试和监控用 LangGraph 构建的 LLM 应用程序。
选择 LLM¶
首先我们初始化我们的 LLM。任何支持工具调用的模型都应该可以工作。我们在下面使用 OpenAI。
配置数据库¶
我们将为本教程创建一个 SQLite 数据库。SQLite 是一个轻量级数据库,易于设置和使用。我们将加载 chinook 数据库,这是一个代表数字媒体商店的示例数据库。
在这里找到更多关于该数据库的信息。
为了方便,我们在公共 GCS 存储桶上托管了该数据库(Chinook.db)。
import requests
url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"
response = requests.get(url)
if response.status_code == 200:
# 以二进制写入模式打开本地文件
with open("Chinook.db", "wb") as file:
# 将响应的内容(文件)写入本地文件
file.write(response.content)
print("File downloaded and saved as Chinook.db")
else:
print(f"Failed to download the file. Status code: {response.status_code}")
我们将使用 langchain_community 包中提供的便捷 SQL 数据库包装器来与数据库交互。该包装器提供了一个简单的接口来执行 SQL 查询和获取结果:
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(f"Dialect: {db.dialect}")
print(f"Available tables: {db.get_usable_table_names()}")
print(f'Sample output: {db.run("SELECT * FROM Artist LIMIT 5;")}')
输出:
Dialect: sqlite
Available tables: ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
Sample output: [(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains')]
数据库交互工具¶
langchain-community 实现了一些用于与我们的 SQLDatabase 交互的内置工具,包括用于列出表、读取表模式以及检查和运行查询的工具:
from langchain_community.agent_toolkits import SQLDatabaseToolkit
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()
for tool in tools:
print(f"{tool.name}: {tool.description}\n")
输出:
sql_db_query: Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.
sql_db_schema: Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3
sql_db_list_tables: Input is an empty string, output is a comma-separated list of tables in the database.
sql_db_query_checker: Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!
2. 使用预构建智能体¶
给定这些工具,我们可以在一行中初始化一个预构建的智能体。为了自定义我们的智能体行为,我们编写一个描述性的系统提示。
from langgraph.prebuilt import create_react_agent
system_prompt = """
你是一个专为与 SQL 数据库交互而设计的智能体。
针对输入的问题,创建一个语法正确的 {dialect} 查询来运行,
然后查看查询结果并返回答案。除非用户
指定了希望获得的示例数量,否则始终将查询
限制为最多 {top_k} 个结果。
你可以按相关列对结果进行排序,以返回数据库中最有趣的
示例。永远不要查询特定表的所有列,
只查询问题相关的列。
你必须在执行查询之前仔细检查查询。如果执行
查询时出错,请重写查询并重试。
不要对数据库执行任何 DML 语句(INSERT、UPDATE、DELETE、DROP 等)。
首先你应该始终查看数据库中的表,以了解可以
查询什么。不要跳过此步骤。
然后你应该查询最相关表的模式。
""".format(
dialect=db.dialect,
top_k=5,
)
agent = create_react_agent(
llm,
tools,
prompt=system_prompt,
)
Note
此系统提示包含许多指令,例如始终在其他工具之前或之后运行特定工具。在下一节中,我们将通过图的结构来强制执行这些行为,为我们提供更大程度的控制并允许我们简化提示。
让我们在一个示例查询上运行这个智能体并观察它的行为:
question = "Which genre on average has the longest tracks?"
for step in agent.stream(
{"messages": [{"role": "user", "content": question}]},
stream_mode="values",
):
step["messages"][-1].pretty_print()
输出:
================================ Human Message =================================
Which genre on average has the longest tracks?
================================== Ai Message ==================================
Tool Calls:
sql_db_list_tables (call_d8lCgywSroCgpVl558nmXKwA)
Call ID: call_d8lCgywSroCgpVl558nmXKwA
Args:
================================= Tool Message =================================
Name: sql_db_list_tables
Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
================================== Ai Message ==================================
Tool Calls:
sql_db_schema (call_nNf6IIUcwMYLIkE0l6uWkZHe)
Call ID: call_nNf6IIUcwMYLIkE0l6uWkZHe
Args:
table_names: Genre, Track
================================= Tool Message =================================
Name: sql_db_schema
CREATE TABLE "Genre" (
"GenreId" INTEGER NOT NULL,
"Name" NVARCHAR(120),
PRIMARY KEY ("GenreId")
)
/*
3 rows from Genre table:
GenreId Name
1 Rock
2 Jazz
3 Metal
*/
CREATE TABLE "Track" (
"TrackId" INTEGER NOT NULL,
"Name" NVARCHAR(200) NOT NULL,
"AlbumId" INTEGER,
"MediaTypeId" INTEGER NOT NULL,
"GenreId" INTEGER,
"Composer" NVARCHAR(220),
"Milliseconds" INTEGER NOT NULL,
"Bytes" INTEGER,
"UnitPrice" NUMERIC(10, 2) NOT NULL,
PRIMARY KEY ("TrackId"),
FOREIGN KEY("MediaTypeId") REFERENCES "MediaType" ("MediaTypeId"),
FOREIGN KEY("GenreId") REFERENCES "Genre" ("GenreId"),
FOREIGN KEY("AlbumId") REFERENCES "Album" ("AlbumId")
)
/*
3 rows from Track table:
TrackId Name AlbumId MediaTypeId GenreId Composer Milliseconds Bytes UnitPrice
1 For Those About To Rock (We Salute You) 1 1 1 Angus Young, Malcolm Young, Brian Johnson 343719 11170334 0.99
2 Balls to the Wall 2 2 1 None 342562 5510424 0.99
3 Fast As a Shark 3 2 1 F. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman 230619 3990994 0.99
*/
================================== Ai Message ==================================
Tool Calls:
sql_db_query_checker (call_urTRmtiGtTxkwHtscec7Fd2K)
Call ID: call_urTRmtiGtTxkwHtscec7Fd2K
Args:
query: SELECT Genre.Name, AVG(Track.Milliseconds) AS AvgMilliseconds
FROM Track
JOIN Genre ON Track.GenreId = Genre.GenreId
GROUP BY Genre.Name
ORDER BY AvgMilliseconds DESC
LIMIT 1;
================================= Tool Message =================================
Name: sql_db_query_checker
\`\`\`sql
SELECT Genre.Name, AVG(Track.Milliseconds) AS AvgMilliseconds
FROM Track
JOIN Genre ON Track.GenreId = Genre.GenreId
GROUP BY Genre.Name
ORDER BY AvgMilliseconds DESC
LIMIT 1;
\`\`\`
================================== Ai Message ==================================
Tool Calls:
sql_db_query (call_RNMqyUEMv0rvy0UxSwrXY2AV)
Call ID: call_RNMqyUEMv0rvy0UxSwrXY2AV
Args:
query: SELECT Genre.Name, AVG(Track.Milliseconds) AS AvgMilliseconds
FROM Track
JOIN Genre ON Track.GenreId = Genre.GenreId
GROUP BY Genre.Name
ORDER BY AvgMilliseconds DESC
LIMIT 1;
================================= Tool Message =================================
Name: sql_db_query
[('Sci Fi & Fantasy', 2911783.0384615385)]
================================== Ai Message ==================================
The genre with the longest average track length is "Sci Fi & Fantasy," with an average duration of about 2,911,783 milliseconds (approximately 48.5 minutes) per track.
这运行得相当好:智能体正确地列出了表,获取了模式,编写了查询,检查了查询,并运行它以提供最终响应。
Tip
你可以在 LangSmith 追踪中检查上述运行的所有方面,包括采取的步骤、调用的工具、LLM 看到的提示等。
3. 自定义智能体¶
预构建的智能体让我们快速开始,但在每一步智能体都可以访问完整的工具集。上面,我们依赖系统提示来约束其行为——例如,我们指示智能体始终从"列出表"工具开始,并始终在执行查询之前运行查询检查器工具。
我们可以通过自定义智能体在 LangGraph 中强制执行更高程度的控制。下面,我们实现了一个简单的 ReAct-智能体设置,具有用于特定工具调用的专用节点。我们将使用与预构建智能体相同的状态。
我们为以下步骤构建专用节点:
- 列出数据库表
- 调用"获取模式"工具
- 生成查询
- 检查查询
将这些步骤放在专用节点中让我们(1)在需要时强制工具调用,以及(2)自定义与每个步骤关联的提示。
from typing import Literal
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")
get_schema_node = ToolNode([get_schema_tool], name="get_schema")
run_query_tool = next(tool for tool in tools if tool.name == "sql_db_query")
run_query_node = ToolNode([run_query_tool], name="run_query")
# 示例:创建预定的工具调用
def list_tables(state: MessagesState):
tool_call = {
"name": "sql_db_list_tables",
"args": {},
"id": "abc123",
"type": "tool_call",
}
tool_call_message = AIMessage(content="", tool_calls=[tool_call])
list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables")
tool_message = list_tables_tool.invoke(tool_call)
response = AIMessage(f"Available tables: {tool_message.content}")
return {"messages": [tool_call_message, tool_message, response]}
# 示例:强制模型创建工具调用
def call_get_schema(state: MessagesState):
# 注意,LangChain 强制所有模型接受 `tool_choice="any"`
# 以及 `tool_choice=<工具名称字符串>`。
llm_with_tools = llm.bind_tools([get_schema_tool], tool_choice="any")
response = llm_with_tools.invoke(state["messages"])
return {"messages": [response]}
generate_query_system_prompt = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
""".format(
dialect=db.dialect,
top_k=5,
)
def generate_query(state: MessagesState):
system_message = {
"role": "system",
"content": generate_query_system_prompt,
}
# 我们不会在这里强制工具调用,以允许模型在获得解决方案时自然地响应。
llm_with_tools = llm.bind_tools([run_query_tool])
response = llm_with_tools.invoke([system_message] + state["messages"])
return {"messages": [response]}
check_query_system_prompt = """
你是一位注重细节的 SQL 专家。
仔细检查 {dialect} 查询中的常见错误,包括:
- 使用 NOT IN 处理 NULL 值
- 应该使用 UNION ALL 时却使用了 UNION
- 使用 BETWEEN 处理排他性范围
- 谓词中的数据类型不匹配
- 正确引用标识符
- 使用正确数量的函数参数
- 转换为正确的数据类型
- 使用正确的连接列
如果存在上述任何错误,请重写查询。如果没有错误,
只需重现原始查询。
运行此检查后,你将调用适当的工具来执行查询。
""".format(dialect=db.dialect)
def check_query(state: MessagesState):
system_message = {
"role": "system",
"content": check_query_system_prompt,
}
# 生成人工用户消息以检查
tool_call = state["messages"][-1].tool_calls[0]
user_message = {"role": "user", "content": tool_call["args"]["query"]}
llm_with_tools = llm.bind_tools([run_query_tool], tool_choice="any")
response = llm_with_tools.invoke([system_message, user_message])
response.id = state["messages"][-1].id
return {"messages": [response]}
最后,我们使用 Graph API 将这些步骤组装成一个工作流。我们在查询生成步骤定义一个条件边,如果生成了查询,将路由到查询检查器;如果没有工具调用,则结束,这样 LLM 就已经提供了对查询的响应。
def should_continue(state: MessagesState) -> Literal[END, "check_query"]:
messages = state["messages"]
last_message = messages[-1]
if not last_message.tool_calls:
return END
else:
return "check_query"
builder = StateGraph(MessagesState)
builder.add_node(list_tables)
builder.add_node(call_get_schema)
builder.add_node(get_schema_node, "get_schema")
builder.add_node(generate_query)
builder.add_node(check_query)
builder.add_node(run_query_node, "run_query")
builder.add_edge(START, "list_tables")
builder.add_edge("list_tables", "call_get_schema")
builder.add_edge("call_get_schema", "get_schema")
builder.add_edge("get_schema", "generate_query")
builder.add_conditional_edges(
"generate_query",
should_continue,
)
builder.add_edge("check_query", "run_query")
builder.add_edge("run_query", "generate_query")
agent = builder.compile()
我们在下面可视化应用:
from IPython.display import Image, display
from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod, NodeStyles
display(Image(agent.get_graph().draw_mermaid_png()))

注意: 当你运行这段代码时,它将生成并显示 SQL 智能体图的可视化表示,展示不同节点之间的流程(list_tables → call_get_schema → get_schema → generate_query → check_query → run_query)。
我们现在可以像以前一样调用图:
question = "Which genre on average has the longest tracks?"
for step in agent.stream(
{"messages": [{"role": "user", "content": question}]},
stream_mode="values",
):
step["messages"][-1].pretty_print()
输出:
================================ Human Message =================================
Which genre on average has the longest tracks?
================================== Ai Message ==================================
Available tables: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
================================== Ai Message ==================================
Tool Calls:
sql_db_schema (call_qxKtYiHgf93AiTDin9ez5wFp)
Call ID: call_qxKtYiHgf93AiTDin9ez5wFp
Args:
table_names: Genre,Track
================================= Tool Message =================================
Name: sql_db_schema
CREATE TABLE "Genre" (
"GenreId" INTEGER NOT NULL,
"Name" NVARCHAR(120),
PRIMARY KEY ("GenreId")
)
/*
3 rows from Genre table:
GenreId Name
1 Rock
2 Jazz
3 Metal
*/
CREATE TABLE "Track" (
"TrackId" INTEGER NOT NULL,
"Name" NVARCHAR(200) NOT NULL,
"AlbumId" INTEGER,
"MediaTypeId" INTEGER NOT NULL,
"GenreId" INTEGER,
"Composer" NVARCHAR(220),
"Milliseconds" INTEGER NOT NULL,
"Bytes" INTEGER,
"UnitPrice" NUMERIC(10, 2) NOT NULL,
PRIMARY KEY ("TrackId"),
FOREIGN KEY("MediaTypeId") REFERENCES "MediaType" ("MediaTypeId"),
FOREIGN KEY("GenreId") REFERENCES "Genre" ("GenreId"),
FOREIGN KEY("AlbumId") REFERENCES "Album" ("AlbumId")
)
/*
3 rows from Track table:
TrackId Name AlbumId MediaTypeId GenreId Composer Milliseconds Bytes UnitPrice
1 For Those About To Rock (We Salute You) 1 1 1 Angus Young, Malcolm Young, Brian Johnson 343719 11170334 0.99
2 Balls to the Wall 2 2 1 None 342562 5510424 0.99
3 Fast As a Shark 3 2 1 F. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman 230619 3990994 0.99
*/
================================== Ai Message ==================================
Tool Calls:
sql_db_query (call_RPN3GABMfb6DTaFTLlwnZxVN)
Call ID: call_RPN3GABMfb6DTaFTLlwnZxVN
Args:
query: SELECT Genre.Name, AVG(Track.Milliseconds) AS AvgTrackLength
FROM Track
JOIN Genre ON Track.GenreId = Genre.GenreId
GROUP BY Genre.GenreId
ORDER BY AvgTrackLength DESC
LIMIT 1;
================================== Ai Message ==================================
Tool Calls:
sql_db_query (call_PR4s8ymiF3ZQLaoZADXtdqcl)
Call ID: call_PR4s8ymiF3ZQLaoZADXtdqcl
Args:
query: SELECT Genre.Name, AVG(Track.Milliseconds) AS AvgTrackLength
FROM Track
JOIN Genre ON Track.GenreId = Genre.GenreId
GROUP BY Genre.GenreId
ORDER BY AvgTrackLength DESC
LIMIT 1;
================================= Tool Message =================================
Name: sql_db_query
[('Sci Fi & Fantasy', 2911783.0384615385)]
================================== Ai Message ==================================
The genre with the longest tracks on average is "Sci Fi & Fantasy," with an average track length of approximately 2,911,783 milliseconds.
Tip
查看上述运行的 LangSmith 追踪。
下一步¶
查看此指南以使用 LangSmith 评估 LangGraph 应用程序,包括像这样的 SQL 智能体。