002-基于Pytorch的手写汉字数字分类

news/2024/4/29 12:53:42/文章来源:https://blog.csdn.net/ahhjian/article/details/137020997

本节将介绍一种

2.1 准备

2.1.1 数据集

(1)MNIST

只要学习过深度学习相关理论的人,都一定听说过名字叫做LeNet-5模型,它是深度学习三巨头只有Yann Lecun在1998年提出的一个CNN模型(很多人认为这是第一个具有实际应用价值的CNN模型)。在当年使用该模型可以很好地完成手写体数字的识别,而该模型所处理的手写体数字数据库称为MNIST。

MNIST全称是:Mixed National Institute of Standards and Technology databas,它包含70000张手写数字的灰度图片,每一张图片包含 28 X 28 个像素点。数据集被分为两部分,其中训练(mnist.train)集包括60000样本,测试集(mnist.test)包含10000样本。训练集又进一步封你为 55000 个样本用于训练,5000样本用于验证。下图是MNIST样本实例图。

MNIST数据集虽然经典,但也有问题。最主要的问题是,它太简单了!相对于现在动辄上百万个参数的“大”模型,MNIST数据集要小很多,且只是简单的十类问题,因此导致现有的模型在MNIST上的分类精度都超过了95%。为了更直观地观察不同算法间的性能差异,需要用一个更复杂一点的数据集,这时Fashion-MNIST出现了。

(2)Fashion-MNIST

FashionMNIST是一个替代MNIST的图像数据集。 它是由一家德国科技公司(Zalando)整理提供。FashionMNIST 的大小、格式和训练集/测试集划分与原始的 MNIST 完全一致。60000/10000 的训练测试数据划分,28x28 的灰度图片。因此,能跑MNIST数据集的代码,只需稍加改动,就可以跑新的数据集。两个数据集的不同之处主要有两点,一是虽然两者都是以灰度图像呈现的,但MNIST呈现的是数字,背景设为0,前景设为1,FashionMNIST则是真正意义的灰度数据集。二是两者内容不同,前者被分类的是手写体数字,后者则是十类衣物服饰(分别是:T恤、裤子、套头衫、连衣裙、大衣、凉鞋、衬衫、运动鞋、包、短靴),其内容的复杂程度远高于手写体数字。下图是FashionMNIST的一个图示。

网上有很多基于FashionMNIST数据集的实例,在此就不再重复介绍。

本节实例选用的是中国版的MNIST,由英国纽卡斯尔大学整理并提供,我们不妨将其称为CHN-MNIST数据集。

(3)CHN-MNIST

该数据集共由100人书写,每人重复书写10遍,因此数据集样本数为1000组,每组包括15个汉字的数字,即“零、一、二、三、四、五、六、七、八、九、十、百、千、万、亿”,总样本数为15000。图像的分辨率为300*300。

2.1.2  模型

对于这样一个简单的分类任务,不需要使用太复杂的网络,前面提到的LeNet-5足能胜任。

对于LeNet-5网络模型的介绍,网上一搜一大把,在此不再赘述,只贴出该模型的示意图,供大家参考。

2.2 代码解析

下面将结合代码,一部分一部分的介绍具体的过程。

(1)载入必要的扩展库

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

由于是第一个例程,我们对所使用的扩展库详细加以介绍:

  • matplotlib库:用于绘图
  • numpy库:用于数值计算
  • pandas库:用于数据分析
  • torch库:提供Pytorch支持
  • PIL库:用于图像绘制
  • tqdm库:Python提供的进度条空间库

(2)设置参数

这一部分完成的是设置一些与模型训练有关的超参数。如下面代码所示:

batch_size = 32  # 批次大小
lr = 0.003  # 学习率
epochs = 10  # 迭代轮数
save_path = './best_model.pkl'  # 模型保存路径
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # 设备

各个参数的功能见注释,至于各个参数数值大小对最终结果的影响,将放在后续的章节介绍。其中最后一行是自动检测是否安装了cuda,如果是,则启动gpu加速。

(3)加载数据集

这一部分完成的是设置一些与模型训练有关的超参数。如下面代码所示:

class CustomDataset(Dataset):def __init__(self, k, l, csv_file='./chinese_mnist.csv'):self.df = pd.read_csv(csv_file)self.k = {'九': int(9), '十': int(10), '百': int(11), '千': int(12), '万': int(13), '亿': int(14), '零': int(0),'一': int(1), '二': int(2), '三': int(3), '四': int(4), '五': int(5), '六': int(6), '七': int(7),'八': int(8)}self.target = 'character'self.features = ['suite_id', 'sample_id', 'code', ]self.labels = np.asarray(self.df.iloc[:, 4])self.y = df[self.target]self.X = df.drop(self.target, axis=1)def __getitem__(self, idx):single_image_label = self.labels[idx]class_id = self.k[single_image_label]img = Image.open(f"./data/data/input_{self.X.iloc[idx, 0]}_{self.X.iloc[idx, 1]}_{self.X.iloc[idx, 2]}.jpg")img = np.array(img)return img, class_iddef __len__(self):return len(self.X)

还需要对数据集进行一下预处理,便于后面的训练过程g

# 1.构建索引到汉字的映射字典
num2char = {int(9): '九', int(10): '十', int(11): '百',int(12): '千', int(13): '万', int(14): '亿',int(0): '零', int(1): '一', int(2): '二',int(3): '三', int(4): '四', int(5): '五',int(6): '六', int(7): '七', int(8): '八'}# 2.读取csv处理文件
df = pd.read_csv('./chinese_mnist.csv', sep=',')# 3.处理数据
train_df = df.groupby('value').apply(lambda x: x.sample(700, random_state=42)).reset_index(drop=True)
x_train, y_train = train_df.iloc[:, :-2], train_df.iloc[:, -2]test_df = df.groupby('value').apply(lambda x: x.sample(300, random_state=42)).reset_index(drop=True)
x_test, y_test = test_df.iloc[:, :-2], test_df.iloc[:, -2]

(未完,待续)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_1027597.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qlib-Server:量化库数据服务器

Qlib-Server:量化库数据服务器 介绍 Qlib-Server 是 Qlib 的配套服务器系统,它利用 Qlib 进行基本计算,并提供广泛的服务器系统和缓存机制。通过 Qlib-Server,可以以集中的方式管理 Qlib 提供的数据。 框架 Qlib 的客户端/服务器框架基于 WebSocket 构建,这是因为 WebS…

学点儿Java_Day10_集合框架(List、Set、HashMap)

1 简介 ArrayList: 有序(放进去顺序和拿出来顺序一致),可重复 HashSet: 无序(放进去顺序和拿出来顺序不一定一致),不可重复 Testpublic void test1() {String[] array new String[3];//List: 有序 可重复//有序: 放入顺序 与 拿出顺序一致,…

【NLP笔记】大模型prompt推理(提问)技巧

文章目录 prompt概述推理(提问)技巧基础prompt构造技巧进阶优化技巧prompt自动优化 参考链接: Pre-train, Prompt, and Predict: A Systematic Survey of Prompting Methods in Natural Language Processing预训练、提示和预测:NL…

【并发】第二篇 ThreadLocal详解

导航 一. ThreadLocal 简介二. ThreadLocal 源码解析1. get2. set3 .remove4. initialValue三. ThreadLocalMap 源码分析1. 构造方法2. getEntry()3. set()4. resize()5. expungeStaleEntries()6. cleanSomeSlots()7. nextIndex()8. remove()9. 总结ThreadLocalMap四. 内存泄漏…

HarmonyOS 应用开发之显式Want与隐式Want匹配规则

在启动目标应用组件时,会通过显式 Want 或者隐式 Want 进行目标应用组件的匹配,这里说的匹配规则就是调用方传入的 want 参数中设置的参数如何与目标应用组件声明的配置文件进行匹配。 显式Want匹配原理 显式 Want 匹配原理如下表所示。 名称类型匹配…

NanoMQ的安装与部署

本文使用docker进行安装,因此安装之前需要已经安装了docker 拉取镜像 docker pull emqx/nanomq:latest 相关配置及密码认证 创建目录/usr/local/nanomq/conf以及配置文件nanomq.conf、pwd.conf # # # # MQTT Broker # # mqtt {property_size 32max_packet_siz…

使用苹果应用商店上架工具实现应用快速审核与发布

摘要 移动应用app上架是开发者关注的重要环节,但常常会面临审核不通过等问题。为帮助开发者顺利完成上架工作,各种辅助工具应运而生。本文探讨移动应用app上架原理、常见辅助工具功能及其作用,最终指出合理使用工具的重要性。 引言 移动应…

第4章.精通标准提示,引领ChatGPT精准输出

标准提示 标准提示,是引导ChatGPT输出的一个简单方法,它提供了一个具体的任务让模型完成。 如果你要生成一篇新闻摘要。你只要发送指示词:汇总这篇新闻 : …… 提示公式:生成[任务] 生成新闻文章的摘要: 任务&#x…

Stable Diffusion WebUI 生成参数:脚本(Script)——提示词矩阵、从文本框或文件载入提示词、X/Y/Z图表

本文收录于《AI绘画从入门到精通》专栏,专栏总目录:点这里,订阅后可阅读专栏内所有文章。 大家好,我是水滴~~ 在本篇文章中,我们将深入探讨 Stable Diffusion WebUI 的另一个引人注目的生成参数——脚本(Script)。我们将逐一细说提示词矩阵、从文本框或文件导入提示词,…

跑腿小程序|基于微信小程序的跑腿平台小程序设计与实现(源码+数据库+文档)

跑腿平台小程序目录 目录 基于微信小程序的跑腿平台小程序设计与实现 一、前言 二、系统设计 三、系统功能设计 1、用户信息管理 2、跑腿任务管理 3、任务类型管理 4、公告信息管理 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、…

如何安全地添加液氮到液氮罐中

液氮是一种极低温的液体,它在许多领域广泛应用,但在处理液氮时需谨慎小心。添加液氮到液氮罐中是一个常见的操作,需要遵循一些安全准则以确保操作人员的安全和设备的完整性。 选择合适的液氮容器 选用专业设计用于存储液氮的容器至关重要。…

SnapGene 5 for Mac 分子生物学软件

SnapGene 5 for Mac是一款专为Mac操作系统设计的分子生物学软件,以其强大的功能和用户友好的界面,为科研人员提供了高效、便捷的基因克隆和分子实验设计体验。 软件下载:SnapGene 5 for Mac v5.3.1中文激活版 这款软件支持DNA构建和克隆设计&…

线性代数 - 应该学啥 以及哪些可以交给计算机

AI很热,所以小伙伴们不免要温故知新旧时噩梦 - 线代。 (十几年前,还有一个逼着大家梦回课堂的风口,图形学。) 这个真的不是什么美好的回忆,且不说老师的口音,也不说教材的云山雾绕,单…

JVM(一)——内存结构

一. 前言 1、什么是 JVM? 1)定义: Java Virtual Machine - java 程序的运行环境(java 二进制字节码的运行环境) 2)好处: 一次编写,到处运行自动内存管理,垃圾回收功能数组下标越…

React Native 应用打包上架

引言 在将React Native应用上架至App Store时,除了通常的上架流程外,还需考虑一些额外的优化策略。本文将介绍如何通过配置App Transport Security、Release Scheme和启动屏优化技巧来提升React Native应用的上架质量和用户体验。 配置 App Transport…

基于振弦采集仪的土体变形监测与分析

基于振弦采集仪的土体变形监测与分析 工程监测振弦采集仪是一种专用于工程监测中的振弦测量的仪器。它能够实时采集及记录结构物的振动信号,以评估结构物的健康状况、安全性能等。它通常由振弦传感器、数据采集模块和数据处理软件组成。振弦传感器负责测量结构物的…

uniApp使用XR-Frame创建3D场景(5)材质贴图的运用

上一篇讲解了如何在uniApp中创建xr-frame子组件并创建简单的3D场景。 这篇我们讲解在xr-frame中如何给几何体赋予贴图材质。 先看源码 <xr-scene render-system"alpha:true" bind:ready"handleReady"><xr-node><xr-assets><xr-asse…

联想 lenovoTab 拯救者平板 Y700 二代_TB320FC原厂ZUI_15.0.677 firmware 线刷包9008固件ROM root方法

联想 lenovoTab 拯救者平板 Y700 二代_TB320FC原厂ZUI_15.0.677 firmware 线刷包9008固件ROM root方法 ro.vendor.config.lgsi.market_name拯救者平板 Y700 ro.vendor.config.lgsi.en.market_nameLegion Tab Y700 #ro.vendor.config.lgsi.short_market_name联想平板 ZUI T # B…

2024多云管理平台CMP排名看这里!

随着云计算技术的迅猛发展&#xff0c;多云管理平台CMP应运而生。多云管理平台CMP仅能够简化对多个云环境的统一管理&#xff0c;还能提高资源利用效率和降低成本。因此了解多云管理平台CMP品牌是必要的。2024多云管理平台CMP排名看这里&#xff01;仅供参考哈&#xff01; 20…

Java八股文(JVM)

Java八股文のJVM JVM JVM 什么是Java虚拟机&#xff08;JVM&#xff09;&#xff1f; Java虚拟机是一个运行Java字节码的虚拟机。 它负责将Java程序翻译成机器代码并执行。 JVM的主要组成部分是什么&#xff1f; JVM包括以下组件&#xff1a; ● 类加载器&#xff08;ClassLoa…