0

0

PyTorch高效矩阵运算:从循环到广播机制的优化实践

DDD

DDD

发布时间:2025-10-07 13:09:01

|

793人浏览过

|

来源于php中文网

原创

PyTorch高效矩阵运算:从循环到广播机制的优化实践

本教程旨在解决PyTorch中矩阵操作的效率问题,特别是当涉及对多个标量-矩阵运算结果求和时。文章将详细阐述如何将低效的Python循环转换为利用PyTorch广播机制的向量化操作,从而显著提升代码性能,实现GPU加速,并确保数值计算的准确性,最终输出简洁高效的优化方案。

1. 问题背景与低效实现分析

pytorch深度学习框架中,python循环(for 循环)通常会导致性能瓶颈,尤其是在处理大型张量时。这是因为python循环是在cpu上执行的,无法充分利用gpu的并行计算能力,也无法利用底层c++或cuda优化的张量操作。

考虑以下一个典型的低效实现,它试图计算一系列矩阵操作的总和:

import torch

m = 100
n = 100
b = torch.rand(m) # 形状为 (m,) 的一维张量
a = torch.rand(m) # 形状为 (m,) 的一维张量
sumation_old = 0
A = torch.rand(n, n) # 形状为 (n, n) 的二维矩阵

# 低效的循环实现
for i in range(m):
    # 每次迭代都进行矩阵减法、标量乘法和矩阵除法
    sumation_old = sumation_old + a[i] / (A - b[i] * torch.eye(n))

print("循环实现的求和结果 (部分):")
print(sumation_old[:2, :2]) # 打印部分结果

在这个例子中,我们迭代 m 次,每次迭代都执行以下操作:

  1. b[i] * torch.eye(n):一个标量与一个单位矩阵相乘。
  2. A - ...:一个矩阵与上一步的结果相减。
  3. a[i] / ...:一个标量除以上一步的矩阵。
  4. 将结果累加到 sumation_old。

这种逐元素或逐次迭代的计算方式,在 m 较大时会显著降低程序执行效率。

2. 向量化:利用PyTorch广播机制

PyTorch的广播(Broadcasting)机制允许不同形状的张量在满足一定条件时执行逐元素操作,而无需显式地复制数据。这是实现向量化操作的关键。其核心思想是,通过巧妙地调整张量的维度,使得操作能够一次性在整个张量上完成,而不是通过循环逐个处理。

对于本例中的操作 a[i] / (A - b[i] * torch.eye(n)),我们可以将其分解为以下几个步骤进行向量化:

  1. 准备 torch.eye(n): torch.eye(n) 的形状是 (n, n)。为了与 b 中的所有元素进行广播乘法,我们需要将其扩展一个维度,使其变为 (1, n, n)。
  2. 准备 b: b 的形状是 (m,)。为了与 (1, n, n) 的单位矩阵进行广播乘法,我们需要将其形状调整为 (m, 1, 1)。
  3. *计算 `b[i] torch.eye(n)的向量化版本:** 将b(形状(m, 1, 1)) 与扩展后的单位矩阵torch.eye(n).unsqueeze(0)(形状(1, n, n)) 相乘。根据广播规则,结果将是形状为(m, n, n)的张量,其中B[k, :, :]等于b[k] * torch.eye(n)`。
  4. 准备 A: A 的形状是 (n, n)。为了与上一步得到的 (m, n, n) 张量进行广播减法,我们需要将其扩展一个维度,使其变为 (1, n, n)。
  5. *计算 `A - b[i] torch.eye(n)的向量化版本:** 将扩展后的A.unsqueeze(0)(形状(1, n, n)) 与上一步得到的B(形状(m, n, n)) 相减。结果将是形状为(m, n, n)` 的张量。
  6. 准备 a: a 的形状是 (m,)。为了与上一步得到的 (m, n, n) 张量进行广播除法,我们需要将其形状调整为 (m, 1, 1)。
  7. 计算 a[i] / (...) 的向量化版本: 将调整后的 a.unsqueeze(1).unsqueeze(2) (形状 (m, 1, 1)) 除以上一步得到的 A_minus_B (形状 (m, n, n))。结果将是形状为 (m, n, n) 的张量。
  8. 求和: 对最终的 (m, n, n) 张量沿着第一个维度(即 m 维度)进行求和,得到最终的 (n, n) 结果。

3. 优化实现与代码示例

根据上述向量化策略,我们可以将原始的循环代码重构为以下高效的PyTorch实现:

WPS AI
WPS AI

金山办公发布的AI办公应用,提供智能文档写作、阅读理解和问答、智能人机交互的能力。

下载
import torch

m = 100
n = 100
b = torch.rand(m)
a = torch.rand(m)
A = torch.rand(n, n)

# 1. 准备单位矩阵并扩展维度
# torch.eye(n) 的形状是 (n, n)
# unsqueeze(0) 后变为 (1, n, n)
identity_matrix_expanded = torch.eye(n).unsqueeze(0)

# 2. 准备 b 并扩展维度
# b 的形状是 (m,)
# unsqueeze(1).unsqueeze(2) 后变为 (m, 1, 1)
b_expanded = b.unsqueeze(1).unsqueeze(2)

# 3. 计算 b[i] * torch.eye(n) 的向量化版本
# (m, 1, 1) * (1, n, n) -> 广播后得到 (m, n, n)
B_terms = identity_matrix_expanded * b_expanded

# 4. 准备 A 并扩展维度
# A 的形状是 (n, n)
# unsqueeze(0) 后变为 (1, n, n)
A_expanded = A.unsqueeze(0)

# 5. 计算 A - b[i] * torch.eye(n) 的向量化版本
# (1, n, n) - (m, n, n) -> 广播后得到 (m, n, n)
A_minus_B_terms = A_expanded - B_terms

# 6. 准备 a 并扩展维度
# a 的形状是 (m,)
# unsqueeze(1).unsqueeze(2) 后变为 (m, 1, 1)
a_expanded = a.unsqueeze(1).unsqueeze(2)

# 7. 计算 a[i] / (...) 的向量化版本
# (m, 1, 1) / (m, n, n) -> 广播后得到 (m, n, n)
division_results = a_expanded / A_minus_B_terms

# 8. 对结果沿第一个维度(m 维度)求和
# torch.sum(..., dim=0) 将 (m, n, n) 压缩为 (n, n)
summation_new = torch.sum(division_results, dim=0)

print("\n向量化实现的求和结果 (部分):")
print(summation_new[:2, :2]) # 打印部分结果

# 完整优化代码(更简洁)
print("\n完整优化代码:")
B = torch.eye(n).unsqueeze(0) * b.unsqueeze(1).unsqueeze(2)
A_minus_B = A.unsqueeze(0) - B
summation_new_concise = torch.sum(a.unsqueeze(1).unsqueeze(2) / A_minus_B, dim=0)
print(summation_new_concise[:2, :2])

4. 数值精度与验证

由于浮点数运算的特性,以及不同计算路径(循环累加 vs. 向量化一次性计算)可能导致微小的舍入误差累积,直接使用 == 运算符比较两个结果张量可能会返回 False,即使它们在数学上是等价的。

为了正确地比较两个浮点张量是否“相等”(即在可接受的误差范围内),PyTorch提供了 torch.allclose() 函数。

# 重新运行循环实现以获取 sumation_old
sumation_old = 0
for i in range(m):
    sumation_old = sumation_old + a[i] / (A - b[i] * torch.eye(n))

# 比较结果
print(f"\n直接比较 (summation_old == summation_new).all(): {(sumation_old == summation_new).all()}")
print(f"使用 torch.allclose 比较: {torch.allclose(sumation_old, summation_new)}")

torch.allclose 会返回 True,表明尽管存在微小的数值差异,但两个结果在数值上是等价的。

5. 总结与注意事项

  • 性能提升: 向量化是PyTorch及其他数值计算库中提高性能的关键技术。它将一系列独立的标量或小张量操作转换为单个大型张量操作,从而能够充分利用底层高度优化的C++/CUDA实现,并实现GPU加速。
  • 代码简洁性: 向量化代码通常比循环代码更简洁、更易读,减少了样板代码。
  • 内存管理: 虽然广播机制避免了显式复制,但中间张量的创建仍然会占用内存。在处理极其巨大的张量时,需要注意内存消耗。
  • 维度匹配: 理解 unsqueeze()、view()、reshape() 等维度操作以及广播规则是编写高效PyTorch代码的基础。广播要求张量维度从末尾开始向前匹配,或者其中一个维度为1。
  • 数值稳定性: 尽管 torch.allclose 可以验证结果的近似相等性,但在某些极端数值计算场景下,不同的实现路径确实可能导致显著的数值差异。通常,向量化实现由于其并行性,有时在数值稳定性上甚至优于串行累加。

通过本教程,读者应能掌握在PyTorch中将循环操作向量化的基本原理和实践方法,从而编写出更高效、更专业的深度学习代码。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
java基础知识汇总
java基础知识汇总

java基础知识有Java的历史和特点、Java的开发环境、Java的基本数据类型、变量和常量、运算符和表达式、控制语句、数组和字符串等等知识点。想要知道更多关于java基础知识的朋友,请阅读本专题下面的的有关文章,欢迎大家来php中文网学习。

1502

2023.10.24

Go语言中的运算符有哪些
Go语言中的运算符有哪些

Go语言中的运算符有:1、加法运算符;2、减法运算符;3、乘法运算符;4、除法运算符;5、取余运算符;6、比较运算符;7、位运算符;8、按位与运算符;9、按位或运算符;10、按位异或运算符等等。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

232

2024.02.23

php三元运算符用法
php三元运算符用法

本专题整合了php三元运算符相关教程,阅读专题下面的文章了解更多详细内容。

87

2025.10.17

pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

433

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

24

2025.12.22

java入门学习合集
java入门学习合集

本专题整合了java入门学习指南、初学者项目实战、入门到精通等等内容,阅读专题下面的文章了解更多详细学习方法。

1

2026.01.29

java配置环境变量教程合集
java配置环境变量教程合集

本专题整合了java配置环境变量设置、步骤、安装jdk、避免冲突等等相关内容,阅读专题下面的文章了解更多详细操作。

2

2026.01.29

java成品学习网站推荐大全
java成品学习网站推荐大全

本专题整合了java成品网站、在线成品网站源码、源码入口等等相关内容,阅读专题下面的文章了解更多详细推荐内容。

0

2026.01.29

Java字符串处理使用教程合集
Java字符串处理使用教程合集

本专题整合了Java字符串截取、处理、使用、实战等等教程内容,阅读专题下面的文章了解详细操作教程。

0

2026.01.29

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
最新Python教程 从入门到精通
最新Python教程 从入门到精通

共4课时 | 22.4万人学习

Django 教程
Django 教程

共28课时 | 3.7万人学习

SciPy 教程
SciPy 教程

共10课时 | 1.3万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号