0

0

在torch.vmap中高效创建与操作批处理张量

霞舞

霞舞

发布时间:2025-10-20 14:39:10

|

187人浏览过

|

来源于php中文网

原创

在torch.vmap中高效创建与操作批处理张量

在使用`torch.vmap`进行函数向量化时,直接在被向量化的函数内部使用`torch.zeros`创建新的张量并期望其自动获得批处理维度是一个常见挑战。本文将深入探讨这一问题,并提供一种优雅的解决方案:通过结合`clone()`和`torch.concatenate`,可以有效地在`vmap`环境中创建和填充具有正确批处理维度的张量,从而避免手动传递预先创建的批处理张量,实现代码的简洁与高效。

torch.vmap与批处理张量创建的挑战

torch.vmap是PyTorch中一个强大的工具,它允许用户对批量输入高效地应用一个单样本函数,而无需手动编写循环或调整张量维度。然而,当被向量化的函数需要在内部创建新的张量时,一个常见的陷阱是这些新创建的张量并不会自动继承批处理维度。

考虑一个计算多项式伴随矩阵的函数polycompanion。这个函数需要根据输入多项式polynomial的维度创建一个新的零矩阵companion,然后填充其部分内容。

import torch

poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=torch.float32)

def polycompanion(polynomial):
    # 计算伴随矩阵的维度
    deg = polynomial.shape[-1] - 2
    # 创建一个 (deg+1, deg+1) 的零矩阵
    companion = torch.zeros((deg + 1, deg + 1), dtype=torch.float32)
    # 填充单位矩阵部分
    companion[1:, :-1] = torch.eye(deg, dtype=torch.float32)
    # 填充最后一列,这部分依赖于输入多项式
    companion[:, -1] = -1. * polynomial[:-1] / polynomial[-1]
    return companion

# 尝试使用 vmap 向量化该函数
polycompanion_vmap = torch.vmap(polycompanion)

# 预期会遇到问题,因为 companion 不是 BatchedTensor
# print(polycompanion_vmap(poly_batched))
# 上述代码会因 vmap 无法处理非 BatchedTensor 的原地操作而失败

在上述代码中,torch.vmap在执行polycompanion时,polynomial是一个BatchedTensor。然而,companion = torch.zeros((deg + 1, deg + 1))创建的companion张量并不是BatchedTensor。当尝试对companion进行原地修改,特别是当修改操作涉及polynomial(一个BatchedTensor)时,vmap无法正确地跟踪和应用批处理语义,导致运行时错误。

常见的“丑陋”解决方案及其局限性

为了规避这个问题,一种常见的(但不推荐的)做法是预先在vmap外部创建批处理的零张量,并将其作为参数传递给被向量化的函数。

import torch

poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=torch.float32)

def polycompanion_workaround(polynomial, companion_template):
    # 注意:这里的 deg 需要根据 companion_template 的形状来推断,或者与 polynomial 保持一致
    # 为了简化,我们假设 companion_template 已经有正确的形状
    deg = companion_template.shape[-1] - 1 # 假设 companion_template 已经是 (deg+1, deg+1)

    # 在 template 上进行原地操作
    companion_template[1:, :-1] = torch.eye(deg, dtype=torch.float32)
    companion_template[:, -1] = -1. * polynomial[:-1] / polynomial[-1]
    return companion_template

polycompanion_vmap_workaround = torch.vmap(polycompanion_workaround)

# 预先创建批处理的零张量
batch_size = poly_batched.shape[0]
companion_dim = poly_batched.shape[-1] - 1 # (deg+1)
initial_companion = torch.zeros(batch_size, companion_dim, companion_dim, dtype=torch.float32)

# 传递预创建的批处理张量
output_workaround = polycompanion_vmap_workaround(poly_batched, initial_companion)
print("Workaround Output:")
print(output_workaround)

输出:

杰易OA办公自动化系统6.0
杰易OA办公自动化系统6.0

基于Intranet/Internet 的Web下的办公自动化系统,采用了当今最先进的PHP技术,是综合大量用户的需求,经过充分的用户论证的基础上开发出来的,独特的即时信息、短信、电子邮件系统、完善的工作流、数据库安全备份等功能使得信息在企业内部传递效率极大提高,信息传递过程中耗费降到最低。办公人员得以从繁杂的日常办公事务处理中解放出来,参与更多的富于思考性和创造性的工作。系统力求突出体系结构简明

下载
Workaround Output:
tensor([[[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]],

        [[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]]])

这种方法虽然能工作,但它破坏了函数的封装性,使得函数签名的设计变得复杂,且在函数内部无法动态决定新张量的批处理大小,不够灵活。

优雅的解决方案:clone()与torch.concatenate

解决此问题的关键在于,对于需要批处理的张量,我们必须确保其批处理维度在vmap的上下文中是明确的。如果一个张量的一部分内容依赖于批处理输入,而另一部分是固定的,我们可以将它们分别处理,然后合并。

核心思路是:

  1. 创建非批处理的固定部分(例如单位矩阵部分)。
  2. 创建批处理的动态部分(例如最后一列,它依赖于polynomial)。
  3. 使用clone()确保非批处理部分可以被独立地操作和复制。
  4. 使用torch.concatenate将这两部分沿着正确的维度合并,同时利用None来添加缺失的维度以进行匹配。
import torch

poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=torch.float32)

def polycompanion_refined(polynomial):
    deg = polynomial.shape[-1] - 2

    # 1. 创建一个非批处理的零矩阵作为基础
    companion_base = torch.zeros((deg + 1, deg + 1), dtype=torch.float32)

    # 2. 填充单位矩阵部分(这部分是固定的,不依赖于批处理)
    # 注意:这里我们只填充除了最后一列之外的部分
    companion_base[1:, :-1] = torch.eye(deg, dtype=torch.float32)

    # 3. 计算最后一列,这部分是依赖于 polynomial (BatchedTensor) 的,因此会是 BatchedTensor
    last_column_batched = -1. * polynomial[:-1] / polynomial[-1]

    # 4. 准备合并:
    #    - companion_base[:, :-1] 是非批处理的,需要 clone 以便后续操作。
    #      clone() 确保 vmap 可以对每个批次独立处理这个副本。
    #    - last_column_batched 是一个一维的 BatchedTensor,形状为 (batch_size, deg+1)。
    #      为了与 companion_base[:, :-1] (形状为 (deg+1, deg)) 合并,
    #      需要将其扩展为 (batch_size, deg+1, 1) 的形状,通过 [:, None] 实现。
    _companion = torch.concatenate([
        companion_base[:, :-1].clone(), # 克隆非批处理的左侧部分
        last_column_batched[:, None]    # 批处理的右侧列,添加一个维度使其可合并
    ], dim=1) # 沿着列维度合并

    return _companion

polycompanion_vmap_refined = torch.vmap(polycompanion_refined)
output_refined = polycompanion_vmap_refined(poly_batched)
print("\nRefined Solution Output:")
print(output_refined)

输出:

Refined Solution Output:
tensor([[[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]],

        [[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]]])

注意事项与总结

  1. torch.zeros_like的适用性:如果新张量的形状可以直接从一个批处理输入张量派生,并且所有元素都初始化为零,那么torch.zeros_like(batched_input)可以很好地工作,因为它会创建一个BatchedTensor。然而,在伴随矩阵的例子中,我们需要一个特定形状的零矩阵,其大小与输入张量的最后一个维度相关,但并非完全相同,且后续需要部分填充。因此,zeros_like在此场景下并不直接适用。
  2. clone()的重要性:在vmap环境中,当一个张量(如companion_base[:, :-1])不是BatchedTensor但需要与BatchedTensor(如last_column_batched)合并时,对其调用clone()可以有效地为每个批次创建一个独立的副本。这使得vmap能够独立地处理每个批次的合并操作,而不会因为原始张量不是批处理的而产生冲突。
  3. 维度匹配:torch.concatenate要求所有输入张量在非合并维度上具有相同的形状。在我们的例子中,last_column_batched是一个形状为(batch_size, deg+1)的一维批处理张量。为了与形状为(deg+1, deg)的companion_base[:, :-1].clone()合并,我们需要将last_column_batched的形状调整为(batch_size, deg+1, 1),这通过[:, None]索引实现,它在最后一个维度上添加了一个新的维度。

通过这种clone()和torch.concatenate的组合,我们能够在torch.vmap的上下文中,在函数内部灵活且优雅地创建和填充新的批处理张量,从而保持代码的简洁性和功能性,避免了不必要的外部参数传递。这种模式对于在vmap函数中构建复杂张量结构非常有用。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

432

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

24

2025.12.22

php中文乱码如何解决
php中文乱码如何解决

本文整理了php中文乱码如何解决及解决方法,阅读节专题下面的文章了解更多详细内容。

1

2026.01.28

Java 消息队列与异步架构实战
Java 消息队列与异步架构实战

本专题系统讲解 Java 在消息队列与异步系统架构中的核心应用,涵盖消息队列基本原理、Kafka 与 RabbitMQ 的使用场景对比、生产者与消费者模型、消息可靠性与顺序性保障、重复消费与幂等处理,以及在高并发系统中的异步解耦设计。通过实战案例,帮助学习者掌握 使用 Java 构建高吞吐、高可靠异步消息系统的完整思路。

1

2026.01.28

Python 自然语言处理(NLP)基础与实战
Python 自然语言处理(NLP)基础与实战

本专题系统讲解 Python 在自然语言处理(NLP)领域的基础方法与实战应用,涵盖文本预处理(分词、去停用词)、词性标注、命名实体识别、关键词提取、情感分析,以及常用 NLP 库(NLTK、spaCy)的核心用法。通过真实文本案例,帮助学习者掌握 使用 Python 进行文本分析与语言数据处理的完整流程,适用于内容分析、舆情监测与智能文本应用场景。

23

2026.01.27

拼多多赚钱的5种方法 拼多多赚钱的5种方法
拼多多赚钱的5种方法 拼多多赚钱的5种方法

在拼多多上赚钱主要可以通过无货源模式一件代发、精细化运营特色店铺、参与官方高流量活动、利用拼团机制社交裂变,以及成为多多进宝推广员这5种方法实现。核心策略在于通过低成本、高效率的供应链管理与营销,利用平台社交电商红利实现盈利。

120

2026.01.26

edge浏览器怎样设置主页 edge浏览器自定义设置教程
edge浏览器怎样设置主页 edge浏览器自定义设置教程

在Edge浏览器中设置主页,请依次点击右上角“...”图标 > 设置 > 开始、主页和新建标签页。在“Microsoft Edge 启动时”选择“打开以下页面”,点击“添加新页面”并输入网址。若要使用主页按钮,需在“外观”设置中开启“显示主页按钮”并设定网址。

51

2026.01.26

苹果官方查询网站 苹果手机正品激活查询入口
苹果官方查询网站 苹果手机正品激活查询入口

苹果官方查询网站主要通过 checkcoverage.apple.com/cn/zh/ 进行,可用于查询序列号(SN)对应的保修状态、激活日期及技术支持服务。此外,查找丢失设备请使用 iCloud.com/find,购买信息与物流可访问 Apple (中国大陆) 订单状态页面。

192

2026.01.26

npd人格什么意思 npd人格有什么特征
npd人格什么意思 npd人格有什么特征

NPD(Narcissistic Personality Disorder)即自恋型人格障碍,是一种心理健康问题,特点是极度夸大自我重要性、需要过度赞美与关注,同时极度缺乏共情能力,背后常掩藏着低自尊和不安全感,影响人际关系、工作和生活,通常在青少年时期开始显现,需由专业人士诊断。

7

2026.01.26

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
React 教程
React 教程

共58课时 | 4.2万人学习

Pandas 教程
Pandas 教程

共15课时 | 1.0万人学习

ASP 教程
ASP 教程

共34课时 | 4.1万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号