基于Pytorch框架的深度学习RegNet神经网络二十五种宝石识别分类系统源码

news/2024/7/14 17:55:57/文章来源:https://blog.csdn.net/m0_59023219/article/details/139276024

 第一步:准备数据

25种宝石数据,总共800张:

{ "0": "Alexandrite","1": "Almandine","2": "Benitoite","3": "Beryl Golden","4": "Carnelian", "5": "Cats Eye","6": "Danburite", "7": "Diamond","8": "Emerald","9": "Fluorite","10": "Garnet Red","11": "Hessonite","12": "Iolite","13": "Jade","14": "Kunzite","15": "Labradorite","16": "Malachite","17": "Onyx Black","18": "Pearl","19": "Quartz Beer","20": "Rhodochrosite","21": "Sapphire Blue","22": "Tanzanite","23": "Variscite","24": "Zircon"}

第二步:搭建模型

本文选择一个RegNet网络,其原理介绍如下:

该论文提出了一个新的网络设计范式,并不是专注于设计单个网络实例,而是设计了一个网络设计空间network design space。整个过程类似于经典的手工网络设计,但被提升到了设计空间的水平。使用本文的方法,作者探索了网络设计的结构方面,并得到了一个由简单、规则的网络构成了低维设计空间并称之为RegNet。RegNet设计空间提供了各个范围flop下简单、快速的网络。在类似的训练设置和flops下,RegNet的效果超过了EfficientNet同时在GPU上快了5倍

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import os
import math
import argparseimport torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_schedulerfrom model import create_regnet
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluatedef main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")print(args)print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')tb_writer = SummaryWriter()if os.path.exists("./weights") is False:os.makedirs("./weights")train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}# 实例化训练数据集train_dataset = MyDataSet(images_path=train_images_path,images_class=train_images_label,transform=data_transform["train"])# 实例化验证数据集val_dataset = MyDataSet(images_path=val_images_path,images_class=val_images_label,transform=data_transform["val"])batch_size = args.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=val_dataset.collate_fn)# 如果存在预训练权重则载入model = create_regnet(model_name=args.model_name,num_classes=args.num_classes).to(device)# print(model)if args.weights != "":if os.path.exists(args.weights):weights_dict = torch.load(args.weights, map_location=device)load_weights_dict = {k: v for k, v in weights_dict.items()if model.state_dict()[k].numel() == v.numel()}print(model.load_state_dict(load_weights_dict, strict=False))else:raise FileNotFoundError("not found weights file: {}".format(args.weights))# 是否冻结权重if args.freeze_layers:for name, para in model.named_parameters():# 除最后的全连接层外,其他权重全部冻结if "head" not in name:para.requires_grad_(False)else:print("train {}".format(name))pg = [p for p in model.parameters() if p.requires_grad]optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5)# Scheduler https://arxiv.org/pdf/1812.01187.pdflf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosinescheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)for epoch in range(args.epochs):# trainmean_loss = train_one_epoch(model=model,optimizer=optimizer,data_loader=train_loader,device=device,epoch=epoch)scheduler.step()# validateacc = evaluate(model=model,data_loader=val_loader,device=device)print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))tags = ["loss", "accuracy", "learning_rate"]tb_writer.add_scalar(tags[0], mean_loss, epoch)tb_writer.add_scalar(tags[1], acc, epoch)tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--num_classes', type=int, default=25)parser.add_argument('--epochs', type=int, default=100)parser.add_argument('--batch-size', type=int, default=4)parser.add_argument('--lr', type=float, default=0.001)parser.add_argument('--lrf', type=float, default=0.01)# 数据集所在根目录# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzparser.add_argument('--data-path', type=str,default=r"G:\demo\data\gemstone\archive_train")parser.add_argument('--model-name', default='RegNetY_400MF', help='create model name')# 预训练权重下载地址# 链接: https://pan.baidu.com/s/1XTo3walj9ai7ZhWz7jh-YA  密码: 8lmuparser.add_argument('--weights', type=str, default='regnety_400mf.pth',help='initial weights path')parser.add_argument('--freeze-layers', type=bool, default=False)parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')opt = parser.parse_args()main(opt)

第四步:统计正确率

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码的下载路径(新窗口打开链接):基于Pytorch框架的深度学习RegNet神经网络二十五种宝石识别分类系统源码

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_1053304.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

架构师系列---RPC通信原理

RPC通信原理 基于网络的调用 问题:谁来解决这个跨进程调用的问题? RPC:Remote Percedure Call 远程过程调用 定义了一台主机上的程序通过网络调用另外一台主机上的程序的子程序这一行为。 RPC符合CS模型,可以实现进程间的通信&a…

超详细的前后端实战项目(Spring系列加上vue3)前端篇(二)(一步步实现+源码)

好了,兄弟们,继昨天的项目之后,开始继续敲前端代码,完成前端部分 昨天完成了全局页面的代码,和登录页面的代码,不过昨天的代码还有一些需要补充的,这里添加一下 内容补充:在调用登…

vxe-form-design 表单设计器的使用

vxe-form-design 在 vue3 中表单设计器的使用 查看官网 https://vxeui.com 安装 npm install vxe-pc-ui // ... import VxeUI from vxe-pc-ui import vxe-pc-ui/lib/style.css // ...// ... createApp(App).use(VxeUI).mount(#app) // ...使用 github vxe-form-design 用…

Vue学习笔记2——创建一个Vue项目

Vue项目 1、创建一个Vue项目2、Vue项目的目录结构3、模版语法4、属性绑定5、条件渲染 1、创建一个Vue项目 vue官方文档: https://cn.vuejs.org/打开命令行界面( “winR"再输入"cmd”),切换位置到指定的位置创建vue项目…

一文详解SpringBoot的自定义starter

目录 一、SpringBoot 二、自定义starter 三、SpringBoot的自定义starter 一、SpringBoot Spring Boot是一个开源的Java框架,由Pivotal团队(现为VMware的一部分)于2013年推出,旨在简化Spring应用程序的创建和部署过程。它基于S…

民国漫画杂志《时代漫画》第28期.PDF

时代漫画28.PDF: https://url03.ctfile.com/f/1779803-1248635321-5c67ad?p9586 (访问密码: 9586) 《时代漫画》的杂志在1934年诞生了,截止1937年6月战争来临被迫停刊共发行了39期。 ps: 资源来源网络!

Linux一键安装Docker、kkfileviewer

Linux一键安装Docker、kkfileviewer 一、安装docker 安装docker脚本 vi initDocker.sh脚本内容 #安装前先更新yum,防止连接镜像失败 yum -y update#卸载系统之前的docker(可选择,我这里直接注释了) #yum remove docker docker…

蓝桥杯物联网竞赛_STM32L071_18_长短按键检测

长短按键的检测是国赛题里面遇到的,省赛没出过有两种实方法 定时器配置: 定时器的话要比delay准确,其中tim7定时器的准度最高 定时器预分配配置32 - 1,计数周期是10000 - 1这样做那么32MHZ/32也就是一秒钟记录10^6的数&#xf…

开源远程协助:分享屏幕,隔空协助!

🖥️ 星控远程协助系统 🖱️ 一个使用Java GUI技术实现的远程控制软件,你现在就可以远程查看和控制你的伙伴的桌面,接受星星的指引吧! 支持系统:Windows / Mac / Linux 🌟 功能导览 &#x1f…

linux清理僵尸进程

1、僵尸进程是什么? 僵尸进程是当子进程比父进程先结束,而父进程又没有回收子进程,释放子进程占用的资源,此时子进程将成为一个僵尸进程。如果父进程先退出 ,子进程被init接管,子进程退出后init会回收其占…

vue2 案例入门

vue2 案例入门 1 vue环境2 案例2.1 1.v-text v-html2.2 v-bind2.3 v-model2.4 v-on2.5 v-for2.6 v-if和v-show2.7 v-else和v-else-if2.8 计算属性和侦听器2.9 过滤器2.10 组件化2.11 生命周期2.12 使用vue脚手架2.13 引入ElementUI2.13.1 npm方式安装2.13.2 main.js导入element…

读书短视频脚本:四川京之华锦信息技术公司

读书短视频脚本:打造引人入胜的文学世界 随着短视频平台的兴起,各类内容以更加直观、生动的方式呈现在观众面前。在这个信息爆炸的时代,如何将书籍的精华和魅力通过短视频这一新兴媒介传递给更多人,成为了一个值得探讨的话题。四…

有哪些藏文翻译器在线翻译?工具分享

有哪些藏文翻译器在线翻译?随着全球化的推进,语言之间的交流变得越来越重要。藏语作为中华民族的重要语言之一,其翻译需求也日益增加。为了满足这一需求,市场上涌现出了多款藏文翻译器在线翻译工具,它们以其高效、准确…

六一儿童节创意项目:教你用HTML5和CSS3制作可爱的雪糕动画

六一儿童节快到了,这是一个充满童趣和欢乐的日子。为了给孩子们增添一份节日惊喜,我们决定用HTML5和CSS3制作一个生动有趣的雪糕动画。通过这个项目,不仅能提升你的前端技能,还能带给孩子们一份特别的节日礼物。无论你是前端开发新…

CISCN——2024——re——app-debug

输入检查类题型 package com.example.re11113;import android.os.Bundle; import android.util.Log; import android.view.View.OnClickListener; import android.view.View; import android.widget.Button; import android.widget.EditText; import android.widget.Toast; im…

服务高峰期gc,导致服务不可用

随着应用程序的复杂性和负载的不断增加,对JVM进行调优,也是保障系统稳定性的一个重要方向。 需要注意,调优并非首选方案,一般来说解决性能问题还是要从应用程序本身入手(业务日志,慢请求等)&am…

[算法][数字][leetcode]2769.找出最大的可达成数字

题目地址 https://leetcode.cn/problems/find-the-maximum-achievable-number/description/ 题目描述 实现代码 class Solution {public int theMaximumAchievableX(int num, int t) {return num2*t;} }

记录一次安装k8s初始化失败

实例化 kubeadm init --configkubeadm.yaml --ignore-preflight-errorsSystemVerification报错 [init] Using Kubernetes version: v1.25.0 [preflight] Running pre-flight checks error execution phase preflight: [preflight] Some fatal errors occurred:[ERROR CRI]: co…

uniapp 使用vuex 在app上能获取到state,小程序获取不到

1. 在根目录下新建store目录, 在store目录下创建index.js定义状态值import Vue from vue; import Vuex from Vuex; import Vuex from vuex; Vue.use(Vuex);const store new Vuex.Store({ state: { login: false, token: , avatarUrl: , userName: }, mutations: { lo…

轻兔推荐 —— vfox

简介 vfox 是一个跨平台且可扩展的版本管理工具,终于有一个可以管理所有运行环境的工具了 - 支持一键安装 Java、Node.js、Flutter、.Net、Golang、PHP、Python等多种环境 - 支持一键切换不同版本 特点 支持Windows(非WSL)、Linux、macOS! 支持不同项目不同版本、…