0

0

JAX 的 jit 缓存机制:如何基于形状与类型动态复用 JAXPR

花韻仙語

花韻仙語

发布时间:2026-02-11 19:00:57

|

747人浏览过

|

来源于php中文网

原创

JAX 的 jit 缓存机制:如何基于形状与类型动态复用 JAXPR

jax 的 `@jit` 并非仅编译一次全局函数,而是根据输入的形状、dtype 和静态参数等构建缓存键,为每组兼容输入独立缓存一份 jaxpr 与编译产物,从而兼顾性能与语义正确性。

JAX 的 @jit 装饰器在首次调用时执行追踪(tracing)→ JAXPR 生成 → XLA 编译三步流程,并将结果缓存;但关键在于:缓存不是“单例式”的,而是多态的。JAX 会为每一组具有不同“缓存键(cache key)”的输入生成并保存独立的 JAXPR 与编译后可执行代码。

这个缓存键由以下要素共同决定:

  • 所有数组参数的 shape 与 dtype(例如 f32[8] 与 f32[3] 视为不同键);
  • 所有被标记为 static_argnums 或 static_argnames 的参数的 Python 值哈希(如 @jit(static_argnums=(0,)) 下传入的整数或布尔值);
  • 全局配置状态(如 jax.default_device()、jax.debug_nans 等);
  • 函数定义本身(源码哈希或 AST 等效性)。

因此,在你的示例中:

import jax
import jax.numpy as jnp

@jax.jit
def test(x):
    if x.shape[0] > 4:
        return 1
    else:
        return -1

x8 = jnp.ones(8)  # shape = (8,)
x3 = jnp.ones(3)  # shape = (3,)

首次调用 test(x8) 时,JAX 追踪得到一个恒返回 1 的 JAXPR(因为 x.shape[0] == 8 > 4 为常量真,分支被完全剪枝),并缓存该版本;
当调用 test(x3) 时,因输入 shape 变为 (3,),缓存键不匹配,JAX 重新追踪,此时 x.shape[0] == 3 > 4 为假,JAXPR 恒返回 -1,并缓存第二个版本。

你可以通过 func._cache_size() 直观验证这一行为:

ChatGPT Website Builder
ChatGPT Website Builder

ChatGPT网站生成器,AI对话快速生成网站

下载
print(test._cache_size())  # 0 — 尚未调用
test(x8)
print(test._cache_size())  # 1 — 缓存了 (8,) 版本
test(x8)
print(test._cache_size())  # 1 — 复用,不新增
test(x3)
print(test._cache_size())  # 2 — 新增 (3,) 版本
✅ 重要提示:JAX 不在运行时插入条件跳转(如 if 分支),而是在追踪阶段依据具体输入值做常量传播与控制流展开。这意味着 if 语句是否被“执行”,取决于其条件能否在追踪时被完全求值(即所有参与运算的值均为已知常量,如 shape、static 参数等)。这也解释了为何 test(jnp.ones(8)) 的 JAXPR 中不出现条件节点——它已被静态消除。

若希望强制共享同一份编译代码(例如规避多次编译开销),可显式使用 static_argnums 将 shape 相关逻辑移出追踪范围(需确保语义允许):

@jax.jit(static_argnums=(0,))
def test_static(n):
    x = jnp.ones(n)  # n 是 Python int,不参与 tracing
    if n > 4:         # 此处 n 是 static,if 在 tracing 时求值
        return 1
    else:
        return -1

但注意:这会使 n 成为编译常量,每次传入新 n 都触发全新编译(除非 n 值重复),且无法支持 n 来自运行时数组(如 n = x.shape[0])。

总结:JAX 的 jit 缓存是细粒度、基于输入签名的多版本缓存,而非单次全局编译。理解缓存键构成,有助于合理设计函数接口(如适时使用 static_argnums)、诊断意外重编译,以及正确预期控制流行为。

相关标签:

本站声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
java基础知识汇总
java基础知识汇总

java基础知识有Java的历史和特点、Java的开发环境、Java的基本数据类型、变量和常量、运算符和表达式、控制语句、数组和字符串等等知识点。想要知道更多关于java基础知识的朋友,请阅读本专题下面的的有关文章,欢迎大家来php中文网学习。

1547

2023.10.24

if什么意思
if什么意思

if的意思是“如果”的条件。它是一个用于引导条件语句的关键词,用于根据特定条件的真假情况来执行不同的代码块。本专题提供if什么意思的相关文章,供大家免费阅读。

804

2023.08.22

java多态详细介绍
java多态详细介绍

本专题整合了java多态相关内容,阅读专题下面的文章了解更多详细内容。

20

2025.11.27

硬盘接口类型介绍
硬盘接口类型介绍

硬盘接口类型有IDE、SATA、SCSI、Fibre Channel、USB、eSATA、mSATA、PCIe等等。详细介绍:1、IDE接口是一种并行接口,主要用于连接硬盘和光驱等设备,它主要有两种类型:ATA和ATAPI,IDE接口已经逐渐被SATA接口;2、SATA接口是一种串行接口,相较于IDE接口,它具有更高的传输速度、更低的功耗和更小的体积;3、SCSI接口等等。

1390

2023.10.19

PHP接口编写教程
PHP接口编写教程

本专题整合了PHP接口编写教程,阅读专题下面的文章了解更多详细内容。

318

2025.10.17

php8.4实现接口限流的教程
php8.4实现接口限流的教程

PHP8.4本身不内置限流功能,需借助Redis(令牌桶)或Swoole(漏桶)实现;文件锁因I/O瓶颈、无跨机共享、秒级精度等缺陷不适用高并发场景。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

2208

2025.12.29

java接口相关教程
java接口相关教程

本专题整合了java接口相关内容,阅读专题下面的文章了解更多详细内容。

36

2026.01.19

2026春节习俗大全
2026春节习俗大全

本专题整合了2026春节习俗大全,阅读专题下面的文章了解更多详细内容。

54

2026.02.11

Yandex网页版官方入口使用指南_国际版与俄罗斯版访问方法解析
Yandex网页版官方入口使用指南_国际版与俄罗斯版访问方法解析

本专题全面整理了Yandex搜索引擎的官方入口信息,涵盖国际版与俄罗斯版官网访问方式、网页版直达入口及免登录使用说明,帮助用户快速、安全地进入Yandex官网,高效使用其搜索与相关服务。

154

2026.02.11

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号