인셔셔RSS 관심 있는 블로그, 뉴스, 기술 정보를 효율적으로 추적하고 읽으세요
원문 읽기 InertiaRSS에서 열기

추천 피드

V
V2EX
博客园 - 叶小钗
Y
Y Combinator Blog
大猫的无限游戏
大猫的无限游戏
博客园 - 【当耐特】
酷 壳 – CoolShell
酷 壳 – CoolShell
D
Docker
WordPress大学
WordPress大学
Blog — PlanetScale
Blog — PlanetScale
博客园 - Franky
G
Google Developers Blog
爱范儿
爱范儿
Google DeepMind News
Google DeepMind News
Stack Overflow Blog
Stack Overflow Blog
云风的 BLOG
云风的 BLOG
Engineering at Meta
Engineering at Meta
aimingoo的专栏
aimingoo的专栏
V
Visual Studio Blog
M
MIT News - Artificial intelligence
Hugging Face - Blog
Hugging Face - Blog

博客园_首页

Plist 二进制格式 第30篇文章:一个大三计科生的自白 Manim如何在数学公式中完美显示中文? Docker 部署 RocketMQ 5 并发编程核心概念辨析 C#事务处理最佳实践:别再让“主表存了、明细丢了”的破事发生 CLI 是什么?为什么大厂突然集体卷命令行? 【从0到1构建一个ClaudeAgent】协作-自主Agent # linux红帽教程-手把手教学 UIImageView 设置图片不生效的原因排查 .NET生态下Native AOT兼容的Cron任务调度框架 Python 潮流周刊#147:Python 和 Ruby 的 JIT 故事 - 豌豆花下猫 可持久化线段树/主席树 学习笔记 如何实现 Claude Code 和 Codex 等 Agent CLI 的自动重试 - Newbe36524 WebSocket 连接池生产级实现:实时行情高可用与负载均衡 - Walter先生 关于代码注释的思考 MicroPython对接大模型:uopenai + 火山方舟实现文字聊天和图片理解 从词向量到大模型:NLP 技术演进浅记 LangChain使用deep agent并且加载SKILL 零成本打造专业域名邮箱:Cloudflare + Gmail 终极配置保姆级全攻略 【从0到1构建一个ClaudeAgent】协作-团队协议 - 程序员Seven 最小二乘问题详解20:无先验约束下的增量式SFM自由网平差 痞子衡嵌入式:大话双核i.MXRT1180之XIP应用里实现可靠Flash IAP的方法 AI Chat 封装, SemanticKerne.AiProvider.Unified 已发布 Windows下右键编辑js文件无法打开记事本——在注册表中使用环境变量 在后台服务中使用 Scoped 服务,为什么总是报错? H200 安装驱动并使用sglang启动模型 wireshark 抓包Trap上报告警内容 我用 AI 辅助开发了一系列小工具(2):图片压缩工具 [A Primer On MC and CC] 2.1 Memory Consistency 1 - 指令重排序和 SC 模型 Oracle数据库SCN推进技术详解与实践指南 玩转控件:封装个带图片的Label控件 Claude Code 4.7 真正该升级的不是模型,而是你的工作流 我用AI写了一个颜值拉满的桌面媒体播放器,全程没动一行代码,这就是AI编程新范式 5. WorkBuddy: 小龙虾的灵魂三件套,让你的小龙虾不只是工具 SQLite 分片方案实战:三种分片策略的深度对比 告别简陋 UI!一款基于 Fluent Design 和基于 WinUI 的开源免费、现代化的 Avalonia UI 控件库 关于二进制排列组合枚举的总结 AI开发-python-LangGraph框架(3-27-LangGraph从零实现大模型智能决策工作流) ElasticSearch主分片和副本分片概念详解 【002】HTTPS 粗解:证书、TLS 握手与对后端配置的影响 Hermes Agent 一周暴涨五万 Star,但我劝你别急着追 一个面向产品化的 Electron + Vue 3 桌面应用脚手架 明明连接的是Redis的DB0,为什么能查到DB3的数据? 【从0到1构建一个ClaudeAgent】协作-Agent团队 熟悉电子元器件之后,电子小白下一步该怎么走? MAF快速入门(23)通过C#类定义Skills .NET 高级开发 | 手写一个对象映射框架 FastAPI数据库ORM怎么选?我肝了三个Demo后,终于不再纠结了 mysqldump 参数拾遗:在遗忘与铭记之间
PyTorch KernelAgent 소스 코드 해석 ---(6)--- Composer
罗西的思考 · 2026-05-22 · via 博客园_首页

PyTorch KernelAgent 소스 코드 분석 ---(6)--- Composer

0x00 요약,

Fuser/compose_end_to_end.py는 Fuser 파이프라인의 마지막 중요한 단계로, 특정 서브그래프에 대한 분산된 Triton 커널을 원활하게 통합하여 단일하고 고성능의 엔드투엔드 Triton 커널을 만들어내며, 원래 PyTorch 구현의 수치적 동등성을 보장합니다.

Composer의 아키텍처 다이어그램은 다음과 같으며, 그 기능 요약은 한 문장으로 표현하면 다음과 같다: 모든 검증된 서브그래프 커널과 원본 질문을 LLM에 넣어 단일 파일로 만든다. 또한 오류 로그를 LLM에 수정하며, 최대 max_iters 횟수까지 재시도하고 자동으로 흔한 Triton 함정을 패치한다.

composer

0x01 핵심 기능

1.1 핵심 역할

Composer의 핵심 기능은 다음과 같다:

  • 서브그래프 커널 통합: Fuser 프로세스에서 분리된 서브그래프와 해당 검증된 Triton 커널을 다시 통합하여 단일 파일의 엔드투엔드 Triton 구현을 만들어, 원본 PyTorch 코드의 전방향 전파 논리를 대체한다.
  • LLM 주도의 코드 생성:원본 질문 코드, 서브플롯 정보, 서브플롯 Triton 커널을 입력으로 하여 맞춤형 프롬프트를 통해 LLM을 호출하여 완전한 Triton 커널 코드를 생성합니다.
  • 기능 검증 및 반복 최적화 :생성된 커널을 자동으로 검증(파이토치 참조 결과와 비교)하며, 검증 실패 시 오류 정보를 기반으로 LLM을 반복적으로 호출하여 코드를 수정하여 검증을 통과하거나 최대 반복 횟수에 도달할 때까지 반복합니다.
  • 제약 보장 :엄격한 코드 규범(예: kernel_function를 포함해야 함, 파이토치 계산 로직을 금지) 및 숫자 검증을 통해 생성된 Triton 커널이 사용 가능하고 올바른지 보장합니다.

1.2 핵심 특징

특징 방향 구체적 설명
엄격한 코드 제약 는 생성된 코드에 kernel_function 최상위 함수(원본 모델 입력과 일치)、@triton.jit 내核、자체 테스트 함수(출력 PASS/0 종료 코드)를 포함하도록 강제합니다; 내核에서는 PyTorch 계산 로직을 사용하지 않도록 금지합니다(자체 테스트 시 비교만 허용).
는 지능적인 오류 반복 는 컴파일/실행 오류(stderr/stdout)를 포착하고, LLM이 문제를 찾아 수정하도록 정밀한 프롬프트를 구축합니다(예: Triton의 흔한 tl.broadcast 잘못 사용).
는 자동 패치로 수정 는 내장 Triton 흔한 문제의 텍스트 패치를 제공(예: tl.broadcast(0.0)를 0.0으로 바꾸기), LLM 반복을 줄입니다.
는 완전한 로그 보관 각 라운드의 Prompt, 생성된 코드, 검증 결과를 보존하여 디버깅 및 생성 과정 추적을 용이하게 합니다.
값의 동등성 보장 자체 검사 함수는 allclose 검증을 사용하여 값( fp32: rtol≤1e-3/atol≤1e-3; fp16/bf16: ≤2e-2)을 검증하여 Triton 구현이 PyTorch 결과와 일치하는지 확인해야 합니다.

1.3 프로세스 다이어그램

핵심 논리 관계 다이어그램

compose_end_to_end.py.逻辑关系图

완전한 실행 프로세스 다이어그램

compose_end_to_end.py.流程图

0x02 상세 기능

2.1 사용

compose_end_to_end.py를 사용하여 원래 KernelBench 문제를 해결하기 위해 Triton 내核을 엔드 투 엔드(compose)합니다.

compose_end_to_end.py는 원본 문제 파일, 서브 그래프 분해 정보 + 각 텐서 형태(来自 subgraphs.json), 그리고 생성된 서브 그래프 Triton 커널을 입력으로 받아, 프롬프트(prompt)를 LLM에게 전송합니다. 그런 다음 LLM은 이 조각들을 하나로 합쳐 의미와 원본 문제와 완전히 일치하는 의 완전한 커널을 만듭니다. 최종적으로 /composed_kernel.py 라는 이름의 파이썬 파일을 반환하며, 이 파일에는 다음을 제공합니다:

  • 하나 이상의 @triton.jit 를 사용하여 꾸며진 Triton 커널.
  • 이름이 kernel_function(...) 인 최상위 파이썬 래퍼 함수, 이 함수는 원본 모델과 동일한 입력 텐서를 받아들이고 Triton 커널의 실행을 조율하며 최종 출력을 반환합니다.
  • 자체 테스트 함수(예:test_kernel 또는 run_tests), 해당 함수는 Triton 구현 결과와 원래 PyTorch 문제 코드의 참조 결과를 비교하고, 성공 시 'PASS'를 출력하고 종료합니다.
  • 생성 과정의 메타데이터와 검증 결과는 JSON 형식의 요약 파일에 기록됩니다 (예: composition_summary.json).

compose_end_to_end.py 사용법은 다음과 같습니다:

python -m Fuser.compose_end_to_end \
    --problem /abs/path/to/kernelbench_problem.py \
    --subgraphs /abs/path/to/subgraphs.json \
    --kernels-summary /abs/path/to/kernels_out/summary.json \
    [--model gpt-5] \
    [--out-dir ./compose_out] \
    [--verify]

--verify 플래그를 통해 자동 검증을 활성화할 수 있습니다. 이 모드에서는 초기 또는 수정된 각 LLM 생성 조합 시도가 Fuser/runner.py의 run_candidate 함수에 전달되어 실행됩니다.

검증의 성공 여부는 실행이 정상적으로 종료되었는지 (exit code 0) 및 출력에 포함되는지에 따라 결정됩니다.'PASS' 문자열 또는 ALL_TESTS_PASSED 문자열입니다.

compose_end_to_end_new

2.2 오류 처리 및 반복

LLM이 생성한 첫 번째 조합 커널이 검증을 통과하지 못하면(실행/컴파일 실패), 스크립트는 오류 메시지(stderr, stdout)를 캡처합니다. 새로운 프롬프트(_build_refinement_prompt)를 구축하고, 오류 메시지를 LLM의 맥락으로 제공하여 코드를 수정하도록 요청합니다. 이 과정은 여러 번 반복될 수 있습니다(max_iters 매개변수가 최대 반복 횟수를 제어), 생성된 커널이 검증을 통과하거나 최대 반복 횟수에 도달할 때까지 계속됩니다.

2.3 _load_kernels_from_summary

_load_kernels_from_summary 프롬프트를 구축하기 위해 「코드 소재」를 제공합니다.(각 서브플롯의 유효한 Triton 내核 코드),구체적으로, _load_kernels_from_summary는 스케줄링 단계에서 생성된 내核 요약 JSON 파일에서 모든 성공적으로 생성된 유효한 서브플롯 Triton 내核을 필터링하고 로드하며, 데이터 형식 및 파일 유효성을 검증하고 표준화된 KernelItem 객체 목록으로 포장하여 LLM 조합 내核을 위해 직접 재사용할 수 있는 유효한 코드 자료를 제공하며, 실패한 또는 유효하지 않은 내核 제품을 필터링하여 무효한 자료가 후속 조합 프로세스를 방해하지 않도록 합니다.

의 특징은 다음과 같습니다:

  • 다차원 유효한 내核 필터링: 순차적으로「리스트 형식 요약 데이터가 아닌 것, 사전 형식 서브 항목이 아닌 것, 실패로 표시된 내核, ID 없음 / 내核 경로 없는 서브 항목, 내核 파일이 존재하지 않는 서브 항목」을 필터링하고, 전체 검증을 통과한 성공적인 내核만 유지하여 코드 자료의 유효성을 원천에서 보장합니다;

  • 핵심 필드 강제 검증:자사 ID(sid)와 커널 파일 경로(kernel_path)는 필수 필드로, 누락 시 즉시 필터링하여 각 유효한 커널이 유일한 자사와 실제 코드 파일에 연결되도록 보장합니다.

  • 표준화된 객체 포장:커널의 자사 ID, 파일 경로, 코드 내용을 KernelItem 객체로 포장하여 원시 사전 / 문자열이 아닌 것으로, 후속 코드 처리의 가독성과 유지보수성을 향상시킵니다.

  • 유효한 커널 없으면 즉시 종료:요약 파일에 유효한 커널이 없으면 즉시 SystemExit 예외를 발생시켜 프로세스를 종료하여 후속 프로세스가 유효한 자료 없이 무의미하게 실행되는 것을 방지합니다.

  • 호환 스케줄링 단계 출력 형식:dispatch 단계에서 생성된 summary.json 형식을 엄격하게 적합하여 상하위 프로세스의 무缝 연결을 구현합니다.

2.4 Prompt

아래 세 개의 함수는 PyTorch KernelAgent의 LLM 생성 단말 Triton 커널의 Prompt 구축 핵심 모듈 로서, LLM에 표준화되고, 높은 지도성, 시나리오에 맞는 정확한 입력 힌트를 제공 하며,「분산된 서브그래프 / 커널 / 문제 데이터」와「LLM이 이해할 수 있는 생성 명령」을 연결하는 핵심 허브입니다. 그 중

  • _summarize_subgraphs_for_prompt는 기본 데이터 처리 함수로서 서브그래프 정보를 형식화하는 역할을 합니다.
  • _build_composition_prompt_build_refinement_prompt는 이중 Prompt 구축 주 함수로서, 각각 초기 단말 커널 조합 생성 오류 기반 반복적 정제를 지원합니다.두 가지 핵심 시나리오를 통해 엄격한 지시 제약, 완전한 문맥 정보, 대상적 최적화 지침을 통해 LLM이 엔지니어링 요구사항, 하드웨어 적합성, 직접 실행 가능한 Triton 핵 코드를 생성하도록 보장합니다.

핵심 역할

  1. 프롬프트용 서브그래프 요약기본 지지 함수로 모델 분해 후의 서브 그래프 정보 리스트를구조화되고 간결하게 텍스트 요약으로 변환, 서브그래프 ID, 타입, 데이터 레이아웃, 데이터 타입, 입력/출력 모양, 핵심 연산자 등 핵심 제약 정보를 추출하여 통일된 형식으로 연결하여 읽기 쉬운 문자열로 만들고, 두 개의 Prompt 구성 함수에 표준화된 서브그래프 정보 설명을 제공하여 LLM이 각 서브그래프의 기능과 계산 제약을 빠르게 이해하도록 합니다.
  2. _건축물_프롬프트_생성LLM을 처음으로 생성하여 구축하다 전체 컨텍스트와 강한 제약 조건의 조합형 프롬프트는 원시 PyTorch 문제 코드, 서브그래프 정보 요약, 각 서브그래프의 유효한 Triton 커널 코드, 목표 하드웨어 플랫폼 구성의 네 가지 핵심 정보를 통합합니다. LLM의 핵심 작업은 서브그래프 커널을 융합하여 엔드투엔드 Triton 구현을 생성하는 것을 명확히 하고, 엄격한 엔지니어링 요구 사항, 하드웨어 적합성 규칙, Triton 개발 규범을 제정하여 LLM이「산발된 서브그래프」에서「일관된 커널」로의 조합과 최적화를 수행하도록 지시합니다.
  3. _build_refinement_prompt:LLM 반복 최적화를 위해오류 경향성, 대상적 정제형 프롬프트를 구축합니다는 핵심 기본 정보를 유지하면서 이전에 생성된 오류 로그(stdout/stderr)와 이전 라운드에 실패한 코드 구현을 추가하고, LLM의 핵심 작업이 오류 정보를 기반으로 문제를 식별하고 코드를 수정하는 것임을 명확히 하며, 더 구체적인 오류 수정 요구 사항을 추가하여 정교화된 코드가 컴파일 및 실행 문제를 해결할 수 있으며 원래 엔지니어링 규범을 위반하지 않도록 합니다.

_build_composition_prompt 전용 특징

  1. 네 가지 핵심 정보 전량 통합:완전히「원본 문제 코드(요구 기준), 서브 그래프 요약(논리 제약), 서브 그래프 핵심(코드 자료), 플랫폼 구성(하드웨어 규칙)」를 통합하여 LLM이「무엇을 해야 하는지(PyTorch 모델 기능)」를 이해하고,「무엇을 사용할 수 있는지(서브 그래프 핵심)」를 알고,「어떻게 맞춰야 하는지(하드웨어 플랫폼)」를 명확히 합니다.
  2. 핵심 철학 명확화: 융합과 융합:LLM이 여러 자식 그래프를 가능한 한 적은 Triton 내核 시작으로 통합하도록 명확히 요구하며, 수치적 의미의 정확성을 보장하는 동시에 실행 효율성을 향상시켜야 하며, Triton 내核의「큰 크기 통합, 내核 시작 비용 감소」의 최적화 핵심 철학과 일치해야 합니다;
  3. 매우 상세한 엄격한 요구사항:10개 이상의 반드시 준수해야 할 엄격한 규칙을 제정하며, 코드 출력 형식, 장치 텐서 관리, 함수 이름과 래퍼싱, 계산 인터페이스 제한(PyTorch 계산 사용 금지), 데이터 형식과 연산자 순서, 수치 검증 요구사항, 가져오기 및 동작 제한 등을 포함하여 코드 생성의 원천에서 규범화해야 합니다;
  4. 실용적인 Triton 개발 가이드:대상적인 Triton 구현 기술을 제공하며, 자식 그래프 통합의 모양 일치, 상수 가중치 최적화, 메모리 접근 규범(tl.load/tl.store 마스크 사용), 그리드와 블록 설계를 포함하며, 동시에 Triton의 흔한 개발 함정 및 회피 방법을 명시하여 LLM이 잘못된 코드를 생성할 가능성을 낮추어야 합니다;
  5. 수치적 동치성 강제 요구사항:생성되는 코드는 반드시 자가 테스트 함수를 포함해야 하며, PyTorch 참조 결과와의 비교를 통해 수치적 정확성을 검증하고, 엄격한 오차 허용 범위를 설정( fp32/_fp16/bf16 구분)하여 Triton 커널의 기능 정확성과 수치적 정확성을 보장해야 합니다.

_build_refinement_prompt 전용 특징

  1. 오류 경향의 정밀 수정:이전에 생성된 stderr/stdout 오류 로그를 핵심 참조로 사용하여 LLM이「문제 위치→문제 수정」에 집중하게 하여, 목적 없는 재생성을 방지하고 반복적 최적화의 효율성과 대상성을 크게 향상시켜야 합니다.
  2. 기존 제약을 유지하고 수정 요구 사항 추가:「기존 모든 요구 사항은 변경되지 않음」을 명확히 하고, 오류 시나리오에 대해 더 구체적인 수정 규칙을 추가(예: tl.broadcast의 스칼라 남용 금지)하여 수정된 코드가 기존 엔지니어링 규범을 위반하지 않으면서도 구체적인 문제를 해결하도록 합니다.
  3. 전체 코드 재생성, 차이점 출력 거부:LLM이 완전한 수정된 코드를 반환하도록 강제 요구, 코드 차이점 또는 수정 제안이 아닌, 후속 코드 연결 및 통합의 추가 작업을 피하고 출력을 직접 대체하여 사용할 수 있도록 보장;
  4. 핵심 식별자 강제 보존:상위 kernel_function 함수 이름, 자체 테스트 함수 및「PASS」프린트 / 종료 코드 규칙을 명확히 요구하여 보수된 코드가 후속 자동 검증 프로세스에 무缝 연결될 수 있도록 하여 검증 논리 수정 없이 사용할 수 있도록 보장;
  5. 실패 코드 참조, 문제 해결 더 효율적:이전 라운드의 실패 코드를 완전히 Prompt에 통합하여 LLM이 직접 오류 로그와 코드 구현을 비교할 수 있도록 하여 빠르게 문제의 원인을 파악할 수 있게 하고(예: 컴파일 오류 줄 번호, 실행 오류 논리), 수정의 정확성을 높임.

_summarize_subgraphs_for_prompt 전용 특징

_summarize_subgraphs_for_prompt「논리적 제약 조건」(각 하위 그래프의 기능, 모양, 레이아웃 등 제약 정보)를 제공합니다. 구체적으로, _summarize_subgraphs_for_prompt: 모델 분해 후의 하위 그래프 정보 리스트를 구조화하고 간결하게 텍스트 요약, 하위 그래프 ID, 유형, 데이터 레이아웃, 데이터 유형, 입력 출력 모양, 핵심 연산자 등 핵심 정보를 추출하여 통일된 형식으로 연결하여 읽기 쉬운 텍스트 문자열로 만들어, LLM 조합 Prompt 구성을 위한 표준화된 하위 그래프 정보 설명을 제공하여 LLM이 각 하위 그래프의 기능, 모양 제약 조건 및 계산 요구 사항을 빠르게 이해하도록 합니다.

  1. 핵심 정보 정확하게 추출, 불필요한 정보 제거:LLM 조합만 유지 / 정제된 핵심 제약 정보(아이디, 타입, 레이아웃, 데이터 타입, 입력/출력 형상, 핵심 연산자)를 제외하고 불필요한 정보를 제거하여 프롬프트 토큰 사용량을 줄이고 LLM 처리 효율을 향상시킵니다;
  2. 적절한 기본값을 보장하여 안정성을 확보합니다:데이터 레이아웃(기본값 NCHW), 데이터 타입(기본값 float32), 연산자 목록(기본값 빈 목록)과 같은 쉽게 누락될 수 있는 필드에 적절한 기본값을 설정하여 필드 누락으로 인한 프롬프트 구성 실패를 방지하고 프로세스의 오류 허용성을 향상시킵니다;
  3. 계층적이고 조밀한 형식, 읽기 쉽고 파싱 용이:「일차 행 주석 서브그래프 기본 속성 + 이차 행 주석 핵심 연산자」의 계층적 형식을 채택하여 정보가 조밀하게 유지될 수 있도록(프롬프트 길이 제한에 맞춤) 동시에 구조가 명확하여 LLM이 서브그래프 ID와 해당 기능 및 제약을 빠르게 연관시킬 수 있도록 합니다;
  4. 연산자 정보 길이 제어, 제한 초과 방지:연산자 목록을 직렬화한 후 앞 400 문자를 잘라내어 연산자가 많아서 프롬프트가 LLM의 컨텍스트 창을 초과하는 것을 방지하고, JSON 직렬화 실패 시 경우에도 직접 문자열 잘라내기를 보조하여 정보의 완전성을 보장합니다;
  5. 형상 정보 유연하게 적합inputs 필드를 우선 사용하여 입력을 설명하고, inputs이 없으면 input_shape으로 격하하여 다양한 서브그래프 분해 도구의 출력 형식 차이에 적합하게 하여 입력 및 출력 형상 정보의 효과적인 전달을 보장합니다.

협력 작업 관계

세 가지 함수는 「기본 데이터 처리→초기 프롬프트 구축→반복적 정제 프롬프트 구축」의 순서로 형성합니다.의 계층적 지지 관계는 LLM이 엔드투엔드 Triton 핵심을 생성하기 위해 전 과정의 Prompt 지지를 제공합니다:

  1. 기반 레이어: _summarize_subgraphs_for_prompt는 원시 서브그래프 정보에 대해 통일된 형식화 처리를 수행하여 표준화된 서브그래프 요약을 생성하고, 상위 두 개의 Prompt 구축 함수에 일관되고 쉽게 해석할 수 있는 서브그래프 제약 정보를 제공하여 데이터의 한 번의 처리와 여러 번의 재사용을 실현합니다;
  2. 초기 생성 레이어: _build_composition_prompt는 표준화된 서브그래프 요약을 기반으로 문제 코드, 서브그래프 핵심, 플랫폼 구성을 통합하여 전체 제약 조건의 조합 Prompt를 구축하고, LLM이 엔드투엔드 Triton 핵심을 처음 생성하는 데 도움을 주어「없음에서 있음으로」의 핵심 지시를 제공합니다;
  3. 반복 정제 레이어:_build_refinement_prompt 표준화된 서브그래프 요약과 핵심 기본 정보를 재사용하고, 오류 로그와 이전 실패 코드를 추가하여 오류 방향성을 가진 정교한 Prompt를 구축하여 LLM이「잘못된 것에서 올바른 것으로」의 반복적 최적화를 수행하도록 지시하는 것은 코드의 효과성을 향상시키는 핵심 요소입니다;
  4. 순환 지원: 세 가지 함수의 출력이 KernelAgent의「생성→검증→정교화→재검증」순환 프로세스를 공동으로 지원하여, 각 라운드의 LLM 생성에 명확한 지시, 충분한 근거, 엄격한 제약을 보장하여 엔드투엔드 Triton 핵심의 생성 성공률과 엔지니어링 품질을 크게 향상시킵니다.

코드

def _build_composition_prompt(
    problem_code: str,
    subgraphs: list[dict[str, Any]],
    kernel_items: list[KernelItem],
    target_platform: PlatformConfig,
) -> str:
    """Create a single user message to instruct composition by the LLM.
    构建初始合成Prompt:
    为LLM生成结构化指令,引导其基于子图和参考内核,合成端到端的Triton算子代码
    参数:
    - problem_code: 原始KernelBench问题代码(PyTorch)
    - subgraphs: 子图信息列表(JSON格式)
    - kernel_items: 参考内核列表(包含子图ID和对应的Triton代码)
    - target_platform: 目标平台配置(如CUDA/XPU)
    返回值:完整的LLM用户指令字符串
    """
    # 第一步:生成子图摘要(压缩子图信息,控制Token消耗,便于LLM快速理解核心特征)
    sg_summary = _summarize_subgraphs_for_prompt(subgraphs)

    # 第二步:构建参考内核代码区块(仅保留核心代码,避免Token溢出)
    # 注释说明:暂时保留完整文件内容,调用方可根据模型窗口限制进一步裁剪
    # 初始化内核区块的文本片段列表
    kernels_section_parts: list[str] = []
    # 遍历每个参考内核项
    for ki in kernel_items:
        # 为每个子图内核构建带格式的代码片段(Markdown Python代码块)
        kernels_section_parts.append(
            f"### Subgraph {ki.subgraph_id}\n```python\n" + ki.code + "\n```\n"
        )
    # 拼接所有内核片段,形成完整的参考内核区块
    kernels_section = "\n".join(kernels_section_parts)
    
    # 第三步:获取平台专属的指导规则(如CUDA的内存访问规则、XPU的编译要求)
    platform_guidance = target_platform.guidance_block

    # 第四步:构建核心指导语(包含任务背景、平台信息、硬性要求、实现技巧)
    # 使用textwrap.dedent去除缩进,保证Prompt格式整洁
    guidance = textwrap.dedent(
        f"""
        You are given:
        - The original problem file (PyTorch module and helpers).
        - A decomposition of the model into fusable subgraphs with exact shapes.
        - Working Triton kernels generated for some subgraphs.

        TARGET PLATFORM: {target_platform.name}
        DEVICE STRING: {target_platform.device_string}
        {platform_guidance}

        Task:
        - Compose an end-to-end Triton implementation that matches the original
          model's forward pass for the provided shapes. You may inline, adapt,
          or reuse the given subgraph kernels. Prefer fusing into as few kernel
          launches as possible while preserving exact numerical semantics.

        Hard requirements:
        - Return ONE complete Python file only, fenced as a single ```python block.
        - Allocate inputs, weights, intermediates, and outputs on device='{target_platform.device_string}' and keep them there throughout forward/verification.
        - CPU is acceptable only for metadata, scalars, and export serialization—avoid `.cpu()` or `.to('cpu')` on compute tensors.
        - Provide at least one @triton.jit kernel and a top-level Python wrapper
          named kernel_function(...). This wrapper must accept the same primary
          input tensor(s) as the model and any required weights/biases with shapes
          implied by the problem; it should orchestrate Triton kernel(s) and
          return the final output tensor.
        - No PyTorch math path: kernel_function MUST compute the final outputs
          using your Triton kernels only. Do NOT implement or fall back to
          torch.nn / torch.nn.functional / torch.* ops
          sigmoid, etc.) for producing the final result. Using PyTorch for
          reference comparisons is allowed only inside the self-test.
        - Use the data layout and dtype semantics indicated by subgraphs, defaulting
          to NCHW + float32 if unspecified. Respect stride/padding/dilation/groups,
          and exact op order.
        - Numerical equivalence: include a self-test (test_kernel or run_tests)
          that compares your Triton-based result to a PyTorch reference computed
          from the original problem code below (use get_init_inputs() and
          get_inputs() if present to instantiate the Model). The test must print
          'PASS' on success and exit with code 0. Use allclose with rtol<=1e-3,
          atol<=1e-3 for fp32; for fp16/bf16 allow up to 2e-2.
        - No imports beyond torch, triton, triton.language as tl, and stdlib. No I/O.
        - Do NOT monkey-patch PyTorch device functions or torch.cuda.is_available()
        - Do NOT manipulate TRITON_BACKENDS environment variable
        - Do NOT disable or mock XPU/CUDA drivers

        Implementation tips:
        - If merging multiple subgraphs, ensure intermediate tensor shapes match.
        - Hoist constant weights or parameters to avoid reloading per block.
        - Use tl.load/tl.store with masks for boundary conditions.
        - Favor coalesced memory access; tile by blocks; compute grid from shape.
        - Common Triton pitfalls to avoid:
          * Do NOT call tl.broadcast on Python scalars; tl.maximum(x, 0.0) works.
          * Prefer scalar constants directly in elementwise ops (no explicit broadcast needed).
          * Keep BLOCK_SIZE power-of-two; mask stores at tail.
        """
    ).strip()  # 去除首尾空白字符

    # 第五步:拼接完整的用户指令(按逻辑组织各部分内容)
    user_lines: list[str] = []
    user_lines.append(guidance)  # 核心指导语
    user_lines.append("")  # 空行分隔
    user_lines.append("SUBGRAPHS (summary):")  # 子图摘要标题
    user_lines.append(sg_summary)  # 子图摘要内容
    user_lines.append("")  # 空行分隔
    user_lines.append("ORIGINAL PROBLEM FILE:")  # 原始问题代码标题
    user_lines.append("```python")  # Python代码块开始标记
    user_lines.append(problem_code)  # 原始问题代码内容
    user_lines.append("```")  # Python代码块结束标记
    user_lines.append("")  # 空行分隔
    user_lines.append("SUBGRAPH KERNELS (reference implementations):")  # 参考内核标题
    user_lines.append(kernels_section)  # 参考内核代码内容
    user_lines.append("")  # 空行分隔
    # 最终要求:仅返回一个包含完整代码的Python代码块
    user_lines.append(
        "Return only one fenced Python code block with your final composed implementation."
    )
    # 拼接所有行,形成完整的Prompt
    return "\n".join(user_lines)

def _build_refinement_prompt(
    problem_code: str,
    subgraphs: list[dict[str, Any]],
    kernel_items: list[KernelItem],
    previous_code: str,
    error_info: dict[str, str],
    target_platform: PlatformConfig,
) -> str:
    """Prompt the LLM to refine the previously produced code based on errors.
    构建迭代优化Prompt:
    基于上一轮代码的错误信息,引导LLM修复Triton算子代码中的编译/运行错误
    参数:
    - previous_code: 上一轮生成的错误代码
    - error_info: 错误信息字典(包含stderr_tail/stdout_tail)
    其他参数同_build_composition_prompt
    返回值:针对性的优化指令字符串
    """
    # 提取错误日志尾部(最后2000字符,聚焦核心错误)
    err_tail = error_info.get("stderr_tail", "")
    # 提取标准输出尾部(辅助分析错误)
    out_tail = error_info.get("stdout_tail", "")

    # 构建优化指导语(聚焦错误修复,保留原有核心要求)
    guidance = textwrap.dedent(
        f"""
        You previously produced a composed Triton implementation, but it failed
        to run/compile. Analyze the ERROR_CONTEXT below and re-emit the entire
        corrected single-file implementation as one ```python block.

        TARGET PLATFORM: {target_platform.name}
        DEVICE STRING: {target_platform.device_string}

        Requirements remain the same. Additionally:
        - Fix any Triton compilation/runtime errors. For scalar constants in
          elementwise ops (e.g., ReLU), do not use tl.broadcast. Use direct
          scalars like 0.0 in tl.maximum(x, 0.0).
        - Keep function name kernel_function(...) unchanged and retain the
          self-test that prints PASS on success and exits 0.
        - Do NOT reintroduce any PyTorch math path in kernel_function. The final
          outputs must be computed via your Triton kernels only (no fallback to
          torch.nn / torch.nn.functional ops).
        - Return the complete corrected file; do not send diffs.
        """
    ).strip()

    # 拼接完整的优化指令(按“指导语→错误信息→原始代码→子图摘要→上一轮代码”组织)
    lines: list[str] = []
    lines.append(guidance)  # 优化指导语
    lines.append("")  # 空行分隔
    # 添加标准错误日志(核心错误信息)
    lines.append("ERROR_CONTEXT (stderr tail):\n```\n" + err_tail + "\n```")
    # 若标准输出非空,添加标准输出日志(辅助分析)
    if out_tail.strip():
        lines.append("STDOUT tail:\n```\n" + out_tail + "\n```")
    lines.append("")  # 空行分隔
    # 添加原始问题代码(保证上下文完整)
    lines.append("ORIGINAL PROBLEM FILE:\n```python\n" + problem_code + "\n```")
    lines.append("")  # 空行分隔
    # 添加子图摘要(避免LLM遗忘核心特征)
    lines.append("SUBGRAPHS (summary):\n" + _summarize_subgraphs_for_prompt(subgraphs))
    lines.append("")  # 空行分隔
    # 添加上一轮错误代码(让LLM对比分析问题)
    lines.append("PREVIOUS_ATTEMPT:\n```python\n" + previous_code + "\n```")
    lines.append("")  # 空行分隔
    # 最终要求:仅返回修正后的完整Python代码块
    lines.append(
        "Return only one fenced Python code block with the corrected implementation."
    )
    # 拼接所有行,形成优化Prompt
    return "\n".join(lines)

def _summarize_subgraphs_for_prompt(subgraphs: list[dict[str, Any]]) -> str:
    """
    生成子图摘要:将复杂的子图JSON信息压缩为简洁的文本格式
    核心目标:控制Token消耗,同时保留子图的ID、类型、布局、 dtype、形状、算子等核心特征
    """
    # 初始化摘要行列表
    lines: list[str] = []
    # 遍历每个子图
    for it in subgraphs:
        # 提取子图ID(默认unknown)
        sid = str(it.get("id", "unknown"))
        # 提取子图类型(如conv2d、linear)
        typ = str(it.get("type", ""))
        # 提取数据布局(默认NCHW)
        layout = it.get("data_layout") or "NCHW"
        # 提取数据类型(默认float32)
        dtype = it.get("dtype") or "float32"
        # 提取输入形状(兼容多输入/单输入格式)
        inputs = it.get("inputs")
        in_shape = it.get("input_shape")
        # 提取输出形状
        out_shape = it.get("output_shape")
        # 提取算子列表
        ops = it.get("ops") or []
        
        # 构建形状描述行(优先多输入格式,其次单输入格式)
        shapes_line = (
            f"inputs={inputs if inputs is not None else in_shape}, output={out_shape}"
        )
        # 构建子图核心信息行
        lines.append(
            f"- ID={sid} type={typ} layout={layout} dtype={dtype} {shapes_line}"
        )
        
        # 构建算子摘要(限制长度为400字符,避免Token溢出)
        try:
            # 尝试转为JSON字符串(结构化)
            ops_short = json.dumps(ops)[:400]
        except Exception:
            # 转换失败则直接转为字符串
            ops_short = str(ops)[:400]
        # 添加算子摘要行(缩进2空格,提升可读性)
        lines.append(f"  ops={ops_short}")
    
    # 拼接所有行,形成最终的子图摘要
    return "\n".join(lines)

2.5 조합 생성

다음 두 함수는 PyTorch KernelAgent에서 엔드투엔드 Triton 핵심 조합 생성의 핵심 실행 모듈입니다.

  • _auto_patch_common_triton_issues 이것은 가볍게 자동화된 코드 수선 도구는 LLM이 생성한 Triton 코드에 대한 사전 방해물 회피 패치를 제공합니다; _auto_patch_common_triton_issues: LLM이 생성한 Triton 코드에 대한 침투 없는 텍스트 수준 자동화 패치를 수행하여 Triton 개발 중 고발생 및 쉽게 범하는 기초 오류를 특별히 수정하고, 코드 실행 / 검증 전 낮은 수준의 오류를 사전에 방지하여 무효 반복을 줄이고 핵심 생성 효율을 향상시킵니다.
  • compose 최상위 입구와 프로세스 스케줄링 핵심로서 '프롬프트 구축, LLM 생성, 코드 패치, 실기 검증, 반복 정비, 제품 보관'의 전 과정을 연결하여 KernelAgent가 '분산 서브그래프 / 핵심 자료'에서 '실행 가능한 엔드투엔드 Triton 핵심'으로 엔드투엔드 드라이빙 핵심입니다., 반복 라운드에 따라 동적으로「조합 / 정비 Prompt」를 구축하고 LLM을 호출하여 코드를 생성한 후, 생성된 코드에 자동 패치를 적용하며, 실기 검증을 통한 다중 라운드 반복 정비를 지원하여 최종적으로 하드웨어 적합성을 갖춘 직접 배포 가능한 통합 Triton 커널 및 표준화된 결과 요약을 출력합니다.

_auto_patch_common_triton_issues 함수

핵심 역할

LLM이 Triton 코드를 생성할 때 자주 발생하는 문제에 대해두 가지 고빈도 저급 오류자동화된 텍스트 수정을 통해 코드 실행 전 함정을 미리 피하고, 기본 오류로 인한 검증 실패를 줄이며, 반복 효율을 향상시킵니다:

  1. 스칼라 남용 수정tl.broadcast문제(예)tl.broadcast(0.0, ...)바로 대입하기0.0), Triton 스칼라 연산 규범에 맞춰서;
  2. 목표 플랫폼과 호환되지 않는 CUDA 전용 hack 코드를 제거하여 플랫폼 간 실행 시 오류를 방지하고 코드가 목표 하드웨어와 일치하도록 보장합니다.
핵심 특징
  1. 보수적인 텍스트 패치, 논리적 침투 없음 : 문자열 수준의 간단한 대체 / 필터링만 수행 / 코드의 핵심 계산 논리를 수정하지 않아 패치가 새로운 논리 오류를 유발하지 않고 보안을 보장합니다;
  2. Triton 고주파 오류에 정확히 대응하여 높은 수리 효율 : 패치 규칙은 LLM이 가장 쉽게 발생하는 Triton 기본 오류에 집중하여 복잡한 논리 문제를 처리하지 않고 빠른 수리 속도와 높은 일중율을 보장합니다;
  3. 플랫폼 맞춤형 패치, 높은 적응성: 타겟 플랫폼 설정에 따라cuda_hacks_to_strip동적으로 호환되지 않는 코드를 제거하고, 다양한 플랫폼 확장을 지원하며, 플랫폼 규칙을 강제로 코딩할 필요 없음;
  4. 코드 블록 전체 필터링을 지원하여 더 포괄적인 처리: 단일 라인 CUDA hack 코드를 필터링하는 것 외에도 식별할 수 있습니다_fake_torch_device이러한 함수 정의, 구현코드 블록 전체 건너뛰기 필터링, 더 완벽하게 처리하십시오;
  5. 수정 상태를 반환하고, 프로세스가 인지될 수 있습니다.: 에서(patched_code, changed)튜플로 결과를 반환하고, 호출자에게 코드가 수정되었는지 명확히 알려주어 후속 로그 기록 및 프로세스 모니터링을 용이하게 합니다;
  6. 강건성이 뛰어나고 코드 침투 위험이 없습니다:모든 패치 작업은 원본 코드를 복사하여 수정하며, 원본 코드 내용을 변경하지 않아 부작용을 피합니다.

compose 함수

의 핵심 기능

는 커널 조합으로 생성된전체 프로세스 스케줄링 핵심로, 입력 로드부터 최종 제품 출력까지 모든 단계를 완료합니다:

  1. 환경 초기화(디렉토리 생성, LLM 서비스 제공자 로드, 목표 플랫폼 구성), 세 가지 핵심 입력(문제 코드, 서브그래프 정보, 유효한 서브그래프 커널) 검증 및 로드;
  2. 라운드별로 동적으로 Prompt 구축(초기에는 조합된 Prompt 사용, 이후에는 오류 주도의 정제 Prompt 사용), LLM을 호출하여 Triton 코드 생성;
  3. 생성된 코드에 자동 패치 적용, 중간 결과물 아카이빙, 실기 검증 주도의 다수 라운드 반복 정제 지원.
  4. 최종적으로 엔드투엔드 Triton 커널 파일을 출력하고, 표준화된 결과를 요약하여 영구 저장하며, 상위 모듈에 통일된 호출 인터페이스를 제공합니다.
핵심 특징
  1. Dual Prompt 동적 전환, 반복적인 정제가 더 정밀합니다 : 반복 횟수에 따라 자동으로 Prompt 유형을 전환합니다 —— 초기 생성 시 전체 제약 조건의 조합 Prompt를 사용 (없음에서 있음으로 커널을 구축), 후속 반복에서는 오류 주도의 정제 Prompt를 사용 (stderr/stdout 로그를 기반으로 대상적인 수정), 의도 없는 재생성을 피하고 반복 효율성을 크게 향상시킵니다;
  2. 검증 주도의 닫힌 루프 반복, 성공률이 높습니다 : 검증을 시작하면, 각 루프에서 생성된 코드가 모두run_candidate 실제 기기에서 실행을 검증하고, 실패하면 오류 로그를 추출하여 LLM으로 정제 요청을 전달하여 「생성→패치→검증→오류 발생→정제」의 순환 구조를 형성하고, 검증이 통과하거나 최대 반복 횟수에 도달할 때까지 최종 커널의 사용 가능성을 크게 향상시킵니다;
  3. 유연한 검증 스위치, 효율성과 효과성을 모두 고려합니다verify 스위치를 통해 실제 기기 검증을 켜고 끌 수 있도록 지원합니다 — 켜지면 제품의 유효성을 보장하고, 끄면 한 번 생성한 후 종료하여 가벼운 사용 시나리오의 효율성을 향상시킵니다;
  4. 전 과정 제품 구조화 보관, 추적 가능성이 높습니다:각 반복 시도마다 독립 파일을 생성합니다 (프롬프트 텍스트, 생성 코드), 검증 로그, 최종 커널, 결과 요약은 디렉토리 구조로 구조화하여 저장되어 모든 단계가 추적 가능하고 재현 가능하며 문제 진단에 용이합니다;
  5. 엄격한 입력 검증, 과정 내성이 높습니다:입력된 서브그래프 정보 형식과 LLM 서비스 제공자의 가용성을 엄격하게 검증하며, 형식 오류 시 프로세스를 즉시 중단하고 명확한 경고를 발생시켜 무효 실행을 방지합니다.
  6. 결과 출력을 표준화하고 엔지니어링 친화적 :성공 상태, 커널 경로, LLM 사용량, 반복 횟수, 검증 결과, 목표 플랫폼을 포함하는 표준화된 사전을 반환하며, JSON 형식의 요약 파일을 생성하여 상위 모듈 호출, 결과 통계 및 통합을 용이하게 합니다.
  7. 반복 횟수는 구성 가능하며 다양한 시나리오에 적합합니다 max_iters를 통해 최대 반복 횟수를 제어합니다 (기본값 5회), 요구 사항에 따라 조정하여 생성 효율성과 성공률을 균형을 맞춥니다.
  8. LLM 사용량 추적, 비용 관리 :각 라운드별 LLM 호출 사용량 정보(last_usage)를 기록하여 최종 결과에 포함시키며, LLM 생성 비용의 통계 및 관리를 지원합니다.
  9. 상류 하류 모듈의 무缝 연결 :상류 자식 도형 분해 및 커널이 생성한 출력 형식에 엄격하게 적합하며, 하류에서 생성된 composed_kernel.py를 직접 호출하여 전체 프로세스의 무缝 연결을 실현합니다;
  10. 타임아웃 및 환경 제어, 검증의 더 나은 신뢰성 run_candidate을 호출할 때 2400초 타임아웃을 설정하고, 격리 실행, 네트워크 금지 등의 구성을 지원하여 검증 과정이 환경 문제로 멈추는 것을 방지하고 검증의 안정성을 향상시킵니다.

협력 관계

두 함수는 「사전 패치 보호 + 전체 프로세스 스케줄링 실행」의 긴밀한 협력 관계를 형성하며, KernelAgent 조합 생성 엔드투엔드 Triton 커널의 핵심 지원입니다:

  1. _auto_patch_common_triton_issuescompose 함수의 전제 단계 로, LLM 코드 생성 후 실기 검증 전에 실행되어 Triton 기본 오류를 사전에 수정하고, 하위 수준의 문제로 인한 검증 실패를 줄이며, compose의 반복 프로세스「부담을 덜어주어」, 전체 실행 효율을 향상시킵니다.
  2. compose 최상위 스케줄러 로, _auto_patch_common_triton_issues를 호출하고 목표 플랫폼 구성을 전달하여 패치 작업이 하드웨어 요구 사항에 맞도록 하며, 패치 후 코드를 기록하고 후속 검증 및 반복 프로세스를 진행합니다.
  3. 를 결합하여「LLM 생성의 유연성 + 자동화 패치의 함정 회피 능력 + 검증 주도의 순환적 개선」을 실현하며, LLM이 서브그래프를 융합하고 복잡한 핵심을 구축하는 능력을 발휘하고, 자동화 패치와 멀티로운 정제를 통해 하위 오류를 회피하고 실행 문제를 해결하여, 최종적으로 높은 품질의 실행 가능한 엔드투엔드 Triton 핵심을 출력합니다.

코드

def _auto_patch_common_triton_issues(
    code: str, target_platform: PlatformConfig
) -> tuple[str, bool]:
    """Apply tiny safe textual patches for known Triton pitfalls.

    - Replace tl.broadcast(0.0, ...) or tl.broadcast(1.0, ...) with scalar constants.
    Returns (patched_code, changed).
    自动修复Triton算子代码中的常见问题:
    - 核心修复点:将tl.broadcast(0.0/1.0/0/1, ...)替换为标量常量(避免Triton广播操作的性能/语法问题)
    - 返回值:(修复后的代码, 是否发生修改)
    """
    # 初始化修复后的代码为原始代码
    patched = code
    # 标记是否发生修改(默认未修改)
    changed = False
    # 修复规则:采用保守的简单启发式规则,仅处理无歧义的常见问题
    patterns = [
        # 规则1:tl.broadcast(0.0 → 替换为0.0(移除不必要的广播操作)
        ("tl.broadcast(0.0", "0.0"),
        # 规则2:tl.broadcast(1.0 → 替换为1.0
        ("tl.broadcast(1.0", "1.0"),
        # 规则3:tl.broadcast(0, → 替换为0.0(统一数值类型为浮点数)
        ("tl.broadcast(0,", "0.0"),
        # 规则4:tl.broadcast(1, → 替换为1.0
        ("tl.broadcast(1,", "1.0"),
    ]
    # 遍历所有修复规则
    for old, new in patterns:
        # 若原始代码包含待修复的模式
        if old in patched:
            # 执行文本替换
            patched = patched.replace(old, new)
            # 标记为已修改
            changed = True
    # 移除CUDA相关的冗余hack代码(平台适配)
    # 获取当前目标平台需要剥离的CUDA hack模式列表
    cuda_hacks = target_platform.cuda_hacks_to_strip
    if cuda_hacks:
        # 将代码按行拆分
        lines = patched.split("\n")
        # 存储过滤后的代码行
        filtered_lines = []
        # 标记是否需要跳过直到空行(用于移除整个函数块)
        skip_until_blank = False
        # 逐行处理代码
        for line in lines:
            # 若处于跳过模式:跳过当前行,直到遇到空行
            if skip_until_blank:
                if line.strip() == "":
                    # 遇到空行,退出跳过模式
                    skip_until_blank = False
                continue
            # 检查当前行是否包含需要剥离的CUDA hack模式
            if any(hack in line for hack in cuda_hacks):
                # 标记为已修改
                changed = True
                # 特殊处理:若为_fake_torch_device函数定义,需跳过整个函数块(直到空行)
                if "def _fake_torch_device" in line:
                    skip_until_blank = True
                # 跳过当前行(移除hack代码)
                continue
            # 保留当前行
            filtered_lines.append(line)
        # 重新拼接过滤后的代码行
        patched = "\n".join(filtered_lines)
    # 返回修复后的代码和是否修改的标记
    return patched, changed

def compose(
    problem_path: Path,
    subgraphs_path: Path,
    kernels_summary_path: Path,
    out_dir: Path,
    model_name: str,
    verify: bool = False,
    max_iters: int = 5,
    target_platform: str = "cuda",
) -> dict[str, Any]:
    """
    核心函数:基于子图信息和内核摘要,合成/优化Triton算子代码
    参数说明:
    - problem_path: KernelBench问题文件路径
    - subgraphs_path: 子图信息JSON文件路径(由之前的子图提取模块生成)
    - kernels_summary_path: 内核摘要文件路径
    - out_dir: 输出目录路径
    - model_name: LLM模型名称
    - verify: 是否验证生成的代码(默认False)
    - max_iters: 最大迭代次数(默认5)
    - target_platform: 目标平台(默认cuda)
    返回值:包含合成结果的字典(成功状态、代码路径、迭代次数等)
    """
    # 前置检查:确保LLM提供商模块可导入
    if get_model_provider is None:
        raise SystemExit(
            "KernelAgent providers unavailable; ensure package import and dependencies"
        )

    # 创建输出目录(递归创建父目录,已存在则忽略)
    out_dir.mkdir(parents=True, exist_ok=True)
    # 获取LLM提供商实例(用于调用LLM生成代码)
    provider = get_model_provider(model_name)

    # 初始化目标平台配置(加载平台相关的修复规则、hack列表等)
    platform = get_platform(target_platform)

    # 加载输入文件:
    # 1. 加载KernelBench问题代码(原始PyTorch代码)
    problem_code = _read_text(problem_path)
    # 2. 加载子图信息JSON(由子图提取模块生成)
    subgraphs = json.loads(_read_text(subgraphs_path))
    # 检查子图信息是否为列表:非列表则退出
    if not isinstance(subgraphs, list):
        raise SystemExit("subgraphs.json must be a JSON array")
    # 3. 从内核摘要文件加载内核信息
    kernels = _load_kernels_from_summary(kernels_summary_path)

    # 创建迭代尝试目录(存储每一轮的生成代码)
    attempts_dir = out_dir / "attempts"
    attempts_dir.mkdir(parents=True, exist_ok=True)

    # 初始化变量:
    # - last_usage: 最后一次LLM调用的token使用信息
    # - last_code: 上一轮生成的代码
    # - verify_info: 验证信息字典
    last_usage = None
    last_code = None
    verify_info: dict[str, Any] = {}

    # 多轮迭代生成/优化代码(最多max_iters轮)
    for i in range(1, max_iters + 1):
        # 构建LLM Prompt:
        # 第一轮/上一轮代码为空 → 构建初始合成Prompt
        if i == 1 or last_code is None:
            prompt = _build_composition_prompt(
                problem_code, subgraphs, kernels, target_platform=platform
            )
        else:
            # 非第一轮 → 基于上一轮的错误信息构建优化Prompt
            # 提取验证日志的最后2000字符(便于LLM定位问题)
            stderr_tail = ""
            stdout_tail = ""
            try:
                # 读取标准错误日志尾部
                if verify_info.get("stderr_path"):
                    with open(
                        verify_info["stderr_path"],
                        "r",
                        encoding="utf-8",
                        errors="ignore",
                    ) as f:
                        stderr_tail = f.read()[-2000:]
                # 读取标准输出日志尾部
                if verify_info.get("stdout_path"):
                    with open(
                        verify_info["stdout_path"],
                        "r",
                        encoding="utf-8",
                        errors="ignore",
                    ) as f:
                        stdout_tail = f.read()[-2000:]
            except Exception:
                # 日志读取失败则忽略(避免中断迭代)
                pass
            # 构建优化Prompt(包含上一轮代码和错误信息)
            prompt = _build_refinement_prompt(
                problem_code,
                subgraphs,
                kernels,
                previous_code=last_code,
                error_info={"stderr_tail": stderr_tail, "stdout_tail": stdout_tail},
                target_platform=platform,
            )

        # 保存当前轮次的Prompt(便于追溯和调试)
        (attempts_dir / f"attempt_{i}.prompt.txt").write_text(prompt, encoding="utf-8")
        # 调用LLM生成代码:
        # - 消息格式:仅包含user角色的Prompt
        # - 最大生成Token数:50000(适配长代码生成)
        response = provider.get_response(
            model_name, [{"role": "user", "content": prompt}], max_tokens=50000
        )
        # 记录最后一次LLM调用的token使用信息
        last_usage = response.usage
        # 提取LLM输出的原始文本
        raw_text = response.content or ""

        # 从LLM输出中提取Python代码(剥离无关文本)
        extracted = extract_single_python_file(raw_text)
        code = extracted.code
        # 自动修复Triton常见问题(文本补丁)
        code, changed = _auto_patch_common_triton_issues(code, platform)
        # 保存当前轮次的生成代码
        (attempts_dir / f"attempt_{i}.py").write_text(code, encoding="utf-8")
        # 更新last_code为当前轮次的代码
        last_code = code

        # 验证当前轮次的代码(若开启验证)
        if verify:
            # 运行候选代码验证:
            # - artifacts_code_path: 当前轮次的代码路径
            # - run_root: 验证运行目录
            # - timeout_s: 验证超时时间(2400秒)
            # - isolated: 非隔离模式
            # - deny_network: 允许网络访问
            rr = run_candidate(
                artifacts_code_path=attempts_dir / f"attempt_{i}.py",
                run_root=out_dir / "runs",
                timeout_s=2400,
                isolated=False,
                deny_network=False,
            )
            # 记录验证信息
            verify_info = {
                "verify_rc": rr.rc,  # 验证退出码
                "verify_passed": rr.passed,  # 是否验证通过
                "verify_reason": rr.reason,  # 验证结果原因
                "validator": rr.validator_used,  # 使用的验证器
                "stdout_path": str(rr.stdout_path),  # 标准输出日志路径
                "stderr_path": str(rr.stderr_path),  # 标准错误日志路径
            }
            # 若验证通过,终止迭代(无需继续优化)
            if rr.passed:
                break
        else:
            # 若未开启验证,仅执行第一轮后终止
            break

    # 保存最终合成的算子代码(取最后一轮的代码)
    composed_path = out_dir / "composed_kernel.py"
    composed_path.write_text(last_code or "", encoding="utf-8")

    # 构建结果字典:包含核心合成信息
    result: dict[str, Any] = {
        # 成功状态:验证模式下为是否通过验证,非验证模式下默认成功
        "success": bool(verify_info.get("verify_passed", not verify)),
        # 最终合成代码的绝对路径
        "composed_path": str(composed_path.resolve()),
        # 使用的LLM模型名称
        "model": model_name,
        # LLM token使用信息
        "usage": last_usage,
        # 实际迭代次数
        "rounds": i,
        # 目标平台
        "target_platform": target_platform,
    }
    # 合并验证信息到结果字典
    result.update(verify_info)

    # 保存合成摘要(结构化JSON文件)
    (out_dir / "composition_summary.json").write_text(
        json.dumps(result, indent=2), encoding="utf-8"
    )
    # 返回结果字典
    return result

0xFF 참조