0

0

R-Drop论文复现

P粉084495128

P粉084495128

发布时间:2025-07-31 10:57:31

|

914人浏览过

|

来源于php中文网

原创

r-drop是基于dropout的改进正则化方法,通过对模型输出层施加约束减少过拟合。其让每个样本两次通过带dropout的同一模型,用kl散度约束两次输出一致,总损失为交叉熵与kl散度之和。代码实现仅增kl项,实验显示能有效提升模型正确率。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

r-drop论文复现 - php中文网

R-Drop: Regularized Dropout for Neural Networks

  由于深度神经网络非常容易过拟合,因此 Dropout 方法采用了随机丢弃每层的部分神经元,以此来避免在训练过程中的过拟合问题。正是因为每次随机丢弃部分神经元,导致每次丢弃后产生的子模型都不一样,所以 Dropout 的操作一定程度上使得训练后的模型是一种多个子模型的组合约束。基于 Dropout 的这种特殊方式对网络带来的随机性,研究员们提出了 R-Drop 来进一步对(子模型)网络的输出预测进行了正则约束。论文通过实验得出一种改进的正则化方法R-dropout,简单来说,它通过使用若干次(论文中使用了两次)dropout,定义新的损失函数。实验结果表明,尽管结构非常简单,但是却能很好的防止模型过拟合,进一步提高模型的正确率。模型主体如下图所示。

R-Drop论文复现 - php中文网        

论文贡献

  由于深度神经网络非常容易过拟合,因此 Dropout 方法采用了随机丢弃每层的部分神经元,以此来避免在训练过程中的过拟合问题。正是因为每次随机丢弃部分神经元,导致每次丢弃后产生的子模型都不一样,所以 Dropout 的操作一定程度上使得训练后的模型是一种多个子模型的组合约束。基于 Dropout 的这种特殊方式对网络带来的随机性,研究员们提出了 R-Drop 来进一步对(子模型)网络的输出预测进行了正则约束。

实现思路

  与传统作用于神经元(Dropout)或者模型参数(DropConnect)上的约束方法不同,R-Drop 作用于模型的输出层,弥补了 Dropout 在训练和测试时的不一致性。简单来说就是在每个 mini-batch 中,每个数据样本过两次带有 Dropout 的同一个模型,R-Drop 再使用 KL-divergence 约束两次的输出一致。既约束了由于 Dropout 带来的两个随机子模型的输出一致性。

R-Drop论文复现 - php中文网        

论文公式

模型的训练目标包含两个部分,一个是两次输出之间的KL散度,如下:

R-Drop论文复现 - php中文网        

另一个是模型自有的损失函数交叉熵,如下:

I-Shop购物系统
I-Shop购物系统

部分功能简介:商品收藏夹功能热门商品最新商品分级价格功能自选风格打印结算页面内部短信箱商品评论增加上一商品,下一商品功能增强商家提示功能友情链接用户在线统计用户来访统计用户来访信息用户积分功能广告设置用户组分类邮件系统后台实现更新用户数据系统图片设置模板管理CSS风格管理申诉内容过滤功能用户注册过滤特征字符IP库管理及来访限制及管理压缩,恢复,备份数据库功能上传文件管理商品类别管理商品添加/修改/

下载

R-Drop论文复现 - php中文网        

总损失函数为:

R-Drop论文复现 - php中文网        

代码实现

与传统的训练方法相比,R- Drop 只是简单增加了一个 KL-divergence 损失函数项,并没有其他任何改动。其PaddlePaddle版本对应的代码实现如下所示。

  • 散度损失

交叉熵=熵+相对熵(KL散度) 其与交叉熵的关系如下:

R-Drop论文复现 - php中文网        

代码实现示意

import paddle.nn.functional as F

# define your task model, which outputs the classifier logits
model = TaskModel()def compute_kl_loss(self, p, q, pad_mask=None):
    
    p_loss = F.kl_div(F.log_softmax(p, axis=-1), F.softmax(q, axis=-1), reduction='none')
    q_loss = F.kl_div(F.log_softmax(q, axis=-1), F.softmax(p, axis=-1), reduction='none')
    
    # pad_mask is for seq-level tasks    if pad_mask is not None:
        p_loss.masked_fill_(pad_mask, 0.)
        q_loss.masked_fill_(pad_mask, 0.)

    # You can choose whether to use function "sum" and "mean" depending on your task
    p_loss = p_loss.sum()
    q_loss = q_loss.sum()

    loss = (p_loss + q_loss) / 2
    return loss

# keep dropout and forward twice
logits = model(x)

logits2 = model(x)

# cross entropy loss for classifier
ce_loss = 0.5 * (cross_entropy_loss(logits, label) + cross_entropy_loss(logits2, label))

kl_loss = compute_kl_loss(logits, logits2)# 论文中对于CV任务的超参数
α = 0.6# carefully choose hyper-parameters
loss = ce_loss + α * kl_loss
   

代码实现实战

项目说明

本次实验以白菜生长的四个周期为例,进行生长情况识别实验。数据来自于讯飞的比赛。数据展示如下:发芽期、幼苗期、莲座期、结球期。

R-Drop论文复现 - php中文网 R-Drop论文复现 - php中文网R-Drop论文复现 - php中文网 R-Drop论文复现 - php中文网        

In [ ]
!cd 'data/data107306' && unzip -q img.zip!cd 'data/data106868' && unzip -q pdweights.zip
   
In [ ]
# 导入所需要的库from sklearn.utils import shuffleimport osimport pandas as pdimport numpy as npfrom PIL import Imageimport paddleimport paddle.nn as nnfrom paddle.io import Datasetimport paddle.vision.transforms as Timport paddle.nn.functional as Ffrom paddle.metric import Accuracyimport warnings
warnings.filterwarnings("ignore")# 读取数据train_images = pd.read_csv('data/data107306/img/df_all.csv')

train_images = shuffle(train_images)# 划分训练集和校验集all_size = len(train_images)# print(all_size)train_size = int(all_size * 0.9)
train_image_list = train_images[:train_size]
val_image_list = train_images[train_size:]


train_image_path_list = train_image_list['image'].values
label_list = train_image_list['label'].values
train_label_list = paddle.to_tensor(label_list, dtype='int64')

val_image_path_list = val_image_list['image'].values
val_label_list1 = val_image_list['label'].values
val_label_list = paddle.to_tensor(val_label_list1, dtype='int64')# 定义数据预处理data_transforms = T.Compose([
    T.Resize(size=(256, 256)),
   
    T.Transpose(),    # HWC -> CHW
    T.Normalize(
       
        mean = [0, 0, 0],
        std = [255, 255, 255],
        to_rgb=True)    
])# 构建Datasetclass MyDataset(paddle.io.Dataset):
    """
    步骤一:继承paddle.io.Dataset类
    """
    def __init__(self, train_img_list, val_img_list,train_label_list,val_label_list, mode='train'):
        """
        步骤二:实现构造函数,定义数据读取方式,划分训练和测试数据集
        """
        super(MyDataset, self).__init__()
        self.img = []
        self.label = []
        self.valimg = []
        self.vallabel = []        # 借助pandas读csv的库
        self.train_images = train_img_list
        self.test_images = val_img_list
        self.train_label = train_label_list
        self.test_label = val_label_list        # self.mode = mode
        if mode == 'train':             # 读train_images的数据
            for img,la in zip(self.train_images, self.train_label):
                self.img.append('data/data107306/img/imgV/'+img)
                self.label.append(la)        else :            # 读test_images的数据
            for img,la in zip(self.test_images, self.test_label):
                self.img.append('data/data107306/img/imgV/'+img)
                self.label.append(la)    def load_img(self, image_path):
        # 实际使用时使用Pillow相关库进行图片读取即可,这里我们对数据先做个模拟
        image = Image.open(image_path).convert('RGB')
        image = np.array(image).astype('float32')        return image    def __getitem__(self, index):
        """
        步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)
        """

        image = self.load_img(self.img[index])
        label = self.label[index]       
        return data_transforms(image), label    def __len__(self):
        """
        步骤四:实现__len__方法,返回数据集总数目
        """
        return len(self.img)
   
In [ ]
#train_loadertrain_dataset = MyDataset(train_img_list=train_image_path_list, val_img_list=val_image_path_list, train_label_list=train_label_list, val_label_list=val_label_list, mode='train')
train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), batch_size=8, shuffle=True, num_workers=0)#val_loaderval_dataset = MyDataset(train_img_list=train_image_path_list, val_img_list=val_image_path_list, train_label_list=train_label_list, val_label_list=val_label_list, mode='test')
val_loader = paddle.io.DataLoader(val_dataset, places=paddle.CPUPlace(), batch_size=8, shuffle=True, num_workers=0)
   
In [ ]
from work.senet154 import SE_ResNeXt50_vd_32x4dfrom work.res2net import Res2Net50_vd_26w_4sfrom work.se_resnet import SE_ResNet50_vd# 模型封装# model_re2 = SE_ResNeXt50_vd_32x4d(class_num=4)model_re2 = Res2Net50_vd_26w_4s(class_dim=4)
model_ss = SE_ResNet50_vd(class_num=4)

model_ss.train()
model_re2.train()
epochs = 2optim1 = paddle.optimizer.Adam(learning_rate=3e-4, parameters=model_re2.parameters())
optim2 = paddle.optimizer.Adam(learning_rate=3e-4, parameters=model_ss.parameters())
   
In [ ]
import paddle.nn.functional as Fdef compute_kl_loss(p, q, pad_mask=None):

    p_loss = F.kl_div(F.log_softmax(p, axis=-1), F.softmax(q, axis=-1), reduction='none')
    q_loss = F.kl_div(F.log_softmax(q, axis=-1), F.softmax(p, axis=-1), reduction='none')    
    # pad_mask is for seq-level tasks
    if pad_mask is not None:
        p_loss.masked_fill_(pad_mask, 0.)
        q_loss.masked_fill_(pad_mask, 0.)    # You can choose whether to use function "sum" and "mean" depending on your task
    p_loss = p_loss.sum()
    q_loss = q_loss.sum()

    loss = (p_loss + q_loss) / 2
    return loss
   
In [7]
# 用Adam作为优化函数for epoch in range(epochs):    for batch_id, data in enumerate(train_loader()):

        x_data = data[0]
        y_data = data[1]

        predicts1 = model_re2(x_data)
        predicts2 = model_ss(x_data)
        
        loss1 = F.cross_entropy(predicts1, y_data, soft_label=False)
        loss2 = F.cross_entropy(predicts2, y_data, soft_label=False)        
        # cross entropy loss for classifier
        ce_loss = 0.5 * (loss1 + loss2)
        kl_loss = compute_kl_loss(predicts1, predicts2)        # 论文中对于CV任务的超参数
        α = 0.6
        # carefully choose hyper-parameters
        loss = ce_loss + α * kl_loss        # 计算损失
        acc1 = paddle.metric.accuracy(predicts1, y_data)
        acc2 = paddle.metric.accuracy(predicts2, y_data)
        
        loss.backward()        if batch_id % 50 == 0:            print("epoch: {}, batch_id: {}, loss1 is: {}".format(epoch, batch_id, loss.numpy()))

        optim1.step()
        optim1.clear_grad()
    
        optim2.step()
        optim2.clear_grad()
   

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
数据分析的方法
数据分析的方法

数据分析的方法有:对比分析法,分组分析法,预测分析法,漏斗分析法,AB测试分析法,象限分析法,公式拆解法,可行域分析法,二八分析法,假设性分析法。php中文网为大家带来了数据分析的相关知识、以及相关文章等内容。

503

2023.07.04

数据分析方法有哪几种
数据分析方法有哪几种

数据分析方法有:1、描述性统计分析;2、探索性数据分析;3、假设检验;4、回归分析;5、聚类分析。本专题为大家提供数据分析方法的相关的文章、下载、课程内容,供大家免费下载体验。

292

2023.08.07

网站建设功能有哪些
网站建设功能有哪些

网站建设功能包括信息发布、内容管理、用户管理、搜索引擎优化、网站安全、数据分析、网站推广、响应式设计、社交媒体整合和电子商务等功能。这些功能可以帮助网站管理员创建一个具有吸引力、可用性和商业价值的网站,实现网站的目标。

756

2023.10.16

数据分析网站推荐
数据分析网站推荐

数据分析网站推荐:1、商业数据分析论坛;2、人大经济论坛-计量经济学与统计区;3、中国统计论坛;4、数据挖掘学习交流论坛;5、数据分析论坛;6、网站数据分析;7、数据分析;8、数据挖掘研究院;9、S-PLUS、R统计论坛。想了解更多数据分析的相关内容,可以阅读本专题下面的文章。

534

2024.03.13

Python 数据分析处理
Python 数据分析处理

本专题聚焦 Python 在数据分析领域的应用,系统讲解 Pandas、NumPy 的数据清洗、处理、分析与统计方法,并结合数据可视化、销售分析、科研数据处理等实战案例,帮助学员掌握使用 Python 高效进行数据分析与决策支持的核心技能。

81

2025.09.08

Python 数据分析与可视化
Python 数据分析与可视化

本专题聚焦 Python 在数据分析与可视化领域的核心应用,系统讲解数据清洗、数据统计、Pandas 数据操作、NumPy 数组处理、Matplotlib 与 Seaborn 可视化技巧等内容。通过实战案例(如销售数据分析、用户行为可视化、趋势图与热力图绘制),帮助学习者掌握 从原始数据到可视化报告的完整分析能力。

59

2025.10.14

Go高并发任务调度与Goroutine池化实践
Go高并发任务调度与Goroutine池化实践

本专题围绕 Go 语言在高并发任务处理场景中的实践展开,系统讲解 Goroutine 调度模型、Channel 通信机制以及并发控制策略。内容包括任务队列设计、Goroutine 池化管理、资源限制控制以及并发任务的性能优化方法。通过实际案例演示,帮助开发者构建稳定高效的 Go 并发任务处理系统,提高系统在高负载环境下的处理能力与稳定性。

22

2026.03.10

Kotlin Android模块化架构与组件化开发实践
Kotlin Android模块化架构与组件化开发实践

本专题围绕 Kotlin 在 Android 应用开发中的架构实践展开,重点讲解模块化设计与组件化开发的实现思路。内容包括项目模块拆分策略、公共组件封装、依赖管理优化、路由通信机制以及大型项目的工程化管理方法。通过真实项目案例分析,帮助开发者构建结构清晰、易扩展且维护成本低的 Android 应用架构体系,提升团队协作效率与项目迭代速度。

48

2026.03.09

JavaScript浏览器渲染机制与前端性能优化实践
JavaScript浏览器渲染机制与前端性能优化实践

本专题围绕 JavaScript 在浏览器中的执行与渲染机制展开,系统讲解 DOM 构建、CSSOM 解析、重排与重绘原理,以及关键渲染路径优化方法。内容涵盖事件循环机制、异步任务调度、资源加载优化、代码拆分与懒加载等性能优化策略。通过真实前端项目案例,帮助开发者理解浏览器底层工作原理,并掌握提升网页加载速度与交互体验的实用技巧。

93

2026.03.06

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Git 教程
Git 教程

共21课时 | 4.2万人学习

Git版本控制工具
Git版本控制工具

共8课时 | 1.6万人学习

Git中文开发手册
Git中文开发手册

共0课时 | 94人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号