0

0

JAX vjp 失败原因与 vmap + custom_vjp 的正确组合方式

霞舞

霞舞

发布时间:2026-01-09 12:53:25

|

160人浏览过

|

来源于php中文网

原创

JAX vjp 失败原因与 vmap + custom_vjp 的正确组合方式

当对带有 `custom_vjp` 的函数先 `vmap` 再调用 `vjp` 时,若在定义 `vmap` 版本后覆盖了原始函数名,会导致前向传播中递归调用错误的 vmapped 版本,从而引发 cotangent 形状不匹配的错误。

在 JAX 中,custom_vjp 的前向函数(fwd)必须严格调用原始未变换的函数,以确保其输入/输出形状与 vjp 约定一致:即前向传播返回的 primal_out 形状应与后续 vjp 接收的 cotangent 形状完全匹配(即 cotangent.shape == primal_out.shape)。

问题代码中,关键错误在于:

test_func = vmap(test_func, in_axes=(None, 0))  # ❌ 覆盖了原始 test_func

这导致 test_func_fwd 内部调用的 test_func(jnp.dot(R, R)) 实际执行的是 已 vmapped 的版本,而 jnp.dot(R, R) 的输入 R 是标量(因 R 是 jnp.dot 的结果,shape 为 ()),但 vmapped test_func 期望 R 具有 batch 维度(如 (10, 3)),于是内部逻辑错乱,最终使 primal_out 的隐式形状与 vjp 期望不符——vjp 认为输出是 (10,),但 bwd 接收到的 residual 和 cotangent 却因前向误调而维度失配,触发报错:

ValueError: Shape of cotangent input to vjp pullback function (10,) must be the same as the shape of corresponding primal input (10, 3).

该错误信息虽表述为“cotangent 应与 primal input 同形”,实则是 JAX 在反向传播校验阶段,因前向路径被污染,无法正确推导出梯度传播所需的张量结构所致。

魔珐星云
魔珐星云

无需昂贵GPU,一键解锁超写实/二次元等多风格3D数字人,跨端适配千万级并发的具身智能平台。

下载

✅ 正确做法是:保留原始 test_func 不变,仅将 vmap 结果赋给新变量名

# ✅ 保持原始 test_func 不被覆盖
test_func_mapped = vmap(test_func, in_axes=(None, 0))

# 在 vjp 中使用映射后的版本
primal, f_vjp = vjp(partial(test_func_mapped, f), jnp.ones((10, 3)))
cotangent = jnp.ones(10)  # shape matches primal_out: (10,)
cotangent_out = f_vjp(cotangent)

print(cotangent_out[0].shape)  # → (10, 3)

? 补充注意事项:

  • custom_vjp 的 fwd 函数中禁止调用任何高阶变换(如 vmap, jit, grad)后的函数,除非明确设计为支持嵌套;
  • 若需批量处理并保留 vjp 可用性,推荐使用 vmap 包裹整个 vjp 调用(即 vmap(vjp(...))),而非先 vmap 函数再 vjp;
  • 对于复杂控制流或状态依赖场景,建议通过 jax.custom_vjp + non-differentiable arguments 显式隔离不变参数,并始终在 fwd/bwd 中使用原始函数引用。

遵循“函数变换不覆盖原名”这一原则,可避免绝大多数 vmap 与 custom_vjp 组合时的静默行为异常。

相关专题

更多
点击input框没有光标怎么办
点击input框没有光标怎么办

点击input框没有光标的解决办法:1、确认输入框焦点;2、清除浏览器缓存;3、更新浏览器;4、使用JavaScript;5、检查硬件设备;6、检查输入框属性;7、调试JavaScript代码;8、检查页面其他元素;9、考虑浏览器兼容性。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

183

2023.11.24

Golang 性能分析与pprof调优实战
Golang 性能分析与pprof调优实战

本专题系统讲解 Golang 应用的性能分析与调优方法,重点覆盖 pprof 的使用方式,包括 CPU、内存、阻塞与 goroutine 分析,火焰图解读,常见性能瓶颈定位思路,以及在真实项目中进行针对性优化的实践技巧。通过案例讲解,帮助开发者掌握 用数据驱动的方式持续提升 Go 程序性能与稳定性。

9

2026.01.22

html编辑相关教程合集
html编辑相关教程合集

本专题整合了html编辑相关教程合集,阅读专题下面的文章了解更多详细内容。

56

2026.01.21

三角洲入口地址合集
三角洲入口地址合集

本专题整合了三角洲入口地址合集,阅读专题下面的文章了解更多详细内容。

28

2026.01.21

AO3中文版入口地址大全
AO3中文版入口地址大全

本专题整合了AO3中文版入口地址大全,阅读专题下面的的文章了解更多详细内容。

378

2026.01.21

妖精漫画入口地址合集
妖精漫画入口地址合集

本专题整合了妖精漫画入口地址合集,阅读专题下面的文章了解更多详细内容。

115

2026.01.21

java版本选择建议
java版本选择建议

本专题整合了java版本相关合集,阅读专题下面的文章了解更多详细内容。

3

2026.01.21

Java编译相关教程合集
Java编译相关教程合集

本专题整合了Java编译相关教程,阅读专题下面的文章了解更多详细内容。

16

2026.01.21

C++多线程相关合集
C++多线程相关合集

本专题整合了C++多线程相关教程,阅读专题下面的的文章了解更多详细内容。

9

2026.01.21

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Java 教程
Java 教程

共578课时 | 49.2万人学习

国外Web开发全栈课程全集
国外Web开发全栈课程全集

共12课时 | 1.0万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号