
本文介绍如何使用 torch.tensor_split() 配合累积和,将 PyTorch 张量按指定长度列表(而非切片位置)精准分段,适用于数据批处理、序列截断等场景。
本文介绍如何使用 `torch.tensor_split()` 配合累积和,将 pytorch 张量按指定长度列表(而非切片位置)精准分段,适用于数据批处理、序列截断等场景。
在 PyTorch 中,torch.split() 要求传入的是每段的长度(如 [1, 2, 5, 10]),而 torch.chunk() 则按等份数划分——但本例需求略有不同:给定一个 长度列表 splits = [1, 2, 5, 10],需将张量从头开始依次切出长度为 1、2、5、10 的子张量,剩余部分(如有)默认丢弃。注意:这不是按「绝对索引位置」(如 [1, 3, 8, 18])切分,而是按「连续段长度」进行累积划分。
此时最简洁、健壮的方案是使用 torch.tensor_split() —— 它接受一组分割点(split points),即沿维度的 切片起始索引位置(不包括 0)。我们需要将长度列表转换为对应的累积索引位置:
- splits = [1, 2, 5, 10]
- 累积和 cumsum = [1, 3, 8, 18] → 表示第 1 个切分点在索引 1 处(左闭右开:[0:1]),第 2 个在索引 3 处([1:3]),依此类推。
关键细节:tensor_split 会在每个指定位置“切一刀”,最终生成 len(split_points) + 1 段;而我们仅需前 len(splits) 段(对应各指定长度),因此需截掉最后一段。
✅ 正确实现如下:
import torch
import numpy as np
t = torch.arange(20) # shape: (20,)
splits = [1, 2, 5, 10]
# 计算累积索引位置(切分点)
split_points = np.cumsum(splits).tolist() # [1, 3, 8, 18]
# 分割并截取前 len(splits) 段
segments = t.tensor_split(split_points)[:-1]
for i, seg in enumerate(segments):
print(f"Segment {i}: {seg.tolist()} (length: {len(seg)})")输出:
Segment 0: [0] (length: 1) Segment 1: [1, 2] (length: 2) Segment 2: [3, 4, 5, 6, 7] (length: 5) Segment 3: [8, 9, 10, 11, 12, 13, 14, 15, 16, 17] (length: 10)
⚠️ 注意事项:
- tensor_split 是 PyTorch 1.8+ 引入的稳定 API,确保环境版本兼容;
- np.cumsum 可替换为纯 PyTorch 实现(如 torch.tensor(splits).cumsum(0)),避免依赖 NumPy:
split_points = torch.tensor(splits).cumsum(0).tolist()
- 若输入张量长度不足以容纳所有 sum(splits) 元素,tensor_split 不会报错,但末段可能被截断——建议预先校验:assert t.numel() >= sum(splits);
- 此方法默认沿第 0 维操作;如需在其他维度分割,请显式指定 dim= 参数(如 t.tensor_split(split_points, dim=1))。
总结:将「按长度列表分割」转化为「按累积索引点分割」,再借助 tensor_split + 切片裁剪,是清晰、高效且符合 PyTorch 设计哲学的解决方案。相比手动循环拼接或多次 narrow(),它兼具可读性与执行性能。










