0

0

JAX 中实现动态轴索引的数组滚动(roll)操作

聖光之護

聖光之護

发布时间:2026-02-18 10:46:12

|

740人浏览过

|

来源于php中文网

原创

JAX 中实现动态轴索引的数组滚动(roll)操作

在 JAX 编译函数中,jnp.roll 不支持动态 axis 参数;本文介绍一种基于 lax.broadcasted_iota 与索引映射的纯静态可追踪方案,实现沿变量轴高效、可 jit 的数组滚动。

在 jax 编译函数中,`jnp.roll` 不支持动态 `axis` 参数;本文介绍一种基于 `lax.broadcasted_iota` 与索引映射的纯静态可追踪方案,实现沿变量轴高效、可 jit 的数组滚动。

JAX 的函数式与静态图特性要求所有控制流和形状/索引相关操作必须在编译期可推导(即“静态可知”)。jnp.roll(A, shift, axis=ind) 中若 ind 是 traced 值(如来自 lax.map 的迭代变量),则触发 ConcretizationTypeError——因为底层 roll 实际通过 lax.gather/lax.scatter 等原语实现,而这些原语强制 axis 必须是 Python int 或 compile-time 常量。

直接使用 @jax.jit(static_argnums=...) 无法解决该问题:static_argnums 仅对函数参数本身生效,而 lax.map(fn, xs) 会将 xs 中每个元素作为动态输入传入 fn,因此 ind 在 roll(ind) 内部始终是 tracer,无法提升为 static。

✅ 正确解法:绕过 jnp.roll,手动构建动态轴上的循环移位索引。核心思想是:

Nimo.space
Nimo.space

智能画布式AI工作台

下载
  • 利用 jax.lax.broadcasted_iota 生成各维度的标准坐标网格;
  • 对目标轴 ind 对应的索引序列单独执行 jnp.roll(..., -1);
  • 用 jnp.meshgrid(..., sparse=True) 高效构造稀疏索引元组;
  • 最终通过高级索引 A[tuple(indices)] 完成等效滚动。

以下为完整、可 jit 的实现:

import jax
import jax.numpy as jnp

def roll_dynamic(A, ind, shift=-1):
    """Roll array A along dynamic axis `ind` by `shift` positions.

    Works under jit/lax.map. Requires all dimensions of A to be equal
    (for clean iota-based indexing; generalization possible but more complex).
    """
    assert A.ndim > 0
    assert len(set(A.shape)) == 1, "All dimensions must be equal for this implementation."

    D = A.ndim
    N = A.shape[0]

    # Step 1: Create base indices for each axis: shape (D, N)
    # e.g., for D=4, N=4 → [[0,1,2,3], [0,1,2,3], [0,1,2,3], [0,1,2,3]]
    base_indices = jax.lax.broadcasted_iota(jnp.int32, (D, N), 1)

    # Step 2: Identify which axis to roll — broadcast `ind` to (D, 1)
    axis_mask = (jnp.arange(D)[:, None] == ind)  # shape (D, 1)

    # Step 3: Roll only the target axis' indices
    rolled_axis_indices = jnp.roll(base_indices, shift, axis=-1)
    indices = jnp.where(axis_mask, rolled_axis_indices, base_indices)

    # Step 4: Build sparse meshgrid indices for advanced indexing
    # meshgrid(..., sparse=True) returns tuple of length D, each shape (N,)
    grid = jnp.meshgrid(*indices, indexing='ij', sparse=True)

    return A[tuple(grid)]

# 示例验证
A = jnp.arange(256).reshape(4, 4, 4, 4)
indList = jnp.asarray([0, 1, 2])

# ✅ 可安全用于 lax.map
result = jax.lax.map(lambda ind: roll_dynamic(A, ind), indList)

# 验证等价性(与传统 roll 对齐)
for ind in range(4):
    ref = jnp.roll(A, -1, axis=ind)
    assert jnp.array_equal(result[ind], ref), f"Failed at axis {ind}"

⚠️ 注意事项:

  • 维度约束:上述实现假设 A.shape 各维相等(如 (4,4,4,4)),这是为简化 iota 构造与广播逻辑。若需支持不规则形状,需为每维单独生成 jnp.arange(dim_size) 并拼接,但会显著增加代码复杂度与 trace 开销。
  • 性能权衡:相比原生 jnp.roll,该方法引入额外索引计算与 gather 操作,在超大数组上可能略慢,但保证了完全可编译性。
  • shift 支持:shift 参数同样可设为 traced 值(如 lax.map 中传入不同偏移),只需将其也加入 roll_dynamic 参数并参与 jnp.roll(..., shift, ...) 即可。
  • 扩展建议:对高维稀疏场景,可结合 lax.dynamic_slice + lax.concatenate 手动拼接切片,避免全量索引内存开销;但需自行处理正/负 shift 边界。

总结而言,当 JAX 原语限制迫使你脱离高层 API 时,以 lax.iota/lax.broadcasted_iota 构建结构化索引 + jnp.where 动态路由 + 高级索引取值,是解决“动态轴操作”问题的经典范式。它不仅适用于 roll,还可推广至 swapaxes、moveaxis、甚至自定义轴重排等场景。

本站声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
java基础知识汇总
java基础知识汇总

java基础知识有Java的历史和特点、Java的开发环境、Java的基本数据类型、变量和常量、运算符和表达式、控制语句、数组和字符串等等知识点。想要知道更多关于java基础知识的朋友,请阅读本专题下面的的有关文章,欢迎大家来php中文网学习。

1553

2023.10.24

string转int
string转int

在编程中,我们经常会遇到需要将字符串(str)转换为整数(int)的情况。这可能是因为我们需要对字符串进行数值计算,或者需要将用户输入的字符串转换为整数进行处理。php中文网给大家带来了相关的教程以及文章,欢迎大家前来学习阅读。

770

2023.08.02

int占多少字节
int占多少字节

int占4个字节,意味着一个int变量可以存储范围在-2,147,483,648到2,147,483,647之间的整数值,在某些情况下也可能是2个字节或8个字节,int是一种常用的数据类型,用于表示整数,需要根据具体情况选择合适的数据类型,以确保程序的正确性和性能。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

572

2024.08.29

c++怎么把double转成int
c++怎么把double转成int

本专题整合了 c++ double相关教程,阅读专题下面的文章了解更多详细内容。

254

2025.08.29

C++中int的含义
C++中int的含义

本专题整合了C++中int相关内容,阅读专题下面的文章了解更多详细内容。

210

2025.08.29

go语言 数组和切片
go语言 数组和切片

本专题整合了go语言数组和切片的区别与含义,阅读专题下面的文章了解更多详细内容。

47

2025.09.03

golang map内存释放
golang map内存释放

本专题整合了golang map内存相关教程,阅读专题下面的文章了解更多相关内容。

75

2025.09.05

golang map相关教程
golang map相关教程

本专题整合了golang map相关教程,阅读专题下面的文章了解更多详细内容。

36

2025.11.16

pixiv网页版官网登录与阅读指南_pixiv官网直达入口与在线访问方法
pixiv网页版官网登录与阅读指南_pixiv官网直达入口与在线访问方法

本专题系统整理pixiv网页版官网入口及登录访问方式,涵盖官网登录页面直达路径、在线阅读入口及快速进入方法说明,帮助用户高效找到pixiv官方网站,实现便捷、安全的网页端浏览与账号登录体验。

462

2026.02.13

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号