
本文介绍使用 PyTorch torch.func(原 functorch)高效计算批量输入下模型输出对全部可训练参数的完整雅可比矩阵,即形状为 (B, Y, P) 的梯度张量,避免显式循环,显著提升性能。
本文介绍使用 pytorch `torch.func`(原 functorch)高效计算批量输入下模型输出对全部可训练参数的完整雅可比矩阵,即形状为 `(b, y, p)` 的梯度张量,避免显式循环,显著提升性能。
在深度学习可解释性、元学习、二阶优化(如牛顿法)、神经正切核(NTK)分析及对抗样本生成等任务中,常需获取每个样本的每个输出维度对模型所有参数的梯度——即输出向量关于参数向量的雅可比矩阵(Jacobian),其形状为 (batch_size, output_dim, num_params)。传统方法通过双重嵌套循环(遍历 batch 和输出维度)调用 torch.autograd.grad,时间复杂度高、无法利用 GPU 并行性,且难以扩展。
PyTorch 2.0+ 内置的 torch.func 模块为此类高阶自动微分任务提供了原生、高效、函数式(functional)的解决方案,核心依赖三个关键能力:
- functional_call:将模块(nn.Module)转化为纯函数,接收参数字典、缓冲区字典和输入,避免隐式状态依赖;
- jacrev:计算标量或向量值函数关于输入(此处为参数字典)的反向模式雅可比;
- vmap:对任意函数进行批量化向量化(vectorized mapping),实现零开销的 batch 维度并行,替代 Python 循环。
以下是一个完整、可运行的实现示例:
import torch
import torch.nn as nn
from torch import func
# 示例模型(同问题中 MLP)
class MLP(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, output_size)
def forward(self, x):
return self.fc2(self.relu(self.fc1(x)))
# 初始化模型与数据
net = MLP(28*28, 20, 10)
X = torch.randn(4, 28*28) # batch=4, input flattened
# 1. 提取参数与缓冲区(分离状态,确保函数式调用)
params = {k: v for k, v in net.named_parameters()}
buffers = {k: v for k, v in net.named_buffers()}
# 2. 定义单样本雅可比计算函数
def jac_per_sample(x):
# 构造纯函数:给定参数字典,返回该样本的输出向量
def model_fn(p):
return func.functional_call(net, (p, buffers), x)
# 计算输出向量(10维)关于参数字典的雅可比:返回 dict[str, Tensor]
jacobian_dict = func.jacrev(model_fn)(params)
# 将各参数梯度展平并拼接为 (10, P) 张量
grads_flat = torch.cat([j.view(j.size(0), -1) for j in jacobian_dict.values()], dim=1)
return grads_flat # shape: (10, P)
# 3. 使用 vmap 批量处理所有样本 → 输出 shape: (B, 10, P)
grads_batch = func.vmap(jac_per_sample)(X) # 自动向量化 batch 维度
print(f"Output shape: {grads_batch.shape}") # e.g., torch.Size([4, 10, 4020])✅ 关键优势说明:
- 无显式循环:vmap 在 C++ 层实现向量化,避免 Python 解释器开销;
- GPU 友好:整个流程可在 CUDA 张量上无缝运行,充分利用显存带宽与并行计算单元;
- 内存可控:jacrev 对每个输出分量分别反向传播,不构造稠密 (Y×P) 矩阵(除非必要),适合大模型;
- 函数式安全:functional_call 避免修改原模型参数,支持高阶导数与嵌套微分。
⚠️ 注意事项与最佳实践:
- torch.func 要求 PyTorch ≥ 2.0(推荐 ≥ 2.2);若环境受限,可安装 functorch(已归档,仅兼容旧版);
- jacrev 默认对输出每个元素独立求导,适用于 output_dim 不过大(如分类 logits ≤ 1000)的场景;若 Y 极大(如像素级回归),应改用 jacfwd 或分块计算;
- 参数字典 params 必须包含 全部 requires_grad=True 的参数,且键名需与 named_parameters() 严格一致;
- 如需二阶导数(如 Hessian),可嵌套 func.grad 或 func.hessian,但需注意内存增长;
- 实际部署前务必验证数值正确性:torch.allclose(grads_batch[0], manual_jac_for_first_sample)。
综上,借助 torch.func 的函数式编程范式,我们能以简洁、高性能、可维护的方式完成原本繁琐的逐样本雅可比计算任务,这是现代 PyTorch 高阶自动微分能力的典型体现。










