0

0

深入理解 NumPy einsum 操作的细节与机制

心靈之曲

心靈之曲

发布时间:2025-10-23 14:44:39

|

498人浏览过

|

来源于php中文网

原创

深入理解 NumPy einsum 操作的细节与机制

`np.einsum` 是 numpy 中一个强大且灵活的函数,用于执行多维数组的乘积、求和、转置等操作。本文将通过两个核心方法——分解求和过程和显式循环模拟,详细解析 `np.einsum('ijk,jil->kl', a, b)` 如何进行元素级的乘积和求和,帮助读者彻底理解其内部机制。

np.einsum 基础与挑战

np.einsum (Einstein summation convention) 提供了一种简洁的字符串表示法来描述张量运算,包括点积、外积、转置、求和等。其核心在于通过输入张量的索引字符串和输出张量的索引字符串来定义操作。例如,'ijk,jil->kl' 表示将两个张量 a 和 b 进行运算,其中 i 和 j 是求和索引(它们出现在输入中但未出现在输出中),k 和 l 是输出索引。

尽管 einsum 语法简洁,但理解其内部元素是如何组合、相乘并最终求和的,对于初学者而言可能是一个挑战。本文将以具体示例 np.einsum('ijk,jil->kl', a, b) 为切入点,深入探讨其运算细节。

假设我们有以下两个 NumPy 张量:

import numpy as np

a = np.arange(8.).reshape(4, 2, 1) # 形状 (4, 2, 1)
b = np.arange(16.).reshape(2, 4, 2) # 形状 (2, 4, 2)

print("Tensor a:\n", a)
print("Tensor b:\n", b)

我们的目标是理解 np.einsum('ijk,jil->kl', a, b) 的计算过程。

方法一:分解求和过程,观察中间乘积

理解 einsum 运算的一种有效方法是逐步分解其求和过程。通过暂时保留所有中间索引,我们可以观察到每个元素的乘积,然后再手动执行求和。

对于 np.einsum('ijk,jil->kl', a, b),输出索引是 kl。这意味着所有在输入索引中出现但未在 kl 中出现的索引(即 i 和 j)都将被求和。

为了查看所有未经求和的乘积,我们可以将输出索引字符串扩展为包含所有输入索引:'ijk,jil->ijkl'。这样,einsum 将返回一个形状为 (i_len, j_len, k_len, l_len) 的张量,其中每个元素都是 a 和 b 中对应元素的乘积,而没有任何求和操作。

# 步骤1: 获取所有未经求和的乘积
products = np.einsum('ijk,jil->ijkl', a, b)
print("所有未经求和的乘积 (shape:", products.shape, "):\n", products)

在这个 products 张量中,products[i, j, k, l] 对应于 a[i, j, k] * b[j, i, l] 的乘积。这清楚地展示了 a 和 b 的元素是如何根据索引匹配进行组合的。

现在,为了回到原始的 ->kl 行为,我们需要对 i 和 j 轴进行求和。我们可以分两步完成:

# 步骤2: 对 j 轴(products 的第1轴)进行求和
sum_over_j = products.sum(axis=1)
print("\n对 j 轴求和后 (shape:", sum_over_j.shape, "):\n", sum_over_j)

# 步骤3: 对 i 轴(sum_over_j 的第0轴)进行求和
final_result = sum_over_j.sum(axis=0)
print("\n对 i 轴求和后 (shape:", final_result.shape, "):\n", final_result)

# 验证与原始 einsum 结果一致
original_einsum_result = np.einsum('ijk,jil->kl', a, b)
print("\n原始 einsum 结果 (shape:", original_einsum_result.shape, "):\n", original_einsum_result)

assert np.allclose(final_result, original_einsum_result)
print("\n分解求和结果与原始 einsum 结果一致。")

通过这种分解方式,我们直观地看到了每个元素乘积的形成,以及随后如何通过对特定轴求和来聚合这些乘积。

方法二:显式循环模拟 einsum 运算

另一种深入理解 einsum 的方法是将其转换为等价的显式嵌套循环。这能最清晰地展示每个元素的访问和累加过程。

对于 np.einsum('ijk,jil->kl', a, b),我们可以构建一个循环来遍历所有可能的 i, j, k, l 组合,并按照 einsum 的规则进行乘积和累加。

首先,确定输出张量的形状。由于输出是 kl,其形状将是 (k_len, l_len)。然后,我们遍历所有可能的 i, j, k, l 值。

def sum_array_explicit_loop(A, B):
    # 获取张量 A 的维度长度
    i_len, j_len, k_len = A.shape
    # 获取张量 B 的维度长度 (注意 B 的形状是 (j_len, i_len, l_len) 
    # 如果按照 einsum 的 jil 索引来理解,但其原始形状是 (2, 4, 2),
    # 这里的 _ 和 l_len 对应 B 的第0维和第2维)
    # 实际上,B 的原始形状是 (B_dim0, B_dim1, B_dim2)
    # 在 'jil' 中,j 对应 B_dim0, i 对应 B_dim1, l 对应 B_dim2
    # 所以,B.shape[0] 是 j 的最大值,B.shape[1] 是 i 的最大值,B.shape[2] 是 l 的最大值
    # 但是,i_len 和 j_len 已经由 A 决定,所以我们只需要 l_len
    # 确保维度兼容性:A.shape[1] (j_len_A) 必须等于 B.shape[0] (j_len_B)
    # A.shape[0] (i_len_A) 必须等于 B.shape[1] (i_len_B)
    # 这里我们直接从 A 和 B 的实际形状推导循环范围

    # 重新确认循环范围的正确性:
    # i 循环范围由 A.shape[0] 决定
    # j 循环范围由 A.shape[1] 决定
    # k 循环范围由 A.shape[2] 决定
    # l 循环范围由 B.shape[2] 决定 (因为 B 的第三个索引是 l)

    # 对于 'ijk,jil->kl'
    # i 的范围是 A.shape[0]
    # j 的范围是 A.shape[1] (同时也是 B.shape[0])
    # k 的范围是 A.shape[2]
    # l 的范围是 B.shape[2]

    i_max = A.shape[0]
    j_max = A.shape[1]
    k_max = A.shape[2]
    l_max = B.shape[2] # l 是 B 的最后一个维度

    # 初始化结果张量,形状为 (k_len, l_len)
    ret = np.zeros((k_max, l_max))

    # 四重嵌套循环模拟 einsum 运算
    for i in range(i_max):
        for j in range(j_max):
            for k in range(k_max):
                for l in range(l_max):
                    # 核心操作:A[i, j, k] * B[j, i, l] 并累加到 ret[k, l]
                    # 注意 B 的索引顺序是 j, i, l,这意味着 B 的原始第0维对应 j,第1维对应 i,第2维对应 l
                    ret[k, l] += A[i, j, k] * B[j, i, l]
    return ret

# 使用显式循环计算结果
explicit_loop_result = sum_array_explicit_loop(a, b)
print("\n显式循环计算结果 (shape:", explicit_loop_result.shape, "):\n", explicit_loop_result)

assert np.allclose(explicit_loop_result, original_einsum_result)
print("\n显式循环结果与原始 einsum 结果一致。")

通过显式循环,我们可以清晰地看到:

  • ret[k, l] 是输出张量中的一个元素。
  • += 操作表示对所有匹配的 i 和 j 进行求和。
  • A[i, j, k] 按照 ijk 的顺序访问 a 的元素。
  • B[j, i, l] 按照 jil 的顺序访问 b 的元素。这意味着 b 的原始第一个维度被当作 j,第二个维度被当作 i,第三个维度被当作 l。这就是 einsum 灵活之处,它会自动处理这种维度重排(permutation)。

einsum 索引规则总结

从以上两种方法中,我们可以提炼出 einsum 索引字符串的关键规则:

  1. 匹配与乘积: einsum 会遍历所有输入张量中相同索引的组合。例如,在 'ijk,jil->kl' 中,i 和 j 同时出现在 a 和 b 的索引中,因此 einsum 会在它们的值相等时将 a[i,j,k] 和 b[j,i,l] 的元素相乘。
  2. 维度重排(Permutation): 输入字符串中的索引顺序决定了如何访问张量的维度。例如,'jil' 对于张量 b 意味着 b 的第一个维度被视为 j,第二个维度被视为 i,第三个维度被视为 l。einsum 会自动处理这种访问顺序,无需手动 transpose。
  3. 求和(Reduction): 任何出现在输入索引字符串中,但未出现在输出索引字符串中的索引,都将被求和。在 'ijk,jil->kl' 中,i 和 j 出现在输入中但未出现在 kl 中,因此 einsum 会对所有可能的 i 和 j 值进行求和。
  4. 输出维度(Output Dimensions): 输出索引字符串 (kl) 定义了结果张量的维度和顺序。结果张量的形状将由这些输出索引的长度决定。

结论

np.einsum 是一个极其强大的工具,它通过简洁的字符串语法封装了复杂的张量运算。通过分解求和过程和显式循环模拟,我们可以深入理解 einsum 如何在元素级别上执行乘积和求和,以及它如何灵活地处理张量的维度重排和广播。掌握这些细节不仅有助于调试和优化 einsum 表达式,还能提升对多维数组运算的整体理解。在实际应用中,einsum 通常比手动循环或组合多个 NumPy 函数更高效、更具可读性。

相关专题

更多
js 字符串转数组
js 字符串转数组

js字符串转数组的方法:1、使用“split()”方法;2、使用“Array.from()”方法;3、使用for循环遍历;4、使用“Array.split()”方法。本专题为大家提供js字符串转数组的相关的文章、下载、课程内容,供大家免费下载体验。

258

2023.08.03

js截取字符串的方法
js截取字符串的方法

js截取字符串的方法有substring()方法、substr()方法、slice()方法、split()方法和slice()方法。本专题为大家提供字符串相关的文章、下载、课程内容,供大家免费下载体验。

209

2023.09.04

java基础知识汇总
java基础知识汇总

java基础知识有Java的历史和特点、Java的开发环境、Java的基本数据类型、变量和常量、运算符和表达式、控制语句、数组和字符串等等知识点。想要知道更多关于java基础知识的朋友,请阅读本专题下面的的有关文章,欢迎大家来php中文网学习。

1468

2023.10.24

字符串介绍
字符串介绍

字符串是一种数据类型,它可以是任何文本,包括字母、数字、符号等。字符串可以由不同的字符组成,例如空格、标点符号、数字等。在编程中,字符串通常用引号括起来,如单引号、双引号或反引号。想了解更多字符串的相关内容,可以阅读本专题下面的文章。

620

2023.11.24

java读取文件转成字符串的方法
java读取文件转成字符串的方法

Java8引入了新的文件I/O API,使用java.nio.file.Files类读取文件内容更加方便。对于较旧版本的Java,可以使用java.io.FileReader和java.io.BufferedReader来读取文件。在这些方法中,你需要将文件路径替换为你的实际文件路径,并且可能需要处理可能的IOException异常。想了解更多java的相关内容,可以阅读本专题下面的文章。

550

2024.03.22

php中定义字符串的方式
php中定义字符串的方式

php中定义字符串的方式:单引号;双引号;heredoc语法等等。想了解更多字符串的相关内容,可以阅读本专题下面的文章。

546

2024.04.29

go语言字符串相关教程
go语言字符串相关教程

本专题整合了go语言字符串相关教程,阅读专题下面的文章了解更多详细内容。

165

2025.07.29

c++字符串相关教程
c++字符串相关教程

本专题整合了c++字符串相关教程,阅读专题下面的文章了解更多详细内容。

81

2025.08.07

Java JVM 原理与性能调优实战
Java JVM 原理与性能调优实战

本专题系统讲解 Java 虚拟机(JVM)的核心工作原理与性能调优方法,包括 JVM 内存结构、对象创建与回收流程、垃圾回收器(Serial、CMS、G1、ZGC)对比分析、常见内存泄漏与性能瓶颈排查,以及 JVM 参数调优与监控工具(jstat、jmap、jvisualvm)的实战使用。通过真实案例,帮助学习者掌握 Java 应用在生产环境中的性能分析与优化能力。

13

2026.01.20

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
React 教程
React 教程

共58课时 | 3.9万人学习

Pandas 教程
Pandas 教程

共15课时 | 0.9万人学习

ASP 教程
ASP 教程

共34课时 | 3.8万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号