



















Fuser/compose_end_to_end.py 是 Fuser 管道中的最后一个关键步骤,它将分散的、针对特定子图优化的 Triton 内核无缝地整合成一个单一的、高性能的端到端 Triton 内核,同时确保其功能与原始 PyTorch 实现的数值等价性。
Composer的架构图如下,其功能概括是一句话:把所有验证通过的子图内核+原问题喂给LLM拼成单文件。也会把错误日志让LLM修,最多重试 max_iters轮,并自动 patch 常见 Triton 陷阱。

Composer的核心功能如下:
kernel_function、禁止 PyTorch 计算逻辑)和数值校验,确保生成的 Triton 内核可用、正确。| 特色方向 | 具体说明 |
|---|---|
| 严格的代码约束 | 强制要求生成的代码包含 kernel_function 顶层函数(与原始模型输入一致)、@triton.jit 内核、自测函数(输出 PASS/0 退出码);禁止内核中使用 PyTorch 计算逻辑(仅允许自测时对比)。 |
| 智能错误迭代 | 捕获编译 / 运行错误(stderr/stdout),构建精细化 Prompt 让 LLM 定位并修正问题(如 Triton 常见的 tl.broadcast 误用)。 |
| 自动补丁修复 | 内置 Triton 常见问题的文本补丁(如替换 tl.broadcast(0.0) 为 0.0),减少无意义的 LLM 迭代。 |
| 完整的日志留存 | 保存每一轮的 Prompt、生成的代码、验证结果,便于调试和追溯生成过程。 |
| 数值等价性保障 | 要求自测函数使用 allclose 校验数值(fp32: rtol≤1e-3/atol≤1e-3;fp16/bf16: ≤2e-2),确保 Triton 实现与 PyTorch 结果一致。 |


compose_end_to_end.py 会合成(compose)一个端到端的 Triton 内核,用来解决原始 KernelBench 问题。
compose_end_to_end.py 会将原始问题文件、子图分解信息 + 各张量形状(来自 subgraphs.json)、以及已生成的子图 Triton 内核作为输入,构建提示(prompt)发送给 LLM。然后LLM 会把这些碎片拼成一个语义与原始问题完全一致的完整内核,最终返回一个 Python 文件(/composed_kernel.py),里面提供:
@triton.jit 装饰的 Triton 内核。kernel_function(...) 的顶层 Python 包装函数,它接受与原始模型相同的输入张量,并协调 Triton 内核的执行,返回最终输出。test_kernel 或 run_tests),该函数比较 Triton 实现的结果与原始 PyTorch 问题代码的参考结果,并在成功时打印 'PASS' 并退出。composition_summary.json)。compose_end_to_end.py 用法如下:
python -m Fuser.compose_end_to_end \
--problem /abs/path/to/kernelbench_problem.py \
--subgraphs /abs/path/to/subgraphs.json \
--kernels-summary /abs/path/to/kernels_out/summary.json \
[--model gpt-5] \
[--out-dir ./compose_out] \
[--verify]
可以通过 --verify 标志启用自动验证。在此模式下,每个 LLM 生成的组合尝试(无论是初始的还是经过修正的)都会被传递给 Fuser/runner.py 中的 run_candidate 函数来执行。
验证的成功与否取决于执行是否正常退出(exit code 0)并且输出中包含 'PASS' 字符串或 ALL_TESTS_PASSED 字符串。

如果 LLM 生成的第一个组合内核无法通过验证(运行 / 编译失败),该脚本会捕获错误信息(stderr,stdout)。它会构建一个新的提示(_build_refinement_prompt),将错误信息作为上下文提供给 LLM,要求其修正代码。这个过程可以重复多次(max_iters 参数控制最大迭代次数),直到生成的内核通过验证或达到最大迭代次数。
_load_kernels_from_summary 为构建 prompt 提供 「代码素材」(各子图的有效 Triton 内核代码),具体而言,_load_kernels_from_summary从调度阶段生成的内核汇总 JSON 文件中,过滤并加载所有成功生成的有效子图 Triton 内核,校验数据格式与文件有效性,封装为标准化 KernelItem 对象列表,为 LLM 组合内核提供可直接复用的有效代码素材,过滤失败、无效的内核产物,避免无效素材干扰后续组合流程。
其特殊如下:
多维度有效内核过滤:依次过滤「非列表格式汇总数据、非字典格式子项、标记为失败的内核、无 ID / 无内核路径的子项、内核文件不存在的子项」,仅保留全量校验通过的成功内核,从源头保证代码素材的有效性;
关键字段强制校验:子图 ID(sid)和内核文件路径(kernel_path)为必选字段,缺失任一则直接过滤,确保每个有效内核都能关联到唯一子图且存在实际代码文件;
标准化对象封装:将内核的子图 ID、文件路径、代码内容封装为 KernelItem 对象,而非原始字典 / 字符串,提升后续代码处理的可读性与可维护性;
无有效内核直接终止:若汇总文件中无任何有效内核,直接抛出 SystemExit 异常终止流程,避免后续流程因无有效素材而无意义执行;
兼容调度阶段输出格式:严格适配 dispatch 步骤生成的 summary.json 格式,实现上下游流程的无缝衔接。
以下三个函数是 PyTorch KernelAgent 中LLM 生成端到端 Triton 内核的 Prompt 构建核心模块,为 LLM 提供标准化、高指导性、场景适配的精准输入提示,是连接「分散的子图 / 内核 / 问题数据」与「LLM 可理解的生成指令」的关键枢纽。其中
_summarize_subgraphs_for_prompt 为基础数据处理函数,负责将子图信息格式化;_build_composition_prompt 和 _build_refinement_prompt 为双 Prompt 构建主函数,分别支撑首次端到端内核组合生成和基于错误的迭代精修两大核心场景,通过严格的指令约束、完整的上下文信息、针对性的优化指导,确保 LLM 生成符合工程要求、硬件适配、可直接运行的 Triton 内核代码。kernel_function 函数名、自测试函数及「PASS」打印 / 退出码规则,确保精修后的代码能无缝对接后续的自动化验证流程,无需修改验证逻辑;_summarize_subgraphs_for_prompt 提供 「逻辑约束」(各子图的功能、形状、布局等约束信息)。具体而言,_summarize_subgraphs_for_prompt:对模型分解后的子图信息列表进行结构化、简洁化的文本汇总,提取子图 ID、类型、数据布局、数据类型、输入输出形状、核心算子等关键信息,按统一格式拼接为易读的文本字符串,为构建 LLM 组合 Prompt 提供标准化的子图信息描述,让 LLM 快速理解各子图的功能、形状约束与计算要求。
inputs 字段描述输入,无 inputs 则降级为 input_shape,适配不同子图分解工具的输出格式差异,保证输入输出形状信息的有效传递。三个函数形成 「基础数据处理→首次生成 Prompt 构建→迭代精修 Prompt 构建」 的层级支撑关系,为 LLM 生成端到端 Triton 内核提供全流程的 Prompt 支撑:
_summarize_subgraphs_for_prompt 对原始子图信息做统一格式化处理,生成标准化子图摘要,为上层两个 Prompt 构建函数提供一致、易解析的子图约束信息,实现数据的一次处理、多次复用;_build_composition_prompt 基于标准化子图摘要,整合问题代码、子图内核、平台配置,构建全量约束的组合 Prompt,指导 LLM 完成首次端到端 Triton 内核生成,是「从无到有」的核心指导;_build_refinement_prompt 复用标准化子图摘要与核心基础信息,新增错误日志和前次失败代码,构建错误导向的精修 Prompt,指导 LLM 完成「从错到对」的迭代优化,是提升代码有效性的关键;def _build_composition_prompt(
problem_code: str,
subgraphs: list[dict[str, Any]],
kernel_items: list[KernelItem],
target_platform: PlatformConfig,
) -> str:
"""Create a single user message to instruct composition by the LLM.
构建初始合成Prompt:
为LLM生成结构化指令,引导其基于子图和参考内核,合成端到端的Triton算子代码
参数:
- problem_code: 原始KernelBench问题代码(PyTorch)
- subgraphs: 子图信息列表(JSON格式)
- kernel_items: 参考内核列表(包含子图ID和对应的Triton代码)
- target_platform: 目标平台配置(如CUDA/XPU)
返回值:完整的LLM用户指令字符串
"""
# 第一步:生成子图摘要(压缩子图信息,控制Token消耗,便于LLM快速理解核心特征)
sg_summary = _summarize_subgraphs_for_prompt(subgraphs)
# 第二步:构建参考内核代码区块(仅保留核心代码,避免Token溢出)
# 注释说明:暂时保留完整文件内容,调用方可根据模型窗口限制进一步裁剪
# 初始化内核区块的文本片段列表
kernels_section_parts: list[str] = []
# 遍历每个参考内核项
for ki in kernel_items:
# 为每个子图内核构建带格式的代码片段(Markdown Python代码块)
kernels_section_parts.append(
f"### Subgraph {ki.subgraph_id}\n```python\n" + ki.code + "\n```\n"
)
# 拼接所有内核片段,形成完整的参考内核区块
kernels_section = "\n".join(kernels_section_parts)
# 第三步:获取平台专属的指导规则(如CUDA的内存访问规则、XPU的编译要求)
platform_guidance = target_platform.guidance_block
# 第四步:构建核心指导语(包含任务背景、平台信息、硬性要求、实现技巧)
# 使用textwrap.dedent去除缩进,保证Prompt格式整洁
guidance = textwrap.dedent(
f"""
You are given:
- The original problem file (PyTorch module and helpers).
- A decomposition of the model into fusable subgraphs with exact shapes.
- Working Triton kernels generated for some subgraphs.
TARGET PLATFORM: {target_platform.name}
DEVICE STRING: {target_platform.device_string}
{platform_guidance}
Task:
- Compose an end-to-end Triton implementation that matches the original
model's forward pass for the provided shapes. You may inline, adapt,
or reuse the given subgraph kernels. Prefer fusing into as few kernel
launches as possible while preserving exact numerical semantics.
Hard requirements:
- Return ONE complete Python file only, fenced as a single ```python block.
- Allocate inputs, weights, intermediates, and outputs on device='{target_platform.device_string}' and keep them there throughout forward/verification.
- CPU is acceptable only for metadata, scalars, and export serialization—avoid `.cpu()` or `.to('cpu')` on compute tensors.
- Provide at least one @triton.jit kernel and a top-level Python wrapper
named kernel_function(...). This wrapper must accept the same primary
input tensor(s) as the model and any required weights/biases with shapes
implied by the problem; it should orchestrate Triton kernel(s) and
return the final output tensor.
- No PyTorch math path: kernel_function MUST compute the final outputs
using your Triton kernels only. Do NOT implement or fall back to
torch.nn / torch.nn.functional / torch.* ops
sigmoid, etc.) for producing the final result. Using PyTorch for
reference comparisons is allowed only inside the self-test.
- Use the data layout and dtype semantics indicated by subgraphs, defaulting
to NCHW + float32 if unspecified. Respect stride/padding/dilation/groups,
and exact op order.
- Numerical equivalence: include a self-test (test_kernel or run_tests)
that compares your Triton-based result to a PyTorch reference computed
from the original problem code below (use get_init_inputs() and
get_inputs() if present to instantiate the Model). The test must print
'PASS' on success and exit with code 0. Use allclose with rtol<=1e-3,
atol<=1e-3 for fp32; for fp16/bf16 allow up to 2e-2.
- No imports beyond torch, triton, triton.language as tl, and stdlib. No I/O.
- Do NOT monkey-patch PyTorch device functions or torch.cuda.is_available()
- Do NOT manipulate TRITON_BACKENDS environment variable
- Do NOT disable or mock XPU/CUDA drivers
Implementation tips:
- If merging multiple subgraphs, ensure intermediate tensor shapes match.
- Hoist constant weights or parameters to avoid reloading per block.
- Use tl.load/tl.store with masks for boundary conditions.
- Favor coalesced memory access; tile by blocks; compute grid from shape.
- Common Triton pitfalls to avoid:
* Do NOT call tl.broadcast on Python scalars; tl.maximum(x, 0.0) works.
* Prefer scalar constants directly in elementwise ops (no explicit broadcast needed).
* Keep BLOCK_SIZE power-of-two; mask stores at tail.
"""
).strip() # 去除首尾空白字符
# 第五步:拼接完整的用户指令(按逻辑组织各部分内容)
user_lines: list[str] = []
user_lines.append(guidance) # 核心指导语
user_lines.append("") # 空行分隔
user_lines.append("SUBGRAPHS (summary):") # 子图摘要标题
user_lines.append(sg_summary) # 子图摘要内容
user_lines.append("") # 空行分隔
user_lines.append("ORIGINAL PROBLEM FILE:") # 原始问题代码标题
user_lines.append("```python") # Python代码块开始标记
user_lines.append(problem_code) # 原始问题代码内容
user_lines.append("```") # Python代码块结束标记
user_lines.append("") # 空行分隔
user_lines.append("SUBGRAPH KERNELS (reference implementations):") # 参考内核标题
user_lines.append(kernels_section) # 参考内核代码内容
user_lines.append("") # 空行分隔
# 最终要求:仅返回一个包含完整代码的Python代码块
user_lines.append(
"Return only one fenced Python code block with your final composed implementation."
)
# 拼接所有行,形成完整的Prompt
return "\n".join(user_lines)
def _build_refinement_prompt(
problem_code: str,
subgraphs: list[dict[str, Any]],
kernel_items: list[KernelItem],
previous_code: str,
error_info: dict[str, str],
target_platform: PlatformConfig,
) -> str:
"""Prompt the LLM to refine the previously produced code based on errors.
构建迭代优化Prompt:
基于上一轮代码的错误信息,引导LLM修复Triton算子代码中的编译/运行错误
参数:
- previous_code: 上一轮生成的错误代码
- error_info: 错误信息字典(包含stderr_tail/stdout_tail)
其他参数同_build_composition_prompt
返回值:针对性的优化指令字符串
"""
# 提取错误日志尾部(最后2000字符,聚焦核心错误)
err_tail = error_info.get("stderr_tail", "")
# 提取标准输出尾部(辅助分析错误)
out_tail = error_info.get("stdout_tail", "")
# 构建优化指导语(聚焦错误修复,保留原有核心要求)
guidance = textwrap.dedent(
f"""
You previously produced a composed Triton implementation, but it failed
to run/compile. Analyze the ERROR_CONTEXT below and re-emit the entire
corrected single-file implementation as one ```python block.
TARGET PLATFORM: {target_platform.name}
DEVICE STRING: {target_platform.device_string}
Requirements remain the same. Additionally:
- Fix any Triton compilation/runtime errors. For scalar constants in
elementwise ops (e.g., ReLU), do not use tl.broadcast. Use direct
scalars like 0.0 in tl.maximum(x, 0.0).
- Keep function name kernel_function(...) unchanged and retain the
self-test that prints PASS on success and exits 0.
- Do NOT reintroduce any PyTorch math path in kernel_function. The final
outputs must be computed via your Triton kernels only (no fallback to
torch.nn / torch.nn.functional ops).
- Return the complete corrected file; do not send diffs.
"""
).strip()
# 拼接完整的优化指令(按“指导语→错误信息→原始代码→子图摘要→上一轮代码”组织)
lines: list[str] = []
lines.append(guidance) # 优化指导语
lines.append("") # 空行分隔
# 添加标准错误日志(核心错误信息)
lines.append("ERROR_CONTEXT (stderr tail):\n```\n" + err_tail + "\n```")
# 若标准输出非空,添加标准输出日志(辅助分析)
if out_tail.strip():
lines.append("STDOUT tail:\n```\n" + out_tail + "\n```")
lines.append("") # 空行分隔
# 添加原始问题代码(保证上下文完整)
lines.append("ORIGINAL PROBLEM FILE:\n```python\n" + problem_code + "\n```")
lines.append("") # 空行分隔
# 添加子图摘要(避免LLM遗忘核心特征)
lines.append("SUBGRAPHS (summary):\n" + _summarize_subgraphs_for_prompt(subgraphs))
lines.append("") # 空行分隔
# 添加上一轮错误代码(让LLM对比分析问题)
lines.append("PREVIOUS_ATTEMPT:\n```python\n" + previous_code + "\n```")
lines.append("") # 空行分隔
# 最终要求:仅返回修正后的完整Python代码块
lines.append(
"Return only one fenced Python code block with the corrected implementation."
)
# 拼接所有行,形成优化Prompt
return "\n".join(lines)
def _summarize_subgraphs_for_prompt(subgraphs: list[dict[str, Any]]) -> str:
"""
生成子图摘要:将复杂的子图JSON信息压缩为简洁的文本格式
核心目标:控制Token消耗,同时保留子图的ID、类型、布局、 dtype、形状、算子等核心特征
"""
# 初始化摘要行列表
lines: list[str] = []
# 遍历每个子图
for it in subgraphs:
# 提取子图ID(默认unknown)
sid = str(it.get("id", "unknown"))
# 提取子图类型(如conv2d、linear)
typ = str(it.get("type", ""))
# 提取数据布局(默认NCHW)
layout = it.get("data_layout") or "NCHW"
# 提取数据类型(默认float32)
dtype = it.get("dtype") or "float32"
# 提取输入形状(兼容多输入/单输入格式)
inputs = it.get("inputs")
in_shape = it.get("input_shape")
# 提取输出形状
out_shape = it.get("output_shape")
# 提取算子列表
ops = it.get("ops") or []
# 构建形状描述行(优先多输入格式,其次单输入格式)
shapes_line = (
f"inputs={inputs if inputs is not None else in_shape}, output={out_shape}"
)
# 构建子图核心信息行
lines.append(
f"- ID={sid} type={typ} layout={layout} dtype={dtype} {shapes_line}"
)
# 构建算子摘要(限制长度为400字符,避免Token溢出)
try:
# 尝试转为JSON字符串(结构化)
ops_short = json.dumps(ops)[:400]
except Exception:
# 转换失败则直接转为字符串
ops_short = str(ops)[:400]
# 添加算子摘要行(缩进2空格,提升可读性)
lines.append(f" ops={ops_short}")
# 拼接所有行,形成最终的子图摘要
return "\n".join(lines)
以下两个函数是 PyTorch KernelAgent 中端到端 Triton 内核组合生成的核心执行模块。
_auto_patch_common_triton_issues 作为轻量自动化代码修复工具,为 LLM 生成的 Triton 代码做前置避坑补丁;_auto_patch_common_triton_issues:对 LLM 生成的 Triton 代码做无侵入式文本级自动化补丁,专门修复 Triton 开发中高频、易犯的基础错误,在代码运行 / 验证前提前规避低级错误,减少无效迭代,提升内核生成效率。compose 作为顶层入口与流程调度核心,串联起「Prompt 构建、LLM 生成、代码补丁、真机验证、迭代精修、产物归档」的全流程,是 KernelAgent 从「分散子图 / 内核素材」到「可运行端到端 Triton 内核」的端到端驱动核心,根据迭代轮次动态构建「组合 / 精修 Prompt」,调用 LLM 生成代码,对生成代码做自动化补丁,支持真机验证驱动的多轮迭代精修,最终输出硬件适配、可直接部署的一体化 Triton 内核及标准化结果汇总。针对 LLM 生成 Triton 代码时易出现的两类高频低级错误做自动化文本修复,在代码运行前提前避坑,减少因基础错误导致的验证失败,提升迭代效率:
tl.broadcast 问题(如 tl.broadcast(0.0, ...) 替换为直接标量 0.0),贴合 Triton 标量运算规范;cuda_hacks_to_strip 动态剥离不兼容代码,支持多平台扩展,无需硬编码平台规则;_fake_torch_device 这类函数定义,实现整段代码块的跳过过滤,处理更彻底;(patched_code, changed) 元组返回结果,明确告知调用方是否对代码做了修改,便于后续日志记录与流程监控;作为内核组合生成的全流程调度核心,完成从输入加载到最终产物输出的所有步骤:
run_candidate 做真机运行验证,失败则提取错误日志反馈给 LLM 精修,形成 「生成→补丁→验证→报错→精修」 的闭环,直至验证通过或达到最大迭代轮次,大幅提升最终内核的可用性;verify 开关控制是否开启真机验证 —— 开启时保证产物有效性,关闭时仅做一次生成即终止,提升轻量使用场景的效率;max_iters 控制最大迭代轮次(默认 5 轮),可根据需求调整,平衡生成效率与成功率;last_usage)并纳入最终结果,支持 LLM 生成成本的统计与管控;composed_kernel.py,实现全流程的无缝衔接;run_candidate 时设置 2400 秒超时,同时支持隔离运行、禁止网络等配置,避免验证过程因环境问题卡死,提升验证的稳定性。两个函数形成 「前置补丁防护 + 全流程调度执行」 的紧密协同关系,是 KernelAgent 组合生成端到端 Triton 内核的核心支撑:
compose 函数的前置子步骤,在 LLM 生成代码后、真机验证前执行,提前修复 Triton 基础错误,减少因低级问题导致的验证失败,为 compose 的迭代流程「减负」,提升整体执行效率;_auto_patch_common_triton_issues,并为其传递目标平台配置,让补丁操作贴合硬件要求,同时记录补丁后的代码并推进后续验证、迭代流程;def _auto_patch_common_triton_issues(
code: str, target_platform: PlatformConfig
) -> tuple[str, bool]:
"""Apply tiny safe textual patches for known Triton pitfalls.
- Replace tl.broadcast(0.0, ...) or tl.broadcast(1.0, ...) with scalar constants.
Returns (patched_code, changed).
自动修复Triton算子代码中的常见问题:
- 核心修复点:将tl.broadcast(0.0/1.0/0/1, ...)替换为标量常量(避免Triton广播操作的性能/语法问题)
- 返回值:(修复后的代码, 是否发生修改)
"""
# 初始化修复后的代码为原始代码
patched = code
# 标记是否发生修改(默认未修改)
changed = False
# 修复规则:采用保守的简单启发式规则,仅处理无歧义的常见问题
patterns = [
# 规则1:tl.broadcast(0.0 → 替换为0.0(移除不必要的广播操作)
("tl.broadcast(0.0", "0.0"),
# 规则2:tl.broadcast(1.0 → 替换为1.0
("tl.broadcast(1.0", "1.0"),
# 规则3:tl.broadcast(0, → 替换为0.0(统一数值类型为浮点数)
("tl.broadcast(0,", "0.0"),
# 规则4:tl.broadcast(1, → 替换为1.0
("tl.broadcast(1,", "1.0"),
]
# 遍历所有修复规则
for old, new in patterns:
# 若原始代码包含待修复的模式
if old in patched:
# 执行文本替换
patched = patched.replace(old, new)
# 标记为已修改
changed = True
# 移除CUDA相关的冗余hack代码(平台适配)
# 获取当前目标平台需要剥离的CUDA hack模式列表
cuda_hacks = target_platform.cuda_hacks_to_strip
if cuda_hacks:
# 将代码按行拆分
lines = patched.split("\n")
# 存储过滤后的代码行
filtered_lines = []
# 标记是否需要跳过直到空行(用于移除整个函数块)
skip_until_blank = False
# 逐行处理代码
for line in lines:
# 若处于跳过模式:跳过当前行,直到遇到空行
if skip_until_blank:
if line.strip() == "":
# 遇到空行,退出跳过模式
skip_until_blank = False
continue
# 检查当前行是否包含需要剥离的CUDA hack模式
if any(hack in line for hack in cuda_hacks):
# 标记为已修改
changed = True
# 特殊处理:若为_fake_torch_device函数定义,需跳过整个函数块(直到空行)
if "def _fake_torch_device" in line:
skip_until_blank = True
# 跳过当前行(移除hack代码)
continue
# 保留当前行
filtered_lines.append(line)
# 重新拼接过滤后的代码行
patched = "\n".join(filtered_lines)
# 返回修复后的代码和是否修改的标记
return patched, changed
def compose(
problem_path: Path,
subgraphs_path: Path,
kernels_summary_path: Path,
out_dir: Path,
model_name: str,
verify: bool = False,
max_iters: int = 5,
target_platform: str = "cuda",
) -> dict[str, Any]:
"""
核心函数:基于子图信息和内核摘要,合成/优化Triton算子代码
参数说明:
- problem_path: KernelBench问题文件路径
- subgraphs_path: 子图信息JSON文件路径(由之前的子图提取模块生成)
- kernels_summary_path: 内核摘要文件路径
- out_dir: 输出目录路径
- model_name: LLM模型名称
- verify: 是否验证生成的代码(默认False)
- max_iters: 最大迭代次数(默认5)
- target_platform: 目标平台(默认cuda)
返回值:包含合成结果的字典(成功状态、代码路径、迭代次数等)
"""
# 前置检查:确保LLM提供商模块可导入
if get_model_provider is None:
raise SystemExit(
"KernelAgent providers unavailable; ensure package import and dependencies"
)
# 创建输出目录(递归创建父目录,已存在则忽略)
out_dir.mkdir(parents=True, exist_ok=True)
# 获取LLM提供商实例(用于调用LLM生成代码)
provider = get_model_provider(model_name)
# 初始化目标平台配置(加载平台相关的修复规则、hack列表等)
platform = get_platform(target_platform)
# 加载输入文件:
# 1. 加载KernelBench问题代码(原始PyTorch代码)
problem_code = _read_text(problem_path)
# 2. 加载子图信息JSON(由子图提取模块生成)
subgraphs = json.loads(_read_text(subgraphs_path))
# 检查子图信息是否为列表:非列表则退出
if not isinstance(subgraphs, list):
raise SystemExit("subgraphs.json must be a JSON array")
# 3. 从内核摘要文件加载内核信息
kernels = _load_kernels_from_summary(kernels_summary_path)
# 创建迭代尝试目录(存储每一轮的生成代码)
attempts_dir = out_dir / "attempts"
attempts_dir.mkdir(parents=True, exist_ok=True)
# 初始化变量:
# - last_usage: 最后一次LLM调用的token使用信息
# - last_code: 上一轮生成的代码
# - verify_info: 验证信息字典
last_usage = None
last_code = None
verify_info: dict[str, Any] = {}
# 多轮迭代生成/优化代码(最多max_iters轮)
for i in range(1, max_iters + 1):
# 构建LLM Prompt:
# 第一轮/上一轮代码为空 → 构建初始合成Prompt
if i == 1 or last_code is None:
prompt = _build_composition_prompt(
problem_code, subgraphs, kernels, target_platform=platform
)
else:
# 非第一轮 → 基于上一轮的错误信息构建优化Prompt
# 提取验证日志的最后2000字符(便于LLM定位问题)
stderr_tail = ""
stdout_tail = ""
try:
# 读取标准错误日志尾部
if verify_info.get("stderr_path"):
with open(
verify_info["stderr_path"],
"r",
encoding="utf-8",
errors="ignore",
) as f:
stderr_tail = f.read()[-2000:]
# 读取标准输出日志尾部
if verify_info.get("stdout_path"):
with open(
verify_info["stdout_path"],
"r",
encoding="utf-8",
errors="ignore",
) as f:
stdout_tail = f.read()[-2000:]
except Exception:
# 日志读取失败则忽略(避免中断迭代)
pass
# 构建优化Prompt(包含上一轮代码和错误信息)
prompt = _build_refinement_prompt(
problem_code,
subgraphs,
kernels,
previous_code=last_code,
error_info={"stderr_tail": stderr_tail, "stdout_tail": stdout_tail},
target_platform=platform,
)
# 保存当前轮次的Prompt(便于追溯和调试)
(attempts_dir / f"attempt_{i}.prompt.txt").write_text(prompt, encoding="utf-8")
# 调用LLM生成代码:
# - 消息格式:仅包含user角色的Prompt
# - 最大生成Token数:50000(适配长代码生成)
response = provider.get_response(
model_name, [{"role": "user", "content": prompt}], max_tokens=50000
)
# 记录最后一次LLM调用的token使用信息
last_usage = response.usage
# 提取LLM输出的原始文本
raw_text = response.content or ""
# 从LLM输出中提取Python代码(剥离无关文本)
extracted = extract_single_python_file(raw_text)
code = extracted.code
# 自动修复Triton常见问题(文本补丁)
code, changed = _auto_patch_common_triton_issues(code, platform)
# 保存当前轮次的生成代码
(attempts_dir / f"attempt_{i}.py").write_text(code, encoding="utf-8")
# 更新last_code为当前轮次的代码
last_code = code
# 验证当前轮次的代码(若开启验证)
if verify:
# 运行候选代码验证:
# - artifacts_code_path: 当前轮次的代码路径
# - run_root: 验证运行目录
# - timeout_s: 验证超时时间(2400秒)
# - isolated: 非隔离模式
# - deny_network: 允许网络访问
rr = run_candidate(
artifacts_code_path=attempts_dir / f"attempt_{i}.py",
run_root=out_dir / "runs",
timeout_s=2400,
isolated=False,
deny_network=False,
)
# 记录验证信息
verify_info = {
"verify_rc": rr.rc, # 验证退出码
"verify_passed": rr.passed, # 是否验证通过
"verify_reason": rr.reason, # 验证结果原因
"validator": rr.validator_used, # 使用的验证器
"stdout_path": str(rr.stdout_path), # 标准输出日志路径
"stderr_path": str(rr.stderr_path), # 标准错误日志路径
}
# 若验证通过,终止迭代(无需继续优化)
if rr.passed:
break
else:
# 若未开启验证,仅执行第一轮后终止
break
# 保存最终合成的算子代码(取最后一轮的代码)
composed_path = out_dir / "composed_kernel.py"
composed_path.write_text(last_code or "", encoding="utf-8")
# 构建结果字典:包含核心合成信息
result: dict[str, Any] = {
# 成功状态:验证模式下为是否通过验证,非验证模式下默认成功
"success": bool(verify_info.get("verify_passed", not verify)),
# 最终合成代码的绝对路径
"composed_path": str(composed_path.resolve()),
# 使用的LLM模型名称
"model": model_name,
# LLM token使用信息
"usage": last_usage,
# 实际迭代次数
"rounds": i,
# 目标平台
"target_platform": target_platform,
}
# 合并验证信息到结果字典
result.update(verify_info)
# 保存合成摘要(结构化JSON文件)
(out_dir / "composition_summary.json").write_text(
json.dumps(result, indent=2), encoding="utf-8"
)
# 返回结果字典
return result
此内容由惯性聚合(RSS阅读器)自动聚合整理,仅供阅读参考。 原文来自 — 版权归原作者所有。