
本文介绍使用 torch.Tensor.scatter_add_ 配合索引展开与值重复,高效完成一维张量到另一维张量的一对多映射累加操作,避免 Python 循环,完全基于向量化运算。
本文介绍使用 `torch.tensor.scatter_add_` 配合索引展开与值重复,高效完成一维张量到另一维张量的一对多映射累加操作,避免 python 循环,完全基于向量化运算。
在 PyTorch 中处理「一对多」映射关系(即每个输入元素贡献至多个输出位置)并执行聚合(如求和)时,若采用 Python 循环或列表推导,不仅代码冗长,更会严重拖慢训练速度、破坏计算图完整性,且无法充分利用 GPU 并行能力。幸运的是,PyTorch 提供了高度优化的原语——scatter_add,它专为这类“按索引分散累加”场景设计,可一次性完成全部映射与聚合。
核心思想是将不规则映射结构(如嵌套列表 mapping)转化为两个齐次一维张量:
- src:待累加的源值序列,其中每个 input[i] 根据其映射目标数量被重复;
- index:对应的目标位置索引序列,与 src 严格对齐;
- out:初始化为零的输出张量,长度由最大目标索引决定。
以下为完整实现示例:
import torch # 输入定义 input = torch.tensor([0, 1, 2, 3], dtype=torch.float32) mapping = [[1], [0, 2, 4], [0, 3], [1, 2]] # 步骤 1:计算各输入项的重复次数(即每个 input[i] 映射到多少个 output 位置) reps = torch.tensor([len(x) for x in mapping]) # 步骤 2:构建 src —— 按 reps 重复 input 中每个元素 src = input.repeat_interleave(reps) # tensor([0, 1, 1, 1, 2, 2, 3, 3]) # 步骤 3:构建 index —— 展平 mapping,得到所有 (src[i] → output[j]) 的 j 序列 index = torch.tensor([j for sublist in mapping for j in sublist]) # tensor([1, 0, 2, 4, 0, 3, 1, 2]) # 步骤 4:初始化输出张量(长度 = max(index) + 1) out = torch.zeros(max(index) + 1, dtype=src.dtype) # 步骤 5:执行向量化累加:out[j] += src[i] for each (i,j) pair result = out.scatter_add(dim=0, index=index, src=src) print(result) # tensor([3., 3., 4., 2., 1.])
✅ 关键优势:
- 全程无 Python 循环,100% 张量操作,支持 CUDA 加速;
- 时间复杂度为 O(∑|mapping[i]|),空间复杂度为 O(len(output)),理论最优;
- 自动兼容梯度传播(scatter_add 是可微分操作),适用于模型中间层。
⚠️ 注意事项:
- index 中的索引必须是非负整数,且严格小于 out.size(dim),否则抛出 RuntimeError;
- 若 mapping 可能为空(如 []),需提前过滤或用 max(index, default=0) 防御;
- 当 output 维度极大但稀疏时,该方法仍会分配全量内存;如需极致稀疏支持,可考虑结合 torch.sparse 或自定义 CUDA kernel,但绝大多数场景 scatter_add 已足够高效。
总结而言,scatter_add 是解决 PyTorch 中「一对多映射+聚合」问题的标准、简洁且高性能方案。掌握其与 repeat_interleave、索引展平等组合技巧,能显著提升数据预处理与自定义层的表达力与执行效率。










