Qwen量化脚本run_gptq.py解析

Qwen量化脚本run_gptq.py解析

代码路径 https://github.com/QwenLM/Qwen/
run_gptq.py路径 https://github.com/QwenLM/Qwen/blob/main/run_gptq.py

代码解析:

import argparse
import json
from typing import Dict
import logging

import torch
import transformers
from transformers import AutoTokenizer
from transformers.trainer_pt_utils import LabelSmoother
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
IGNORE_TOKEN_ID = LabelSmoother.ignore_index

#其中json文件格式如下
# [
#   {
#     "id": "identity_0",
#     "conversations": [
#       {
#         "from": "user",
#         "value": "xxxx"
#       },
#       {
#         "from": "assistant",
#         "value": "xxx"
#       }
#     ]
#   },
#   {
#     "id": "identity_1",
#     "conversations": [
#       {
#         "from": "user",
#         "value": "xxx"
#       },
#       {
#         "from": "assistant",
#         "value": "xxx"
#       }
#     ]
#   },
# ]

def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    max_len: int,
    system_message: str = "You are a helpful assistant."
) -> Dict:
    """preprocess函数接收一个包含对话数据的json列表作为输入,\n
    通过调用transformers库中的tokenizer对数据进行编码,\n
    并按照特定格式构建输入ID序列和目标ID序列.\n
    返回一个包含预处理数据的列表,这些数据已转换为PyTorch张量,适合于后续模型训练或推断"""
    
    #roles字典:为对话中的角色("user"和"assistant")分配特殊的前缀标签,用于区分对话双方
    roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}

    #im_start和im_end:指定tokenizer中im_start_id和im_end_id对应的整数ID。
    im_start = tokenizer.im_start_id
    im_end = tokenizer.im_end_id

    #nl_tokens:存储tokenizer处理换行符\n得到的输入ID序列。
    nl_tokens = tokenizer('\n').input_ids

    #_system、_user和_assistant:分别存储经过tokenizer处理后的"system"、"user"和"assistant"标签及其后的换行符对应的输入ID序列。
    _system = tokenizer('system').input_ids + nl_tokens
    _user = tokenizer('user').input_ids + nl_tokens
    _assistant = tokenizer('assistant').input_ids + nl_tokens

    # Apply prompt templates 定义空列表data,用于存放预处理后的数据样本
    data = []

    # input_ids, targets = [], []

    #遍历输入数据sources中的每个样本(source)
    for i, source in enumerate(sources):
        source = source["conversations"]

        #检查首个对话是否由用户发起(即source[0]["from"]是否为"user"),如果不是,则从源数据中移除首个对话。
        #过滤无效的identity
        if roles[source[0]["from"]] != roles["user"]:
            source = source[1:]

        #初始化空列表input_id和target,分别用于存储当前样本的输入ID序列和目标ID序列
        input_id, target = [], []

        #添加系统消息:将系统消息(包含system_message内容)转换为ID序列,添加到input_id和target中。
        system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
        input_id += system

        #target中的非关键部分(如系统标签和消息内容)用IGNORE_TOKEN_ID填充。
        target += [im_start] + [IGNORE_TOKEN_ID] * (len(system)-3) + [im_end] + nl_tokens

        assert len(input_id) == len(target)

        #遍历源数据中的每个对话(sentence)
        for j, sentence in enumerate(source):

            # 提取角色和消息内容,并转换为ID序列
            role = roles[sentence["from"]]
            _input_id = tokenizer(role).input_ids + nl_tokens + \
                tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
            
            #添加到input_id中
            input_id += _input_id

            #根据角色类型,生成对应_target的目标ID序列,_target只提取assistant的对话内容,忽略user的对话内容。
            if role == '<|im_start|>user':
                #若角色为"user",则目标ID序列仅包含开始标签和结束标签,用忽略ID填充对话内容。
                _target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id)-3) + [im_end] + nl_tokens

            #若角色为"assistant",则目标ID序列包含开始标签、忽略ID填充(仅对角色标签)、对话内容(不包括角色标签和结束标签)、结束标签
            elif role == '<|im_start|>assistant':
                _target = [im_start] + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) + \
                    _input_id[len(tokenizer(role).input_ids)+1:-2] + [im_end] + nl_tokens
            else:
                raise NotImplementedError
            
            target += _target


        assert len(input_id) == len(target)
        #截取并转换为张量:
        #截取input_id和target至最大长度max_len
        input_id = torch.tensor(input_id[:max_len], dtype=torch.int)
        target = torch.tensor(target[:max_len], dtype=torch.int)
        
        #创建一个字典,包含键input_ids(存储输入张量)和attention_mask(等于输入张量,用于指示非填充位置)。将该字典添加到data列表中
        data.append(dict(input_ids=input_id, attention_mask=input_id.ne(tokenizer.pad_token_id)))

    return data


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Model Quantization using AutoGPTQ")
    parser.add_argument("--model_name_or_path", type=str, help="model path")
    parser.add_argument("--data_path", type=str, help="calibration data path")
    parser.add_argument("--out_path", type=str, help="output path of the quantized model")
    parser.add_argument("--max_len", type=int, default=8192, help="max length of calibration data")
    parser.add_argument("--bits", type=int, default=4, help="the bits of quantized model. 4 indicates int4 models.")
    parser.add_argument("--group-size", type=int, default=128, help="the group size of quantized model")
    args = parser.parse_args()
    
    quantize_config = BaseQuantizeConfig(
        bits=args.bits,
        group_size=args.group_size,
        damp_percent=0.01,
        desc_act=False,  # set to False can significantly speed up inference but the perplexity may slightly bad
        static_groups=False,
        sym=True,
        true_sequential=True,
        model_name_or_path=None,
        model_file_base_name="model"
    )

    #使用AutoTokenizer类从给定路径args.model_name_or_path加载预训练的tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
    tokenizer.pad_token_id = tokenizer.eod_id

    #加载json数据文件,调用process函数预处理数据,返回处理后的数据
    data = preprocess(json.load(open(args.data_path)), tokenizer, args.max_len)

    #加载预训练的模型
    model = AutoGPTQForCausalLM.from_pretrained(args.model_name_or_path, quantize_config, device_map="auto", trust_remote_code=True)

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
    )

    #对模型进行量化,不在GPU上缓存示例数据
    model.quantize(data, cache_examples_on_gpu=False)

    #保存量化后的模型
    model.save_quantized(args.out_path, use_safetensors=True)
    #将tokenizer保存到输出路径
    tokenizer.save_pretrained(args.out_path)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/553528.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

遥测终端赋能水库泄洪监测预警,筑牢度汛安全防线!

4月10日&#xff0c;水利部召开水库安全度汛视频会议。会议要求着力强化水库防洪“四预”措施&#xff0c;加快构建雨水情监测预报“三道防线”&#xff0c;完善预警信息发布机制&#xff0c;推进数字孪生水利工程建设&#xff0c;为科学调度指挥决策提供支持。强调坚决牢牢守住…

基于3D点云的散货库存体积计算

首先&#xff0c;你需要散货库存的点云。 我将使用 IntelRealSense 捕获的散货库存的 .ply文件。 然而&#xff0c;任何其他产生点云的成像技术都同样有效。 点击这里查看本教程的 Github 上的代码。 NSDT工具推荐&#xff1a; Three.js AI纹理开发包 - YOLO合成数据生成器 - …

二叉树的中序遍历 - LeetCode 热题 36

大家好&#xff01;我是曾续缘&#x1f603; 今天是《LeetCode 热题 100》系列 发车第 36 天 二叉树第 1 题 ❤️点赞 &#x1f44d; 收藏 ⭐再看&#xff0c;养成习惯 二叉树的中序遍历 给定一个二叉树的根节点 root &#xff0c;返回 它的 中序 遍历 。 示例 1&#xff1a; 输…

爬楼梯(c)

文章目录 描述分析思路关键代码运行结果 描述 给定一个整数数组 cost &#xff0c;其中 cost[i]是从楼梯第i 个台阶向上爬需要支付的费用&#xff0c;下标从0开始。-旦你支付此费用&#xff0c;即可选择向上爬一个或者两个台阶 要求&#xff1a;请你计算并返回达到楼梯顶部的…

4.17

while(1) { HAL_ADC_Start(&hadc); adcVal HAL_ADC_GetValue(&hadc); TIM3->CCR3 adcVal-2000; } 1.总结串口的发送和接收功能使用到的函数 HAL_UART_Transmit_DMA(&huart1,"hello world",strlen("hello world")); HAL_UART_Tr…

Linux:如何删除指定时间之前修改的文件?

1、与文件有关的时间 在说明如何删除符合这种要求的文件之前&#xff0c;先来看看与文件有关的有哪些时间 简名全名中文名含义atimeaccess time访问时间文件中的数据最后被访问的时间mtimemodify time修改时间文件中的数据最后被修改的时间ctime change time变化时间文件的元…

JavaSE高阶篇-IO流

第一部分 file类 1&#xff09;File类 计算机常识: 1.名字为".jpg"的一定是图片吗? 不一定,有可能是文件夹 2.什么叫做文本文档: 用记事本打开,人能看懂的文件 比如:.txt .html .css等 .doc -> 不是 …

如何安装 IntelliJ IDEA 最新版本——详细教程

IntelliJ IDEA 简称 IDEA&#xff0c;被业界公认为最好的 Java 集成开发工具&#xff0c;尤其在智能代码助手、代码自动提示、代码重构、代码版本管理(Git、SVN、Maven)、单元测试、代码分析等方面有着亮眼的发挥。IDEA 产于捷克&#xff0c;开发人员以严谨著称的东欧程序员为主…

vscode 搭建stm32开发环境记录(eide+cortex-debug+jlink)

前言 clion使用的快过期了&#xff0c;所以就准备使用vscode 来代替clion作为代码开发环境 vscode 插件安装 创建个空白工程 添加项目相关的源文件&#xff0c;和配置宏定义和头文件目录 编译和烧录(ok) 结合cortex-debug 结果(测试ok)

数据可视化-ECharts Html项目实战(13)

在之前的文章中&#xff0c;我们深入学习ECharts动态主题切换和自定义ECharts主题。想了解的朋友可以查看这篇文章。同时&#xff0c;希望我的文章能帮助到你&#xff0c;如果觉得我的文章写的不错&#xff0c;请留下你宝贵的点赞&#xff0c;谢谢。 数据可视化-ECharts Html项…

Linux执行命令监控详细实现原理和使用教程,以及相关工具的使用

Linux执行命令监控详细实现原理和使用教程&#xff0c;以及相关工具的使用。 0x00 背景介绍 Linux上的HIDS需要实时对执行的命令进行监控&#xff0c;分析异常或入侵行为&#xff0c;有助于安全事件的发现和预防。为了获取执行命令&#xff0c;大致有如下方法&#xff1a; 遍…

MySQL-笔记-06.数据高级查询

目录 6.1 连接查询 6.1.1 交叉连接&#xff08;cross join&#xff09; 6.1.2 内连接&#xff08;inner join&#xff09; 6.1.3 外连接&#xff08;outer join&#xff09; 6.1.3.1 左外连接&#xff08;left [outer] join&#xff09; 6.1.3.2 右外连接&#xff08;rig…

第2章:车辆纵向控制

2.1 车辆纵向动力学模型 注&#xff1a;车辆的纵向控制是指控制车辆行驶方向上的加减速&#xff0c;使得汽车可以按照期望的速度行驶&#xff0c;并保持安全的前后车距&#xff08;即对汽车油门 / 刹车的控制&#xff09;&#xff1b; 2.1.1 车辆纵向受力模型 &#xff1a;轮胎…

SpringBootSpringCloud升级可能会出现的问题

1.背景 之前负责过我们中台的SpringBoot和Cloud的升级&#xff0c;特次记录分享一下项目中可能出现的问题&#xff0c;方便后续的人快速定位问题。以及下述选择的解决方案都是基于让升级的服务影响和改动最小以及提供通用的解决方案的提前进行选择的。 1.1版本说明 升级前&a…

OpenCV基本图像处理操作(十)——图像特征harris角点

角点 角点是图像中的一个特征点&#xff0c;指的是两条边缘交叉的点&#xff0c;这样的点在图像中通常表示一个显著的几角。在计算机视觉和图像处理中&#xff0c;角点是重要的特征&#xff0c;因为它们通常是图像中信息丰富的区域&#xff0c;可以用于图像分析、对象识别、3D…

JavaSE中的String类

1.定义方式 常见的三种字符串构造 public class Test1 {public static void main(String[] args) {// 使用常量串构造String str1 "abc";System.out.println(str1);// 直接newString对象String str2 new String("ABC");System.out.println(str2);// 使用…

【Linux学习】Linux指令(四)

文章标题 &#x1f680;zip/unzip指令&#xff1a;&#x1f680;tar指令&#xff08;重要&#xff09;&#xff1a;&#x1f680;uname –r指令&#xff1a;&#x1f680;关机指令&#x1f680;几个常用操作 &#x1f680;zip/unzip指令&#xff1a; zip 与 unzip的安装 yum i…

Day20-【Java SE高级】单元测试 反射 注解 动态代理

一、单元测试 就是针对最小的功能单元(方法)&#xff0c;编写测试代码对其进行正确性测试。 1. 咱们之前是如何进行单元测试的?有啥问题? 只能在main方法编写测试代码&#xff0c;去调用其他方法进行测试。无法实现自动化测试&#xff0c;一个方法测试失败&#xff0c;可能…

学习在Debian系统上安装Shadowsocks教程

学习在Debian系统上安装Shadowsocks教程 安装shadowsocks-libev及其所需的依赖启动Shadowsocks服务&#xff1a;如果你想要通过代理本地流量&#xff0c;你可以使用ss-local&#xff1a;启动并设置ss-local&#xff1a;查看状态本地连接 安装shadowsocks-libev及其所需的依赖 …

量化交易为什么独宠Python

“我在学一门叫Python的语言”。“什么是Python&#xff0c;没听说过啊&#xff0c;为什么不学C啊”。这是发生在2014年&#xff0c;上海的一家量化基金&#xff0c;量化研究员和老板之间的对话。 “我想问一下关于Python的课程&#xff0c;什么时候能开班”。“Python啊&#…