
本文介绍如何在 PyTorch 中构建一个专业、可复用的 BatchSampler,精确控制每个 batch 中正样本(如稀疏目标)与负样本的比例(例如 10:90),支持自动过采样不足的正例,且完全兼容标准 DataLoader 流程。
本文介绍如何在 pytorch 中构建一个专业、可复用的 `batchsampler`,精确控制每个 batch 中正样本(如稀疏目标)与负样本的比例(例如 10:90),支持自动过采样不足的正例,且完全兼容标准 `dataloader` 流程。
在处理极度不平衡的数据集(如医学图像中阳性像素占比仅 0.01%)时,简单使用 WeightedRandomSampler 往往无法保证每个 batch 内部的类别分布可控——它仅提供概率加权,不保证每批恰好含指定数量的正样本。此时,更“torch-native”且精准的方案是实现自定义 BatchSampler:它直接生成整批索引(而非单个索引),交由 DataLoader 按序调用 Dataset.__getitem__,从而彻底掌控采样逻辑。
以下是一个生产就绪的 BalancedBatchSampler 实现,满足您的全部需求:
- ✅ 每 batch 严格包含 n_positive 个正样本 + n_negative 个负样本
- ✅ 正样本不足时自动有放回随机重复(replacement=True)
- ✅ 负样本从全量负索引池中无放回随机采样(避免同 batch 重复)
- ✅ 支持 shuffle=True/False 控制每 epoch 内 batch 顺序
- ✅ 与 DataLoader(..., batch_sampler=...) 无缝集成,无需修改 Dataset
import torch
from torch.utils.data import BatchSampler, DataLoader
import random
from typing import List, Iterator, Optional
class BalancedBatchSampler(BatchSampler):
def __init__(
self,
positive_idx: List[int],
negative_idx: List[int],
batch_size: int = 100,
n_positive: int = 10,
drop_last: bool = False,
shuffle: bool = True,
generator: Optional[torch.Generator] = None
):
if n_positive > batch_size:
raise ValueError(f"n_positive ({n_positive}) must be <= batch_size ({batch_size})")
self.positive_idx = positive_idx
self.negative_idx = negative_idx
self.batch_size = batch_size
self.n_positive = n_positive
self.n_negative = batch_size - n_positive
self.drop_last = drop_last
self.shuffle = shuffle
self.generator = generator
# 计算总 batch 数(按正样本约束)
n_batches_by_positive = len(positive_idx) // n_positive
if not self.drop_last and len(positive_idx) % n_positive != 0:
n_batches_by_positive += 1
# 按负样本约束校验(通常负样本充足,此步为健壮性检查)
if len(negative_idx) < self.n_negative:
raise ValueError(f"Insufficient negative samples: {len(negative_idx)} < {self.n_negative}")
self._n_batches = n_batches_by_positive
def __len__(self) -> int:
return self._n_batches
def __iter__(self) -> Iterator[List[int]]:
# 每次迭代前重置索引池(支持 shuffle)
pos_pool = self.positive_idx.copy()
neg_pool = self.negative_idx.copy()
if self.shuffle:
random.shuffle(pos_pool, _rand=random.random)
random.shuffle(neg_pool, _rand=random.random)
# 生成所有 batch 的索引列表
for _ in range(self._n_batches):
# 正样本:有放回采样(允许重复)
batch_pos = random.choices(pos_pool, k=self.n_positive)
# 负样本:无放回采样(避免 batch 内重复)
batch_neg = random.sample(neg_pool, k=self.n_negative)
# 合并并打乱 batch 内顺序(提升训练稳定性)
batch_indices = batch_pos + batch_neg
if self.shuffle:
random.shuffle(batch_indices, _rand=random.random)
yield batch_indices使用示例:
# 假设你的 Dataset 已定义好(如题中 MyDataset)
ds = MyDataset() # 包含 self.positive_idx 和 self.negative_idx
# 创建平衡批采样器:每 batch 100 样本,其中 10 个正例 + 90 个负例
sampler = BalancedBatchSampler(
positive_idx=ds.positive_idx,
negative_idx=ds.negative_idx,
batch_size=100,
n_positive=10,
shuffle=True # 由 sampler 控制 shuffle,DataLoader 中 shuffle=False
)
# 关键:传入 batch_sampler(而非 sampler),且禁用 DataLoader 的 shuffle
dl = DataLoader(
ds,
batch_sampler=sampler,
num_workers=4,
pin_memory=True
)
# 验证效果
for i, (images, labels) in enumerate(dl):
print(f"Batch {i}: images.shape={images.shape}, "
f"positive_ratio={labels.sum().item() / labels.numel():.4f}")
if i >= 2: break # 仅查看前 3 个 batch重要注意事项:
- ? DataLoader 参数设置:必须使用 batch_sampler= 参数,并将 shuffle=False(因为重采样逻辑已由 BalancedBatchSampler 承担)。若同时启用 shuffle=True,会导致行为不可预测。
- ? 内存与效率:该采样器在每次 __iter__() 调用时生成全量索引列表,适用于中等规模数据集(
- ? 与 DistributedSampler 兼容性:如需多卡训练,需将 BalancedBatchSampler 封装进 DistributedSampler 的 sampler= 参数,或改写为支持 rank/world_size 的分布式版本。
- ? 扩展性:可通过继承增加多类别平衡、分层采样、动态难度加权等功能,保持接口一致性。
综上,BalancedBatchSampler 是 PyTorch 生态中实现确定性类别平衡最直接、最可控的方式。它规避了 WeightedRandomSampler 的统计波动,也优于手动在 __getitem__ 中做逻辑判断(破坏数据加载管线分离原则),是工业级不平衡学习任务的推荐实践。










