
本文介绍如何使用 torch.tensor_split 结合累积和,将 PyTorch 张量按指定长度列表(而非切分点位置)进行精确分段,适用于数据批处理、序列截断等场景。
本文介绍如何使用 `torch.tensor_split` 结合累积和,将 pytorch 张量按指定长度列表(而非切分点位置)进行精确分段,适用于数据批处理、序列截断等场景。
在 PyTorch 中,若需将一个一维张量按「每段长度」(如 [1, 2, 5, 10])拆分为多个子张量,不能直接使用 torch.split——因为 torch.split(tensor, split_size_or_sections) 要求所有段长度相等(当传入整数时),或需显式提供各段起始/结束索引(当传入 list[int] 时,它实际解释为切分点位置,即 split_points = [1, 2, 5, 10] 表示在索引 1、2、5、10 处切割,这与题设语义不同)。
题设中 splits = [1, 2, 5, 10] 的含义是:第一段取 1 个元素,第二段取 2 个,第三段取 5 个,第四段取 10 个,累计覆盖前 1+2+5+10 = 18 个元素。因此,真正需要的切分点(cumulative split points)应为 [1, 1+2=3, 1+2+5=8, 1+2+5+10=18],即 [1, 3, 8, 18]。此时可调用 tensor_split 在这些位置执行切割。
✅ 推荐方案:torch.tensor_split + np.cumsum(简洁、通用、无循环)
import torch import numpy as np # 示例数据 t = torch.arange(20) # shape: (20,) splits = [1, 2, 5, 10] # 各段期望长度 # 计算累积切分点(不包含起始 0,也不包含末尾总长) split_points = np.cumsum(splits).tolist() # → [1, 3, 8, 18] # 切分(返回 tuple of tensors) chunks = t.tensor_split(split_points) # 注意:tensor_split 会在每个 split_point 处切割,产生 len(split_points)+1 个片段 # 但题设只要前 len(splits) 段(即忽略最后一段剩余部分) result = chunks[:-1] # → 4 个 tensor,对应 splits 中的 4 个长度 print([c.tolist() for c in result]) # 输出: # [[0], [1, 2], [3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15, 16, 17]]
? 关键说明:
- torch.tensor_split(input, split_points) 将 input 在 split_points[i] 索引处左闭右开切割,生成 len(split_points) + 1 个子张量;
- np.cumsum(splits) 自动计算前缀和,比手写 torch.tensor(splits).cumsum(0).tolist() 更轻量(无需引入 torch 依赖);
- [:-1] 截断是为了排除最后一段(即索引 ≥18 的剩余元素),严格匹配题设输出;若需保留剩余部分,可省略该切片。
⚠️ 注意事项:
- 输入张量长度必须 ≥ sum(splits),否则 tensor_split 可能触发索引越界或静默截断(取决于 PyTorch 版本),建议提前校验:assert t.numel() >= sum(splits);
- tensor_split 是 PyTorch 1.8+ 引入的稳定 API,旧版本请升级或改用 torch.narrow 手动实现;
- 该方法天然支持任意维度张量(默认沿 dim=0 切分),如需按列切分,添加 dim=1 参数即可。
? 进阶技巧:若需动态兼容「剩余部分也作为一段」,可直接使用 chunks 全量结果,并通过 len(splits) 控制取舍逻辑;对于超大规模张量,np.cumsum 性能远优于 Python 循环,且全程不涉及 CPU-GPU 数据迁移(split_points 仅为纯 Python list)。










