0

0

在 torch.vmap 中高效处理内部张量创建

碧海醫心

碧海醫心

发布时间:2025-10-20 14:47:07

|

213人浏览过

|

来源于php中文网

原创

在 torch.vmap 中高效处理内部张量创建

理解 torch.vmap 与内部张量创建的挑战

torch.vmap 是 PyTorch 提供的一个强大工具,它允许我们将一个处理单个样本的函数(即非批处理函数)转换为一个能够高效处理一批样本的函数,而无需手动管理批处理维度。这在编写通用代码和加速计算方面非常有用。然而,当被 vmap 向量化的函数内部需要创建新的张量,并且这些张量的形状依赖于批处理输入的形状时,就会遇到一个常见的陷阱。

考虑以下场景:我们有一个函数 polycompanion,它接收一个多项式系数张量,并计算其伴随矩阵。伴随矩阵的维度取决于多项式的次数。

import torch

poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=torch.float32)

def polycompanion(polynomial):
    # polynomial.shape[-1] 是多项式系数的个数,例如 [a, b, c, d] 代表 ax^3 + bx^2 + cx + d
    # 次数 deg = 系数个数 - 1 - 1 = 系数个数 - 2 (如果最后一个系数是常数项)
    deg = polynomial.shape[-1] - 2

    # 尝试创建伴随矩阵
    companion = torch.zeros((deg + 1, deg + 1), dtype=torch.float32)

    # 填充单位矩阵部分
    companion[1:, :-1] = torch.eye(deg, dtype=torch.float32)

    # 填充最后一列
    # 注意这里 polynomial[:-1] 表示除了最后一个系数以外的所有系数
    # polynomial[-1] 表示最后一个系数
    companion[:, -1] = -1. * polynomial[:-1] / polynomial[-1]
    return companion

# 尝试使用 vmap 向量化
polycompanion_vmap = torch.vmap(polycompanion)

try:
    print(polycompanion_vmap(poly_batched))
except Exception as e:
    print(f"Initial attempt failed: {e}")

上述代码在执行 polycompanion_vmap(poly_batched) 时会失败。原因是 polycompanion 函数内部通过 torch.zeros((deg+1, deg+1)) 创建了一个新的 companion 张量。尽管 deg 是从 polynomial(一个批处理输入)派生出来的,但 torch.zeros 本身创建的是一个普通的、非批处理的张量。当 vmap 试图对这个非批处理的 companion 张量执行批处理操作(例如,将其与从 polynomial 派生的批处理张量进行索引或赋值)时,就会出现维度不匹配或类型不兼容的问题,因为 vmap 期望所有参与运算的张量都带有批处理维度。

为什么 torch.zeros 不会自动批处理?

torch.vmap 的核心机制是跟踪批处理维度,并将操作提升到批处理层面。它能识别作为 vmap 输入的张量及其通过各种张量操作(如加法、乘法、切片等)派生出的张量,并为它们自动添加和管理批处理维度。然而,像 torch.zeros 这种从零开始创建新张量的操作,其默认行为是创建一个标准张量,不包含任何批处理维度信息。即使其形状参数 (deg+1, deg+1) 是基于批处理输入计算得出的,torch.zeros 也无法“感知”到外部的 vmap 上下文,从而无法自动生成一个 BatchedTensor。

torch.zeros_like 是一个例外,因为它基于一个已存在的张量来创建新张量。如果这个已存在的张量是 BatchedTensor,那么 torch.zeros_like 也能创建出一个 BatchedTensor。但在本例中,我们没有一个现成的 BatchedTensor 可以作为 zeros_like 的模板来创建 companion。

规避方案:预分配与外部传递

一种可行的(但不理想的)规避方法是,在调用 vmap 之前,手动创建一个带有批处理维度的 companion 张量,并将其作为函数的额外输入传递给 vmap。

def polycompanion_workaround(polynomial, companion_template):
    # 注意:这里的 deg 现在从 companion_template 的形状推断,因为它已经有了批处理维度
    deg = companion_template.shape[-1] - 1 

    # 在传入的 companion_template 上进行就地修改
    companion_template[1:, :-1] = torch.eye(deg, dtype=torch.float32)
    companion_template[:, -1] = -1. * polynomial[:-1] / polynomial[-1]
    return companion_template

polycompanion_vmap_workaround = torch.vmap(polycompanion_workaround)

# 预先创建批处理的 companion 模板
# poly_batched.shape[0] 是批次大小
# poly_batched.shape[-1]-1 是伴随矩阵的行/列维度
companion_init_shape = (poly_batched.shape[0], poly_batched.shape[-1] - 1, poly_batched.shape[-1] - 1)
pre_batched_companion = torch.zeros(companion_init_shape, dtype=torch.float32)

print("--- Workaround Output ---")
print(polycompanion_vmap_workaround(poly_batched, pre_batched_companion))

这种方法虽然能够正确输出结果,但存在明显缺点:

杰易OA办公自动化系统6.0
杰易OA办公自动化系统6.0

基于Intranet/Internet 的Web下的办公自动化系统,采用了当今最先进的PHP技术,是综合大量用户的需求,经过充分的用户论证的基础上开发出来的,独特的即时信息、短信、电子邮件系统、完善的工作流、数据库安全备份等功能使得信息在企业内部传递效率极大提高,信息传递过程中耗费降到最低。办公人员得以从繁杂的日常办公事务处理中解放出来,参与更多的富于思考性和创造性的工作。系统力求突出体系结构简明

下载
  1. 函数签名改变:polycompanion 函数现在需要一个额外的 companion_template 参数,这破坏了其原始的、独立处理单个样本的语义。
  2. 外部依赖:在调用 vmap 之前,必须手动计算并创建具有正确批处理维度的 pre_batched_companion 张量,增加了代码的复杂性和耦合性。

推荐解决方案:利用 clone 和 concatenate

为了在 vmap 上下文中优雅地创建和填充张量,我们可以避免在非批处理的 torch.zeros 张量上进行就地修改。相反,我们将伴随矩阵视为由两部分组成:一个包含单位矩阵的左侧部分,以及一个由多项式系数计算得出的右侧(最后一列)部分。然后,我们分别构建这两部分,并使用 torch.concatenate 将它们合并。

关键在于:

  1. 静态部分:对于伴随矩阵中相对固定的部分(如单位矩阵),我们可以先在一个非批处理的 torch.zeros 张量上构建。
  2. 动态部分:对于依赖于批处理输入的部分(如最后一列),我们直接从批处理输入 polynomial 计算。
  3. 合并:使用 torch.concatenate 将这两部分合并。concatenate 是一种张量操作,vmap 能够很好地处理其批处理行为。

以下是改进后的 polycompanion 函数:

def polycompanion_optimized(polynomial):
    deg = polynomial.shape[-1] - 2

    # 1. 创建一个基础的非批处理张量来填充单位矩阵部分
    # 这是一个临时的、非批处理的张量
    base_matrix = torch.zeros((deg + 1, deg + 1), dtype=torch.float32)
    base_matrix[1:, :-1] = torch.eye(deg, dtype=torch.float32)

    # 2. 提取 base_matrix 的左侧部分,并进行克隆
    # clone() 创建了一个新的张量,虽然它仍然是非批处理的,
    # 但在 vmap 上下文中,当它与批处理张量拼接时,vmap 会正确处理
    left_part = base_matrix[:, :-1].clone()

    # 3. 计算伴随矩阵的最后一列
    # 这一部分完全从批处理输入 polynomial 派生,因此 vmap 会将其视为批处理张量
    # polynomial[:-1] 是 (deg+1,) 形状
    # polynomial[-1] 是标量
    # 结果是一个 (deg+1,) 形状的张量
    last_column_values = -1. * polynomial[:-1] / polynomial[-1]

    # 4. 扩展最后一列的维度,使其可以与 left_part 进行拼接
    # last_column_values 是 (deg+1,),我们需要将其变为 (deg+1, 1)
    last_column_reshaped = last_column_values[:, None] 

    # 5. 使用 concatenate 组合左右两部分
    # vmap 会识别 left_part 和 last_column_reshaped,并为它们在批次维度上执行拼接
    final_companion = torch.concatenate([left_part, last_column_reshaped], dim=1)

    return final_companion

polycompanion_vmap_optimized = torch.vmap(polycompanion_optimized)

print("\n--- Optimized Solution Output ---")
print(polycompanion_vmap_optimized(poly_batched))

输出:

tensor([[[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]],

        [[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]]])

这个解决方案成功地生成了批处理的伴随矩阵,同时保持了 polycompanion_optimized 函数的简洁性,使其能够独立处理单个样本,并且不需要外部预分配张量。

注意事项与最佳实践

  • 函数式编程思维:在使用 torch.vmap 时,尽量采用函数式编程的思维,即函数主要通过返回新张量来完成操作,而不是通过就地修改输入张量。这有助于 vmap 更好地跟踪张量的依赖关系和批处理维度。
  • 避免在 vmap 内部进行就地修改:除非你确切知道自己在做什么,并且只对批处理输入进行就地修改,否则应避免在 vmap 内部对非批处理张量进行就地修改。
  • clone() 的作用:在上述解决方案中,clone() 是关键。它创建了一个 base_matrix 切片的新副本。虽然 base_matrix 本身是非批处理的,但通过 clone() 得到的 left_part 可以被 concatenate 操作正确地与批处理的 last_column_reshaped 结合。
  • 维度匹配:当使用 torch.concatenate 或 torch.stack 时,确保所有参与拼接的张量在非拼接维度上形状一致。[:, None] 技巧常用于为张量添加一个维度,使其符合拼接要求。
  • 性能考量:虽然 concatenate 方案解决了功能问题,但频繁创建和拼接中间张量可能会带来一定的性能开销。对于极致性能敏感的场景,可能需要权衡 vmap 的便利性与手动批处理的优化潜力。然而,对于大多数情况,vmap 带来的代码简化和潜在加速(尤其是在支持的后端)是值得的。

总结

在 torch.vmap 中处理函数内部的张量创建是一个常见的挑战。通过理解 vmap 对批处理张量的期望,并采用 clone() 结合 torch.concatenate 的策略,我们能够优雅地构建出所需的批处理张量,而无需妥协函数的简洁性或引入复杂的外部依赖。这种方法体现了在 PyTorch 中进行高效张量操作的灵活性和强大功能,是掌握 torch.vmap 的一个重要技巧。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
go语言 数组和切片
go语言 数组和切片

本专题整合了go语言数组和切片的区别与含义,阅读专题下面的文章了解更多详细内容。

46

2025.09.03

pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

432

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

24

2025.12.22

Python 自然语言处理(NLP)基础与实战
Python 自然语言处理(NLP)基础与实战

本专题系统讲解 Python 在自然语言处理(NLP)领域的基础方法与实战应用,涵盖文本预处理(分词、去停用词)、词性标注、命名实体识别、关键词提取、情感分析,以及常用 NLP 库(NLTK、spaCy)的核心用法。通过真实文本案例,帮助学习者掌握 使用 Python 进行文本分析与语言数据处理的完整流程,适用于内容分析、舆情监测与智能文本应用场景。

10

2026.01.27

拼多多赚钱的5种方法 拼多多赚钱的5种方法
拼多多赚钱的5种方法 拼多多赚钱的5种方法

在拼多多上赚钱主要可以通过无货源模式一件代发、精细化运营特色店铺、参与官方高流量活动、利用拼团机制社交裂变,以及成为多多进宝推广员这5种方法实现。核心策略在于通过低成本、高效率的供应链管理与营销,利用平台社交电商红利实现盈利。

109

2026.01.26

edge浏览器怎样设置主页 edge浏览器自定义设置教程
edge浏览器怎样设置主页 edge浏览器自定义设置教程

在Edge浏览器中设置主页,请依次点击右上角“...”图标 > 设置 > 开始、主页和新建标签页。在“Microsoft Edge 启动时”选择“打开以下页面”,点击“添加新页面”并输入网址。若要使用主页按钮,需在“外观”设置中开启“显示主页按钮”并设定网址。

16

2026.01.26

苹果官方查询网站 苹果手机正品激活查询入口
苹果官方查询网站 苹果手机正品激活查询入口

苹果官方查询网站主要通过 checkcoverage.apple.com/cn/zh/ 进行,可用于查询序列号(SN)对应的保修状态、激活日期及技术支持服务。此外,查找丢失设备请使用 iCloud.com/find,购买信息与物流可访问 Apple (中国大陆) 订单状态页面。

138

2026.01.26

npd人格什么意思 npd人格有什么特征
npd人格什么意思 npd人格有什么特征

NPD(Narcissistic Personality Disorder)即自恋型人格障碍,是一种心理健康问题,特点是极度夸大自我重要性、需要过度赞美与关注,同时极度缺乏共情能力,背后常掩藏着低自尊和不安全感,影响人际关系、工作和生活,通常在青少年时期开始显现,需由专业人士诊断。

7

2026.01.26

windows安全中心怎么关闭 windows安全中心怎么执行操作
windows安全中心怎么关闭 windows安全中心怎么执行操作

关闭Windows安全中心(Windows Defender)可通过系统设置暂时关闭,或使用组策略/注册表永久关闭。最简单的方法是:进入设置 > 隐私和安全性 > Windows安全中心 > 病毒和威胁防护 > 管理设置,将实时保护等选项关闭。

6

2026.01.26

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
React 教程
React 教程

共58课时 | 4.2万人学习

Pandas 教程
Pandas 教程

共15课时 | 1.0万人学习

ASP 教程
ASP 教程

共34课时 | 4.1万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号