0

0

JAX Pallas 中静态参数的正确传递方式

聖光之護

聖光之護

发布时间:2026-03-11 13:17:24

|

263人浏览过

|

来源于php中文网

原创

JAX Pallas 中静态参数的正确传递方式

在 JAX Pallas 中,pallas_call 会强制追踪所有传入参数,即使其在外部 jit 中被标记为静态;解决方法是通过闭包捕获静态值,而非直接作为 kernel 参数传递。

在 jax pallas 中,`pallas_call` 会强制追踪所有传入参数,即使其在外部 `jit` 中被标记为静态;解决方法是通过闭包捕获静态值,而非直接作为 kernel 参数传递。

JAX 的追踪(tracing)机制是其函数式变换(如 jit、vmap、grad)的核心基础:它将 Python 函数重写为可优化、可编译的计算图(JAXPR)。然而,这一机制在嵌套调用中具有层级独立性——即上层函数(如 jit)声明的 static_argnums 仅对本层有效,不会透传给内部调用的低级原语(如 pallas_call)

在你的代码中:

@functools.partial(jax.jit, static_argnums=(1,))  # ✅ offsets 在 jit 层被视为静态
def dia_matmul(diags, offsets, other):
    return pl.pallas_call(...)(diags, offsets, other)  # ❌ pallas_call 无视此声明,全部参数被 traced

offsets 虽被 jit 视为静态(不参与梯度、不随输入变化),但一旦作为参数传入 pallas_call,Pallas 运行时会将其包装为 Traced(即动态内存引用),导致 print(offsets) 输出类似 Traced<...>{int32[]} 的调试信息——这正是你观察到的现象。根本原因在于:pallas_call 当前不支持 static_argnums 或等效机制,所有 kernel 参数均按动态张量处理

✅ 正确解法:利用 Python 闭包(closure)将静态值“冻结”在 kernel 外部作用域,使其完全不进入 pallas_call 的参数列表:

蛙蛙写作——超级AI智能写作助手
蛙蛙写作——超级AI智能写作助手

蛙蛙写作辅助AI写文,帮助获取创意灵感,提供拆书、小说转剧本、视频生成等功能,是一款功能全面的AI智能写作工具。

下载
@functools.partial(jax.jit, static_argnums=(1,))
def dia_matmul(diags: Array, offsets: tuple[int], other: Array) -> Array:
    # 闭包捕获 offsets → 它不再作为 kernel 参数,而是编译时常量
    def kernel(diags_ref, other_ref, o_ref):
        diags_val = diags_ref[...]
        other_val = other_ref[...]
        N = other_val.shape[0]
        out = jnp.zeros((N, N), dtype=other_val.dtype)

        # offsets 现在是纯 Python tuple,可在 for 循环中安全解包
        for i, offset in enumerate(offsets):  # ✅ 静态 tuple,支持 len()、索引、迭代
            diag = diags_val[i]  # 假设 diags.shape[0] == len(offsets)
            start = jax.lax.max(0, offset)
            end = jax.lax.min(N, N + offset)
            top = jax.lax.max(0, -offset)
            bottom = top + (end - start)

            # 注意:Pallas 中需用 jax.lax.min/max 替代 Python min/max
            # 且切片需保证静态形状推断(start/end/top/bottom 必须是标量 tracer)
            out = out.at[top:bottom, :].add(
                diag[start:end, None] * other_val[start:end, :]
            )
        o_ref[...] = out

    return pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct(other.shape, other.dtype)
    )(diags, other)

? 关键改进说明:

  • offsets 不再出现在 pallas_call 的参数签名或调用中,而是通过闭包在 kernel 内部直接使用;
  • for i, offset in enumerate(offsets) 可正常执行,因为 offsets 是编译期已知的 Python tuple(非 tracer);
  • 所有 min/max 替换为 jax.lax.min/max,确保在 traced 上下文中安全;
  • diag[start:end, None] 中的 start:end 切片依赖于 offset,但因 offset 是静态整数,JAX 能推断出切片长度为常量,满足 Pallas 的 shape 约束。

⚠️ 注意事项:

  • 避免在 kernel 中使用 print() 调试 traced 值:它输出的是 tracer 对象,无实际数值意义;应改用 jax.debug.print("offset: {x}", x=offset)(需启用 --jax_debug_nans 或 jax.config.update("jax_explain_traces", True) 辅助诊断);
  • 若 offsets 长度可变(如 tuple[int, ...]),需确保其长度在编译时固定(例如通过 static_argnums 限定),否则 enumerate(offsets) 将报错;
  • Pallas kernel 中所有控制流(for、if)必须能被 JAX 静态展开,因此循环次数必须由静态值(如 len(offsets))决定。

总结:JAX 的静态性是“逐层显式声明”的,不存在隐式继承。要让值在 Pallas kernel 中保持静态,唯一可靠的方式是将其移出参数列表,转为闭包变量。这是 JAX 高性能内核开发中的关键范式,也是理解 tracing 与 compilation 边界的重要实践。

本站声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
python中print函数的用法
python中print函数的用法

python中print函数的语法是“print(value1, value2, ..., sep=' ', end=' ', file=sys.stdout, flush=False)”。本专题为大家提供print相关的文章、下载、课程内容,供大家免费下载体验。

192

2023.09.27

python print用法与作用
python print用法与作用

本专题整合了python print的用法、作用、函数功能相关内容,阅读专题下面的文章了解更多详细教程。

18

2026.02.03

java基础知识汇总
java基础知识汇总

java基础知识有Java的历史和特点、Java的开发环境、Java的基本数据类型、变量和常量、运算符和表达式、控制语句、数组和字符串等等知识点。想要知道更多关于java基础知识的朋友,请阅读本专题下面的的有关文章,欢迎大家来php中文网学习。

1566

2023.10.24

if什么意思
if什么意思

if的意思是“如果”的条件。它是一个用于引导条件语句的关键词,用于根据特定条件的真假情况来执行不同的代码块。本专题提供if什么意思的相关文章,供大家免费阅读。

846

2023.08.22

string转int
string转int

在编程中,我们经常会遇到需要将字符串(str)转换为整数(int)的情况。这可能是因为我们需要对字符串进行数值计算,或者需要将用户输入的字符串转换为整数进行处理。php中文网给大家带来了相关的教程以及文章,欢迎大家前来学习阅读。

1010

2023.08.02

int占多少字节
int占多少字节

int占4个字节,意味着一个int变量可以存储范围在-2,147,483,648到2,147,483,647之间的整数值,在某些情况下也可能是2个字节或8个字节,int是一种常用的数据类型,用于表示整数,需要根据具体情况选择合适的数据类型,以确保程序的正确性和性能。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

610

2024.08.29

c++怎么把double转成int
c++怎么把double转成int

本专题整合了 c++ double相关教程,阅读专题下面的文章了解更多详细内容。

334

2025.08.29

C++中int的含义
C++中int的含义

本专题整合了C++中int相关内容,阅读专题下面的文章了解更多详细内容。

235

2025.08.29

C# ASP.NET Core微服务架构与API网关实践
C# ASP.NET Core微服务架构与API网关实践

本专题围绕 C# 在现代后端架构中的微服务实践展开,系统讲解基于 ASP.NET Core 构建可扩展服务体系的核心方法。内容涵盖服务拆分策略、RESTful API 设计、服务间通信、API 网关统一入口管理以及服务治理机制。通过真实项目案例,帮助开发者掌握构建高可用微服务系统的关键技术,提高系统的可扩展性与维护效率。

3

2026.03.11

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号