
本文介绍如何在 PyTorch 中利用 scatter_add 高效完成一对多索引映射:将源张量按不规则映射关系分散累加到目标张量中,完全避免 Python 循环,兼顾性能与可读性。
本文介绍如何在 pytorch 中利用 `scatter_add` 高效完成一对多索引映射:将源张量按不规则映射关系分散累加到目标张量中,完全避免 python 循环,兼顾性能与可读性。
在深度学习和图神经网络等场景中,常需将一个张量的元素按自定义、非均匀的索引关系“分发”并聚合到另一个更大或结构不同的张量中——例如,将节点特征按邻接关系聚合至超节点,或将稀疏事件流映射到时间桶中求和。此时若用 Python 循环遍历映射列表,不仅代码冗长,更会严重拖慢训练速度(尤其在 GPU 上触发主机-设备同步)。PyTorch 提供的 torch.Tensor.scatter_add_ 正是为此类“稀疏散列+原子累加”操作量身定制的原语。
核心思路是将不规则的二维映射结构(如 mapping[i] = [j1, j2, ...])展平为一维索引序列,并同步扩展输入张量,使二者长度一致,从而满足 scatter_add 的张量对齐要求。具体分为三步:
- 计算重复次数:统计每个 input[i] 需映射的目标位置数量,即 reps = [len(mapping[0]), len(mapping[1]), ...];
- 构建源值向量 src:使用 input.repeat_interleave(reps) 将每个 input[i] 重复 reps[i] 次,得到待累加的值序列;
- 构建索引向量 index:展平 mapping 得到全局目标下标序列;
- 初始化输出张量 out:大小为 max(index) + 1,类型与 input 一致,初始为零;
- 执行原子累加:调用 out.scatter_add_(dim=0, index=index, src=src) 完成全部映射。
以下为完整可运行示例:
import torch input = torch.tensor([0, 1, 2, 3], dtype=torch.float32) mapping = [[1], [0, 2, 4], [0, 3], [1, 2]] # Step 1: 计算每个 input 元素需重复的次数 reps = torch.tensor([len(x) for x in mapping]) # Step 2: 构建 src —— input[i] 重复 reps[i] 次 src = input.repeat_interleave(reps) # tensor([0., 1., 1., 1., 2., 2., 3., 3.]) # Step 3: 展平 mapping 得到全局索引 index = torch.tensor([r for x in mapping for r in x]) # tensor([1, 0, 2, 4, 0, 3, 1, 2]) # Step 4: 初始化 output 张量(注意:dtype 必须匹配 src) out = torch.zeros(max(index) + 1, dtype=src.dtype) # Step 5: 执行 scatter_add(in-place 累加) out.scatter_add_(dim=0, index=index, src=src) print(out) # tensor([3., 3., 4., 2., 1.])
✅ 关键注意事项:
- scatter_add_ 是 in-place 操作,若需保留原始 out,请先 out.clone() 或使用函数式接口 torch.scatter_add(out, dim, index, src)(PyTorch ≥ 1.12);
- index 中的值必须是非负整数,且严格小于 out.size(dim),否则将触发 RuntimeError;
- src 与 index 的长度必须相等,这是 scatter_add 的硬性要求,repeat_interleave 和列表推导式确保了这一点;
- 若 mapping 来源于 CPU 列表,建议尽早转为 torch.Tensor 并移至 GPU(如 .to(device)),避免混合设备操作;
- 对于超大规模映射(如百万级索引),可考虑使用 torch.sparse 或 torch.compile 进一步优化,但本方案在绝大多数中等规模任务中已足够高效。
该方法将原本 O(N×M) 的隐式循环(N 为 input 长度,M 为平均映射数)转化为底层高度优化的 CUDA kernel 调用,实测在 GPU 上提速可达 10–100 倍。掌握 scatter_add 不仅解决当前问题,更是构建高性能自定义聚合层(如 Pooling、Message Passing)的重要基石。






