本文介绍如何使用 torch.tensor_split() 结合累积和,将 PyTorch 张量按指定长度列表(而非切片位置)精准分段,适用于数据预处理、批量划分等场景。
本文介绍如何使用 `torch.tensor_split()` 结合累积和,将 pytorch 张量按指定长度列表(而非切片位置)精准分段,适用于数据预处理、批量划分等场景。
在 PyTorch 中,若需将一个一维张量按「每段长度」(如 [1, 2, 5, 10])进行分割,而非按绝对索引位置(如 [1, 3, 8, 18])切片,不能直接传入长度列表给 torch.split() 或 torch.chunk()——因为前者要求累计长度边界,后者要求等长分块。正确做法是:先将长度列表转换为累积切分点(cumulative split points),再交由 torch.tensor_split() 处理。
torch.tensor_split() 是 PyTorch 1.8+ 推荐的通用分块函数,支持非等长分割,且行为稳定(相比已弃用的 torch.split 在某些边界下更可靠)。关键在于:它接收的是沿指定维度的切分位置索引(即“在哪之后切”),因此需将长度列表 [1,2,5,10] 转换为累计位置 [1, 3, 8, 18](即 1, 1+2, 1+2+5, 1+2+5+10),然后调用:
import torch import numpy as np x = torch.arange(20) # tensor([0, 1, ..., 19]) lengths = [1, 2, 5, 10] # 计算累计切分点(不包含起始0,也不含总长) split_points = np.cumsum(lengths).tolist() # → [1, 3, 8, 18] # 分割(返回 tuple of tensors) chunks = x.tensor_split(split_points) # 注意:tensor_split 会额外生成最后一个空/剩余段(若总长 > sum(lengths)) # 此处因 sum(lengths) == 18 < 20,故 chunks[-1] 包含剩余元素 [18, 19] # 若严格只需前 len(lengths) 段,取前 N 段即可: result = chunks[:-1] # 或显式限制:chunks[:len(lengths)] print(result)
输出:
(tensor([0]), tensor([1, 2]), tensor([3, 4, 5, 6, 7]), tensor([ 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]))
✅ 注意事项与最佳实践:
- torch.tensor_split() 默认沿第 0 维操作,多维张量需显式指定 dim= 参数;
- 累计和必须为 Python list 或 1D torch.Tensor(np.cumsum(...).tolist() 最稳妥);
- 若 sum(lengths) == x.numel(),则 chunks[:-1] 与 chunks 完全等价;但为保持逻辑清晰与兼容性,建议始终按 len(lengths) 截断;
- 不推荐使用 torch.split(x, lengths):它虽接受长度列表,但语义不同——它按顺序分配各段长度,不校验总长是否越界,且当 sum(lengths) < x.numel() 时会静默丢弃尾部元素,易引发隐蔽 bug;
- 纯 PyTorch 方案(避免 NumPy 依赖)可改用:torch.tensor(lengths).cumsum(0).tolist()。
综上,tensor_split + cumsum 是语义明确、行为可预测、兼容性佳的标准解法,应作为 PyTorch 中“按长度列表分块”的首选模式。










