0

0

PyTorch 中实现可微分的张量选择:从硬索引到软选择的完整教程

霞舞

霞舞

发布时间:2026-03-10 13:33:31

|

944人浏览过

|

来源于php中文网

原创

在 PyTorch 中,直接使用非整数张量(如含梯度的浮点型标量)作为切片索引会导致梯度中断;本文详解为何 e[:d] 不可导,并提供基于 Gumbel-Softmax 重参数化的可微分软选择方案,附可运行代码与关键注意事项。

pytorch 中,直接使用非整数张量(如含梯度的浮点型标量)作为切片索引会导致梯度中断;本文详解为何 `e[:d]` 不可导,并提供基于 gumbel-softmax 重参数化的可微分软选择方案,附可运行代码与关键注意事项。

在深度学习中,我们常需根据模型输出动态选择张量中的部分元素(例如 top-k 检索、条件路由或注意力掩码生成)。然而,像 e[:d] 这类依赖于可学习变量 d 的硬索引操作(hard indexing)本质上不可导——因为索引本身是离散的、非连续的操作,PyTorch 的自动微分引擎无法计算其对 d 的梯度。即使将 d 强制转为 long(如 e[:d.to(torch.long)]),梯度也会在类型转换处截断,导致上游参数(如 a)无法更新。

要实现“可学习的选择”,必须用连续、可微的近似替代离散决策。主流方法是采用软选择(soft selection),核心思想是:不直接取索引,而是为每个候选位置分配一个可学习的权重,再通过加权聚合实现选择。其中,Gumbel-Softmax 重参数化技巧是兼顾可微性与离散语义的经典方案。

以下是一个端到端可微分的软选择实现(适用于一维张量按数量截取场景,如 e[:d] 的替代):

Freepik Mystic
Freepik Mystic

Freepik Mystic 是一款革命性的AI图像生成器,可以直接生成全高清图像

下载
import torch
import torch.nn.functional as F

# 原始设定:d 是含梯度的标量(如 min(a,b,c)),e 是待选数组
a = torch.tensor([4.], requires_grad=True)
b = torch.tensor([5.])
c = torch.tensor([6.])
d = a.min(b).min(c)  # d.shape == torch.Size([]), requires_grad=True

e = torch.arange(10, dtype=torch.float32)  # e.shape == [10]

# ✅ 可微分替代方案:将 "取前 d 个" 转为 "对前 floor(d)+1 个位置施加软权重"
# Step 1: 构建可学习的 logits(维度与 e 对齐),代表每个位置被选中的倾向
logits = torch.randn_like(e, requires_grad=True)  # 初始化为随机,实际中可由网络预测

# Step 2: 生成 soft selection weights(概率分布)
weights = F.softmax(logits, dim=0)  # shape [10], sum=1.0

# Step 3: 构造 soft mask,模拟“取前 k 个”的行为
# 我们定义 mask[i] = 1 if i < d, else 0 → 但 d 是浮点数,需平滑化
# 使用 sigmoid 构建平滑阶跃:mask[i] ≈ σ((d - i) * temperature)
temperature = 10.0  # 控制陡峭程度,越大越接近硬阈值
indices = torch.arange(len(e), dtype=torch.float32)
soft_mask = torch.sigmoid((d - indices) * temperature)  # shape [10]

# Step 4: 加权选择(可微)
f_soft = e * soft_mask  # shape [10],每个元素被缩放

# Step 5: 定义损失并反向传播(示例:最小化 f_soft 的 L2 norm)
loss = f_soft.sum()  # 或其他任务相关 loss
loss.backward()

print(f"d.grad = {d.grad}")   # 非 None!梯度成功回传至 d
print(f"a.grad = {a.grad}")   # 进而回传至原始参数 a

? 关键说明

  • 上述 soft_mask 使用 sigmoid((d - i) * T) 实现了对“前 d 个”位置的平滑、可微近似:当 i d 时趋近 0;temperature 控制过渡带宽,训练初期可用较小值(如 1–5)提升稳定性,后期增大以逼近硬选择。
  • 若需严格保持输出长度为 floor(d) 或支持更复杂选择逻辑(如 top-k、条件采样),推荐使用 torch.nn.functional.gumbel_softmax 配合 one_hot + argmax 的 Straight-Through Estimator(STE)变体,但需注意梯度估计偏差。
  • 永远避免 e[int(d.item())] 或 e[:d.long()] 等隐式转换操作——它们会切断计算图,使 d 及其上游参数无法更新。

总结而言,PyTorch 中的索引操作天然不可导,但通过将“选择”重构为连续权重分配 + 平滑掩码,我们既能保留梯度流,又能逼近原始语义。这一范式广泛应用于神经架构搜索(NAS)、稀疏激活、可微分搜索等前沿领域。实践中,应根据任务需求权衡软选择的平滑程度与离散精度,并始终通过 assert param.grad is not None 验证梯度连通性。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
string转int
string转int

在编程中,我们经常会遇到需要将字符串(str)转换为整数(int)的情况。这可能是因为我们需要对字符串进行数值计算,或者需要将用户输入的字符串转换为整数进行处理。php中文网给大家带来了相关的教程以及文章,欢迎大家前来学习阅读。

990

2023.08.02

int占多少字节
int占多少字节

int占4个字节,意味着一个int变量可以存储范围在-2,147,483,648到2,147,483,647之间的整数值,在某些情况下也可能是2个字节或8个字节,int是一种常用的数据类型,用于表示整数,需要根据具体情况选择合适的数据类型,以确保程序的正确性和性能。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

607

2024.08.29

c++怎么把double转成int
c++怎么把double转成int

本专题整合了 c++ double相关教程,阅读专题下面的文章了解更多详细内容。

314

2025.08.29

C++中int的含义
C++中int的含义

本专题整合了C++中int相关内容,阅读专题下面的文章了解更多详细内容。

235

2025.08.29

go语言 数组和切片
go语言 数组和切片

本专题整合了go语言数组和切片的区别与含义,阅读专题下面的文章了解更多详细内容。

53

2025.09.03

C++类型转换方式
C++类型转换方式

本专题整合了C++类型转换相关内容,想了解更多相关内容,请阅读专题下面的文章。

319

2025.07.15

pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

466

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

27

2025.12.22

Kotlin Android模块化架构与组件化开发实践
Kotlin Android模块化架构与组件化开发实践

本专题围绕 Kotlin 在 Android 应用开发中的架构实践展开,重点讲解模块化设计与组件化开发的实现思路。内容包括项目模块拆分策略、公共组件封装、依赖管理优化、路由通信机制以及大型项目的工程化管理方法。通过真实项目案例分析,帮助开发者构建结构清晰、易扩展且维护成本低的 Android 应用架构体系,提升团队协作效率与项目迭代速度。

24

2026.03.09

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号