0

0

CPM-Distill:经过知识蒸馏的小型文本生成模型

P粉084495128

P粉084495128

发布时间:2025-07-18 13:46:16

|

246人浏览过

|

来源于php中文网

原创

本文介绍知识蒸馏技术及基于PaddleNLP加载CPM-Distill模型实现文本生成。知识蒸馏是模型压缩方法,以“教师-学生网络”思想,让简单模型拟合复杂模型输出,效果优于从头训练。CPM-Distill由GPT-2 Large蒸馏得到,文中还给出安装依赖、加载模型、解码方法及文本生成示例。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

cpm-distill:经过知识蒸馏的小型文本生成模型 - php中文网

引入

  • 近些年来,随着 Bert 这样的大规模预训练模型的问世,NLP 领域的模型也逐渐变得越来越大了
  • 受限于算力水平,如此大规模的模型要应用在实际的部署场景都是不太实际的
  • 因此需要通过一些方式对大规模的模型进行压缩,使其能够在部署场景下达到一个相对可用的速度
  • 常见的模型压缩方法有:剪枝、量化、知识蒸馏等
  • 最近 CPM(Chinese Pre-Trained Models)项目又开源了一个使用知识蒸馏得到的小型文本生成模型 CPM-Distill
  • 本次项目就简单介绍一下知识蒸馏技术并且通过 PaddleNLP 套件加载 CPM-Distill 模型实现文本生成

相关项目

  • Paddle2.0:构建一个经典的文本生成模型GPT-2
  • 文本生成:使用GPT-2加载CPM-LM模型实现简单的问答机器人
  • 文本生成:让AI帮你写文章吧
  • 【AI创造营】PaddleHub 配合 PaddleNLP 实现简单的文本生成

相关资料

  • 论文:
    • CPM: A Large-scale Generative Chinese Pre-trained Language Model
    • Distilling the Knowledge in a Neural Network
  • 官方实现:TsinghuaAI/CPM-Distill

模型压缩技术

CPM-Distill:经过知识蒸馏的小型文本生成模型 - php中文网

知识蒸馏(Knowledge Distillation)

  • 知识蒸馏是一种模型压缩方法,是一种基于“教师-学生网络思想”的训练方法。

  • 由 Hinton 在 2015 年 Distilling the Knowledge in a Neural Network 的论文首次提出了知识蒸馏的并尝试在 CV 领域中使用,旨在把大模型学到的知识灌输到小模型中,以达到缩小模型的目标,示意图如下:

CPM-Distill:经过知识蒸馏的小型文本生成模型 - php中文网

MaxAI
MaxAI

MaxAI.me是一款功能强大的浏览器AI插件,集成了多种AI模型。

下载
  • 说人话就是指用一个简单模型去拟合复杂模型的输出,这个输出也叫做“软标签”,当然也可以加入真实数据作为“硬标签”一同训练。
  • 使用知识蒸馏技术相比直接从头训练的效果一般会更好一些,因为教师模型能够指导学生模型收敛到一个更佳的位置。

CPM-Distill:经过知识蒸馏的小型文本生成模型 - php中文网

  • 知识蒸馏技术除了可以用来将网络从大网络转化成一个小网络,并保留接近于大网络的性能;
  • 也可以将多个网络的学到的知识转移到一个网络中,使得单个网络的性能接近 emsemble 的结果。

蒸馏模型信息

  • 教师模型为 GPT-2 Large,具体的模型参数如下:
teacher_model = GPTModel(
    vocab_size=30000,
    hidden_size=2560,
    num_hidden_layers=32,
    num_attention_heads=32,
    intermediate_size=10240,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=1024,
    type_vocab_size=1,
    initializer_range=0.02,
    pad_token_id=0,
    topo=None)
  • 学生模型为 GPT-2 Small,具体的模型参数如下:
teacher_model = GPTModel(
    vocab_size=30000,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=1024,
    type_vocab_size=1,
    initializer_range=0.02,
    pad_token_id=0,
    topo=None)

蒸馏 loss

  • 将大模型和小模型每个位置上输出之间的 KL 散度作为蒸馏 loss,同时加上原来的 language model loss。总 loss 如下:

CPM-Distill:经过知识蒸馏的小型文本生成模型 - php中文网

其中 LlmLlm 为 GPT-2 原始的 language modeling loss。

安装依赖

In [ ]
!pip install paddlenlp==2.0.1 sentencepiece==0.1.92

加载模型

In [1]
import paddlefrom paddlenlp.transformers import GPTModel, GPTForPretraining, GPTChineseTokenizer# tokenizer 与 CPM-LM 模型一致tokenizer = GPTChineseTokenizer.from_pretrained('gpt-cpm-large-cn')# 实例化 GPT2-small 模型gpt = GPTModel(
    vocab_size=30000,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=1024,
    type_vocab_size=1,
    initializer_range=0.02,
    pad_token_id=0,
    topo=None)# 加载预训练模型参数params = paddle.load('data/data92160/gpt-cpm-small-cn-distill.pdparams')# 设置参数gpt.set_dict(params)# 使用 GPTForPretraining 向模型中添加输出层model = GPTForPretraining(gpt)# 将模型设置为评估模式model.eval()
[2021-05-28 19:38:04,469] [    INFO] - Found /home/aistudio/.paddlenlp/models/gpt-cpm-large-cn/gpt-cpm-cn-sentencepiece.model

模型解码

In [40]
import paddleimport numpy as np# Greedy Searchdef greedy_search(text, max_len=32, end_word=None):
    # # 终止标志
    if end_word is not None:
        stop_id = tokenizer.encode(end_word)['input_ids']
        length = len(stop_id)    else:
        stop_id = [tokenizer.eod_token_id]
        length = len(stop_id)    
    # 初始预测
    ids = tokenizer.encode(text)['input_ids']
    input_id = paddle.to_tensor(np.array(ids).reshape(1, -1).astype('int64'))
    output, cached_kvs = model(input_id, use_cache=True)
    next_token = int(np.argmax(output[0, -1].numpy()))
    ids.append(next_token)    # 使用缓存进行继续预测
    for i in range(max_len-1):
        input_id = paddle.to_tensor(np.array([next_token]).reshape(1, -1).astype('int64'))
        output, cached_kvs = model(input_id, use_cache=True, cache=cached_kvs)
        next_token = int(np.argmax(output[0, -1].numpy()))
        ids.append(next_token)        # 根据终止标志停止预测
        if ids[-length:]==stop_id:            if end_word is None:
               ids = ids[:-1]            break
    
    return tokenizer.convert_ids_to_string(ids)
In [39]
import paddleimport numpy as np# top_k and top_p filteringdef top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.shape[-1])  # Safety check
    logits_np = logits.numpy()    if top_k > 0:        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits_np < np.sort(logits_np)[-top_k]
        logits_np[indices_to_remove] = filter_value    if top_p < 1.0:
        sorted_logits = paddle.sort(logits, descending=True)
        sorted_indices = paddle.argsort(logits, descending=True).numpy()
        cumulative_probs = paddle.cumsum(paddle.nn.functional.softmax(sorted_logits, axis=-1), axis=-1).numpy()        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1]
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits_np[indices_to_remove] = filter_value    return paddle.to_tensor(logits_np)# Nucleus Sampledef nucleus_sample(text, max_len=32, end_word=None, repitition_penalty=1.0, temperature=1.0, top_k=0, top_p=1.0):
    # 终止标志
    if end_word is not None:
        stop_id = tokenizer.encode(end_word)['input_ids']
        length = len(stop_id)    else:
        stop_id = [tokenizer.eod_token_id]
        length = len(stop_id)    # 初始预测
    ids = tokenizer.encode(text)['input_ids']
    input_id = paddle.to_tensor(np.array(ids).reshape(1, -1).astype('int64'))
    output, cached_kvs = model(input_id, use_cache=True)
    next_token_logits = output[0, -1, :]    for id in set(ids):
        next_token_logits[id] /= repitition_penalty
    next_token_logits = next_token_logits / temperature
    filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
    next_token = paddle.multinomial(paddle.nn.functional.softmax(filtered_logits, axis=-1), num_samples=1).numpy()
    ids += [int(next_token)]    # 使用缓存进行继续预测
    for i in range(max_len-1):
        input_id = paddle.to_tensor(np.array([next_token]).reshape(1, -1).astype('int64'))
        output, cached_kvs = model(input_id, use_cache=True, cache=cached_kvs)
        next_token_logits = output[0, -1, :]        for id in set(ids):
            next_token_logits[id] /= repitition_penalty
        next_token_logits = next_token_logits / temperature
        filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
        next_token = paddle.multinomial(paddle.nn.functional.softmax(filtered_logits, axis=-1), num_samples=1).numpy()
        ids += [int(next_token)]        # 根据终止标志停止预测
        if ids[-length:]==stop_id:            if end_word is None:
               ids = ids[:-1]            break

    return tokenizer.convert_ids_to_string(ids)

文本生成

In [41]
# 输入文本inputs = input('请输入文本:')print(inputs)# 使用 Nucleus Sample 进行文本生成outputs = greedy_search(
    inputs, # 输入文本
    max_len=128, # 最大生成文本的长度
    end_word=None)# 打印输出print(outputs)
请输入文本:请在此处输入你的姓名
请在此处输入你的姓名,然后点击“确定”,就可以开始游戏了。
游戏目标:在限定时间内,成功地把所有的牌都通通打完。
In [43]
# 输入文本inputs = input('请输入文本:')print(inputs)for x in range(5):    # 使用 Nucleus Sample 进行文本生成
    outputs = nucleus_sample(
        inputs, # 输入文本
        max_len=128, # 最大生成文本的长度
        end_word='。', # 终止符号
        repitition_penalty=1.0, # 重复度抑制
        temperature=1.0, # 温度
        top_k=3000, # 取前k个最大输出再进行采样
        top_p=0.9 # 抑制概率低于top_p的输出再进行采样
    )    # 打印输出
    print(outputs)
请输入文本:请在此处输入你的姓名
请在此处输入你的姓名、学校、专业及学科,并在社交媒体上公布你的个人简介。
请在此处输入你的姓名或者电话,对方会及时通知你。
请在此处输入你的姓名、民族及籍贯信息,当您找到 CADULI 的联系方式后,我们会按您所选择的申请中心,以电子邮件的形式向您发送邮件。
请在此处输入你的姓名和电话号码,由资深会所接待员进行介绍,因为此处有不少中国的大老板,英文能看。
请在此处输入你的姓名、联系电话、银行卡号和手机号。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

WorkBuddy
WorkBuddy

腾讯云推出的AI原生桌面智能体工作台

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
TypeScript类型系统进阶与大型前端项目实践
TypeScript类型系统进阶与大型前端项目实践

本专题围绕 TypeScript 在大型前端项目中的应用展开,深入讲解类型系统设计与工程化开发方法。内容包括泛型与高级类型、类型推断机制、声明文件编写、模块化结构设计以及代码规范管理。通过真实项目案例分析,帮助开发者构建类型安全、结构清晰、易维护的前端工程体系,提高团队协作效率与代码质量。

76

2026.03.13

Python异步编程与Asyncio高并发应用实践
Python异步编程与Asyncio高并发应用实践

本专题围绕 Python 异步编程模型展开,深入讲解 Asyncio 框架的核心原理与应用实践。内容包括事件循环机制、协程任务调度、异步 IO 处理以及并发任务管理策略。通过构建高并发网络请求与异步数据处理案例,帮助开发者掌握 Python 在高并发场景中的高效开发方法,并提升系统资源利用率与整体运行性能。

117

2026.03.12

C# ASP.NET Core微服务架构与API网关实践
C# ASP.NET Core微服务架构与API网关实践

本专题围绕 C# 在现代后端架构中的微服务实践展开,系统讲解基于 ASP.NET Core 构建可扩展服务体系的核心方法。内容涵盖服务拆分策略、RESTful API 设计、服务间通信、API 网关统一入口管理以及服务治理机制。通过真实项目案例,帮助开发者掌握构建高可用微服务系统的关键技术,提高系统的可扩展性与维护效率。

350

2026.03.11

Go高并发任务调度与Goroutine池化实践
Go高并发任务调度与Goroutine池化实践

本专题围绕 Go 语言在高并发任务处理场景中的实践展开,系统讲解 Goroutine 调度模型、Channel 通信机制以及并发控制策略。内容包括任务队列设计、Goroutine 池化管理、资源限制控制以及并发任务的性能优化方法。通过实际案例演示,帮助开发者构建稳定高效的 Go 并发任务处理系统,提高系统在高负载环境下的处理能力与稳定性。

63

2026.03.10

Kotlin Android模块化架构与组件化开发实践
Kotlin Android模块化架构与组件化开发实践

本专题围绕 Kotlin 在 Android 应用开发中的架构实践展开,重点讲解模块化设计与组件化开发的实现思路。内容包括项目模块拆分策略、公共组件封装、依赖管理优化、路由通信机制以及大型项目的工程化管理方法。通过真实项目案例分析,帮助开发者构建结构清晰、易扩展且维护成本低的 Android 应用架构体系,提升团队协作效率与项目迭代速度。

109

2026.03.09

JavaScript浏览器渲染机制与前端性能优化实践
JavaScript浏览器渲染机制与前端性能优化实践

本专题围绕 JavaScript 在浏览器中的执行与渲染机制展开,系统讲解 DOM 构建、CSSOM 解析、重排与重绘原理,以及关键渲染路径优化方法。内容涵盖事件循环机制、异步任务调度、资源加载优化、代码拆分与懒加载等性能优化策略。通过真实前端项目案例,帮助开发者理解浏览器底层工作原理,并掌握提升网页加载速度与交互体验的实用技巧。

108

2026.03.06

Rust内存安全机制与所有权模型深度实践
Rust内存安全机制与所有权模型深度实践

本专题围绕 Rust 语言核心特性展开,深入讲解所有权机制、借用规则、生命周期管理以及智能指针等关键概念。通过系统级开发案例,分析内存安全保障原理与零成本抽象优势,并结合并发场景讲解 Send 与 Sync 特性实现机制。帮助开发者真正理解 Rust 的设计哲学,掌握在高性能与安全性并重场景中的工程实践能力。

243

2026.03.05

PHP高性能API设计与Laravel服务架构实践
PHP高性能API设计与Laravel服务架构实践

本专题围绕 PHP 在现代 Web 后端开发中的高性能实践展开,重点讲解基于 Laravel 框架构建可扩展 API 服务的核心方法。内容涵盖路由与中间件机制、服务容器与依赖注入、接口版本管理、缓存策略设计以及队列异步处理方案。同时结合高并发场景,深入分析性能瓶颈定位与优化思路,帮助开发者构建稳定、高效、易维护的 PHP 后端服务体系。

684

2026.03.04

AI安装教程大全
AI安装教程大全

2026最全AI工具安装教程专题:包含各版本AI绘图、AI视频、智能办公软件的本地化部署手册。全篇零基础友好,附带最新模型下载地址、一键安装脚本及常见报错修复方案。每日更新,收藏这一篇就够了,让AI安装不再报错!

179

2026.03.04

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Git 教程
Git 教程

共21课时 | 4.2万人学习

Git版本控制工具
Git版本控制工具

共8课时 | 1.6万人学习

Git中文开发手册
Git中文开发手册

共0课时 | 94人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号