
`np.einsum` 是 numpy 中一个强大且灵活的函数,用于执行多维数组的乘积、求和、转置等操作。本文将通过两个核心方法——分解求和过程和显式循环模拟,详细解析 `np.einsum('ijk,jil->kl', a, b)` 如何进行元素级的乘积和求和,帮助读者彻底理解其内部机制。
np.einsum 基础与挑战
np.einsum (Einstein summation convention) 提供了一种简洁的字符串表示法来描述张量运算,包括点积、外积、转置、求和等。其核心在于通过输入张量的索引字符串和输出张量的索引字符串来定义操作。例如,'ijk,jil->kl' 表示将两个张量 a 和 b 进行运算,其中 i 和 j 是求和索引(它们出现在输入中但未出现在输出中),k 和 l 是输出索引。
尽管 einsum 语法简洁,但理解其内部元素是如何组合、相乘并最终求和的,对于初学者而言可能是一个挑战。本文将以具体示例 np.einsum('ijk,jil->kl', a, b) 为切入点,深入探讨其运算细节。
假设我们有以下两个 NumPy 张量:
import numpy as np
a = np.arange(8.).reshape(4, 2, 1) # 形状 (4, 2, 1)
b = np.arange(16.).reshape(2, 4, 2) # 形状 (2, 4, 2)
print("Tensor a:\n", a)
print("Tensor b:\n", b)我们的目标是理解 np.einsum('ijk,jil->kl', a, b) 的计算过程。
方法一:分解求和过程,观察中间乘积
理解 einsum 运算的一种有效方法是逐步分解其求和过程。通过暂时保留所有中间索引,我们可以观察到每个元素的乘积,然后再手动执行求和。
对于 np.einsum('ijk,jil->kl', a, b),输出索引是 kl。这意味着所有在输入索引中出现但未在 kl 中出现的索引(即 i 和 j)都将被求和。
为了查看所有未经求和的乘积,我们可以将输出索引字符串扩展为包含所有输入索引:'ijk,jil->ijkl'。这样,einsum 将返回一个形状为 (i_len, j_len, k_len, l_len) 的张量,其中每个元素都是 a 和 b 中对应元素的乘积,而没有任何求和操作。
# 步骤1: 获取所有未经求和的乘积
products = np.einsum('ijk,jil->ijkl', a, b)
print("所有未经求和的乘积 (shape:", products.shape, "):\n", products)在这个 products 张量中,products[i, j, k, l] 对应于 a[i, j, k] * b[j, i, l] 的乘积。这清楚地展示了 a 和 b 的元素是如何根据索引匹配进行组合的。
现在,为了回到原始的 ->kl 行为,我们需要对 i 和 j 轴进行求和。我们可以分两步完成:
# 步骤2: 对 j 轴(products 的第1轴)进行求和
sum_over_j = products.sum(axis=1)
print("\n对 j 轴求和后 (shape:", sum_over_j.shape, "):\n", sum_over_j)
# 步骤3: 对 i 轴(sum_over_j 的第0轴)进行求和
final_result = sum_over_j.sum(axis=0)
print("\n对 i 轴求和后 (shape:", final_result.shape, "):\n", final_result)
# 验证与原始 einsum 结果一致
original_einsum_result = np.einsum('ijk,jil->kl', a, b)
print("\n原始 einsum 结果 (shape:", original_einsum_result.shape, "):\n", original_einsum_result)
assert np.allclose(final_result, original_einsum_result)
print("\n分解求和结果与原始 einsum 结果一致。")通过这种分解方式,我们直观地看到了每个元素乘积的形成,以及随后如何通过对特定轴求和来聚合这些乘积。
方法二:显式循环模拟 einsum 运算
另一种深入理解 einsum 的方法是将其转换为等价的显式嵌套循环。这能最清晰地展示每个元素的访问和累加过程。
对于 np.einsum('ijk,jil->kl', a, b),我们可以构建一个循环来遍历所有可能的 i, j, k, l 组合,并按照 einsum 的规则进行乘积和累加。
首先,确定输出张量的形状。由于输出是 kl,其形状将是 (k_len, l_len)。然后,我们遍历所有可能的 i, j, k, l 值。
def sum_array_explicit_loop(A, B):
# 获取张量 A 的维度长度
i_len, j_len, k_len = A.shape
# 获取张量 B 的维度长度 (注意 B 的形状是 (j_len, i_len, l_len)
# 如果按照 einsum 的 jil 索引来理解,但其原始形状是 (2, 4, 2),
# 这里的 _ 和 l_len 对应 B 的第0维和第2维)
# 实际上,B 的原始形状是 (B_dim0, B_dim1, B_dim2)
# 在 'jil' 中,j 对应 B_dim0, i 对应 B_dim1, l 对应 B_dim2
# 所以,B.shape[0] 是 j 的最大值,B.shape[1] 是 i 的最大值,B.shape[2] 是 l 的最大值
# 但是,i_len 和 j_len 已经由 A 决定,所以我们只需要 l_len
# 确保维度兼容性:A.shape[1] (j_len_A) 必须等于 B.shape[0] (j_len_B)
# A.shape[0] (i_len_A) 必须等于 B.shape[1] (i_len_B)
# 这里我们直接从 A 和 B 的实际形状推导循环范围
# 重新确认循环范围的正确性:
# i 循环范围由 A.shape[0] 决定
# j 循环范围由 A.shape[1] 决定
# k 循环范围由 A.shape[2] 决定
# l 循环范围由 B.shape[2] 决定 (因为 B 的第三个索引是 l)
# 对于 'ijk,jil->kl'
# i 的范围是 A.shape[0]
# j 的范围是 A.shape[1] (同时也是 B.shape[0])
# k 的范围是 A.shape[2]
# l 的范围是 B.shape[2]
i_max = A.shape[0]
j_max = A.shape[1]
k_max = A.shape[2]
l_max = B.shape[2] # l 是 B 的最后一个维度
# 初始化结果张量,形状为 (k_len, l_len)
ret = np.zeros((k_max, l_max))
# 四重嵌套循环模拟 einsum 运算
for i in range(i_max):
for j in range(j_max):
for k in range(k_max):
for l in range(l_max):
# 核心操作:A[i, j, k] * B[j, i, l] 并累加到 ret[k, l]
# 注意 B 的索引顺序是 j, i, l,这意味着 B 的原始第0维对应 j,第1维对应 i,第2维对应 l
ret[k, l] += A[i, j, k] * B[j, i, l]
return ret
# 使用显式循环计算结果
explicit_loop_result = sum_array_explicit_loop(a, b)
print("\n显式循环计算结果 (shape:", explicit_loop_result.shape, "):\n", explicit_loop_result)
assert np.allclose(explicit_loop_result, original_einsum_result)
print("\n显式循环结果与原始 einsum 结果一致。")通过显式循环,我们可以清晰地看到:
- ret[k, l] 是输出张量中的一个元素。
- += 操作表示对所有匹配的 i 和 j 进行求和。
- A[i, j, k] 按照 ijk 的顺序访问 a 的元素。
- B[j, i, l] 按照 jil 的顺序访问 b 的元素。这意味着 b 的原始第一个维度被当作 j,第二个维度被当作 i,第三个维度被当作 l。这就是 einsum 灵活之处,它会自动处理这种维度重排(permutation)。
einsum 索引规则总结
从以上两种方法中,我们可以提炼出 einsum 索引字符串的关键规则:
- 匹配与乘积: einsum 会遍历所有输入张量中相同索引的组合。例如,在 'ijk,jil->kl' 中,i 和 j 同时出现在 a 和 b 的索引中,因此 einsum 会在它们的值相等时将 a[i,j,k] 和 b[j,i,l] 的元素相乘。
- 维度重排(Permutation): 输入字符串中的索引顺序决定了如何访问张量的维度。例如,'jil' 对于张量 b 意味着 b 的第一个维度被视为 j,第二个维度被视为 i,第三个维度被视为 l。einsum 会自动处理这种访问顺序,无需手动 transpose。
- 求和(Reduction): 任何出现在输入索引字符串中,但未出现在输出索引字符串中的索引,都将被求和。在 'ijk,jil->kl' 中,i 和 j 出现在输入中但未出现在 kl 中,因此 einsum 会对所有可能的 i 和 j 值进行求和。
- 输出维度(Output Dimensions): 输出索引字符串 (kl) 定义了结果张量的维度和顺序。结果张量的形状将由这些输出索引的长度决定。
结论
np.einsum 是一个极其强大的工具,它通过简洁的字符串语法封装了复杂的张量运算。通过分解求和过程和显式循环模拟,我们可以深入理解 einsum 如何在元素级别上执行乘积和求和,以及它如何灵活地处理张量的维度重排和广播。掌握这些细节不仅有助于调试和优化 einsum 表达式,还能提升对多维数组运算的整体理解。在实际应用中,einsum 通常比手动循环或组合多个 NumPy 函数更高效、更具可读性。









