本文详解 Vision Transformer(ViT)自定义实现中常见的 mat1 and mat2 shapes cannot be multiplied 错误,聚焦 PatchEmbedding 层的维度转换逻辑缺陷,提供可直接复用的修正代码与调试方法。
本文详解 vision transformer(vit)自定义实现中常见的 `mat1 and mat2 shapes cannot be multiplied` 错误,聚焦 patchembedding 层的维度转换逻辑缺陷,提供可直接复用的修正代码与调试方法。
在 PyTorch 中从零构建 Vision Transformer 时,报错 RuntimeError: mat1 and mat2 shapes cannot be multiplied (30x50176 and 768x768) 是一个典型的输入特征维度与线性层权重维度不兼容问题。该错误发生在 PatchEmbedding.forward() 中的 self.projection(x) 调用处——此时 x 的最后一个维度(即特征长度)为 50176,而 nn.Linear(50176, 768) 的权重形状是 (768, 50176),但 PyTorch 的 F.linear 要求输入 x 形状为 (*, in_features),即末维必须等于 in_features = 50176;而实际传入的 x 形状却是 (30, 50176)(batch=30),导致 F.linear 尝试执行 (30×50176) @ (768×768) —— 显然维度无法对齐。
根本原因在于 PatchEmbedding.forward() 的张量重排逻辑存在严重错误:
- 输入图像尺寸:224 × 224,通道数 C=3,patch 大小 16×16 → 每图划分为 (224//16)² = 14² = 196 个 patch;
- 每个 patch 的原始展平维度为 patch_size² × in_channels = 16×16×3 = 768(✅ 这才是 ViT 的标准嵌入输入维度);
- 但当前代码中:
x = x.reshape(B, C, H // self.patch_size, W // self.patch_size, self.patch_size, self.patch_size) x = x.permute(0, 1, 4, 2, 5, 3) # ❌ 错误:将 C 和 patch 维度混排,破坏了 patch 内部结构 x = x.flatten(2) # ❌ 错误:从 dim=2 开始展平,得到 B × C × (14×14×16×16) = 30×3×50176 → 50176!
实际生成了 B × 3 × 50176 张量,再送入 Linear(50176, 768),必然失败。
✅ 正确做法是:每个 patch 独立展平为 768 维向量,再拼接为序列。修正后的 PatchEmbedding 如下:
class PatchEmbedding(torch.nn.Module):
def __init__(self, patch_size, in_channels, embed_dim):
super().__init__()
self.patch_size = patch_size
self.embed_dim = embed_dim
# 输入:每个 patch 是 [C, P, P] → 展平为 C*P*P 维 → 输出 embed_dim 维
self.projection = torch.nn.Linear(patch_size * patch_size * in_channels, embed_dim)
def forward(self, x):
B, C, H, W = x.shape
# Step 1: 切分图像为非重叠 patches → (B, C, num_patches_h, patch_h, num_patches_w, patch_w)
x = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
# Shape: (B, C, H//P, P, W//P, P)
# Step 2: 调整维度并展平每个 patch → (B, num_patches_h * num_patches_w, C * P * P)
x = x.permute(0, 2, 4, 1, 3, 5).flatten(1, 2).flatten(2) # ✅ 关键:先合并空间维度,再展平 patch
# 或更清晰写法:
# x = x.reshape(B, C, -1, self.patch_size, self.patch_size)
# x = x.permute(0, 2, 1, 3, 4).flatten(2) # → (B, N, C*P*P), where N = (H//P)*(W//P)
# Step 3: 线性投影到 embedding 空间
x = self.projection(x) # → (B, N, embed_dim)
return x此外,还需同步修正 VisionTransformer.forward() 中的 class token 机制(原代码 cls_token = 0 未定义为可学习参数,且未拼接):
class VisionTransformer(torch.nn.Module):
def __init__(self, patch_size, in_channels, embed_dim, num_heads, num_layers, num_classes):
super().__init__()
self.patch_embed = PatchEmbedding(patch_size, in_channels, embed_dim)
self.num_patches = (224 // patch_size) ** 2 # 假设固定输入尺寸;生产环境建议动态计算
# ✅ 添加可学习的 class token
self.cls_token = torch.nn.Parameter(torch.randn(1, 1, embed_dim))
# ✅ 添加位置编码(ViT 必需!)
self.pos_embed = torch.nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
self.encoder = TransformerEncoder(embed_dim, num_heads, num_layers)
self.classifier = torch.nn.Linear(embed_dim, num_classes)
def forward(self, x):
x = self.patch_embed(x) # → (B, N, D)
# 拼接 cls token
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) # (B, 1, D)
x = torch.cat((cls_tokens, x), dim=1) # (B, N+1, D)
# 加位置编码
x = x + self.pos_embed
x = self.encoder(x)
x = x[:, 0] # 取 cls token 对应输出
x = self.classifier(x)
return x⚠️ 关键注意事项:
- TransformerEncoder 当前实现有结构性缺陷:它将 MultiHeadAttention 与 FFN、LayerNorm 错误地串联在 Sequential 中,但标准 ViT 要求 Attention → LayerNorm → FFN → LayerNorm(或 Pre-LN)。建议重构为独立子模块;
- MultiHeadAttention 中 einsum('bhnd,bhnd->bhn') 计算的是点积相似度,但 q 和 k 应转置以匹配 B×T×E 输入(当前 q/k 形状为 B×T×num_heads×head_dim,einsum 正确,但后续 softmax(dim=2) 应为 dim=-1 或 dim=3,需验证);
- 务必在训练前插入 shape debug 打印:
x = self.patch_embed(x) print("After patch_embed:", x.shape) # 应为 (B, 196, 768) x = torch.cat((cls_tokens, x), dim=1) print("After cat cls & pos:", x.shape) # 应为 (B, 197, 768)
总结:ViT 实现的核心在于 严格遵循 patch→flatten→project→pos_embed→cls_token→encoder 流程,任何一步的维度错位都会引发矩阵乘法异常。牢记 patch_size=16, img_size=224 ⇒ num_patches=196, patch_dim=768 这一黄金组合,并通过 print(x.shape) 在每层后校验,即可高效定位并修复此类形状错误。










