利用InceptionV3实现图像分类

news/2024/4/20 11:58:31/文章来源:https://blog.csdn.net/ziele_008/article/details/129198386

最近在做一个机审的项目,初步希望实现图像的四分类,即:正常(neutral)、涉政(political)、涉黄(porn)、涉恐(terrorism)。有朋友给推荐了个github上面的文章,浏览量还挺大的。地址如下:

https://github.com/xqtbox/generalImageClassification

我导入试了一下,发现博主没有放他训练的模型文件my_model.h5,所以代码trainMyDataWithKerasModel.py不能直接运行。必须先自己训练个模型才行,所以只好自己搞了。我开发电脑上安装的python版本是3.9.12,这个版本通常会遇到兼容性的问题,所以我决定先搭建个虚拟环境来测试一下。虚拟环境就用3.7.16了。

1、执行:conda create -n InceptionV3 python=3.7

在C:\Users\用户名\anaconda3\envs目录下创建虚拟环境InceptionV3目录。

2、执行:conda activate InceptionV3

启动InceptionV3虚拟环境。

3、执行:pip install -i https://pypi.douban.com/simple/ tensorflow==1.14.0

我的显卡是Nvidia GeForce RTX 3060的,CUDA是11.8,Cudnn是8.7.0,查了一下对应的。查了一下对应tensorflow版本是1.14.0,所以就安装这个。

4、执行:pip install -i https://pypi.douban.com/simple/ protobuf==3.19.0

5、执行:pip install -i https://pypi.douban.com/simple/ tensorflow_hub==0.9.0

6、执行:pip install -i https://pypi.douban.com/simple/ opencv-python

7、执行:pip install -i https://pypi.douban.com/simple/ scikit-learn

8、执行:pip install -i https://pypi.douban.com/simple/ albumentations==1.2.0

9、执行:pip install -i https://pypi.douban.com/simple/ h5py==2.10.0

10、执行:pip install -i https://pypi.douban.com/simple/ matplotlib

11、执行:pip install -i https://pypi.douban.com/simple/ Tensorflow-gpu==2.4.0

12、执行:pip install -i https://pypi.douban.com/simple/ keras==2.6.0

13、下面是训练代码,文件名是train1.py

import numpy as np
from tensorflow.keras.optimizers import Adamimport cv2
from tensorflow.keras.preprocessing.image import img_to_array
from sklearn.model_selection import train_test_splitfrom tensorflow.python.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.applications import InceptionV3
import os
import tensorflow as tffrom tensorflow.python.keras.layers import Dense
from tensorflow.python.keras.models import Sequentialimport albumentations
norm_size = 224
datapath = 'data/train'
EPOCHS = 20
INIT_LR = 3e-4
labelList = []# 这里是分类详情
dicClass = {'neutral':0, 'political':1, 'porn':2, 'terrorism':3}
# 这是分类个数
classnum = 4batch_size = 2
np.random.seed(42)# tf.config.list_physical_devices('GPU')
# tf.test.is_gpu_available()def loadImageData():imageList = []listClasses = os.listdir(datapath)  # 类别文件夹print(listClasses)for class_name in listClasses:label_id = dicClass[class_name]class_path = os.path.join(datapath, class_name)image_names = os.listdir(class_path)for image_name in image_names:image_full_path = os.path.join(class_path, image_name)labelList.append(label_id)imageList.append(image_full_path)return imageListprint("开始加载数据")
imageArr = loadImageData()
labelList = np.array(labelList)
print("加载数据完成")
print(labelList)
trainX, valX, trainY, valY = train_test_split(imageArr, labelList, test_size=0.3, random_state=42)train_transform = albumentations.Compose([albumentations.OneOf([albumentations.RandomGamma(gamma_limit=(60, 120), p=0.9),albumentations.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.9),albumentations.CLAHE(clip_limit=4.0, tile_grid_size=(4, 4), p=0.9),]),albumentations.HorizontalFlip(p=0.5),albumentations.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=20,interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_CONSTANT, p=1),albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0)])
val_transform = albumentations.Compose([albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0)])def generator(file_pathList, labels, batch_size, train_action=False):L = len(file_pathList)while True:input_labels = []input_samples = []for row in range(0, batch_size):temp = np.random.randint(0, L)X = file_pathList[temp]Y = labels[temp]image = cv2.imdecode(np.fromfile(X, dtype=np.uint8), -1)if image.shape[2] > 3:image = image[:,:,:3]if train_action:image = train_transform(image=image)['image']else:image = val_transform(image=image)['image']image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)image = img_to_array(image)input_samples.append(image)input_labels.append(Y)batch_x = np.asarray(input_samples)batch_y = np.asarray(input_labels)yield (batch_x, batch_y)checkpointer = ModelCheckpoint(filepath='best_model.hdf5',monitor='val_acc', verbose=1, save_best_only=True, mode='max')reduce = ReduceLROnPlateau(monitor='val_acc', patience=10,verbose=1,factor=0.5,min_lr=1e-6)model = Sequential()
model.add(InceptionV3(include_top=False, pooling='avg', weights='imagenet'))
model.add(Dense(classnum, activation='softmax'))optimizer = Adam(learning_rate=INIT_LR)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['acc'])# print('trainX = ' + str(trainX))
# print('trainY = ' + str(trainY))model.add(tf.keras.layers.BatchNormalization())history = model.fit(generator(trainX, trainY, batch_size, train_action=True),steps_per_epoch=len(trainX) / batch_size,validation_data=generator(valX, valY, batch_size, train_action=False),epochs=EPOCHS,validation_steps=len(valX) / batch_size,callbacks=[checkpointer, reduce])
model.save('my_model.h5')
print(history)loss_trend_graph_path = r"WW_loss.jpg"
acc_trend_graph_path = r"WW_acc.jpg"
import matplotlib.pyplot as pltprint("Now,we start drawing the loss and acc trends graph...")
# summarize history for acc
fig = plt.figure(1)
plt.plot(history.history["acc"])
plt.plot(history.history["val_acc"])
plt.title("Model acc")
plt.ylabel("acc")
plt.xlabel("epoch")
plt.legend(["train", "test"], loc="upper left")
plt.savefig(acc_trend_graph_path)
plt.close(1)
# summarize history for loss
fig = plt.figure(2)
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("Model loss")
plt.ylabel("loss")
plt.xlabel("epoch")
plt.legend(["train", "test"], loc="upper left")
plt.savefig(loss_trend_graph_path)
plt.close(2)
print("We are done, everything seems OK...")

13.1、norm_size = 224 设置输入图像的大小,InceptionV3默认的图片尺寸是224×224。但是我的图片有300px以上的,好像也没什么问题

13.2、datapath = ‘data/train’ 设置图片存放的路径

13.3、EPOCHS = 20 epochs的数量,关于epoch的设置多少合适,这个问题很纠结,一般情况设置300足够了,如果感觉没有训练好,再载入模型训练。

13.4、INIT_LR = 1e-3 学习率,一般情况从0.001开始逐渐降低,也别太小了到1e-6就可以了。

13.5、classnum = 12 类别数量,数据集有两个类别,所有就分为两类。

13.6、batch_size = 4 batchsize,根据硬件的情况和数据集的大小设置,太小了loss浮动太大,太大了收敛不好,根据经验来,一般设置为2的次方。windows可以通过任务管理器查看显存的占用情况。

14、工程目录的文件如下图:

其中train1.py是训练程序;test.py是检测程序,本文后面会再详细讲怎么用;FormatImages.py是格式化图片的程序,功能就是把从网上爬下来比较大的图片等比压缩成300px以内。

data目录存放的就是训练用的数据,如下图:

其中train存放的是训练图片,test存放的是测试图片。train下的目录如下图:

可以看到,图中的train目录中的文件夹名要与train1.py中dicClass的值对应起来,训练数据放到对应目录下就可以了。如下图:

15、下面开始训练了,在训练之前有几个事情要做一下。

首先检查一下自己的cuda安装好没有,方法是在cmd下面输入命令nvcc -V,如果显示版本号就没问题了,如下图:

如果还没有安装也没关系,先看看自己显卡的cuda版本,如下图:

然后去https://developer.nvidia.com/cuda-toolkit-archive下载显卡对应版本的cuda工具包。如下图:

下载完成后安装到默认目录就行,一般是安装在C:\Program Files\NVIDIA GPU Computing Toolkit,如下图:

安装完成后在到https://developer.nvidia.com/rdp/cudnn-download去下载cudnn

下载完成后解压缩,把解压缩后的目录cudnn-windows-x86_64-8.8.0.121_cuda12-archive下的bin、include、lib三个目录里的文件分别复制到C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8的bin、include、lib三个目录里。如下图:

最后到https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#install-zlib-windows下载ZLIB.DLL。如下图:

下载完成后解压缩,把解压后zlib123dllx64\dll_x64\zlibwapi.dll文件复制到C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin目录下

现在,在train1.py目录下执行:python train1.py

可以看一下任务管理器,压力应该都在GPU上:

16、训练完成后,可以看到train1.py目录下多了几个文件,如下图:

其中my_model.h5就是咱们训练出来的模型文件。WW_acc.jpg和WW_loss.jpg是训练结果保存的图,看了一下觉得还不错。

17、接下来要验证一下模型的效果,现在data\test\放一张用于预测的图。如下图:

18、下面是测试代码,文件名是test.py:

import cv2
import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array
from  tensorflow.keras.models import load_model
import time
import albumentations
norm_size = 224
imagelist = []emotion_labels = {0: 'neutral',1: 'political',2: 'porn',3: 'terrorism',
}val_transform = albumentations.Compose([albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0)])
emotion_classifier = load_model("best_model.hdf5")
t1 = time.time()
image = cv2.imdecode(np.fromfile('data/test/01.jpg', dtype=np.uint8), -1)
image = val_transform(image=image)['image']
image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
image = img_to_array(image)
imagelist.append(image)
imageList = np.array(imagelist, dtype="float")
out = emotion_classifier.predict(imageList)
print(out)
pre = np.argmax(out)
emotion = emotion_labels[pre]
t2 = time.time()
print(emotion)
t3 = t2 - t1
print(t3)

其中emotion_labels是分类,填上与训练文件中一致的内容。

在image = cv2.imdecode(np.fromfile('data/test/01.jpg', dtype=np.uint8), -1)这行修改路径,指向到用于预测的图片位置。

19、执行python test.py

可以看到,data/test/01.jpg被预测成为terrorism,验证正确。至此大功告成。

后记:我是python的领域的新兵,在开发过程中遇到最麻烦的事情就是版本的问题。tensorflow最新版本已经2.11.0了,但是使用起来会有各种问题。我尝试了很多版本,查了不少资料,最后才确定了能用的这个组合。尤其是过程中gpu一直利用不上,程序总是使用cpu在训练,经过一顿折腾总算是能用了,但是为什么这么组合,我也没有找到一个清晰的说明,希望能有大神能给解释一下CUDA、Cudnn、tensorflow、tensorflow-gpu的版本怎么组合最合理。下面把我虚拟环境的配置发上来供大家参考:

Package                 Version
----------------------- ---------
absl-py                 0.15.0
albumentations          1.2.0
astor                   0.8.1
astunparse              1.6.3
cachetools              5.3.0
certifi                 2022.12.7
charset-normalizer      3.0.1
cycler                  0.11.0
flatbuffers             1.12
fonttools               4.38.0
gast                    0.3.3
google-auth             2.16.1
google-auth-oauthlib    0.4.6
google-pasta            0.2.0
grpcio                  1.32.0
h5py                    2.10.0
idna                    3.4
imageio                 2.25.1
importlib-metadata      6.0.0
joblib                  1.2.0
keras                   2.6.0
Keras-Applications      1.0.8
Keras-Preprocessing     1.1.2
kiwisolver              1.4.4
Markdown                3.4.1
MarkupSafe              2.1.2
matplotlib              3.5.3
networkx                2.6.3
numpy                   1.19.5
oauthlib                3.2.2
opencv-python           4.7.0.68
opencv-python-headless  4.7.0.68
opt-einsum              3.3.0
packaging               23.0
Pillow                  9.4.0
pip                     22.3.1
protobuf                3.19.0
pyasn1                  0.4.8
pyasn1-modules          0.2.8
pyparsing               3.0.9
python-dateutil         2.8.2
PyWavelets              1.3.0
PyYAML                  6.0
qudida                  0.0.4
requests                2.28.2
requests-oauthlib       1.3.1
rsa                     4.9
scikit-image            0.18.3
scikit-learn            1.0.2
scipy                   1.7.3
setuptools              65.6.3
six                     1.15.0
tensorboard             2.11.2
tensorboard-data-server 0.6.1
tensorboard-plugin-wit  1.8.1
tensorflow              1.14.0
tensorflow-estimator    2.4.0
tensorflow-gpu          2.4.0
tensorflow-hub          0.9.0
termcolor               1.1.0
threadpoolctl           3.1.0
tifffile                2021.11.2
typing-extensions       3.7.4.3
urllib3                 1.26.14
Werkzeug                2.2.3
wheel                   0.38.4
wincertstore            0.2
wrapt                   1.12.1
zipp                    3.14.0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_73635.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习笔记之近似推断(一)从深度学习角度认识推断

机器学习笔记之近似推断——从深度学习角度认识推断引言推断——基本介绍精确推断难的原因虽然能够表示,但计算代价太大无法直接表示引言 本节是一篇关于推断总结的博客,侧重点在于深度学习模型中的推断任务。 推断——基本介绍 推断(Inference\text{…

Python中实现将内容进行base64编码与解码

一、需求说明需要使用Python实现将内容转为base64编码,解码,方便后续的数据操作。二、base64简介Base64是一种二进制到文本的编码方式【是一种基于 64 个可打印字符来表示二进制数据的表示方法(由于 2^664,所以每 6 个比特为一个单…

国产音质好的蓝牙耳机有哪些?国产音质最好的耳机排行

随着时间的推移,真无线蓝牙耳机逐渐占据耳机市场的份额,成为人们日常生活中必备的数码产品之一。蓝牙耳机品牌也多得数不胜数,哪些国产蓝牙耳机音质好?下面,我们从音质出来,来给大家介绍几款国产蓝牙耳机&a…

硬件系统工程师宝典(11)-----去耦电容布局“有讲究”

各位同学大家好,欢迎继续做客电子工程学习圈,今天我们继续来讲这本书,硬件系统工程师宝典。 上篇我们说到在电源完整性分析的目标就是要做到电源的干净、稳定和快速响应,以及针对不同噪声处理的实现方法。今天我们来看看去耦电容…

父传子与子传父步骤

父传子: 问题:父页面中引入子组件 把想要传给子页面的值用在子组件中用 :值“值” (用同一个值好区分)来绑定。 在子页面中用props接收 子组件不能改变父组件传过来的值。(传多个页面的时候是,比如父传孙的时候我会…

【2023】华为OD机试真题Java-题目0221-AI处理器组合

AI处理器组合 题目描述 某公司研发了一款高性能AI处理器。每台物理设备具备8颗AI处理器,编号分别为0、1、2、3、4、5、6、7。编号0-3的处理器处于同一个链路中,编号4-7的处理器处于另外一个链路中,不通链路中的处理器不能通信,如下图所示。现给定服务器可用的处理器编号数…

STM32---备份寄存器BKP和 FLASH学习使用

BKP库函数 学习BKP,首先就是知道BKP每一个函数的作用然后如何使用即可 使用备份域的作用只需要操作上面的两个函数即可,其余的都是它的其他功能 BKP简介 备份寄存器是42个16位的寄存器,可用来存储84个字节的用户应用程序数据。他们处在备份…

如何高效管理自己的时间,可以从这几个方向着手

如果你是上班族,天选打工人,你的绝大多数时间都属于老板,能够自己支配的时间其实并不多,所以你可能察觉不到时间管理的重要性。但如果你是自由职业者或者创业者,想要做出点成绩,那你就需要做好时间管理&…

2. Dart 开发工具环境配置

很多编辑器都可以用来开发dart,所以大家可以选择自己喜欢的编辑器去进行开发。我还是比较喜欢vs code如果你不用vs code来开发dart的话,这篇文章可以直接跳过。如果想要在vs code里有dart的语法提示,我们需要安装相关的插件如图点开插件输入d…

预编译(回顾)

浏览器运行机制变量提升机制私有变量提升步骤全局对象 GO 和全局变量对象 VO的区别带var和不带var的区别系统输出顺序变量提升在条件判断下的处理防止函数的变量提升 浏览器运行机制 为了能够让代码执行,浏览器首先会形成一个执行环境栈 ECStack(Execution Contex…

TCP协议和TCP连接

TCP协议和TCP连接一、TCP协议的简介二、TCP连接的简介1、TCP连接的建立(TCP三次握手)2、TCP连接的断开(TCP四次挥手)一、TCP协议的简介 TCP(Transmission Control Protocol)传输控制协议是一种面向连接的、…

java 接口 详解

目录 一、概述 1.介绍 : 2.定义 : 二、特点 1.接口成员变量的特点 : 2.接口成员方法的特点 : 3.接口构造方法的特点 : 4.接口创建对象的特点 : 5.接口继承关系的特点 : 三、应用 : 1.情景 : 2.多态 : ①多态的传递性 : ②关于接口的多态参数和多态…

股票量化交易SQL特征工程入门

虽然现在各种量化教程和自助平台铺天盖地,但是对于新人来说入门最重要的事情就是挖掘特征。 对于传统的学习路径第一步是学习Python或者某一门编程语言,虽说Python入门容易上手快,但是要在实际应用中对股票数据进行分析,并挖掘有…

2023/2/24 图数据库Neo4j的理解与应用

1 什么是图数据库(graph database) 十大应用案例:https://go.neo4j.com/rs/710-RRC-335/images/Neo4j-Top-Use-Cases-ZH.pdf “大数据”每年都在增长,但如今的企业领导者不仅需要管理更大规模的数据,还迫切需要从现有…

8000+字,就说一个字Volatile

简介 volatile是Java提供的一种轻量级的同步机制。Java 语言包含两种内在的同步机制:同步块(或方法)和 volatile 变量,相比于synchronized(synchronized通常称为重量级锁),volatile更轻量级&…

【大数据】记一次hadoop集群missing block问题排查和数据恢复

问题描述 集群环境总共有2个NN节点,3个JN节点,40个DN节点,基于hadoop-3.3.1的版本。集群采用的双副本,未使用ec纠删码。 问题如下: bin/hdfs fsck -list-corruptfileblocks / The list of corrupt files under path…

汽车零部件行业mes系统具体功能介绍

众所周知,汽车零部件的组装是汽车制造的关键环节,而汽车零部件江湖变革以精益为终极目标。即汽车零部件制造企业转型升级向精益生产和精益管理方向前进,而车间信息化管理是精益化生产的基础。 汽车零部件行业现状 随着全球汽车产业不断升级…

蓝牙标签操作指南

一、APP安装指南 1.APP权限问题 电子标签APP安装之后,会提示一些权限的申请,点击允许。否则某些会影响APP的正常运行。安装后,搜索不到蓝牙标签,可以关闭App,重新打开。 2.手机功能 运行APP时候,需要打开…

Linux基本介绍与常用操作指令

参考链接: Linux面试必备20个常用命令_无 羡ღ的博客-CSDN博客_linux常用命令 1. Linux简介 Linux是一个支持多用户、多任务、多线程和多CPU的操作系统,特点是免费、稳定、高效, 一般运行在大型服务器上。 1.1 常用目录简介 /:根目…

【华为OD机试模拟题】用 C++ 实现 - 数组的中心位置(2023.Q1)

最近更新的博客 华为OD机试 - 入栈出栈(C++) | 附带编码思路 【2023】 华为OD机试 - 箱子之形摆放(C++) | 附带编码思路 【2023】 华为OD机试 - 简易内存池 2(C++) | 附带编码思路 【2023】 华为OD机试 - 第 N 个排列(C++) | 附带编码思路 【2023】 华为OD机试 - 考古…