0

0

PyTorch从零实现ViT时的张量形状错误排查与修复指南

碧海醫心

碧海醫心

发布时间:2026-03-18 16:46:01

|

822人浏览过

|

来源于php中文网

原创

PyTorch从零实现ViT时的张量形状错误排查与修复指南

本文详解Vision Transformer(ViT)自定义实现中常见的mat1 and mat2 shapes cannot be multiplied错误,聚焦Patch Embedding层输出维度失配问题,通过形状追踪、代码修正与关键设计原则,帮助开发者快速定位并修复线性投影层的输入维度不匹配缺陷。

本文详解vision transformer(vit)自定义实现中常见的`mat1 and mat2 shapes cannot be multiplied`错误,聚焦patch embedding层输出维度失配问题,通过形状追踪、代码修正与关键设计原则,帮助开发者快速定位并修复线性投影层的输入维度不匹配缺陷。

在PyTorch中从零构建Vision Transformer时,报错 RuntimeError: mat1 and mat2 shapes cannot be multiplied (30x50176 and 768x768) 是一个典型的张量维度不兼容信号——它明确指出:某次线性变换(nn.Linear)试图将一个形状为 (30, 50176) 的输入(即 batch_size=30,特征维=50176)乘以权重矩阵 (768, 768),而矩阵乘法要求输入的最后一维(50176)必须等于权重的第一维(768),显然不成立。

问题根源不在Transformer主干或注意力机制,而始于最前端的 PatchEmbedding 模块。我们来逐步追踪前向传播中的张量形状变化:

假设输入图像为标准ViT配置:batch_size=30, C=3, H=224, W=224(经Resize(224,224)和ToTensor()后):

# 输入 x.shape == (30, 3, 224, 224)
x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
# ❌ 错误的reshape顺序!原代码中:
# x.reshape(B, C, H // self.patch_size, W // self.patch_size, self.patch_size, self.patch_size)
# 导致后续permute逻辑混乱,最终flatten(2)产生错误维度

✅ 正确的patch划分应将每个 (patch_size, patch_size) 区域展平为一个向量,共 (H//p) * (W//p) 个patch,每个patch含 C * p * p 维特征。因此:

  • num_patches = (224 // 16) * (224 // 16) = 14 * 14 = 196
  • patch_dim = C * patch_size * patch_size = 3 * 16 * 16 = 768

这意味着:PatchEmbedding 的输入特征维必须是 768,才能被 nn.Linear(768, embed_dim)(即 Linear(768, 768))正确处理。

但原代码中 x.flatten(2) 的位置和 reshape 顺序有误,导致实际展平维度远超预期(50176 ≈ 30 × 196 × ?,实为误算的中间态)。根本修复方式如下:

Hotpot AI Background Remover
Hotpot AI Background Remover

Hotpot.ai推出的图片背景移除工具

下载

✅ 修正 PatchEmbedding 实现(推荐清晰写法)

class PatchEmbedding(torch.nn.Module):
    def __init__(self, patch_size, in_channels, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        # 输入维度 = 3 * 16 * 16 = 768 → 输出 embed_dim = 768
        self.projection = torch.nn.Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=embed_dim
        )

    def forward(self, x):
        B, C, H, W = x.shape
        # Step 1: 将空间维度拆分为 patch 网格 —— (B, C, H, W) → (B, C, num_h, p, num_w, p)
        p = self.patch_size
        num_h, num_w = H // p, W // p
        x = x.reshape(B, C, num_h, p, num_w, p)
        # Step 2: 调整顺序使 patch 内部连续 → (B, num_h, num_w, C, p, p)
        x = x.permute(0, 2, 4, 1, 3, 5)
        # Step 3: 展平每个 patch 的通道与像素 → (B, num_h * num_w, C * p * p)
        x = x.flatten(3)  # flatten last two dims: (C, p, p) → (C*p*p)
        x = x.flatten(1, 2)  # flatten num_h and num_w → (B, num_h*num_w, C*p*p)
        # Step 4: 投影到 embedding 空间
        x = self.projection(x)  # (B, 196, 768)
        return x

? 验证形状:输入 (30, 3, 224, 224) → 输出 (30, 196, 768),完美匹配后续 MultiHeadAttention(embed_dim=768) 的输入要求。

⚠️ 其他关键注意事项

  • cls_token 使用需显式添加:标准ViT在patch嵌入后需拼接可学习的 [CLS] token。当前代码中 x[:, cls_token] 假设 cls_token=0,但未实际插入该token,会导致索引越界或语义错误。应在 PatchEmbedding 后补充:

    self.cls_token = torch.nn.Parameter(torch.randn(1, 1, embed_dim))
    # 在 VisionTransformer.forward 中:
    cls_tokens = self.cls_token.expand(B, -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)  # (B, 197, 768)
  • TransformerEncoder 结构需符合标准范式:当前实现将 LayerNorm 放在注意力之后、FFN之前,但标准ViT采用 Post-LN(即每个子层后接LN),且FFN应为两层线性+GELU(非单层Linear(embed_dim, embed_dim))。建议修正为:

    class TransformerBlock(torch.nn.Module):
        def __init__(self, embed_dim, num_heads):
            super().__init__()
            self.attn = MultiHeadAttention(embed_dim, num_heads)
            self.norm1 = torch.nn.LayerNorm(embed_dim)
            self.mlp = torch.nn.Sequential(
                torch.nn.Linear(embed_dim, embed_dim * 4),
                torch.nn.GELU(),
                torch.nn.Dropout(0.1),
                torch.nn.Linear(embed_dim * 4, embed_dim),
                torch.nn.Dropout(0.1)
            )
            self.norm2 = torch.nn.LayerNorm(embed_dim)
    
        def forward(self, x):
            x = x + self.attn(self.norm1(x))   # residual + attn
            x = x + self.mlp(self.norm2(x))     # residual + mlp
            return x
  • 调试技巧:插入 shape 打印
    在 forward 中添加临时日志,快速定位断点:

    print(f"[PatchEmbed] input: {x.shape} → output: {x.shape}")

✅ 总结

该错误本质是Patch Embedding 层未正确生成 (B, num_patches, embed_dim) 形状张量,导致下游线性层输入维度爆炸(50176)。修复核心在于:

  1. 严格按 C × p × p 计算单patch特征维;
  2. 使用 permute + flatten 正确重组张量,确保 flatten 作用于正确的轴;
  3. 始终验证每层输入/输出shape,尤其在自定义几何操作时;
  4. 遵循ViT原始论文结构(如CLS token、LayerNorm位置、MLP扩展比)。

完成上述修正后,模型即可顺利运行,为后续位置编码、训练调优打下坚实基础。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

WorkBuddy
WorkBuddy

腾讯云推出的AI原生桌面智能体工作台

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

473

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

29

2025.12.22

Python WebSocket实时通信与异步服务开发实践
Python WebSocket实时通信与异步服务开发实践

本专题聚焦 Python 在实时通信场景中的开发实践,系统讲解 WebSocket 协议原理、长连接管理、消息推送机制以及异步服务架构设计。内容包括客户端与服务端通信实现、连接稳定性优化、消息队列集成及高并发处理策略。通过完整案例,帮助开发者构建高效稳定的实时通信系统,适用于聊天应用、实时数据推送等场景。

2

2026.03.18

Java Spring Security权限控制与认证机制实战
Java Spring Security权限控制与认证机制实战

本专题围绕 Java 后端安全体系建设展开,重点讲解 Spring Security 在权限控制与认证机制中的应用实践。内容涵盖用户认证流程、权限模型设计、JWT 鉴权方案、OAuth2 集成以及接口安全防护策略。通过实际项目案例,帮助开发者构建安全可靠的后端认证体系,提升系统安全性与可扩展能力。

0

2026.03.18

抖漫入口地址合集
抖漫入口地址合集

本专题整合了抖漫入口地址相关合集,阅读专题下面的文章了解更多详细地址。

110

2026.03.17

多环境下的 Nginx 安装、结构与运维实战
多环境下的 Nginx 安装、结构与运维实战

本专题聚焦多环境下Nginx实战,详解开发、测试及生产环境的差异化安装策略与目录结构规划。深入剖析配置模块化设计、灰度发布流程及跨环境同步机制。结合监控告警、故障排查与自动化运维工具,提供全链路管理方案,助力团队构建灵活、高可用的Nginx服务体系,从容应对复杂业务场景挑战。

13

2026.03.17

PS 批量添加图片
PS 批量添加图片

本专题整合了PS批量添加图片教程合集,阅读专题下面的文章了解更多详细操作。

10

2026.03.17

Nginx 基础架构:从安装配置到系统化管理
Nginx 基础架构:从安装配置到系统化管理

本专题深入解析Nginx基础架构,涵盖从源码编译与包管理安装,到核心配置文件优化及虚拟主机部署。进一步探讨日志轮转、性能调优、高可用集群构建及自动化运维策略,助力管理员实现从单一服务搭建到企业级系统化管理的全面升级,确保Web服务高效、稳定运行。

7

2026.03.17

mulerun骡子快跑入口地址汇总
mulerun骡子快跑入口地址汇总

本专题整合了mulerun入口地址合集,阅读专题下面的文章了解更多详细内容。

216

2026.03.17

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号