0

0

如何在 JAX 中正确计算批量矩阵指数(expm)

心靈之曲

心靈之曲

发布时间:2026-01-15 23:57:28

|

849人浏览过

|

来源于php中文网

原创

如何在 JAX 中正确计算批量矩阵指数(expm)

本文详解 jax 中 `jax.scipy.linalg.expm` 批量计算失败的常见原因与解决方案,涵盖新版原生支持、旧版兼容写法及关键形状调试技巧。

在使用 JAX 计算矩阵指数(如量子线路中的参数化幺正演化 $ e^{iA} $)时,一个典型错误是:

ValueError: expected A to be a square matrix

尽管你确认最后两维是方阵(如 (4, 4)),但报错仍发生——这往往源于 输入张量的维度结构不符合 expm 的隐式批处理规则

? 根本原因:expm 对输入形状有严格要求

jax.scipy.linalg.expm 自 JAX v0.4.7 起原生支持批量输入,但前提是:
✅ 输入数组的最后两个轴必须构成方阵(如 (..., n, n));
❌ 其余前导维度将被自动视为 batch 维度;
❌ 若中间存在非 batch 的冗余维度(如你的 A.shape = (2, 2, 2, 2, 2, 2, 2, 2, 4, 4)),它仍能工作;
⚠️ 但若 A 的最后两维不满足 n == n(例如 (4, 5)),或 A.ndim

在你的代码中,问题出在 pauli_matrix(num_qubits) 的构造逻辑:

def pauli_matrix(num_qubits):
    _pauli_matrices = jnp.array(
        [[[1, 0], [0, 1]], [[0, 1], [1, 0]], [[0, -1j], [1j, 0]], [[1, 0], [0, -1]]]
    )
    # ❌ 错误:对 _pauli_matrices 重复 kronecker 积,却未指定作用于哪一组 qubit
    # 且 [1:] 切片导致维度混乱,最终使 tensordot 结果 A 的 shape 不符合预期
    return reduce(jnp.kron, (_pauli_matrices for _ in range(num_qubits)))[1:]

该函数实际生成的是 (15, 4**num_qubits, 4**num_qubits) 形状的 Pauli 基(对 2-qubit 应为 (15, 4, 4)),但 reduce(jnp.kron, ...) 在 num_qubits=2 时会生成 (4^2, 4^2) = (16, 16) 矩阵,再 [1:] 切片得 (15, 16, 16) —— 而你 theta 是 (15, 2,2,2,2,2,2,2,2),tensordot 后 A 实际为 (2,2,2,2,2,2,2,2, 16, 16),并非你误以为的 (2,...,2,4,4)。因此 expm 接收的不是 (N, 4, 4),而是高维张量,但只要末两维是方阵,新版 JAX 就能处理。

✅ 正确做法:确保 A 的 shape 为 (..., d, d),其中 d = 2**num_qubits。

CoCo
CoCo

智谱AI推出的首个有记忆的企业自主Agent智能体

下载

✅ 解决方案一:升级 JAX 并规范输入(推荐)

确保使用 JAX ≥ 0.4.7:

pip install --upgrade jax jaxlib

然后修正 pauli_matrix 和 SpecialUnitary:

import jax.numpy as jnp
import jax.scipy.linalg as linalg
from functools import reduce

def pauli_basis_1q():
    return jnp.array([
        [[1., 0.], [0., 1.]],   # I
        [[0., 1.], [1., 0.]],   # X
        [[0., -1j], [1j, 0.]],  # Y
        [[1., 0.], [0., -1.]],  # Z
    ])

def pauli_matrix(num_qubits):
    """返回 (4**num_qubits - 1) 个 traceless n-qubit Pauli 算符,shape (15, 4, 4) for n=2"""
    basis = pauli_basis_1q()
    # 构造所有非恒等的 n-qubit Pauli 张量积:共 4^n - 1 个
    from itertools import product
    ops = []
    for indices in product(range(4), repeat=num_qubits):
        if all(i == 0 for i in indices):  # skip identity
            continue
        op = basis[indices[0]]
        for i in indices[1:]:
            op = jnp.kron(op, basis[i])
        ops.append(op)
    return jnp.stack(ops)  # shape: (15, 4, 4) for num_qubits=2

num_qubits = 2
d = 2 ** num_qubits  # 4
theta = jnp.pi * jnp.random.uniform(shape=(15,))  # 简化:单组参数,shape (15,)

A = jnp.tensordot(theta, pauli_matrix(num_qubits), axes=[[0], [0]])  # -> (4, 4)
U = linalg.expm(1j * A / 2)  # ✅ works: (4, 4)

# 批量示例:theta shape (8, 15) → A shape (8, 4, 4) → U shape (8, 4, 4)
theta_batch = jnp.pi * jnp.random.uniform(shape=(8, 15))
A_batch = jnp.einsum('bi,ij->bjk', theta_batch, pauli_matrix(num_qubits))  # (8, 4, 4)
U_batch = linalg.expm(1j * A_batch / 2)  # ✅ native batch support
print(U_batch.shape)  # (8, 4, 4)

⚙️ 解决方案二:旧版 JAX 兼容写法(jnp.vectorize)

若受限于旧版 JAX(

expm_vec = jnp.vectorize(linalg.expm, signature='(n,n)->(n,n)')

# A_batch shape: (B, d, d)
U_batch = expm_vec(1j * A_batch / 2)  # returns (B, d, d)

⚠️ 注意:vectorize 在 JIT 下可能不如原生批量高效,仅作兼容之用。

? 关键检查清单

  • ✅ 使用 A.shape[-2] == A.shape[-1] 验证末两维是否为方阵;
  • ✅ 避免在 tensordot 或 einsum 中引入意外维度(如你的原始 theta 有 9 维,极易出错);
  • ✅ 优先用 einsum 替代嵌套 tensordot 提升可读性;
  • ✅ 调试时打印 A.shape 和 A.dtype,确认无 float64(JAX 默认 float32,expm 要求浮点)。

掌握这些要点,你就能稳健地在 JAX 中实现量子态演化、李群指数映射等核心计算。

相关专题

更多
go语言 数组和切片
go语言 数组和切片

本专题整合了go语言数组和切片的区别与含义,阅读专题下面的文章了解更多详细内容。

46

2025.09.03

html编辑相关教程合集
html编辑相关教程合集

本专题整合了html编辑相关教程合集,阅读专题下面的文章了解更多详细内容。

37

2026.01.21

三角洲入口地址合集
三角洲入口地址合集

本专题整合了三角洲入口地址合集,阅读专题下面的文章了解更多详细内容。

17

2026.01.21

AO3中文版入口地址大全
AO3中文版入口地址大全

本专题整合了AO3中文版入口地址大全,阅读专题下面的的文章了解更多详细内容。

227

2026.01.21

妖精漫画入口地址合集
妖精漫画入口地址合集

本专题整合了妖精漫画入口地址合集,阅读专题下面的文章了解更多详细内容。

59

2026.01.21

java版本选择建议
java版本选择建议

本专题整合了java版本相关合集,阅读专题下面的文章了解更多详细内容。

3

2026.01.21

Java编译相关教程合集
Java编译相关教程合集

本专题整合了Java编译相关教程,阅读专题下面的文章了解更多详细内容。

14

2026.01.21

C++多线程相关合集
C++多线程相关合集

本专题整合了C++多线程相关教程,阅读专题下面的的文章了解更多详细内容。

6

2026.01.21

无人机驾驶证报考 uom民用无人机综合管理平台官网
无人机驾驶证报考 uom民用无人机综合管理平台官网

无人机驾驶证(CAAC执照)报考需年满16周岁,初中以上学历,身体健康(矫正视力1.0以上,无严重疾病),且无犯罪记录。个人需通过民航局授权的训练机构报名,经理论(法规、原理)、模拟飞行、实操(GPS/姿态模式)及地面站训练后考试合格,通常15-25天拿证。

27

2026.01.21

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Java 教程
Java 教程

共578课时 | 49.1万人学习

国外Web开发全栈课程全集
国外Web开发全栈课程全集

共12课时 | 1.0万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号