










Fuser/compose_end_to_end.py は Fuser パイプラインの最後の重要なステップであり、特定のサブグラフに最適化された分散した Triton 内核を無縫に統合し、単一で高性能なエンドツーエンドの Triton 内核を作成します。同時に、その機能がオリジナルの PyTorch 実装と数值的に等価であることを保証します。
Composerの構成図は以下の通りです。その機能の要約は一言で言えば、「すべての検証に合格したサブグラフのカーネルと元の問題をLLMに渡して単一のファイルにまとめる」です。また、エラーログもLLMに修正させるし、最大max_iters回の再試行を行い、自動的に一般的なTritonの罠をパッチングします。

Composerの核心機能は以下の通りです:
kernel_functionを含める必要があること、PyTorch 計算ロジックを禁止することなど)と数値検証を通じて、生成された Triton カーネルが利用可能で正しいことを保証します。| 機能方向 | 具体的な説明 |
|---|---|
| 厳格なコード制約 | が生成するコードにはkernel_functionのトップレベル関数(元のモデル入力と一致)、@triton.jit 内核、自測関数(出力 PASS/0 退出コード)を含める必要があります;内核で PyTorch 計算ロジックを使用することを禁止します(自測時のみ比較を許可)。 |
| は智能的なエラー反復を行い | はコンパイル/実行エラー(stderr/stdout)をキャッチし、構造化されたプロンプトを構築して LLM が問題を特定し修正できるようにします(例:Triton の一般的な tl.broadcast の誤用)。 |
| は自動パッチを適用し | は Triton の一般的な問題に対するテキストパッチを組み込みます(例:tl.broadcast(0.0) を 0.0 に置き換える)、無意味な LLM 反復を減らします。 |
| は完全なログを保持します。 | 各ラウンドの Prompt、生成コード、検証結果を保存し、デバッグや生成プロセスの追跡を容易にします。 |
| 数値等価性の保証 | 自測関数は allclose を使用して数値を検証する必要があります(fp32: rtol≤1e-3/atol≤1e-3;fp16/bf16: ≤2e-2),Triton 実装が PyTorch の結果と一致することを確認します。 |


compose_end_to_end.py で端到端の Triton 内核を合成(compose)し、元の KernelBench 問題を解決します。
compose_end_to_end.py は、元の問題ファイル、サブ図分解情報 + 各テンソルの形状(subgraphs.json から来る)、および生成済みのサブ図の Triton カーネルを入力として受け取り、プロンプト(prompt)を LLM に送信します。その後、LLM はこれらの断片を組み合わせて、 で意味が元の問題と完全に一致する の完全なカーネルを作成し、最終的に Python ファイル(/composed_kernel.py)を返します。そのファイルには以下が提供されています:
@triton.jit 装飾された Triton カーネル。kernel_function(...) のトップレベルの Python 包装関数、これは元のモデルと同じ入力テンソルを受け取り、Triton カーネルの実行を調整し、最終出力を返します。test_kernel や run_tests)は、その関数がTritonの実装結果とオリジナルのPyTorch問題コードの参照結果を比較し、成功した場合に 'PASS' を表示して終了します。composition_summary.json)に記録されます。compose_end_to_end.py の使い方は以下の通りです:
python -m Fuser.compose_end_to_end \
--problem /abs/path/to/kernelbench_problem.py \
--subgraphs /abs/path/to/subgraphs.json \
--kernels-summary /abs/path/to/kernels_out/summary.json \
[--model gpt-5] \
[--out-dir ./compose_out] \
[--verify]
は、--verify フラグで自動検証を有効にできます。このモードでは、各LLMが生成する組み合わせの試行(初期のものでも修正されたものでも)が、Fuser/runner.py の run_candidate 関数に渡されて実行されます。
の検証結果は、実行が正常に終了(exit code 0)し、出力に含まれるかどうかによって決定されます。'PASS' 文字列または ALL_TESTS_PASSED 文字列。

LLMが生成した最初の組み合わせカーネルが検証を通過できない場合(実行/コンパイル失敗)、スクリプトはエラーメッセージ(stderr、stdout)をキャッチします。新しいプロンプト(_build_refinement_prompt)を構築し、エラーメッセージをLLMのコンテキストとして提供し、コードの修正を要求します。このプロセスは複数回繰り返すことができます(max_iters パラメータが最大反復回数を制御)、生成されたカーネルが検証を通過するか最大反復回数に達するまで。
_load_kernels_from_summary は、プロンプトの構築に 「コード素材」を提供します。(各サブ図の有効なTritonカーネルコード)、具体的には、_load_kernels_from_summaryはスケジューリングフェーズで生成されたカーネルサマリJSONファイルから、すべて成功して生成された有効なサブ図Tritonカーネルをフィルタリングして読み込み、データ形式とファイルの有効性を検証し、標準化されたKernelItemオブジェクトリストとしてラップし、LLMがカーネルを組み合わせるために直接再利用できる有効なコード資源を提供し、失敗したまたは無効なカーネル生成物をフィルタリングし、後続の組み合わせプロセスを無効な資源の干渉から保護します。
の特徴は以下の通り:
多次元有効なカーネルフィルタリング:リスト形式でないサマリデータ、辞書形式でないサブアイテム、マークされて失敗したカーネル、IDまたはカーネルパスがないサブアイテム、カーネルファイルが存在しないサブアイテムを順次フィルタリングし、すべての検証を通過した成功したカーネルのみを残し、コード資源の有効性を元から保証します;
キーフィールドの強制検証:サブグラフ ID(sid)とカーネルファイルパス(kernel_path)は必須フィールドであり、どちらかが欠けても直接フィルタリングされ、各有効なカーネルが唯一のサブグラフに関連付けられ、実際のコードファイルが存在することを保証します;
標準化されたオブジェクトのラッピング:カーネルのサブグラフ ID、ファイルパス、コード内容を KernelItem オブジェクトとしてラップし、元の辞書 / 文字列ではなく、後続のコード処理の読みやすさと保守性を向上させます;
有効なカーネルがない場合は直接終了:総合ファイルに有効なカーネルがない場合、直接 SystemExit 例外をスローしてプロセスを終了し、無効な素材がないために後続のプロセスが無意味に実行されるのを回避します;
スケジューリングフェーズの出力形式の互換性:dispatch ステップで生成された summary.json の形式に厳密に適合し、上下のプロセスの無縫な連携を実現します。
以下の三つの関数は、PyTorch KernelAgent のLLM 生成端到端 Triton 内核の Prompt 构築コアモジュールであり、LLM に標準化、高指導性、シナリオ適応の正確な入力プロンプトを提供、そして「分散されたサブグラフ / 内核 / 問題データ」と「LLM が理解できる生成指令」を結びつける重要なハブです。中で
_summarize_subgraphs_for_prompt は基本データ処理関数で、サブグラフ情報をフォーマットする責任があります;_build_composition_prompt と _build_refinement_prompt は双 Prompt 构築メイン関数で、それぞれ初回端到端内核組み合わせ生成とエラーに基づく反復修正をサポートします。二つの主要なシナリオを通じて、厳格な指示制約、完全なコンテキスト情報、ターゲット指向の最適化ガイドラインにより、LLM がエンジニアリング要件に合致し、ハードウェアに適合し、直接実行可能な Triton 内核コードを生成することを保証します。kernel_function 関数名、自己テスト関数および「PASS」の印刷 / 退出コードルールを明確に要求し、精密に修正されたコードが以降の自動検証フローに無縫に接続できるようにし、検証ロジックの変更を不要にする;_summarize_subgraphs_for_prompt は 「論理的制約」(各サブグラフの機能、形状、レイアウトなどの制約情報)を提供します。具体的には、_summarize_subgraphs_for_prompt:モデル分解後のサブグラフ情報リストを構造化・簡潔化したテキスト要約を行い、サブグラフID、タイプ、データレイアウト、データタイプ、入出力形状、コア演算子などの重要情報を抽出し、統一された形式で連結して読みやすいテキスト文字列にします。これにより、LLM組み合わせPromptの構築に標準化されたサブグラフ情報の説明を提供し、LLMが各サブグラフの機能、形状制約、計算要求を迅速に理解できるようにします。
inputsフィールドを使用して入力を記述する。inputsがない場合はinput_shapeに降格し、異なるサブグラフ分解ツールの出力形式の違いに対応して、入力出力形状情報の有効な伝達を保証する。三つの関数は「基本データ処理→初回プロンプト構築→反復プロンプト精修構築」の形で連携する。 の階層的支援関係は、LLM がエンドツーエンドの Triton カーネルを生成するための全プロセスの Prompt 支援を提供します:
_summarize_subgraphs_for_prompt は原始なサブグラフ情報に対して統一されたフォーマット処理を行い、標準化されたサブグラフ要約を生成し、上の2つの Prompt 構築関数に対して一貫性があり、解析しやすいサブグラフ制約情報を提供し、データの1回の処理と複数回の再利用を実現します;_build_composition_prompt は標準化されたサブグラフ要約に基づき、問題コード、サブグラフカーネル、プラットフォーム設定を統合し、全量の制約の組み合わせの Prompt を構築し、LLM が初めてエンドツーエンドの Triton カーネルを生成するのを指導し、「無から有へ」の核心的な指導です;_build_refinement_prompt は、標準化されたサブグラフの要約と核心的な基本情報を再利用し、エラーログと前回の失敗コードを追加して、エラーに導く精密なプロンプトを構築し、LLM が「間違ったものから正しいものへ」の反復最適化を完了するのを指導する。これはコードの有効性を向上させる鍵である;def _build_composition_prompt(
problem_code: str,
subgraphs: list[dict[str, Any]],
kernel_items: list[KernelItem],
target_platform: PlatformConfig,
) -> str:
"""Create a single user message to instruct composition by the LLM.
构建初始合成Prompt:
为LLM生成结构化指令,引导其基于子图和参考内核,合成端到端的Triton算子代码
参数:
- problem_code: 原始KernelBench问题代码(PyTorch)
- subgraphs: 子图信息列表(JSON格式)
- kernel_items: 参考内核列表(包含子图ID和对应的Triton代码)
- target_platform: 目标平台配置(如CUDA/XPU)
返回值:完整的LLM用户指令字符串
"""
# 第一步:生成子图摘要(压缩子图信息,控制Token消耗,便于LLM快速理解核心特征)
sg_summary = _summarize_subgraphs_for_prompt(subgraphs)
# 第二步:构建参考内核代码区块(仅保留核心代码,避免Token溢出)
# 注释说明:暂时保留完整文件内容,调用方可根据模型窗口限制进一步裁剪
# 初始化内核区块的文本片段列表
kernels_section_parts: list[str] = []
# 遍历每个参考内核项
for ki in kernel_items:
# 为每个子图内核构建带格式的代码片段(Markdown Python代码块)
kernels_section_parts.append(
f"### Subgraph {ki.subgraph_id}\n```python\n" + ki.code + "\n```\n"
)
# 拼接所有内核片段,形成完整的参考内核区块
kernels_section = "\n".join(kernels_section_parts)
# 第三步:获取平台专属的指导规则(如CUDA的内存访问规则、XPU的编译要求)
platform_guidance = target_platform.guidance_block
# 第四步:构建核心指导语(包含任务背景、平台信息、硬性要求、实现技巧)
# 使用textwrap.dedent去除缩进,保证Prompt格式整洁
guidance = textwrap.dedent(
f"""
You are given:
- The original problem file (PyTorch module and helpers).
- A decomposition of the model into fusable subgraphs with exact shapes.
- Working Triton kernels generated for some subgraphs.
TARGET PLATFORM: {target_platform.name}
DEVICE STRING: {target_platform.device_string}
{platform_guidance}
Task:
- Compose an end-to-end Triton implementation that matches the original
model's forward pass for the provided shapes. You may inline, adapt,
or reuse the given subgraph kernels. Prefer fusing into as few kernel
launches as possible while preserving exact numerical semantics.
Hard requirements:
- Return ONE complete Python file only, fenced as a single ```python block.
- Allocate inputs, weights, intermediates, and outputs on device='{target_platform.device_string}' and keep them there throughout forward/verification.
- CPU is acceptable only for metadata, scalars, and export serialization—avoid `.cpu()` or `.to('cpu')` on compute tensors.
- Provide at least one @triton.jit kernel and a top-level Python wrapper
named kernel_function(...). This wrapper must accept the same primary
input tensor(s) as the model and any required weights/biases with shapes
implied by the problem; it should orchestrate Triton kernel(s) and
return the final output tensor.
- No PyTorch math path: kernel_function MUST compute the final outputs
using your Triton kernels only. Do NOT implement or fall back to
torch.nn / torch.nn.functional / torch.* ops
sigmoid, etc.) for producing the final result. Using PyTorch for
reference comparisons is allowed only inside the self-test.
- Use the data layout and dtype semantics indicated by subgraphs, defaulting
to NCHW + float32 if unspecified. Respect stride/padding/dilation/groups,
and exact op order.
- Numerical equivalence: include a self-test (test_kernel or run_tests)
that compares your Triton-based result to a PyTorch reference computed
from the original problem code below (use get_init_inputs() and
get_inputs() if present to instantiate the Model). The test must print
'PASS' on success and exit with code 0. Use allclose with rtol<=1e-3,
atol<=1e-3 for fp32; for fp16/bf16 allow up to 2e-2.
- No imports beyond torch, triton, triton.language as tl, and stdlib. No I/O.
- Do NOT monkey-patch PyTorch device functions or torch.cuda.is_available()
- Do NOT manipulate TRITON_BACKENDS environment variable
- Do NOT disable or mock XPU/CUDA drivers
Implementation tips:
- If merging multiple subgraphs, ensure intermediate tensor shapes match.
- Hoist constant weights or parameters to avoid reloading per block.
- Use tl.load/tl.store with masks for boundary conditions.
- Favor coalesced memory access; tile by blocks; compute grid from shape.
- Common Triton pitfalls to avoid:
* Do NOT call tl.broadcast on Python scalars; tl.maximum(x, 0.0) works.
* Prefer scalar constants directly in elementwise ops (no explicit broadcast needed).
* Keep BLOCK_SIZE power-of-two; mask stores at tail.
"""
).strip() # 去除首尾空白字符
# 第五步:拼接完整的用户指令(按逻辑组织各部分内容)
user_lines: list[str] = []
user_lines.append(guidance) # 核心指导语
user_lines.append("") # 空行分隔
user_lines.append("SUBGRAPHS (summary):") # 子图摘要标题
user_lines.append(sg_summary) # 子图摘要内容
user_lines.append("") # 空行分隔
user_lines.append("ORIGINAL PROBLEM FILE:") # 原始问题代码标题
user_lines.append("```python") # Python代码块开始标记
user_lines.append(problem_code) # 原始问题代码内容
user_lines.append("```") # Python代码块结束标记
user_lines.append("") # 空行分隔
user_lines.append("SUBGRAPH KERNELS (reference implementations):") # 参考内核标题
user_lines.append(kernels_section) # 参考内核代码内容
user_lines.append("") # 空行分隔
# 最终要求:仅返回一个包含完整代码的Python代码块
user_lines.append(
"Return only one fenced Python code block with your final composed implementation."
)
# 拼接所有行,形成完整的Prompt
return "\n".join(user_lines)
def _build_refinement_prompt(
problem_code: str,
subgraphs: list[dict[str, Any]],
kernel_items: list[KernelItem],
previous_code: str,
error_info: dict[str, str],
target_platform: PlatformConfig,
) -> str:
"""Prompt the LLM to refine the previously produced code based on errors.
构建迭代优化Prompt:
基于上一轮代码的错误信息,引导LLM修复Triton算子代码中的编译/运行错误
参数:
- previous_code: 上一轮生成的错误代码
- error_info: 错误信息字典(包含stderr_tail/stdout_tail)
其他参数同_build_composition_prompt
返回值:针对性的优化指令字符串
"""
# 提取错误日志尾部(最后2000字符,聚焦核心错误)
err_tail = error_info.get("stderr_tail", "")
# 提取标准输出尾部(辅助分析错误)
out_tail = error_info.get("stdout_tail", "")
# 构建优化指导语(聚焦错误修复,保留原有核心要求)
guidance = textwrap.dedent(
f"""
You previously produced a composed Triton implementation, but it failed
to run/compile. Analyze the ERROR_CONTEXT below and re-emit the entire
corrected single-file implementation as one ```python block.
TARGET PLATFORM: {target_platform.name}
DEVICE STRING: {target_platform.device_string}
Requirements remain the same. Additionally:
- Fix any Triton compilation/runtime errors. For scalar constants in
elementwise ops (e.g., ReLU), do not use tl.broadcast. Use direct
scalars like 0.0 in tl.maximum(x, 0.0).
- Keep function name kernel_function(...) unchanged and retain the
self-test that prints PASS on success and exits 0.
- Do NOT reintroduce any PyTorch math path in kernel_function. The final
outputs must be computed via your Triton kernels only (no fallback to
torch.nn / torch.nn.functional ops).
- Return the complete corrected file; do not send diffs.
"""
).strip()
# 拼接完整的优化指令(按“指导语→错误信息→原始代码→子图摘要→上一轮代码”组织)
lines: list[str] = []
lines.append(guidance) # 优化指导语
lines.append("") # 空行分隔
# 添加标准错误日志(核心错误信息)
lines.append("ERROR_CONTEXT (stderr tail):\n```\n" + err_tail + "\n```")
# 若标准输出非空,添加标准输出日志(辅助分析)
if out_tail.strip():
lines.append("STDOUT tail:\n```\n" + out_tail + "\n```")
lines.append("") # 空行分隔
# 添加原始问题代码(保证上下文完整)
lines.append("ORIGINAL PROBLEM FILE:\n```python\n" + problem_code + "\n```")
lines.append("") # 空行分隔
# 添加子图摘要(避免LLM遗忘核心特征)
lines.append("SUBGRAPHS (summary):\n" + _summarize_subgraphs_for_prompt(subgraphs))
lines.append("") # 空行分隔
# 添加上一轮错误代码(让LLM对比分析问题)
lines.append("PREVIOUS_ATTEMPT:\n```python\n" + previous_code + "\n```")
lines.append("") # 空行分隔
# 最终要求:仅返回修正后的完整Python代码块
lines.append(
"Return only one fenced Python code block with the corrected implementation."
)
# 拼接所有行,形成优化Prompt
return "\n".join(lines)
def _summarize_subgraphs_for_prompt(subgraphs: list[dict[str, Any]]) -> str:
"""
生成子图摘要:将复杂的子图JSON信息压缩为简洁的文本格式
核心目标:控制Token消耗,同时保留子图的ID、类型、布局、 dtype、形状、算子等核心特征
"""
# 初始化摘要行列表
lines: list[str] = []
# 遍历每个子图
for it in subgraphs:
# 提取子图ID(默认unknown)
sid = str(it.get("id", "unknown"))
# 提取子图类型(如conv2d、linear)
typ = str(it.get("type", ""))
# 提取数据布局(默认NCHW)
layout = it.get("data_layout") or "NCHW"
# 提取数据类型(默认float32)
dtype = it.get("dtype") or "float32"
# 提取输入形状(兼容多输入/单输入格式)
inputs = it.get("inputs")
in_shape = it.get("input_shape")
# 提取输出形状
out_shape = it.get("output_shape")
# 提取算子列表
ops = it.get("ops") or []
# 构建形状描述行(优先多输入格式,其次单输入格式)
shapes_line = (
f"inputs={inputs if inputs is not None else in_shape}, output={out_shape}"
)
# 构建子图核心信息行
lines.append(
f"- ID={sid} type={typ} layout={layout} dtype={dtype} {shapes_line}"
)
# 构建算子摘要(限制长度为400字符,避免Token溢出)
try:
# 尝试转为JSON字符串(结构化)
ops_short = json.dumps(ops)[:400]
except Exception:
# 转换失败则直接转为字符串
ops_short = str(ops)[:400]
# 添加算子摘要行(缩进2空格,提升可读性)
lines.append(f" ops={ops_short}")
# 拼接所有行,形成最终的子图摘要
return "\n".join(lines)
以下の2つの関数は、PyTorch KernelAgent の 端到端 Triton 内核組合生成の核心的な実行モジュール である。
_auto_patch_common_triton_issues として 軽量自動化コード修正ツールは、LLMが生成したTritonコードの前処理として欠陥を回避するパッチを提供します;_auto_patch_common_triton_issues:LLMが生成したTritonコードに対して非侵入型テキストレベルの自動化パッチを実行し、Triton開発中に頻繁に発生しやすい基本的なエラーを修正し、コード実行/検証前に低レベルのエラーを回避し、無駄な反復を減少させ、カーネル生成効率を向上させます。composeはのトップレベルエントリーポイントとプロセススケジューリングの核として、「プロンプト構築、LLM生成、コードパッチ、実機検証、反復精修、プロダクトアーカイブ」の全プロセスを連携させ、KernelAgentが「分散されたサブグラフ/カーネル素材」から「実行可能なエンドツーエンドのTritonカーネル」へとエンドツーエンドのドライブコアとなります。は、反復回数に基づいて動的に「組み合わせ / 精査プロンプト」を構築し、LLMを呼び出してコードを生成し、生成したコードに対して自動的にパッチを適用し、実機検証駆動の複数反復による精査をサポートし、最終的にハードウェアに対応し、直接デプロイ可能な統合Tritonカーネルおよび標準化された結果のまとめを出力します。は、LLMがTritonコードを生成する際に容易に発生するの2種類の高頻度の低レベルエラーに対して自動的にテキストを修正し、コード実行前に落とし穴を回避し、基本的なエラーによる検証失敗を減少させ、反復効率を向上させることです:
tl.broadcast問題(例えばtl.broadcast(0.0, ...)を直接スカラに置き換える)0.0),Tritonのスカラーオペレーション規格に合わせる;cuda_hacks_to_stripで動的に互換性のないコードを削除し、複数のプラットフォームの拡張をサポートし、プラットフォームのルールをハードコーディングする必要がありません;_fake_torch_deviceのような関数定義を識別し、コードブロック全体をスキップするフィルタリングを実現、より徹底的な処理を実行;(patched_code, changed)タプルで結果を返し、コードが修正されたかどうかを呼び出し元に明確に伝え、後続のログ記録とプロセス監視を容易にします;は、カーネルの組み合わせによって生成される全プロセススケジューリングの核心として、入力のロードから最終製品の出力に至るまでのすべてのステップを完了する:
run_candidate は実際の機器での動作検証を行い、失敗した場合にはエラーログを抽出して LLM による精密修正にフィードバックし、「生成→パッチ→検証→エラー→精密修正」 の閉ループを形成し、検証が通過するか最大反復ラウンドに達するまで、最終カーネルの利用可能性を大幅に向上させる;verify スイッチで実際の機器検証を開始するかどうかを制御する機能をサポートしている—— 開始時に生成物の有効性を保証し、終了時に一度生成しただけで終了するため、軽量使用シナリオの効率を向上させる;max_itersを通じて最大反復ラウンド(デフォルト5ラウンド)を制御し、必要に応じて調整して生成効率と成功率のバランスを取る;last_usage)を記録し最終結果に組み込み、LLM生成コストの統計と管理をサポートする;composed_kernel.pyを直接呼び出すことができ、全プロセスのスムーズな接続を実現します;run_candidateを呼び出す際に2400秒のタイムアウトを設定し、同時に隔離実行やネットワーク禁止などの設定をサポートし、検証プロセスが環境問題で固着しないようにし、検証の安定性を向上させます。二つの関数は「事前パッチ保護 + 全プロセススケジュール実行」の緊密な協調関係を形成し、KernelAgentがエンドツーエンドのTritonカーネルを組み合わせ生成するための核心的なサポートとなります:
compose関数の前処理サブステップLLMがコードを生成した後、実際の機器で検証する前に実行し、Tritonの基本的なエラーを事前に修正し、低レベルの問題による検証失敗を減少させ、composeのイテレーションプロセス「減負」、全体の実行効率を向上させる;_auto_patch_common_triton_issues、ターゲットプラットフォームの設定を伝達し、パッチ操作をハードウェアの要件に合わせるようにし、同時にパッチ適用後のコードを記録し、後続の検証、イテレーションプロセスを進める。def _auto_patch_common_triton_issues(
code: str, target_platform: PlatformConfig
) -> tuple[str, bool]:
"""Apply tiny safe textual patches for known Triton pitfalls.
- Replace tl.broadcast(0.0, ...) or tl.broadcast(1.0, ...) with scalar constants.
Returns (patched_code, changed).
自动修复Triton算子代码中的常见问题:
- 核心修复点:将tl.broadcast(0.0/1.0/0/1, ...)替换为标量常量(避免Triton广播操作的性能/语法问题)
- 返回值:(修复后的代码, 是否发生修改)
"""
# 初始化修复后的代码为原始代码
patched = code
# 标记是否发生修改(默认未修改)
changed = False
# 修复规则:采用保守的简单启发式规则,仅处理无歧义的常见问题
patterns = [
# 规则1:tl.broadcast(0.0 → 替换为0.0(移除不必要的广播操作)
("tl.broadcast(0.0", "0.0"),
# 规则2:tl.broadcast(1.0 → 替换为1.0
("tl.broadcast(1.0", "1.0"),
# 规则3:tl.broadcast(0, → 替换为0.0(统一数值类型为浮点数)
("tl.broadcast(0,", "0.0"),
# 规则4:tl.broadcast(1, → 替换为1.0
("tl.broadcast(1,", "1.0"),
]
# 遍历所有修复规则
for old, new in patterns:
# 若原始代码包含待修复的模式
if old in patched:
# 执行文本替换
patched = patched.replace(old, new)
# 标记为已修改
changed = True
# 移除CUDA相关的冗余hack代码(平台适配)
# 获取当前目标平台需要剥离的CUDA hack模式列表
cuda_hacks = target_platform.cuda_hacks_to_strip
if cuda_hacks:
# 将代码按行拆分
lines = patched.split("\n")
# 存储过滤后的代码行
filtered_lines = []
# 标记是否需要跳过直到空行(用于移除整个函数块)
skip_until_blank = False
# 逐行处理代码
for line in lines:
# 若处于跳过模式:跳过当前行,直到遇到空行
if skip_until_blank:
if line.strip() == "":
# 遇到空行,退出跳过模式
skip_until_blank = False
continue
# 检查当前行是否包含需要剥离的CUDA hack模式
if any(hack in line for hack in cuda_hacks):
# 标记为已修改
changed = True
# 特殊处理:若为_fake_torch_device函数定义,需跳过整个函数块(直到空行)
if "def _fake_torch_device" in line:
skip_until_blank = True
# 跳过当前行(移除hack代码)
continue
# 保留当前行
filtered_lines.append(line)
# 重新拼接过滤后的代码行
patched = "\n".join(filtered_lines)
# 返回修复后的代码和是否修改的标记
return patched, changed
def compose(
problem_path: Path,
subgraphs_path: Path,
kernels_summary_path: Path,
out_dir: Path,
model_name: str,
verify: bool = False,
max_iters: int = 5,
target_platform: str = "cuda",
) -> dict[str, Any]:
"""
核心函数:基于子图信息和内核摘要,合成/优化Triton算子代码
参数说明:
- problem_path: KernelBench问题文件路径
- subgraphs_path: 子图信息JSON文件路径(由之前的子图提取模块生成)
- kernels_summary_path: 内核摘要文件路径
- out_dir: 输出目录路径
- model_name: LLM模型名称
- verify: 是否验证生成的代码(默认False)
- max_iters: 最大迭代次数(默认5)
- target_platform: 目标平台(默认cuda)
返回值:包含合成结果的字典(成功状态、代码路径、迭代次数等)
"""
# 前置检查:确保LLM提供商模块可导入
if get_model_provider is None:
raise SystemExit(
"KernelAgent providers unavailable; ensure package import and dependencies"
)
# 创建输出目录(递归创建父目录,已存在则忽略)
out_dir.mkdir(parents=True, exist_ok=True)
# 获取LLM提供商实例(用于调用LLM生成代码)
provider = get_model_provider(model_name)
# 初始化目标平台配置(加载平台相关的修复规则、hack列表等)
platform = get_platform(target_platform)
# 加载输入文件:
# 1. 加载KernelBench问题代码(原始PyTorch代码)
problem_code = _read_text(problem_path)
# 2. 加载子图信息JSON(由子图提取模块生成)
subgraphs = json.loads(_read_text(subgraphs_path))
# 检查子图信息是否为列表:非列表则退出
if not isinstance(subgraphs, list):
raise SystemExit("subgraphs.json must be a JSON array")
# 3. 从内核摘要文件加载内核信息
kernels = _load_kernels_from_summary(kernels_summary_path)
# 创建迭代尝试目录(存储每一轮的生成代码)
attempts_dir = out_dir / "attempts"
attempts_dir.mkdir(parents=True, exist_ok=True)
# 初始化变量:
# - last_usage: 最后一次LLM调用的token使用信息
# - last_code: 上一轮生成的代码
# - verify_info: 验证信息字典
last_usage = None
last_code = None
verify_info: dict[str, Any] = {}
# 多轮迭代生成/优化代码(最多max_iters轮)
for i in range(1, max_iters + 1):
# 构建LLM Prompt:
# 第一轮/上一轮代码为空 → 构建初始合成Prompt
if i == 1 or last_code is None:
prompt = _build_composition_prompt(
problem_code, subgraphs, kernels, target_platform=platform
)
else:
# 非第一轮 → 基于上一轮的错误信息构建优化Prompt
# 提取验证日志的最后2000字符(便于LLM定位问题)
stderr_tail = ""
stdout_tail = ""
try:
# 读取标准错误日志尾部
if verify_info.get("stderr_path"):
with open(
verify_info["stderr_path"],
"r",
encoding="utf-8",
errors="ignore",
) as f:
stderr_tail = f.read()[-2000:]
# 读取标准输出日志尾部
if verify_info.get("stdout_path"):
with open(
verify_info["stdout_path"],
"r",
encoding="utf-8",
errors="ignore",
) as f:
stdout_tail = f.read()[-2000:]
except Exception:
# 日志读取失败则忽略(避免中断迭代)
pass
# 构建优化Prompt(包含上一轮代码和错误信息)
prompt = _build_refinement_prompt(
problem_code,
subgraphs,
kernels,
previous_code=last_code,
error_info={"stderr_tail": stderr_tail, "stdout_tail": stdout_tail},
target_platform=platform,
)
# 保存当前轮次的Prompt(便于追溯和调试)
(attempts_dir / f"attempt_{i}.prompt.txt").write_text(prompt, encoding="utf-8")
# 调用LLM生成代码:
# - 消息格式:仅包含user角色的Prompt
# - 最大生成Token数:50000(适配长代码生成)
response = provider.get_response(
model_name, [{"role": "user", "content": prompt}], max_tokens=50000
)
# 记录最后一次LLM调用的token使用信息
last_usage = response.usage
# 提取LLM输出的原始文本
raw_text = response.content or ""
# 从LLM输出中提取Python代码(剥离无关文本)
extracted = extract_single_python_file(raw_text)
code = extracted.code
# 自动修复Triton常见问题(文本补丁)
code, changed = _auto_patch_common_triton_issues(code, platform)
# 保存当前轮次的生成代码
(attempts_dir / f"attempt_{i}.py").write_text(code, encoding="utf-8")
# 更新last_code为当前轮次的代码
last_code = code
# 验证当前轮次的代码(若开启验证)
if verify:
# 运行候选代码验证:
# - artifacts_code_path: 当前轮次的代码路径
# - run_root: 验证运行目录
# - timeout_s: 验证超时时间(2400秒)
# - isolated: 非隔离模式
# - deny_network: 允许网络访问
rr = run_candidate(
artifacts_code_path=attempts_dir / f"attempt_{i}.py",
run_root=out_dir / "runs",
timeout_s=2400,
isolated=False,
deny_network=False,
)
# 记录验证信息
verify_info = {
"verify_rc": rr.rc, # 验证退出码
"verify_passed": rr.passed, # 是否验证通过
"verify_reason": rr.reason, # 验证结果原因
"validator": rr.validator_used, # 使用的验证器
"stdout_path": str(rr.stdout_path), # 标准输出日志路径
"stderr_path": str(rr.stderr_path), # 标准错误日志路径
}
# 若验证通过,终止迭代(无需继续优化)
if rr.passed:
break
else:
# 若未开启验证,仅执行第一轮后终止
break
# 保存最终合成的算子代码(取最后一轮的代码)
composed_path = out_dir / "composed_kernel.py"
composed_path.write_text(last_code or "", encoding="utf-8")
# 构建结果字典:包含核心合成信息
result: dict[str, Any] = {
# 成功状态:验证模式下为是否通过验证,非验证模式下默认成功
"success": bool(verify_info.get("verify_passed", not verify)),
# 最终合成代码的绝对路径
"composed_path": str(composed_path.resolve()),
# 使用的LLM模型名称
"model": model_name,
# LLM token使用信息
"usage": last_usage,
# 实际迭代次数
"rounds": i,
# 目标平台
"target_platform": target_platform,
}
# 合并验证信息到结果字典
result.update(verify_info)
# 保存合成摘要(结构化JSON文件)
(out_dir / "composition_summary.json").write_text(
json.dumps(result, indent=2), encoding="utf-8"
)
# 返回结果字典
return result
このコンテンツは慣性聚合(RSSリーダー)によって自動集約されています。参考としてご覧ください。 原文出典 — 著作権は原著者に帰属します。