tensorflow多层感知机+mnist数据集

news/2024/4/30 17:30:31/文章来源:https://blog.csdn.net/m0_54510474/article/details/127286582

这里写目录标题

  • keras与tensorflow建立模型的不同
  • 加载mnist
  • tensorflow多层感知机构建
    • 全连接层函数定义
    • 感知机各层的输入和输出
    • 损失函数、优化器
    • 模型准确率计算
    • 模型训练参数定义
    • 训练开始
    • 模型训练效果可视化
    • 模型评分
    • 利用模型进行预测
    • 显示混淆矩阵

keras与tensorflow建立模型的不同

kerastensorflow
模型建立model=Sqeuential():建立线性堆叠模型,
model.add():将各个神经网络层加入网络即可
自定义layer函数,使用layer建立多层感知机模型
模型优化在model.compile()函数中,模型预编译阶段完成自定义loss损失函数以及optimizer优化器
模型训练在model.fit()函数中,传入x_train,y_train,epochs,batch_size,验证集比例等参数后由函数自动完成,返回history对象其中包含模型训练过程中的loss和acc变化训练过程需要自定义,通过循环定义模型在每个epoch中需要执行的操作,每次epoch通过自定义的loss_fun和acc函数计算模型的loss和acc存入数组中作为训练过程的"history"

加载mnist

import tensorflow.examples.tutorials.mnist.input_data as input_data
path=r'E:\mydataset\data\LeYun_mnist'
mnist=input_data.read_data_sets(path,one_hot=True)
mnist对象的属性值数据分类属性属性值
mnist.train.num_examples训练集图片数量55000
.images.shape-数据形状(55000, 784)
.labels.shape-数据形状(55000, 10)
mnist.test.num_examples测试集图片数量10000
.images.shape-数据形状(10000, 784)
.images.shape-数据形状(10000, 10)
mnist.validation.num_examples交叉验证集图片数量5000
.images.shape-数据形状(5000,784)
.labels.shape-数据形状(5000,10)

注:path文件夹下面需要包含mnist数据集压缩包(共4个文件),否则函数将先从网络上下载,如图在这里插入图片描述

tensorflow多层感知机构建

全连接层函数定义

import tensorflow as tf
def layer(output_dim,input_dim,inputs,activation=None):w=tf.Variable(tf.random_normal([input_dim,output_dim]))b=tf.Variable(tf.random_normal([1,output_dim]))y=tf.matmul(inputs,w)+bif activation is None:return yelse:return activation(y)

感知机各层的输入和输出

# 感知机各层输入和输出
x=tf.placeholder('float',[None,784])
h1=layer(output_dim=256,input_dim=784,inputs=x,activation=tf.nn.relu)
y_predict=layer(output_dim=10,input_dim=256,inputs=h1,activation=None)
y_label=tf.placeholder('float',[None,10])

损失函数、优化器

loss=tf.nn.softmax_cross_entropy_with_logits(logits=y_predict,labels=y_label)
loss_fun=tf.reduce_mean(loss)
optimizer=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss_fun)

模型准确率计算

correct_prediction=tf.equal(tf.argmax(y_label,1),tf.argmax(y_predict,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,'float'))

模型训练参数定义

epochs=15
batch_size=100
total_epoch=int(mnist.train.num_examples/batch_size)
loss_list=[]
accuracy_list=[]
sess=tf.Session()
sess.run(tf.global_variables_initializer())

训练开始

for epoch in range(epochs):for i in range(total_epoch):batch_x,batch_y=mnist.train.next_batch(batch_size)sess.run(optimizer,feed_dict={x:batch_x,y_label:batch_y})loss,acc=sess.run([loss_fun,accuracy],feed_dict={x:mnist.validation.images,y_label:mnist.validation.labels})loss_list.append(loss)accuracy_list.append(acc)print((epoch+1,loss,acc))

运行结果
在这里插入图片描述

模型训练效果可视化

import matplotlib.pyplot as pltfig=plt.gcf()
plt.plot(range(1,16),accuracy_list,label='accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(loc='upper left')

运行结果
在这里插入图片描述

fig=plt.gcf()
plt.plot(range(1,16),loss_list,label='loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(loc='upper left')

运行结果
在这里插入图片描述

模型评分

score=sess.run(accuracy,feed_dict={x:mnist.test.images,y_label:mnist.test.labels})
>>> 0.9427

利用模型进行预测

import numpy as np
pred=sess.run(tf.argmax(y_predict,1),feed_dict={x:mnist.test.images})
pred[:10]
>>> array([7, 2, 1, 0, 4, 1, 4, 9, 6, 9], dtype=int64)
real_label=np.argmax(mnist.test.labels,axis=-1)
real_label[:10]
>>> array([7, 2, 1, 0, 4, 1, 4, 9, 5, 9], dtype=int64)

运行结果
在这里插入图片描述

显示混淆矩阵

import pandas as pd
real_label=real_label.reshape(-1)
pd.crosstab(real_label,pred,rownames=['label'],colnames=['predict'])

运行结果
在这里插入图片描述

总体来看效果还算不错

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_22425.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

两栏布局与三栏布局(圣杯布局与双飞翼布局)

两栏布局 右侧绝对定位的写法 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta http-equiv"X-UA-Compatible" content"IEedge" /><meta name"viewport" content&qu…

浅析某城商行手机银行水平授权漏洞问题

-问题现象描述 据报道&#xff0c;某黑客通过软件抓包、PS身份证等非法手段&#xff0c;在某城商行手机银行APP内使用虚假身份信息注册银行Ⅱ、Ⅲ类账户。 其操作方法具体来说&#xff0c;是在注册账户过程中&#xff0c;先输入本人身份信息&#xff0c;待进行人脸识别步骤时…

Flask学习笔记(十二)-Flask-Migrate实现数据库迁移详解

一、定义flask-migrate是基于Alembic的一个封装,并集成到Flask中 所有的迁移操作其实都是Alembic做的,能跟踪模型的变化,并将变化映射到数据库中。二、Flask-Migrate安装pip install flask-migrate三、使用Flask-Migrate步骤  实例展示: 目录结构:  flask_SQLalchemy:. …

MindSpore体验--在Windows10中源码安装

MindSpore体验--在Windows10中源码安装&#xff08;反面教材&#xff09; 一直以来安装包都是直接pip intall&#xff0c;发现安装MindSpore的操作流程中教学了源码编译安装&#xff0c;借此学习一下使用源码安装。 环境创建 为了方便管理环境&#xff0c;此处我新创建了一个…

多测师肖sir_高级讲师_第2个月第27讲解jmeter性能测试jmeter性能实战

jmeter性能实战 一、单接口性能测试 1、先建接口cms 登录接口 2、在监听器中添加聚合报告 3、设置线程组 &#xff08;1&#xff09;线程组&#xff1a;一个线程组中有若干个请求 &#xff08;2&#xff09;线程 &#xff1a;一个虚拟用户就是一个线程 &#xff08;3&#…

webpack的一些常用打包配置

1.webpack 是什么&#xff1f; webpack 是一个模块化打包工具 2.模块是什么&#xff1f; 模块我理解就是 import xx 后面导入的文件就是一个模块 它可以是js css 图片 等等 3&#xff0c;webpack的配置文件的作用&#xff1f; 就是根据需求自定义配置webpack webpack默认只能打…

轻轻松松搞定分布式Token校验

文章目录前言token存储Token 存储实体login 业务代码枚举类修改存储效果客户端存储token验证前端提交后端校验自定义注解切面处理使用总结前言 没想到前天小水了一篇博文&#xff0c;竟然就火了&#xff1f;&#xff01;&#xff01;既然如此&#xff0c;那我再来一篇&#xf…

第7章 单行函数

1.函数的理解 *函数可以把我们经常使用的代码封装起来&#xff0c;需要的时候直接调用即可。这样既提高了代码效率&#xff0c;又提高了可维护性。在SQL中我们也可以使用函数对检索出来的数据进行函数操作。使用这些函数&#xff0c;可以极大地提高用户对数据库的管理效率。 …

微信小程序|基于小程序实现打卡功能

文章目录一、文章前言二、开发流程及准备三、开发步骤一、文章前言 此文主要在小程序内实现打卡功能&#xff0c;可根据用户位置与公司设定的打卡范围实时判断打卡场景。 二、开发流程及准备 2.1、注册微信公众平台账号。 2.2、准备腾讯地图用户Key。 三、开发步骤 3.1、访问…

【面试题常考!!!】JZ39 数组中出现次数超过一半的数字【五种方法解决】

欢迎观看我的博客&#xff0c;如有问题交流&#xff0c;欢迎评论区留言&#xff0c;一定尽快回复&#xff01;&#xff08;大家可以去看我的专栏&#xff0c;是所有文章的目录&#xff09; 字体风格&#xff1a; 红色文字表示&#xff1a;重难点 蓝色文字表示&#xff1a;思路以…

神经网络模型数据处理,神经网络模型参数辨识

1、有哪些深度神经网络模型&#xff1f; 目前经常使用的深度神经网络模型主要有卷积神经网络(CNN) 、递归神经网络(RNN)、深信度网络(DBN) 、深度自动编码器(AutoEncoder) 和生成对抗网络(GAN) 等。 递归神经网络实际.上包含了两种神经网络。一种是循环神经网络(Recurrent Neu…

STM32F4单片机读取AT24c02

​STM32F4是由ST&#xff08;意法半导体&#xff09;开发的一种高性能微控制器系列。其采用了90nm的NVM工艺和ART技术&#xff08;自适应实时存储加速器&#xff0c;Adaptive Real-Time MemoryAccelerator™&#xff09; AT24C02是Atmel公司出品的一个2K位串行CMOS E2PROM&…

【k8s】五、Pod生命周期(一)

目录 前言 Pod生命周期 Pod 相位 状态值 挂起&#xff08;Pending&#xff09; 运行中&#xff08;Running&#xff09; 成功&#xff08;Succeeded&#xff09; 失败&#xff08;Failed&#xff09; 未知&#xff08;Unknown&#xff09; Init Containers Init Cont…

pc端引擎颠覆电脑兼容性

张小龙曾在讲座上阐述小程序理念的精髓&#xff0c;小程序承载着张小龙及微信团队对未来程序形态的一种见解&#xff0c;总结为五个字&#xff1a;所见即所得。原文如下&#xff1a; 它是一种真正的所见即所得的形态&#xff0c;我说的所见即所得不同于在PC时代&#xff0c;我…

组合模式+桥接模式

目录 组合模式 定义&#xff1a; 业务实现例子&#xff1a; 桥接模式 JDBC中的桥接模式 组合模式 定义&#xff1a; 将对象组合通过树形结构进行展示&#xff0c;使得用户——>不管对单个对象or组合对象的使用具有一致性 可以理解为部分-整体模式——>简单来说就…

深度学习环境搭建

(1) 安装 Anaconda :建立 Python 应用环境 安装成功界面如下:(2) Visual Studio Code: 建立代码编辑环境 1.安装Python扩展2.选择合适的Python解释器 3.安装下列应用扩展:codeRunner : 快速运行程序 Jupyter : 交互式运行程序 Pylance : 高效代码提示 安装完成如图所示:4.创…

Linux基础组件之muduo日志库分析

muduo日志库分析异步日志机制双缓存机制前台日志写入栈后台日志(落盘)写入栈使用示例总结后言异步日志机制 #mermaid-svg-nrIugWYiOaAGFTWH {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-nrIugWYiOaAGFTWH .error-…

如何做架构规划

文章目录架构师的职责WhyWhatHow架构活动生命周期环境搭建目标确认可行性探索架构规划统一语义需求确认任务边界划分确认规划完整性项目启动阶段性价值交付复盘经历过的典型案例参考架构师的职责 Why 互联网架构活动的挑战较多&#xff0c;如&#xff1a; 反射式的研发行为。…

Scratch软件编程等级考试四级——20200913

Scratch软件编程等级考试四级——20200913理论单选题判断题实操奇偶之和创意画图数字之和用逗号分隔列表数字反转理论 单选题 1、执行下面程序&#xff0c;输入4和7后&#xff0c;角色说出的内容是&#xff1f;&#xff08;&#xff09; A、4&#xff0c;7 B、7,7 C、7,4 D、…

为什么会发生云中断?如何防范?

IT 越依赖云服务&#xff0c;用户就越有可能因云中断而遭受停机和收入损失。由于云中断事件的发生&#xff0c;超过 60% 的使用公共云的组织在 2022 年报告了损失&#xff0c;因此云中断并不是公司不太可能面临的异常事件。 但是中断是否足以成为永远离开云的理由?还是应该坚持…