
本文介绍如何在 PyTorch 中根据给定的长度列表(而非切分位置)将一维张量精准划分为多个子张量,核心方法是结合 torch.tensor_split 与累积和计算,兼顾简洁性与可靠性。
本文介绍如何在 pytorch 中根据给定的长度列表(而非切分位置)将一维张量精准划分为多个子张量,核心方法是结合 `torch.tensor_split` 与累积和计算,兼顾简洁性与可靠性。
在 PyTorch 实际开发中,常需按「每段长度」而非「绝对索引位置」对张量进行分块(例如按 batch 内样本长度切分序列、按子任务尺寸划分数据)。但 PyTorch 原生不提供类似 NumPy 的 array_split 或直接接受长度列表的 split 接口。此时,torch.tensor_split() 是最稳妥的选择——它接受一组切分点(split points),即沿维度的绝对索引位置(不包含起始点 0,也不含末尾),而我们需要将用户提供的「各段长度列表」(如 [1, 2, 5, 10])转换为对应的累积切分点。
关键步骤如下:
- 对长度列表计算前缀和(cumulative sum),得到每个分块结束位置的索引;
- 将该前缀和转为 Python 列表,传入 tensor_split;
- 由于 tensor_split 会在最后一个切分点后额外生成一个空/剩余张量,而题目要求恰好返回 len(splits) 个块,因此需截取 [:-1]。
以下为完整可运行示例:
import torch
import numpy as np
# 示例数据
x = torch.arange(20) # tensor([0, 1, ..., 19])
lengths = [1, 2, 5, 10] # 各段期望长度
# 计算累积和作为切分点(注意:不含起始0,含各段结束索引)
split_points = np.cumsum(lengths).tolist() # → [1, 3, 8, 18]
# 使用 tensor_split 并去除末尾冗余块
chunks = x.tensor_split(split_points)[:-1]
for i, chunk in enumerate(chunks):
print(f"Chunk {i}: {chunk.tolist()}")输出:
Chunk 0: [0] Chunk 1: [1, 2] Chunk 2: [3, 4, 5, 6, 7] Chunk 3: [8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
✅ 注意事项:
- torch.tensor_split 自 PyTorch 1.8+ 起稳定支持,确保环境版本 ≥ 1.8;
- 输入长度列表之和必须 ≤ 张量总长度,否则 tensor_split 将在末尾补空张量(不会报错,但结果不符合预期);建议提前校验:assert sum(lengths) <= len(x);
- 若需纯 PyTorch 实现(避免依赖 NumPy),可用 torch.cumsum(torch.tensor(lengths), dim=0) 替代 np.cumsum;
- tensor_split 返回的是 tuple 而非 list,若需列表操作可显式转为 list(chunks)。
该方法简洁、高效且语义清晰,是 PyTorch 生态中处理「按长度分块」场景的标准实践。










