










Fuser/compose_end_to_end.py는 Fuser 파이프라인의 마지막 중요한 단계로, 특정 서브그래프에 대한 분산된 Triton 커널을 원활하게 통합하여 단일하고 고성능의 엔드투엔드 Triton 커널을 만들어내며, 원래 PyTorch 구현의 수치적 동등성을 보장합니다.
Composer의 아키텍처 다이어그램은 다음과 같으며, 그 기능 요약은 한 문장으로 표현하면 다음과 같다: 모든 검증된 서브그래프 커널과 원본 질문을 LLM에 넣어 단일 파일로 만든다. 또한 오류 로그를 LLM에 수정하며, 최대 max_iters 횟수까지 재시도하고 자동으로 흔한 Triton 함정을 패치한다.

Composer의 핵심 기능은 다음과 같다:
kernel_function를 포함해야 함, 파이토치 계산 로직을 금지) 및 숫자 검증을 통해 생성된 Triton 커널이 사용 가능하고 올바른지 보장합니다. | 특징 방향 | 구체적 설명 |
|---|---|
| 엄격한 코드 제약 | 는 생성된 코드에 kernel_function 최상위 함수(원본 모델 입력과 일치)、@triton.jit 내核、자체 테스트 함수(출력 PASS/0 종료 코드)를 포함하도록 강제합니다; 내核에서는 PyTorch 계산 로직을 사용하지 않도록 금지합니다(자체 테스트 시 비교만 허용). |
| 는 지능적인 오류 반복 | 는 컴파일/실행 오류(stderr/stdout)를 포착하고, LLM이 문제를 찾아 수정하도록 정밀한 프롬프트를 구축합니다(예: Triton의 흔한 tl.broadcast 잘못 사용). |
| 는 자동 패치로 수정 | 는 내장 Triton 흔한 문제의 텍스트 패치를 제공(예: tl.broadcast(0.0)를 0.0으로 바꾸기), LLM 반복을 줄입니다. |
| 는 완전한 로그 보관 | 각 라운드의 Prompt, 생성된 코드, 검증 결과를 보존하여 디버깅 및 생성 과정 추적을 용이하게 합니다. |
| 값의 동등성 보장 | 자체 검사 함수는 allclose 검증을 사용하여 값( fp32: rtol≤1e-3/atol≤1e-3; fp16/bf16: ≤2e-2)을 검증하여 Triton 구현이 PyTorch 결과와 일치하는지 확인해야 합니다. |


compose_end_to_end.py를 사용하여 원래 KernelBench 문제를 해결하기 위해 Triton 내核을 엔드 투 엔드(compose)합니다.
compose_end_to_end.py는 원본 문제 파일, 서브 그래프 분해 정보 + 각 텐서 형태(来自 subgraphs.json), 그리고 생성된 서브 그래프 Triton 커널을 입력으로 받아, 프롬프트(prompt)를 LLM에게 전송합니다. 그런 다음 LLM은 이 조각들을 하나로 합쳐 의미와 원본 문제와 완전히 일치하는 의 완전한 커널을 만듭니다. 최종적으로 /composed_kernel.py 라는 이름의 파이썬 파일을 반환하며, 이 파일에는 다음을 제공합니다:
@triton.jit 를 사용하여 꾸며진 Triton 커널. kernel_function(...) 인 최상위 파이썬 래퍼 함수, 이 함수는 원본 모델과 동일한 입력 텐서를 받아들이고 Triton 커널의 실행을 조율하며 최종 출력을 반환합니다. test_kernel 또는 run_tests), 해당 함수는 Triton 구현 결과와 원래 PyTorch 문제 코드의 참조 결과를 비교하고, 성공 시 'PASS'를 출력하고 종료합니다. composition_summary.json). compose_end_to_end.py 사용법은 다음과 같습니다:
python -m Fuser.compose_end_to_end \
--problem /abs/path/to/kernelbench_problem.py \
--subgraphs /abs/path/to/subgraphs.json \
--kernels-summary /abs/path/to/kernels_out/summary.json \
[--model gpt-5] \
[--out-dir ./compose_out] \
[--verify]
는 --verify 플래그를 통해 자동 검증을 활성화할 수 있습니다. 이 모드에서는 초기 또는 수정된 각 LLM 생성 조합 시도가 Fuser/runner.py의 run_candidate 함수에 전달되어 실행됩니다.
검증의 성공 여부는 실행이 정상적으로 종료되었는지 (exit code 0) 및 출력에 포함되는지에 따라 결정됩니다.'PASS' 문자열 또는 ALL_TESTS_PASSED 문자열입니다.

LLM이 생성한 첫 번째 조합 커널이 검증을 통과하지 못하면(실행/컴파일 실패), 스크립트는 오류 메시지(stderr, stdout)를 캡처합니다. 새로운 프롬프트(_build_refinement_prompt)를 구축하고, 오류 메시지를 LLM의 맥락으로 제공하여 코드를 수정하도록 요청합니다. 이 과정은 여러 번 반복될 수 있습니다(max_iters 매개변수가 최대 반복 횟수를 제어), 생성된 커널이 검증을 통과하거나 최대 반복 횟수에 도달할 때까지 계속됩니다.
_load_kernels_from_summary 프롬프트를 구축하기 위해 「코드 소재」를 제공합니다.(각 서브플롯의 유효한 Triton 내核 코드),구체적으로, _load_kernels_from_summary는 스케줄링 단계에서 생성된 내核 요약 JSON 파일에서 모든 성공적으로 생성된 유효한 서브플롯 Triton 내核을 필터링하고 로드하며, 데이터 형식 및 파일 유효성을 검증하고 표준화된 KernelItem 객체 목록으로 포장하여 LLM 조합 내核을 위해 직접 재사용할 수 있는 유효한 코드 자료를 제공하며, 실패한 또는 유효하지 않은 내核 제품을 필터링하여 무효한 자료가 후속 조합 프로세스를 방해하지 않도록 합니다.
의 특징은 다음과 같습니다:
다차원 유효한 내核 필터링: 순차적으로「리스트 형식 요약 데이터가 아닌 것, 사전 형식 서브 항목이 아닌 것, 실패로 표시된 내核, ID 없음 / 내核 경로 없는 서브 항목, 내核 파일이 존재하지 않는 서브 항목」을 필터링하고, 전체 검증을 통과한 성공적인 내核만 유지하여 코드 자료의 유효성을 원천에서 보장합니다;
핵심 필드 강제 검증:자사 ID(sid)와 커널 파일 경로(kernel_path)는 필수 필드로, 누락 시 즉시 필터링하여 각 유효한 커널이 유일한 자사와 실제 코드 파일에 연결되도록 보장합니다.
표준화된 객체 포장:커널의 자사 ID, 파일 경로, 코드 내용을 KernelItem 객체로 포장하여 원시 사전 / 문자열이 아닌 것으로, 후속 코드 처리의 가독성과 유지보수성을 향상시킵니다.
유효한 커널 없으면 즉시 종료:요약 파일에 유효한 커널이 없으면 즉시 SystemExit 예외를 발생시켜 프로세스를 종료하여 후속 프로세스가 유효한 자료 없이 무의미하게 실행되는 것을 방지합니다.
호환 스케줄링 단계 출력 형식:dispatch 단계에서 생성된 summary.json 형식을 엄격하게 적합하여 상하위 프로세스의 무缝 연결을 구현합니다.
아래 세 개의 함수는 PyTorch KernelAgent의 LLM 생성 단말 Triton 커널의 Prompt 구축 핵심 모듈 로서, LLM에 표준화되고, 높은 지도성, 시나리오에 맞는 정확한 입력 힌트를 제공 하며,「분산된 서브그래프 / 커널 / 문제 데이터」와「LLM이 이해할 수 있는 생성 명령」을 연결하는 핵심 허브입니다. 그 중
_summarize_subgraphs_for_prompt는 기본 데이터 처리 함수로서 서브그래프 정보를 형식화하는 역할을 합니다. _build_composition_prompt와 _build_refinement_prompt는 이중 Prompt 구축 주 함수로서, 각각 초기 단말 커널 조합 생성 과 오류 기반 반복적 정제를 지원합니다.두 가지 핵심 시나리오를 통해 엄격한 지시 제약, 완전한 문맥 정보, 대상적 최적화 지침을 통해 LLM이 엔지니어링 요구사항, 하드웨어 적합성, 직접 실행 가능한 Triton 핵 코드를 생성하도록 보장합니다.kernel_function 함수 이름, 자체 테스트 함수 및「PASS」프린트 / 종료 코드 규칙을 명확히 요구하여 보수된 코드가 후속 자동 검증 프로세스에 무缝 연결될 수 있도록 하여 검증 논리 수정 없이 사용할 수 있도록 보장;_summarize_subgraphs_for_prompt는 「논리적 제약 조건」(각 하위 그래프의 기능, 모양, 레이아웃 등 제약 정보)를 제공합니다. 구체적으로, _summarize_subgraphs_for_prompt: 모델 분해 후의 하위 그래프 정보 리스트를 구조화하고 간결하게 텍스트 요약, 하위 그래프 ID, 유형, 데이터 레이아웃, 데이터 유형, 입력 출력 모양, 핵심 연산자 등 핵심 정보를 추출하여 통일된 형식으로 연결하여 읽기 쉬운 텍스트 문자열로 만들어, LLM 조합 Prompt 구성을 위한 표준화된 하위 그래프 정보 설명을 제공하여 LLM이 각 하위 그래프의 기능, 모양 제약 조건 및 계산 요구 사항을 빠르게 이해하도록 합니다.
inputs 필드를 우선 사용하여 입력을 설명하고, inputs이 없으면 input_shape으로 격하하여 다양한 서브그래프 분해 도구의 출력 형식 차이에 적합하게 하여 입력 및 출력 형상 정보의 효과적인 전달을 보장합니다. 세 가지 함수는 「기본 데이터 처리→초기 프롬프트 구축→반복적 정제 프롬프트 구축」의 순서로 형성합니다.의 계층적 지지 관계는 LLM이 엔드투엔드 Triton 핵심을 생성하기 위해 전 과정의 Prompt 지지를 제공합니다:
_summarize_subgraphs_for_prompt는 원시 서브그래프 정보에 대해 통일된 형식화 처리를 수행하여 표준화된 서브그래프 요약을 생성하고, 상위 두 개의 Prompt 구축 함수에 일관되고 쉽게 해석할 수 있는 서브그래프 제약 정보를 제공하여 데이터의 한 번의 처리와 여러 번의 재사용을 실현합니다; _build_composition_prompt는 표준화된 서브그래프 요약을 기반으로 문제 코드, 서브그래프 핵심, 플랫폼 구성을 통합하여 전체 제약 조건의 조합 Prompt를 구축하고, LLM이 엔드투엔드 Triton 핵심을 처음 생성하는 데 도움을 주어「없음에서 있음으로」의 핵심 지시를 제공합니다; _build_refinement_prompt 표준화된 서브그래프 요약과 핵심 기본 정보를 재사용하고, 오류 로그와 이전 실패 코드를 추가하여 오류 방향성을 가진 정교한 Prompt를 구축하여 LLM이「잘못된 것에서 올바른 것으로」의 반복적 최적화를 수행하도록 지시하는 것은 코드의 효과성을 향상시키는 핵심 요소입니다; def _build_composition_prompt(
problem_code: str,
subgraphs: list[dict[str, Any]],
kernel_items: list[KernelItem],
target_platform: PlatformConfig,
) -> str:
"""Create a single user message to instruct composition by the LLM.
构建初始合成Prompt:
为LLM生成结构化指令,引导其基于子图和参考内核,合成端到端的Triton算子代码
参数:
- problem_code: 原始KernelBench问题代码(PyTorch)
- subgraphs: 子图信息列表(JSON格式)
- kernel_items: 参考内核列表(包含子图ID和对应的Triton代码)
- target_platform: 目标平台配置(如CUDA/XPU)
返回值:完整的LLM用户指令字符串
"""
# 第一步:生成子图摘要(压缩子图信息,控制Token消耗,便于LLM快速理解核心特征)
sg_summary = _summarize_subgraphs_for_prompt(subgraphs)
# 第二步:构建参考内核代码区块(仅保留核心代码,避免Token溢出)
# 注释说明:暂时保留完整文件内容,调用方可根据模型窗口限制进一步裁剪
# 初始化内核区块的文本片段列表
kernels_section_parts: list[str] = []
# 遍历每个参考内核项
for ki in kernel_items:
# 为每个子图内核构建带格式的代码片段(Markdown Python代码块)
kernels_section_parts.append(
f"### Subgraph {ki.subgraph_id}\n```python\n" + ki.code + "\n```\n"
)
# 拼接所有内核片段,形成完整的参考内核区块
kernels_section = "\n".join(kernels_section_parts)
# 第三步:获取平台专属的指导规则(如CUDA的内存访问规则、XPU的编译要求)
platform_guidance = target_platform.guidance_block
# 第四步:构建核心指导语(包含任务背景、平台信息、硬性要求、实现技巧)
# 使用textwrap.dedent去除缩进,保证Prompt格式整洁
guidance = textwrap.dedent(
f"""
You are given:
- The original problem file (PyTorch module and helpers).
- A decomposition of the model into fusable subgraphs with exact shapes.
- Working Triton kernels generated for some subgraphs.
TARGET PLATFORM: {target_platform.name}
DEVICE STRING: {target_platform.device_string}
{platform_guidance}
Task:
- Compose an end-to-end Triton implementation that matches the original
model's forward pass for the provided shapes. You may inline, adapt,
or reuse the given subgraph kernels. Prefer fusing into as few kernel
launches as possible while preserving exact numerical semantics.
Hard requirements:
- Return ONE complete Python file only, fenced as a single ```python block.
- Allocate inputs, weights, intermediates, and outputs on device='{target_platform.device_string}' and keep them there throughout forward/verification.
- CPU is acceptable only for metadata, scalars, and export serialization—avoid `.cpu()` or `.to('cpu')` on compute tensors.
- Provide at least one @triton.jit kernel and a top-level Python wrapper
named kernel_function(...). This wrapper must accept the same primary
input tensor(s) as the model and any required weights/biases with shapes
implied by the problem; it should orchestrate Triton kernel(s) and
return the final output tensor.
- No PyTorch math path: kernel_function MUST compute the final outputs
using your Triton kernels only. Do NOT implement or fall back to
torch.nn / torch.nn.functional / torch.* ops
sigmoid, etc.) for producing the final result. Using PyTorch for
reference comparisons is allowed only inside the self-test.
- Use the data layout and dtype semantics indicated by subgraphs, defaulting
to NCHW + float32 if unspecified. Respect stride/padding/dilation/groups,
and exact op order.
- Numerical equivalence: include a self-test (test_kernel or run_tests)
that compares your Triton-based result to a PyTorch reference computed
from the original problem code below (use get_init_inputs() and
get_inputs() if present to instantiate the Model). The test must print
'PASS' on success and exit with code 0. Use allclose with rtol<=1e-3,
atol<=1e-3 for fp32; for fp16/bf16 allow up to 2e-2.
- No imports beyond torch, triton, triton.language as tl, and stdlib. No I/O.
- Do NOT monkey-patch PyTorch device functions or torch.cuda.is_available()
- Do NOT manipulate TRITON_BACKENDS environment variable
- Do NOT disable or mock XPU/CUDA drivers
Implementation tips:
- If merging multiple subgraphs, ensure intermediate tensor shapes match.
- Hoist constant weights or parameters to avoid reloading per block.
- Use tl.load/tl.store with masks for boundary conditions.
- Favor coalesced memory access; tile by blocks; compute grid from shape.
- Common Triton pitfalls to avoid:
* Do NOT call tl.broadcast on Python scalars; tl.maximum(x, 0.0) works.
* Prefer scalar constants directly in elementwise ops (no explicit broadcast needed).
* Keep BLOCK_SIZE power-of-two; mask stores at tail.
"""
).strip() # 去除首尾空白字符
# 第五步:拼接完整的用户指令(按逻辑组织各部分内容)
user_lines: list[str] = []
user_lines.append(guidance) # 核心指导语
user_lines.append("") # 空行分隔
user_lines.append("SUBGRAPHS (summary):") # 子图摘要标题
user_lines.append(sg_summary) # 子图摘要内容
user_lines.append("") # 空行分隔
user_lines.append("ORIGINAL PROBLEM FILE:") # 原始问题代码标题
user_lines.append("```python") # Python代码块开始标记
user_lines.append(problem_code) # 原始问题代码内容
user_lines.append("```") # Python代码块结束标记
user_lines.append("") # 空行分隔
user_lines.append("SUBGRAPH KERNELS (reference implementations):") # 参考内核标题
user_lines.append(kernels_section) # 参考内核代码内容
user_lines.append("") # 空行分隔
# 最终要求:仅返回一个包含完整代码的Python代码块
user_lines.append(
"Return only one fenced Python code block with your final composed implementation."
)
# 拼接所有行,形成完整的Prompt
return "\n".join(user_lines)
def _build_refinement_prompt(
problem_code: str,
subgraphs: list[dict[str, Any]],
kernel_items: list[KernelItem],
previous_code: str,
error_info: dict[str, str],
target_platform: PlatformConfig,
) -> str:
"""Prompt the LLM to refine the previously produced code based on errors.
构建迭代优化Prompt:
基于上一轮代码的错误信息,引导LLM修复Triton算子代码中的编译/运行错误
参数:
- previous_code: 上一轮生成的错误代码
- error_info: 错误信息字典(包含stderr_tail/stdout_tail)
其他参数同_build_composition_prompt
返回值:针对性的优化指令字符串
"""
# 提取错误日志尾部(最后2000字符,聚焦核心错误)
err_tail = error_info.get("stderr_tail", "")
# 提取标准输出尾部(辅助分析错误)
out_tail = error_info.get("stdout_tail", "")
# 构建优化指导语(聚焦错误修复,保留原有核心要求)
guidance = textwrap.dedent(
f"""
You previously produced a composed Triton implementation, but it failed
to run/compile. Analyze the ERROR_CONTEXT below and re-emit the entire
corrected single-file implementation as one ```python block.
TARGET PLATFORM: {target_platform.name}
DEVICE STRING: {target_platform.device_string}
Requirements remain the same. Additionally:
- Fix any Triton compilation/runtime errors. For scalar constants in
elementwise ops (e.g., ReLU), do not use tl.broadcast. Use direct
scalars like 0.0 in tl.maximum(x, 0.0).
- Keep function name kernel_function(...) unchanged and retain the
self-test that prints PASS on success and exits 0.
- Do NOT reintroduce any PyTorch math path in kernel_function. The final
outputs must be computed via your Triton kernels only (no fallback to
torch.nn / torch.nn.functional ops).
- Return the complete corrected file; do not send diffs.
"""
).strip()
# 拼接完整的优化指令(按“指导语→错误信息→原始代码→子图摘要→上一轮代码”组织)
lines: list[str] = []
lines.append(guidance) # 优化指导语
lines.append("") # 空行分隔
# 添加标准错误日志(核心错误信息)
lines.append("ERROR_CONTEXT (stderr tail):\n```\n" + err_tail + "\n```")
# 若标准输出非空,添加标准输出日志(辅助分析)
if out_tail.strip():
lines.append("STDOUT tail:\n```\n" + out_tail + "\n```")
lines.append("") # 空行分隔
# 添加原始问题代码(保证上下文完整)
lines.append("ORIGINAL PROBLEM FILE:\n```python\n" + problem_code + "\n```")
lines.append("") # 空行分隔
# 添加子图摘要(避免LLM遗忘核心特征)
lines.append("SUBGRAPHS (summary):\n" + _summarize_subgraphs_for_prompt(subgraphs))
lines.append("") # 空行分隔
# 添加上一轮错误代码(让LLM对比分析问题)
lines.append("PREVIOUS_ATTEMPT:\n```python\n" + previous_code + "\n```")
lines.append("") # 空行分隔
# 最终要求:仅返回修正后的完整Python代码块
lines.append(
"Return only one fenced Python code block with the corrected implementation."
)
# 拼接所有行,形成优化Prompt
return "\n".join(lines)
def _summarize_subgraphs_for_prompt(subgraphs: list[dict[str, Any]]) -> str:
"""
生成子图摘要:将复杂的子图JSON信息压缩为简洁的文本格式
核心目标:控制Token消耗,同时保留子图的ID、类型、布局、 dtype、形状、算子等核心特征
"""
# 初始化摘要行列表
lines: list[str] = []
# 遍历每个子图
for it in subgraphs:
# 提取子图ID(默认unknown)
sid = str(it.get("id", "unknown"))
# 提取子图类型(如conv2d、linear)
typ = str(it.get("type", ""))
# 提取数据布局(默认NCHW)
layout = it.get("data_layout") or "NCHW"
# 提取数据类型(默认float32)
dtype = it.get("dtype") or "float32"
# 提取输入形状(兼容多输入/单输入格式)
inputs = it.get("inputs")
in_shape = it.get("input_shape")
# 提取输出形状
out_shape = it.get("output_shape")
# 提取算子列表
ops = it.get("ops") or []
# 构建形状描述行(优先多输入格式,其次单输入格式)
shapes_line = (
f"inputs={inputs if inputs is not None else in_shape}, output={out_shape}"
)
# 构建子图核心信息行
lines.append(
f"- ID={sid} type={typ} layout={layout} dtype={dtype} {shapes_line}"
)
# 构建算子摘要(限制长度为400字符,避免Token溢出)
try:
# 尝试转为JSON字符串(结构化)
ops_short = json.dumps(ops)[:400]
except Exception:
# 转换失败则直接转为字符串
ops_short = str(ops)[:400]
# 添加算子摘要行(缩进2空格,提升可读性)
lines.append(f" ops={ops_short}")
# 拼接所有行,形成最终的子图摘要
return "\n".join(lines)
다음 두 함수는 PyTorch KernelAgent에서 엔드투엔드 Triton 핵심 조합 생성의 핵심 실행 모듈입니다.
_auto_patch_common_triton_issues 이것은 가볍게 자동화된 코드 수선 도구는 LLM이 생성한 Triton 코드에 대한 사전 방해물 회피 패치를 제공합니다; _auto_patch_common_triton_issues: LLM이 생성한 Triton 코드에 대한 침투 없는 텍스트 수준 자동화 패치를 수행하여 Triton 개발 중 고발생 및 쉽게 범하는 기초 오류를 특별히 수정하고, 코드 실행 / 검증 전 낮은 수준의 오류를 사전에 방지하여 무효 반복을 줄이고 핵심 생성 효율을 향상시킵니다.compose는 최상위 입구와 프로세스 스케줄링 핵심로서 '프롬프트 구축, LLM 생성, 코드 패치, 실기 검증, 반복 정비, 제품 보관'의 전 과정을 연결하여 KernelAgent가 '분산 서브그래프 / 핵심 자료'에서 '실행 가능한 엔드투엔드 Triton 핵심'으로 엔드투엔드 드라이빙 핵심입니다., 반복 라운드에 따라 동적으로「조합 / 정비 Prompt」를 구축하고 LLM을 호출하여 코드를 생성한 후, 생성된 코드에 자동 패치를 적용하며, 실기 검증을 통한 다중 라운드 반복 정비를 지원하여 최종적으로 하드웨어 적합성을 갖춘 직접 배포 가능한 통합 Triton 커널 및 표준화된 결과 요약을 출력합니다.LLM이 Triton 코드를 생성할 때 자주 발생하는 문제에 대해두 가지 고빈도 저급 오류자동화된 텍스트 수정을 통해 코드 실행 전 함정을 미리 피하고, 기본 오류로 인한 검증 실패를 줄이며, 반복 효율을 향상시킵니다:
tl.broadcast문제(예)tl.broadcast(0.0, ...)바로 대입하기0.0), Triton 스칼라 연산 규범에 맞춰서; cuda_hacks_to_strip동적으로 호환되지 않는 코드를 제거하고, 다양한 플랫폼 확장을 지원하며, 플랫폼 규칙을 강제로 코딩할 필요 없음;_fake_torch_device이러한 함수 정의, 구현코드 블록 전체 건너뛰기 필터링, 더 완벽하게 처리하십시오;(patched_code, changed)튜플로 결과를 반환하고, 호출자에게 코드가 수정되었는지 명확히 알려주어 후속 로그 기록 및 프로세스 모니터링을 용이하게 합니다;는 커널 조합으로 생성된전체 프로세스 스케줄링 핵심로, 입력 로드부터 최종 제품 출력까지 모든 단계를 완료합니다:
run_candidate 실제 기기에서 실행을 검증하고, 실패하면 오류 로그를 추출하여 LLM으로 정제 요청을 전달하여 「생성→패치→검증→오류 발생→정제」의 순환 구조를 형성하고, 검증이 통과하거나 최대 반복 횟수에 도달할 때까지 최종 커널의 사용 가능성을 크게 향상시킵니다; verify 스위치를 통해 실제 기기 검증을 켜고 끌 수 있도록 지원합니다 — 켜지면 제품의 유효성을 보장하고, 끄면 한 번 생성한 후 종료하여 가벼운 사용 시나리오의 효율성을 향상시킵니다; max_iters를 통해 최대 반복 횟수를 제어합니다 (기본값 5회), 요구 사항에 따라 조정하여 생성 효율성과 성공률을 균형을 맞춥니다. last_usage)를 기록하여 최종 결과에 포함시키며, LLM 생성 비용의 통계 및 관리를 지원합니다.composed_kernel.py를 직접 호출하여 전체 프로세스의 무缝 연결을 실현합니다; run_candidate을 호출할 때 2400초 타임아웃을 설정하고, 격리 실행, 네트워크 금지 등의 구성을 지원하여 검증 과정이 환경 문제로 멈추는 것을 방지하고 검증의 안정성을 향상시킵니다. 두 함수는 「사전 패치 보호 + 전체 프로세스 스케줄링 실행」의 긴밀한 협력 관계를 형성하며, KernelAgent 조합 생성 엔드투엔드 Triton 커널의 핵심 지원입니다:
compose 함수의 전제 단계 로, LLM 코드 생성 후 실기 검증 전에 실행되어 Triton 기본 오류를 사전에 수정하고, 하위 수준의 문제로 인한 검증 실패를 줄이며, compose의 반복 프로세스「부담을 덜어주어」, 전체 실행 효율을 향상시킵니다. _auto_patch_common_triton_issues를 호출하고 목표 플랫폼 구성을 전달하여 패치 작업이 하드웨어 요구 사항에 맞도록 하며, 패치 후 코드를 기록하고 후속 검증 및 반복 프로세스를 진행합니다.def _auto_patch_common_triton_issues(
code: str, target_platform: PlatformConfig
) -> tuple[str, bool]:
"""Apply tiny safe textual patches for known Triton pitfalls.
- Replace tl.broadcast(0.0, ...) or tl.broadcast(1.0, ...) with scalar constants.
Returns (patched_code, changed).
自动修复Triton算子代码中的常见问题:
- 核心修复点:将tl.broadcast(0.0/1.0/0/1, ...)替换为标量常量(避免Triton广播操作的性能/语法问题)
- 返回值:(修复后的代码, 是否发生修改)
"""
# 初始化修复后的代码为原始代码
patched = code
# 标记是否发生修改(默认未修改)
changed = False
# 修复规则:采用保守的简单启发式规则,仅处理无歧义的常见问题
patterns = [
# 规则1:tl.broadcast(0.0 → 替换为0.0(移除不必要的广播操作)
("tl.broadcast(0.0", "0.0"),
# 规则2:tl.broadcast(1.0 → 替换为1.0
("tl.broadcast(1.0", "1.0"),
# 规则3:tl.broadcast(0, → 替换为0.0(统一数值类型为浮点数)
("tl.broadcast(0,", "0.0"),
# 规则4:tl.broadcast(1, → 替换为1.0
("tl.broadcast(1,", "1.0"),
]
# 遍历所有修复规则
for old, new in patterns:
# 若原始代码包含待修复的模式
if old in patched:
# 执行文本替换
patched = patched.replace(old, new)
# 标记为已修改
changed = True
# 移除CUDA相关的冗余hack代码(平台适配)
# 获取当前目标平台需要剥离的CUDA hack模式列表
cuda_hacks = target_platform.cuda_hacks_to_strip
if cuda_hacks:
# 将代码按行拆分
lines = patched.split("\n")
# 存储过滤后的代码行
filtered_lines = []
# 标记是否需要跳过直到空行(用于移除整个函数块)
skip_until_blank = False
# 逐行处理代码
for line in lines:
# 若处于跳过模式:跳过当前行,直到遇到空行
if skip_until_blank:
if line.strip() == "":
# 遇到空行,退出跳过模式
skip_until_blank = False
continue
# 检查当前行是否包含需要剥离的CUDA hack模式
if any(hack in line for hack in cuda_hacks):
# 标记为已修改
changed = True
# 特殊处理:若为_fake_torch_device函数定义,需跳过整个函数块(直到空行)
if "def _fake_torch_device" in line:
skip_until_blank = True
# 跳过当前行(移除hack代码)
continue
# 保留当前行
filtered_lines.append(line)
# 重新拼接过滤后的代码行
patched = "\n".join(filtered_lines)
# 返回修复后的代码和是否修改的标记
return patched, changed
def compose(
problem_path: Path,
subgraphs_path: Path,
kernels_summary_path: Path,
out_dir: Path,
model_name: str,
verify: bool = False,
max_iters: int = 5,
target_platform: str = "cuda",
) -> dict[str, Any]:
"""
核心函数:基于子图信息和内核摘要,合成/优化Triton算子代码
参数说明:
- problem_path: KernelBench问题文件路径
- subgraphs_path: 子图信息JSON文件路径(由之前的子图提取模块生成)
- kernels_summary_path: 内核摘要文件路径
- out_dir: 输出目录路径
- model_name: LLM模型名称
- verify: 是否验证生成的代码(默认False)
- max_iters: 最大迭代次数(默认5)
- target_platform: 目标平台(默认cuda)
返回值:包含合成结果的字典(成功状态、代码路径、迭代次数等)
"""
# 前置检查:确保LLM提供商模块可导入
if get_model_provider is None:
raise SystemExit(
"KernelAgent providers unavailable; ensure package import and dependencies"
)
# 创建输出目录(递归创建父目录,已存在则忽略)
out_dir.mkdir(parents=True, exist_ok=True)
# 获取LLM提供商实例(用于调用LLM生成代码)
provider = get_model_provider(model_name)
# 初始化目标平台配置(加载平台相关的修复规则、hack列表等)
platform = get_platform(target_platform)
# 加载输入文件:
# 1. 加载KernelBench问题代码(原始PyTorch代码)
problem_code = _read_text(problem_path)
# 2. 加载子图信息JSON(由子图提取模块生成)
subgraphs = json.loads(_read_text(subgraphs_path))
# 检查子图信息是否为列表:非列表则退出
if not isinstance(subgraphs, list):
raise SystemExit("subgraphs.json must be a JSON array")
# 3. 从内核摘要文件加载内核信息
kernels = _load_kernels_from_summary(kernels_summary_path)
# 创建迭代尝试目录(存储每一轮的生成代码)
attempts_dir = out_dir / "attempts"
attempts_dir.mkdir(parents=True, exist_ok=True)
# 初始化变量:
# - last_usage: 最后一次LLM调用的token使用信息
# - last_code: 上一轮生成的代码
# - verify_info: 验证信息字典
last_usage = None
last_code = None
verify_info: dict[str, Any] = {}
# 多轮迭代生成/优化代码(最多max_iters轮)
for i in range(1, max_iters + 1):
# 构建LLM Prompt:
# 第一轮/上一轮代码为空 → 构建初始合成Prompt
if i == 1 or last_code is None:
prompt = _build_composition_prompt(
problem_code, subgraphs, kernels, target_platform=platform
)
else:
# 非第一轮 → 基于上一轮的错误信息构建优化Prompt
# 提取验证日志的最后2000字符(便于LLM定位问题)
stderr_tail = ""
stdout_tail = ""
try:
# 读取标准错误日志尾部
if verify_info.get("stderr_path"):
with open(
verify_info["stderr_path"],
"r",
encoding="utf-8",
errors="ignore",
) as f:
stderr_tail = f.read()[-2000:]
# 读取标准输出日志尾部
if verify_info.get("stdout_path"):
with open(
verify_info["stdout_path"],
"r",
encoding="utf-8",
errors="ignore",
) as f:
stdout_tail = f.read()[-2000:]
except Exception:
# 日志读取失败则忽略(避免中断迭代)
pass
# 构建优化Prompt(包含上一轮代码和错误信息)
prompt = _build_refinement_prompt(
problem_code,
subgraphs,
kernels,
previous_code=last_code,
error_info={"stderr_tail": stderr_tail, "stdout_tail": stdout_tail},
target_platform=platform,
)
# 保存当前轮次的Prompt(便于追溯和调试)
(attempts_dir / f"attempt_{i}.prompt.txt").write_text(prompt, encoding="utf-8")
# 调用LLM生成代码:
# - 消息格式:仅包含user角色的Prompt
# - 最大生成Token数:50000(适配长代码生成)
response = provider.get_response(
model_name, [{"role": "user", "content": prompt}], max_tokens=50000
)
# 记录最后一次LLM调用的token使用信息
last_usage = response.usage
# 提取LLM输出的原始文本
raw_text = response.content or ""
# 从LLM输出中提取Python代码(剥离无关文本)
extracted = extract_single_python_file(raw_text)
code = extracted.code
# 自动修复Triton常见问题(文本补丁)
code, changed = _auto_patch_common_triton_issues(code, platform)
# 保存当前轮次的生成代码
(attempts_dir / f"attempt_{i}.py").write_text(code, encoding="utf-8")
# 更新last_code为当前轮次的代码
last_code = code
# 验证当前轮次的代码(若开启验证)
if verify:
# 运行候选代码验证:
# - artifacts_code_path: 当前轮次的代码路径
# - run_root: 验证运行目录
# - timeout_s: 验证超时时间(2400秒)
# - isolated: 非隔离模式
# - deny_network: 允许网络访问
rr = run_candidate(
artifacts_code_path=attempts_dir / f"attempt_{i}.py",
run_root=out_dir / "runs",
timeout_s=2400,
isolated=False,
deny_network=False,
)
# 记录验证信息
verify_info = {
"verify_rc": rr.rc, # 验证退出码
"verify_passed": rr.passed, # 是否验证通过
"verify_reason": rr.reason, # 验证结果原因
"validator": rr.validator_used, # 使用的验证器
"stdout_path": str(rr.stdout_path), # 标准输出日志路径
"stderr_path": str(rr.stderr_path), # 标准错误日志路径
}
# 若验证通过,终止迭代(无需继续优化)
if rr.passed:
break
else:
# 若未开启验证,仅执行第一轮后终止
break
# 保存最终合成的算子代码(取最后一轮的代码)
composed_path = out_dir / "composed_kernel.py"
composed_path.write_text(last_code or "", encoding="utf-8")
# 构建结果字典:包含核心合成信息
result: dict[str, Any] = {
# 成功状态:验证模式下为是否通过验证,非验证模式下默认成功
"success": bool(verify_info.get("verify_passed", not verify)),
# 最终合成代码的绝对路径
"composed_path": str(composed_path.resolve()),
# 使用的LLM模型名称
"model": model_name,
# LLM token使用信息
"usage": last_usage,
# 实际迭代次数
"rounds": i,
# 目标平台
"target_platform": target_platform,
}
# 合并验证信息到结果字典
result.update(verify_info)
# 保存合成摘要(结构化JSON文件)
(out_dir / "composition_summary.json").write_text(
json.dumps(result, indent=2), encoding="utf-8"
)
# 返回结果字典
return result
이 콘텐츠는 인셔셔RSS(RSS 리더)가 자동으로 집계한 것으로 읽기 참고용입니다. 원문 출처 — 저작권은 원저작자에게 있습니다.