
本文详解Vision Transformer(ViT)自定义实现中常见的mat1 and mat2 shapes cannot be multiplied错误,聚焦Patch Embedding层输出维度失配问题,通过形状追踪、代码修正与关键设计原则,帮助开发者快速定位并修复线性投影层的输入维度不匹配缺陷。
本文详解vision transformer(vit)自定义实现中常见的`mat1 and mat2 shapes cannot be multiplied`错误,聚焦patch embedding层输出维度失配问题,通过形状追踪、代码修正与关键设计原则,帮助开发者快速定位并修复线性投影层的输入维度不匹配缺陷。
在PyTorch中从零构建Vision Transformer时,报错 RuntimeError: mat1 and mat2 shapes cannot be multiplied (30x50176 and 768x768) 是一个典型的张量维度不兼容信号——它明确指出:某次线性变换(nn.Linear)试图将一个形状为 (30, 50176) 的输入(即 batch_size=30,特征维=50176)乘以权重矩阵 (768, 768),而矩阵乘法要求输入的最后一维(50176)必须等于权重的第一维(768),显然不成立。
问题根源不在Transformer主干或注意力机制,而始于最前端的 PatchEmbedding 模块。我们来逐步追踪前向传播中的张量形状变化:
假设输入图像为标准ViT配置:batch_size=30, C=3, H=224, W=224(经Resize(224,224)和ToTensor()后):
# 输入 x.shape == (30, 3, 224, 224) x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size) # ❌ 错误的reshape顺序!原代码中: # x.reshape(B, C, H // self.patch_size, W // self.patch_size, self.patch_size, self.patch_size) # 导致后续permute逻辑混乱,最终flatten(2)产生错误维度
✅ 正确的patch划分应将每个 (patch_size, patch_size) 区域展平为一个向量,共 (H//p) * (W//p) 个patch,每个patch含 C * p * p 维特征。因此:
- num_patches = (224 // 16) * (224 // 16) = 14 * 14 = 196
- patch_dim = C * patch_size * patch_size = 3 * 16 * 16 = 768
这意味着:PatchEmbedding 的输入特征维必须是 768,才能被 nn.Linear(768, embed_dim)(即 Linear(768, 768))正确处理。
但原代码中 x.flatten(2) 的位置和 reshape 顺序有误,导致实际展平维度远超预期(50176 ≈ 30 × 196 × ?,实为误算的中间态)。根本修复方式如下:
✅ 修正 PatchEmbedding 实现(推荐清晰写法)
class PatchEmbedding(torch.nn.Module):
def __init__(self, patch_size, in_channels, embed_dim):
super().__init__()
self.patch_size = patch_size
self.embed_dim = embed_dim
# 输入维度 = 3 * 16 * 16 = 768 → 输出 embed_dim = 768
self.projection = torch.nn.Linear(
in_features=patch_size * patch_size * in_channels,
out_features=embed_dim
)
def forward(self, x):
B, C, H, W = x.shape
# Step 1: 将空间维度拆分为 patch 网格 —— (B, C, H, W) → (B, C, num_h, p, num_w, p)
p = self.patch_size
num_h, num_w = H // p, W // p
x = x.reshape(B, C, num_h, p, num_w, p)
# Step 2: 调整顺序使 patch 内部连续 → (B, num_h, num_w, C, p, p)
x = x.permute(0, 2, 4, 1, 3, 5)
# Step 3: 展平每个 patch 的通道与像素 → (B, num_h * num_w, C * p * p)
x = x.flatten(3) # flatten last two dims: (C, p, p) → (C*p*p)
x = x.flatten(1, 2) # flatten num_h and num_w → (B, num_h*num_w, C*p*p)
# Step 4: 投影到 embedding 空间
x = self.projection(x) # (B, 196, 768)
return x? 验证形状:输入 (30, 3, 224, 224) → 输出 (30, 196, 768),完美匹配后续 MultiHeadAttention(embed_dim=768) 的输入要求。
⚠️ 其他关键注意事项
-
cls_token 使用需显式添加:标准ViT在patch嵌入后需拼接可学习的 [CLS] token。当前代码中 x[:, cls_token] 假设 cls_token=0,但未实际插入该token,会导致索引越界或语义错误。应在 PatchEmbedding 后补充:
self.cls_token = torch.nn.Parameter(torch.randn(1, 1, embed_dim)) # 在 VisionTransformer.forward 中: cls_tokens = self.cls_token.expand(B, -1, -1) x = torch.cat((cls_tokens, x), dim=1) # (B, 197, 768)
-
TransformerEncoder 结构需符合标准范式:当前实现将 LayerNorm 放在注意力之后、FFN之前,但标准ViT采用 Post-LN(即每个子层后接LN),且FFN应为两层线性+GELU(非单层Linear(embed_dim, embed_dim))。建议修正为:
class TransformerBlock(torch.nn.Module): def __init__(self, embed_dim, num_heads): super().__init__() self.attn = MultiHeadAttention(embed_dim, num_heads) self.norm1 = torch.nn.LayerNorm(embed_dim) self.mlp = torch.nn.Sequential( torch.nn.Linear(embed_dim, embed_dim * 4), torch.nn.GELU(), torch.nn.Dropout(0.1), torch.nn.Linear(embed_dim * 4, embed_dim), torch.nn.Dropout(0.1) ) self.norm2 = torch.nn.LayerNorm(embed_dim) def forward(self, x): x = x + self.attn(self.norm1(x)) # residual + attn x = x + self.mlp(self.norm2(x)) # residual + mlp return x -
调试技巧:插入 shape 打印
在 forward 中添加临时日志,快速定位断点:print(f"[PatchEmbed] input: {x.shape} → output: {x.shape}")
✅ 总结
该错误本质是Patch Embedding 层未正确生成 (B, num_patches, embed_dim) 形状张量,导致下游线性层输入维度爆炸(50176)。修复核心在于:
- 严格按 C × p × p 计算单patch特征维;
- 使用 permute + flatten 正确重组张量,确保 flatten 作用于正确的轴;
- 始终验证每层输入/输出shape,尤其在自定义几何操作时;
- 遵循ViT原始论文结构(如CLS token、LayerNorm位置、MLP扩展比)。
完成上述修正后,模型即可顺利运行,为后续位置编码、训练调优打下坚实基础。









