深入理解 NumPy einsum 的张量运算细节

深入理解 NumPy einsum 的张量运算细节

numpy的`einsum`提供了一种简洁高效的张量运算方式,通过爱因斯坦求和约定实现元素乘法与求和。本文将深入解析`np.einsum(‘ijk,jil->kl’, a, b)`这类复杂表达式的内部机制,通过中间索引输出和等效循环两种方法,详细阐述其如何基于共享和非共享索引完成张量元素的组合与累加,帮助读者透彻理解其工作原理,从而更有效地利用`einsum`处理复杂的张量操作。

1. np.einsum 简介与核心机制

np.einsum(Einstein Summation Convention,爱因斯坦求和约定)是NumPy中一个强大而灵活的函数,用于执行各种张量运算,包括点积、外积、转置、求和、矩阵乘法等。其核心在于通过字符串形式的索引标记来定义输入张量的维度关系和期望的输出张量结构。

一个典型的einsum表达式形如 ‘输入下标,输入下标->输出下标’。其中,重复出现的下标表示对该维度进行求和(隐式求和),未在输出下标中出现的输入下标也会被求和。本文将以一个具体的例子np.einsum(‘ijk,jil->kl’, a, b)来深入剖析其内部的元素乘法与求和过程。

假设我们有两个张量 a 和 b:

import numpy as np  # 张量 a 的形状为 (4, 2, 1) a = np.arange(8.).reshape(4, 2, 1) print("张量 a:n", a) # 张量 b 的形状为 (2, 4, 2) b = np.arange(16.).reshape(2, 4, 2) print("n张量 b:n", b)

对于表达式 np.einsum(‘ijk,jil->kl’, a, b):

  • 张量 a 的维度由 i, j, k 表示。
  • 张量 b 的维度由 j, i, l 表示。
  • 输出张量的维度由 k, l 表示。

这意味着:

  1. a 的第一个维度 i 与 b 的第二个维度 i 相乘。
  2. a 的第二个维度 j 与 b 的第一个维度 j 相乘。
  3. k 和 l 是输出张量的维度。
  4. 由于 i 和 j 在输出下标 kl 中没有出现,因此将对 i 和 j 维度进行求和。

2. 方法一:通过中间输出分解求和过程

为了更好地理解 einsum 如何组合并求和元素,我们可以首先生成一个不进行任何求和的中间结果。通过在输出下标中包含所有输入下标,我们可以观察到每个元素乘法的具体结果。

表达式 np.einsum(‘ijk,jil->ijkl’, a, b) 告诉 einsum 保持所有 i, j, k, l 维度,仅执行元素乘法而不进行任何求和。这样,输出张量 ijkl 的每个元素 output[i,j,k,l] 都将是 a[i,j,k] * b[j,i,l] 的结果。

深入理解 NumPy einsum 的张量运算细节

商汤商量

商汤科技研发的AI对话工具,商量商量,都能解决。

深入理解 NumPy einsum 的张量运算细节36

查看详情 深入理解 NumPy einsum 的张量运算细节

# 步骤 1: 执行元素乘法,不进行任何求和 # 通过在输出下标中包含所有输入下标 (i, j, k, l),可以查看每个元素的乘积。 # 此时输出张量的形状为 (i_len, j_len, k_len, l_len) = (4, 2, 1, 2) intermediate_product = np.einsum('ijk,jil->ijkl', a, b) print("中间乘积 (ijkl):n", intermediate_product)  # 步骤 2: 对 'j' 维度进行求和 # 原始表达式 'ijk,jil->kl' 中,j 是一个被求和的维度。 # 对应到 intermediate_product,j 是其第二个轴 (axis=1)。 sum_over_j = intermediate_product.sum(axis=1) print("n对 'j' 维度求和后的结果:n", sum_over_j) # 此时张量形状变为 (i_len, k_len, l_len) = (4, 1, 2)  # 步骤 3: 对 'i' 维度进行求和 # 原始表达式 'ijk,jil->kl' 中,i 也是一个被求和的维度。 # 对应到 sum_over_j,i 是其第一个轴 (axis=0)。 final_result_method1 = sum_over_j.sum(axis=0) print("n对 'i' 维度求和后的最终结果:n", final_result_method1) # 最终张量形状为 (k_len, l_len) = (1, 2),与 'kl' 匹配。

通过这种分解方式,我们清晰地看到了 einsum 如何首先进行所有可能的元素乘法,然后按照爱因斯坦求和约定对未出现在输出下标中的维度进行累加求和。

3. 方法二:等效循环实现

理解 einsum 的另一种有效方法是将其转换为等效的显式循环。这有助于我们追踪每个元素是如何被计算和累加的。

对于 np.einsum(‘ijk,jil->kl’, a, b),我们可以将其转换为以下嵌套循环:

def sum_array_explicit_loop(A, B):     # 获取张量 A 的维度长度     i_len, j_len, k_len = A.shape     # 获取张量 B 的最后一个维度长度(对应 l)     # 注意:B 的实际形状是 (j_max, i_max, l_max),因此 l_len 对应 B.shape[2]     l_len = B.shape[2]      # 初始化结果张量,其形状应为 (k_len, l_len)     ret = np.zeros((k_len, l_len))      # 遍历所有可能的 i, j, k, l 组合     # i 遍历 A 的第一个维度 (0-3)     # j 遍历 A 的第二个维度 (0-1)     # k 遍历 A 的第三个维度 (0-0)     # l 遍历 B 的第三个维度 (0-1)     for i in range(i_len):         for j in range(j_len):             for k in range(k_len):                 for l in range(l_len):                     # 根据 einsum 表达式 'ijk,jil->kl',执行元素乘法                     # A 的索引为 (i, j, k)                     # B 的索引为 (j, i, l)                     # 将乘积累加到结果张量 ret[k, l] 中                     ret[k, l] += A[i, j, k] * B[j, i, l]     return ret  final_result_method2 = sum_array_explicit_loop(a, b) print("n通过显式循环计算的最终结果:n", final_result_method2)

通过运行上述代码,我们可以看到它与 np.einsum(‘ijk,jil->kl’, a, b) 直接计算的结果以及方法一分解求和后的结果是完全一致的。

# 验证与直接使用 einsum 的结果是否一致 einsum_direct_result = np.einsum('ijk,jil->kl', a, b) print("n直接使用 einsum 的结果:n", einsum_direct_result)  # 比较两种方法的结果 print("n两种方法结果是否一致 (方法1 vs 直接einsum):", np.allclose(final_result_method1, einsum_direct_result)) print("两种方法结果是否一致 (方法2 vs 直接einsum):", np.allclose(final_result_method2, einsum_direct_result))

这个循环清晰地展示了 einsum 的内部逻辑:它遍历所有与输入张量形状兼容的索引组合,对每个组合执行元素乘法,并将结果累加到由输出下标定义的相应位置。

4. 注意事项与总结

  • 效率:尽管显式循环有助于理解,但在实际应用中,np.einsum 通常比手动编写的python循环效率高得多,因为它在底层利用了优化的C或Fortran实现。
  • 索引的对应关系:理解输入张量和输出张量中索引的对应关系是使用 einsum 的关键。共享的输入索引表示这些维度必须匹配,并且如果它们未出现在输出中,则会进行求和。
  • 灵活性:einsum 可以表达非常广泛的张量操作,从简单的转置(’ij->ji’)到复杂的张量积。熟练掌握其索引规则能够极大地简化张量运算的代码。
  • 调试:当遇到复杂的 einsum 表达式时,使用本文介绍的“中间输出”方法(即在输出下标中包含所有输入下标)是一个非常有用的调试技巧,可以帮助你一步步追踪计算过程。

通过本文的详细解析,相信读者对 np.einsum 在处理复杂张量运算时的内部机制有了更深入的理解。掌握 einsum 不仅能提升代码的简洁性和效率,更能加深对张量代数的理解。

上一篇
下一篇
text=ZqhQzanResources