
jax 的 `@jit` 并非仅编译一次全局函数,而是根据输入的形状、dtype 和静态参数等构建缓存键,为每组兼容输入独立缓存一份 jaxpr 与编译产物,从而兼顾性能与语义正确性。
JAX 的 @jit 装饰器在首次调用时执行追踪(tracing)→ JAXPR 生成 → XLA 编译三步流程,并将结果缓存;但关键在于:缓存不是“单例式”的,而是多态的。JAX 会为每一组具有不同“缓存键(cache key)”的输入生成并保存独立的 JAXPR 与编译后可执行代码。
这个缓存键由以下要素共同决定:
- 所有数组参数的 shape 与 dtype(例如 f32[8] 与 f32[3] 视为不同键);
- 所有被标记为 static_argnums 或 static_argnames 的参数的 Python 值哈希(如 @jit(static_argnums=(0,)) 下传入的整数或布尔值);
- 全局配置状态(如 jax.default_device()、jax.debug_nans 等);
- 函数定义本身(源码哈希或 AST 等效性)。
因此,在你的示例中:
import jax
import jax.numpy as jnp
@jax.jit
def test(x):
if x.shape[0] > 4:
return 1
else:
return -1
x8 = jnp.ones(8) # shape = (8,)
x3 = jnp.ones(3) # shape = (3,)首次调用 test(x8) 时,JAX 追踪得到一个恒返回 1 的 JAXPR(因为 x.shape[0] == 8 > 4 为常量真,分支被完全剪枝),并缓存该版本;
当调用 test(x3) 时,因输入 shape 变为 (3,),缓存键不匹配,JAX 重新追踪,此时 x.shape[0] == 3 > 4 为假,JAXPR 恒返回 -1,并缓存第二个版本。
你可以通过 func._cache_size() 直观验证这一行为:
print(test._cache_size()) # 0 — 尚未调用 test(x8) print(test._cache_size()) # 1 — 缓存了 (8,) 版本 test(x8) print(test._cache_size()) # 1 — 复用,不新增 test(x3) print(test._cache_size()) # 2 — 新增 (3,) 版本
✅ 重要提示:JAX 不在运行时插入条件跳转(如 if 分支),而是在追踪阶段依据具体输入值做常量传播与控制流展开。这意味着 if 语句是否被“执行”,取决于其条件能否在追踪时被完全求值(即所有参与运算的值均为已知常量,如 shape、static 参数等)。这也解释了为何 test(jnp.ones(8)) 的 JAXPR 中不出现条件节点——它已被静态消除。
若希望强制共享同一份编译代码(例如规避多次编译开销),可显式使用 static_argnums 将 shape 相关逻辑移出追踪范围(需确保语义允许):
@jax.jit(static_argnums=(0,))
def test_static(n):
x = jnp.ones(n) # n 是 Python int,不参与 tracing
if n > 4: # 此处 n 是 static,if 在 tracing 时求值
return 1
else:
return -1但注意:这会使 n 成为编译常量,每次传入新 n 都触发全新编译(除非 n 值重复),且无法支持 n 来自运行时数组(如 n = x.shape[0])。
总结:JAX 的 jit 缓存是细粒度、基于输入签名的多版本缓存,而非单次全局编译。理解缓存键构成,有助于合理设计函数接口(如适时使用 static_argnums)、诊断意外重编译,以及正确预期控制流行为。










