深度学习——优化器Optimizer

news/2024/4/24 9:14:21/文章来源:https://blog.csdn.net/Elon15/article/details/131649932

代码以及详细注释:

import torch
import torch.utils.data as Data
import torch.nn.functional as F
import matplotlib.pyplot as plt# torch.manual_seed(1)    # reproducible
"""超参数
"""
# 学习率
LR = 0.01
# 批大小
BATCH_SIZE = 32
# 轮次
EPOCH = 12# 造数据
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size()))# # plot dataset
# plt.scatter(x.numpy(), y.numpy())
# plt.show()# put dateset into torch dataset
torch_dataset = Data.TensorDataset(x, y)
# 数据加载器
loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,)# default network
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.hidden = torch.nn.Linear(1, 20)   # hidden layerself.predict = torch.nn.Linear(20, 1)   # output layerdef forward(self, x):x = F.relu(self.hidden(x))      # activation function for hidden layerx = self.predict(x)             # linear outputreturn xif __name__ == '__main__':# 相同的网络结构net_SGD         = Net()net_Momentum    = Net()net_RMSprop     = Net()net_Adam        = Net()# 将上面的网络集成到这里nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam]# 不同的优化器opt_SGD         = torch.optim.SGD(net_SGD.parameters(), lr=LR)opt_Momentum    = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)opt_RMSprop     = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)opt_Adam        = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))# 将上面的优化器集成到这里optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam]# 损失函数loss_func = torch.nn.MSELoss()losses_his = [[], [], [], []]   # record loss# 训练轮次for epoch in range(EPOCH):print('Epoch: ', epoch)# 分批训练for step, (b_x, b_y) in enumerate(loader):          # for each training stepfor net, opt, l_his in zip(nets, optimizers, losses_his):output = net(b_x)              # get output for every netloss = loss_func(output, b_y)  # compute loss for every netopt.zero_grad()                # clear gradients for next trainloss.backward()                # backpropagation, compute gradientsopt.step()                     # apply gradientsl_his.append(loss.data)        # loss recoder# 绘图labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']for i, l_his in enumerate(losses_his):plt.plot(l_his, label=labels[i])plt.legend(loc='best')plt.xlabel('Steps')plt.ylabel('Loss')plt.ylim((0, 0.2))plt.show()

运行结果:

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_328917.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

你知道什么是基于StyleNeRF的conditional GAN模型吗

随着深度学习技术的不断发展,生成对抗网络(GAN)已经成为了人工智能研究和应用中的重要组成部分。其中,GAN可以被用来生成高质量的图像、视频等内容,这为娱乐产业和数字化制作带来了新的机遇和挑战。本文将介绍一种基于…

[探地雷达]预处理

探地雷达预处理方式主要包括:雷达数据的编辑、校正零偏、时间零 点校正、滤波、自动增益等。 (1)数据编辑 数据文件在采集后通常需要进行排序、重新排列或合并处理;而数据中明显的废道信号,也需要从回波数据中剔除。…

Gof23设计模式之桥接模式

1.概述 桥接模式(Bridge Pattern)是一种结构型设计模式,它将抽象部分与实现部分分离,使它们可以独立地变化。它的核心思想就是将一个大类或一系列紧密关联的类拆分成两个独立的抽象和实现部分,以便能够更加灵活地扩展…

Windows电源模式(命令行)

一、简介 windows使用powercfg.exe来控制电源方案,像cmd.exe一样,powercfg.exe也是windows自带的。 powercfg命令行选项 选项说明/?、-help显示有关命令行参数的信息。/list、/L列出所有电源方案。/query、/Q显示电源方案的内容。

TCP连接管理(三次握手,四次挥手)

目录 一、回顾一下TCP包头二、连接的建立——“三次握手”三、连接的建立——“四次挥手”保活计时器 一、回顾一下TCP包头 源端口号(Source Port):16 位字段,表示发送方的端口号。 目的端口号(Destination Port&…

有趣的命令——————用shell脚本实现三角形(直角三角形、等腰三角形)

直角三角形 vim zhijiao.sh 输入以下内容&#xff1a;#!/bin/bashread -p "几层的三角形:" nfor ((i1;i<$n;i)) dofor ((jn;j>i;j--))doecho -n " "donefor ((j1;j<i;j))doecho -n "*"doneecho done例&#xff1a; 测试&#xff1a; 等…

【QT】——布局

目录 1.在UI窗口中布局 2.API设置布局 2.1 QLayout 2.2 QHBoxLayout 2.3 QVBoxLayout 2.4 QGirdLayout 注意 示例 Qt 窗口布局是指将多个子窗口按照某种排列方式将其全部展示到对应的父窗口中的一种处理方式。在 Qt 中常用的布局样式有三种&#xff0c;分别是&#xff1…

jmeter的高阶使用技巧——打印时间戳与年月时分秒

Jmeter中提供了一种函数&#xff0c;可以打印时间戳&#xff0c;如下图 年&#xff1a; yyyy 月&#xff1a;MM 日&#xff1a;dd   时&#xff1a; HH 分&#xff1a; mm 秒&#xff1a;ss 关于时间戳的格式&#xff0c;可以自由组合定义&#xff0c;这里我写成这样 yyyy-M…

【C++】float / double 与 0 值比较

【C】float / double 与 0 值比较 文章目录 【C】float / double 与 0 值比较1. 概述不同1.1 - float 与 double 实际存储1.2 - C 语言与 C 中不同 2. 比较方法2.1 - C 风格比较2.2 - 使用 limits 函数 3. 参考链接 References 1. 概述不同 当然使用普通的比较没有问题&#xf…

【七天入门数据库】第一天 MySQL的安装部署

系列文章传送门&#xff1a; 【七天入门数据库】第一天 MySQL的安装部署 【七天入门数据库】第二天 数据库理论基础 【七天入门数据库】第三天 MySQL的库表操作 MySQL数据库存在多种版本&#xff0c;不同的版本在不同的平台上&#xff08;OS&#xff0c;也就是操作系统上&a…

msvcr120.dll找不到是什么原因,怎样修复

msvcr120.dll的定义 msvcr120.dll是微软Visual C Redistributable软件包中的一个动态链接库文件。它是Microsoft Visual 所需的一个重要组件。这个文件主要用于支持和管理C语言编写的应用程序的运行。它包含了许多C的运行库函数和类&#xff0c;以便应用程序能够正常运行和调用…

Python(一):为什么我们要学习Python?

❤️ 专栏简介&#xff1a;本专栏记录了我个人从零开始学习Python编程的过程。在这个专栏中&#xff0c;我将分享我在学习Python的过程中的学习笔记、学习路线以及各个知识点。 ☀️ 专栏适用人群 &#xff1a;本专栏适用于希望学习Python编程的初学者和有一定编程基础的人。无…

精确时钟同步协议ptp/IEEE-1588v2协议-------(2)主从时钟之间的消息交互与时钟同步过程

本文目录 1、主时钟和从时钟之间的消息交互流2、延时delay和偏移offset的计算2.1、延时delay的计算2.2、偏移offset的计算 主时钟和从时钟之间&#xff0c;通过sync, follow up, delay request, delay response这四条消息&#xff0c;完成时钟同步过程。PTP时钟同步系统能工作的…

3.8.cuda运行时API-使用cuda核函数加速yolov5后处理

目录 前言1. Yolov5后处理2. 后处理案例2.1 cpu_decode2.2 gpu_decode 总结 前言 杜老师推出的 tensorRT从零起步高性能部署 课程&#xff0c;之前有看过一遍&#xff0c;但是没有做笔记&#xff0c;很多东西也忘了。这次重新撸一遍&#xff0c;顺便记记笔记。 本次课程学习精简…

vue 2.0 的使用

day01 1. Vue简介 一套用于构建用户界面的 <font colorred>渐进式框架</font> 2. 初识Vue 2.1 搭建Vue开发环境 第一步&#xff1a;去<a href"https://v2.cn.vuejs.org/">Vue2官网</a>&#xff0c;下载依赖包。 第二步&#xff1a;在 …

多线程与并发编程【线程休眠、线程让步、线程联合、判断线程是否存活】(二)-全面详解(学习总结---从入门到深化)

目录 线程休眠 线程让步 线程联合 Thread类中的其他常用方法 判断线程是否存活 线程的优先级 线程休眠 sleep()方法&#xff1a;可以让正在运行的线程进入阻塞状态&#xff0c;直到休眠时间 满了&#xff0c;进入就绪状态。sleep方法的参数为休眠的毫秒数。 public class…

两部搞定Pytorch 安装与配置(小白也能搞定!!!)

Pytorch 安装与配置 NVIDIA系统管理界面查看 nvidia-smi 进入NVIDIA系统管理界面 对应的详细解释看下图 参考博文 (53条消息) nvidia-smi命令详解和一些高阶技巧介绍_Chaos_Wang_的博客-CSDN博客 CUDA 查看 CUDA 有两类&#xff1a;其中一类是驱动API(Driver API)&#xff…

实现windows系统文件传输到Linux系统中的工具

1、实现windows系统文件传输到Linux系统中的工具 yum -y install lrzsz然后就可以将windows中的文件&#xff0c;直接拖到Xshell窗口即可。

【钱处理】商业计算怎样才能保证精度不丢失

以项目驱动学习&#xff0c;以实践检验真知 前言 很多系统都有「处理金额」的需求&#xff0c;比如电商系统、财务系统、收银系统&#xff0c;等等。只要和钱扯上关系&#xff0c;就不得不打起十二万分精神来对待&#xff0c;一分一毫都不能出错&#xff0c;否则对系统和用户来…

Kafka入门,mysql5.7 Kafka-Eagle部署(二十五)

官网 https://www.kafka-eagle.org/ 下载解压 这里使用的是2.0.8 创建mysql数据库 创建名为ke数据库,新版本会自动创建&#xff0c;不会创建的话&#xff0c;自己手动创建&#xff0c;不然会报查不到相关表信息错误 SET NAMES utf8; SET FOREIGN_KEY_CHECKS 0;-- ------…