0

0

利用Paddle2.1高层API实现9种蘑菇的识别

P粉084495128

P粉084495128

发布时间:2025-07-28 10:51:31

|

228人浏览过

|

来源于php中文网

原创

本文围绕九种蘑菇的图像分类任务展开,采用卷积神经网络结构。先解压数据集并标注,划分出训练集与验证集,定义数据集类并做数据增强。接着选用mobilenet_v2网络,配置优化器等,经100轮训练,通过回调函数保存最佳模型,最后存储模型以备后续评估测试。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

利用paddle2.1高层api实现9种蘑菇的识别 - php中文网

① 问题定义

九种蘑菇的分类的本质是图像分类任务,采用卷积审计网络网络结构进行相关实践。

② 数据准备

2.1 解压缩数据集

我们将网上获取的数据集以压缩包的方式上传到aistudio数据集中,并加载到我们的项目内。

在使用之前我们进行数据集压缩包的一个解压。

In [1]
!unzip -oq /home/aistudio/data/data82495/mushrooms_train.zip -d work/
   

2.2 数据标注

我们先看一下解压缩后的数据集长成什么样子。

In [1]
import paddle
paddle.seed(8888)import numpy as npfrom typing import Callable#参数配置config_parameters = {    "class_dim": 9,  #分类数
    "target_path":"/home/aistudio/work/",                     
    'train_image_dir': '/home/aistudio/work/trainImages',    'eval_image_dir': '/home/aistudio/work/evalImages',    'epochs':100,    'batch_size': 128,    'lr': 0.01}
   
In [3]
import osimport randomfrom matplotlib import pyplot as pltfrom PIL import Image

imgs = []
paths = os.listdir('work/mushrooms_train')for path in paths:   
    img_path = os.path.join('work/mushrooms_train', path)    if os.path.isdir(img_path):
        img_paths = os.listdir(img_path)
        img = Image.open(os.path.join(img_path, random.choice(img_paths)))
        imgs.append((img, path))

f, ax = plt.subplots(3, 3, figsize=(12,12))for i, img in enumerate(imgs[:9]):
    ax[i//3, i%3].imshow(img[0])
    ax[i//3, i%3].axis('off')
    ax[i//3, i%3].set_title('label: %s' % img[1])
plt.show()
       
               

2.3 划分数据集与数据集的定义

接下来我们使用标注好的文件进行数据集类的定义,方便后续模型训练使用。

虎课网
虎课网

虎课网是超过1800万用户信赖的自学平台,拥有海量设计、绘画、摄影、办公软件、职业技能等优质的高清教程视频,用户可以根据行业和兴趣爱好,自主选择学习内容,每天免费学习一个...

下载

2.3.1 划分数据集

In [3]
import osimport shutil

train_dir = config_parameters['train_image_dir']
eval_dir = config_parameters['eval_image_dir']
paths = os.listdir('work/mushrooms_train')if not os.path.exists(train_dir):
    os.mkdir(train_dir)if not os.path.exists(eval_dir):
    os.mkdir(eval_dir)for path in paths:
    imgs_dir = os.listdir(os.path.join('work/mushrooms_train', path))
    target_train_dir = os.path.join(train_dir,path)
    target_eval_dir = os.path.join(eval_dir,path)    if not os.path.exists(target_train_dir):
        os.mkdir(target_train_dir)    if not os.path.exists(target_eval_dir):
        os.mkdir(target_eval_dir)    for i in range(len(imgs_dir)):        if ' ' in imgs_dir[i]:
            new_name = imgs_dir[i].replace(' ', '_')        else:
            new_name = imgs_dir[i]
        target_train_path = os.path.join(target_train_dir, new_name)
        target_eval_path = os.path.join(target_eval_dir, new_name)     
        if i % 5 == 0:
            shutil.copyfile(os.path.join(os.path.join('work/mushrooms_train', path), imgs_dir[i]), target_eval_path)        else:
            shutil.copyfile(os.path.join(os.path.join('work/mushrooms_train', path), imgs_dir[i]), target_train_path)print('finished train val split!')
       
finished train val split!
       

2.3.2 导入数据集的定义实现

In [4]
#数据集的定义class TowerDataset(paddle.io.Dataset):
    """
    步骤一:继承paddle.io.Dataset类
    """
    def __init__(self, transforms: Callable, mode: str ='train'):
        """
        步骤二:实现构造函数,定义数据读取方式
        """
        super(TowerDataset, self).__init__()
        
        self.mode = mode
        self.transforms = transforms

        train_image_dir = config_parameters['train_image_dir']
        eval_image_dir = config_parameters['eval_image_dir']

        train_data_folder = paddle.vision.DatasetFolder(train_image_dir)
        eval_data_folder = paddle.vision.DatasetFolder(eval_image_dir)        
        if self.mode  == 'train':
            self.data = train_data_folder        elif self.mode  == 'eval':
            self.data = eval_data_folder    def __getitem__(self, index):
        """
        步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)
        """
        data = np.array(self.data[index][0]).astype('float32')

        data = self.transforms(data)

        label = np.array([self.data[index][1]]).astype('int64')        
        return data, label        
    def __len__(self):
        """
        步骤四:实现__len__方法,返回数据集总数目
        """
        return len(self.data)
   
In [5]
from paddle.vision import transforms as T#数据增强transform_train =T.Compose([T.Resize((256,256)),
                            T.RandomHorizontalFlip(5),
                            T.RandomRotation(15),
                            T.Transpose(),
                            T.Normalize(mean=[0, 0, 0],                           # 像素值归一化
                                        std =[255, 255, 255]),                    # transforms.ToTensor(), # transpose操作 + (img / 255),并且数据结构变为PaddleTensor
                            T.Normalize(mean=[0.50950350, 0.54632660, 0.57409690],# 减均值 除标准差    
                                        std= [0.26059777, 0.26041326, 0.29220656])# 计算过程:output[channel] = (input[channel] - mean[channel]) / std[channel]
                            ])
transform_eval =T.Compose([ T.Resize((256,256)),
                            T.Transpose(),
                            T.Normalize(mean=[0, 0, 0],                           # 像素值归一化
                                        std =[255, 255, 255]),                    # transforms.ToTensor(), # transpose操作 + (img / 255),并且数据结构变为PaddleTensor
                            T.Normalize(mean=[0.50950350, 0.54632660, 0.57409690],# 减均值 除标准差    
                                        std= [0.26059777, 0.26041326, 0.29220656])# 计算过程:output[channel] = (input[channel] - mean[channel]) / std[channel]
                            ])
   
In [6]
train_dataset = TowerDataset(mode='train',transforms=transform_train)
eval_dataset  = TowerDataset(mode='eval', transforms=transform_eval )#数据异步加载train_loader = paddle.io.DataLoader(train_dataset, 
                                    places=paddle.CUDAPlace(0), 
                                    batch_size=128, 
                                    shuffle=True,                                    #num_workers=2,
                                    #use_shared_memory=True
                                    )
eval_loader = paddle.io.DataLoader (eval_dataset, 
                                    places=paddle.CUDAPlace(0), 
                                    batch_size=128,                                    #num_workers=2,
                                    #use_shared_memory=True
                                    )
   

2.3.3 实例化数据集类

根据所使用的数据集需求实例化数据集类,并查看总样本量。

In [7]
print('训练集样本量: {},验证集样本量: {}'.format(len(train_loader), len(eval_loader)))
       
训练集样本量: 42,验证集样本量: 11
       

③ 模型选择和开发

3.1 网络构建

本次我们使用mobilenet_v2网络来完成我们的案例实践。

In [11]
import paddlefrom paddle.vision.models import mobilenet_v2
network=paddle.vision.models.mobilenet_v2(pretrained=True,num_classes=9)
model=paddle.Model(network)
       
2021-04-20 04:52:16,152 - INFO - unique_endpoints {''}
2021-04-20 04:52:16,153 - INFO - File /home/aistudio/.cache/paddle/hapi/weights/mobilenet_v2_x1.0.pdparams md5 checking...
2021-04-20 04:52:16,203 - INFO - Found /home/aistudio/.cache/paddle/hapi/weights/mobilenet_v2_x1.0.pdparams
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for classifier.1.weight. classifier.1.weight receives a shape [1280, 1000], but the expected shape is [1280, 9].
  warnings.warn(("Skip loading for {}. ".format(key) + str(err)))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for classifier.1.bias. classifier.1.bias receives a shape [1000], but the expected shape is [9].
  warnings.warn(("Skip loading for {}. ".format(key) + str(err)))
       

④ 模型训练和优化器的选择

In [12]
#优化器选择class SaveBestModel(paddle.callbacks.Callback):
    def __init__(self, target=0.5, path='work/best_model', verbose=0):
        self.target = target
        self.epoch = None
        self.path = path    def on_epoch_end(self, epoch, logs=None):
        self.epoch = epoch    def on_eval_end(self, logs=None):
        if logs.get('acc') > self.target:
            self.target = logs.get('acc')
            self.model.save(self.path)            print('best acc is {} at epoch {}'.format(self.target, self.epoch))

callback_visualdl = paddle.callbacks.VisualDL(log_dir='work/mushroom')
callback_savebestmodel = SaveBestModel(target=0.5, path='work/best_model')
callbacks = [callback_visualdl, callback_savebestmodel]

base_lr = config_parameters['lr']
epochs = config_parameters['epochs']def make_optimizer(parameters=None):
    momentum = 0.9

    learning_rate= paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=base_lr, T_max=epochs, verbose=False)
    weight_decay=paddle.regularizer.L2Decay(0.01)
    optimizer = paddle.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=momentum,
        weight_decay=weight_decay,
        parameters=parameters)    return optimizer

optimizer = make_optimizer(model.parameters())

model.prepare(optimizer,
              paddle.nn.CrossEntropyLoss(),
              paddle.metric.Accuracy())
   
In [13]
model.fit(train_loader,
          eval_loader,
          epochs=100,
          batch_size=128,   
          callbacks=callbacks, 
          verbose=1)   # 日志展示格式
   

⑤模型训练效果展示

利用Paddle2.1高层API实现9种蘑菇的识别 - php中文网        

⑥模型存储

将我们训练得到的模型进行保存,以便后续评估和测试使用。

In [14]
model.save(get('model_save_dir'))
   

相关专题

更多
github中文官网入口 github中文版官网网页进入
github中文官网入口 github中文版官网网页进入

github中文官网入口https://docs.github.com/zh/get-started,GitHub 是一种基于云的平台,可在其中存储、共享并与他人一起编写代码。 通过将代码存储在GitHub 上的“存储库”中,你可以: “展示或共享”你的工作。 持续“跟踪和管理”对代码的更改。

1

2026.01.21

windows安全中心怎么关闭打开_windows安全中心操作指南
windows安全中心怎么关闭打开_windows安全中心操作指南

Windows安全中心可以通过系统设置轻松开关。 暂时关闭:打开“设置” -> “隐私和安全性” -> “Windows安全中心” -> “病毒和威胁防护” -> “管理设置”,将“实时保护”关闭。打开:同样路径将开关开启即可。如需彻底关闭,需在组策略(gpedit.msc)或注册表中禁用Windows Defender。

0

2026.01.21

C++游戏开发Unreal Engine_C++怎么用Unreal Engine开发游戏
C++游戏开发Unreal Engine_C++怎么用Unreal Engine开发游戏

虚幻引擎(Unreal Engine, 简称UE)是由Epic Games开发的一款功能强大的工业级3D游戏引擎,以高品质实时渲染(如Nanite和Lumen)闻名 。它基于C++语言,为开发者提供高效率的框架、强大的可视化脚本系统(蓝图)、以及针对PC、主机和移动端的完整开发工具,广泛用于游戏、电影制片等领域。

0

2026.01.21

Python GraphQL API 开发实战
Python GraphQL API 开发实战

本专题系统讲解 Python 在 GraphQL API 开发中的实际应用,涵盖 GraphQL 基础概念、Schema 设计、Query 与 Mutation 实现、权限控制、分页与性能优化,以及与现有 REST 服务和数据库的整合方式。通过完整示例,帮助学习者掌握 使用 Python 构建高扩展性、前后端协作友好的 GraphQL 接口服务,适用于中大型应用与复杂数据查询场景。

1

2026.01.21

云朵浏览器入口合集
云朵浏览器入口合集

本专题整合了云朵浏览器入口合集,阅读专题下面的文章了解更多详细地址。

22

2026.01.20

Java JVM 原理与性能调优实战
Java JVM 原理与性能调优实战

本专题系统讲解 Java 虚拟机(JVM)的核心工作原理与性能调优方法,包括 JVM 内存结构、对象创建与回收流程、垃圾回收器(Serial、CMS、G1、ZGC)对比分析、常见内存泄漏与性能瓶颈排查,以及 JVM 参数调优与监控工具(jstat、jmap、jvisualvm)的实战使用。通过真实案例,帮助学习者掌握 Java 应用在生产环境中的性能分析与优化能力。

29

2026.01.20

PS使用蒙版相关教程
PS使用蒙版相关教程

本专题整合了ps使用蒙版相关教程,阅读专题下面的文章了解更多详细内容。

175

2026.01.19

java用途介绍
java用途介绍

本专题整合了java用途功能相关介绍,阅读专题下面的文章了解更多详细内容。

125

2026.01.19

java输出数组相关教程
java输出数组相关教程

本专题整合了java输出数组相关教程,阅读专题下面的文章了解更多详细内容。

41

2026.01.19

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
最新Python教程 从入门到精通
最新Python教程 从入门到精通

共4课时 | 9.5万人学习

Django 教程
Django 教程

共28课时 | 3.3万人学习

SciPy 教程
SciPy 教程

共10课时 | 1.2万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号