AI深度学习入门与实战06 线性回归模型:在问题中回顾与了解基础概念

news/2024/4/30 19:06:32/文章来源:https://blog.csdn.net/fegus/article/details/127149990

学到这里相信你已经了解了深度学习,知道该如何训练模型、如何评估模型,并且掌握了一种叫作人工神经网络的算法。正所谓温故而知新,这节课我会通过一个线性回归模型将前面的知识串起来,做一个完整的回顾。

线性回归是一个最基本的传统机器学习算法,不算是深度学习的内容,但神经元其实就是线性回归模型。你以后学习物体检测的话,线性回归也会被用在物体检测问题中。

线性回归

那什么是线性回归呢?顾名思义,就是用一种线性关系进行回归。回归与分类都属于有监督学习,分类预测的是一个类别,而回归预测的是一个数值,例如房价、天气的温度、股票的走势。

假设我们有训练数据 D:

Drawing 1.png

其中 xi=(xi,1,xi,2,...,xi,n)。

我们希望找到一个线性关系 f(x),使 f(x)尽可能地模拟出数据 x 与 y 的关系,数学上的表达是f(xi)≈yi,也就是希望 f(x) 接受一个输入 x,它的输出能尽可能地接近真实结果 y。

线性关系 f(x)可以定义为:

Drawing 3.png

向量的表示方式为:

Drawing 5.png

其中 w 为权重,是一个列向量,我们定义为:

Drawing 7.png

b 是一个偏移项。

这个 f(x)就是我们说的线性回归模型。我们对 f(x)的训练,就是获得 f(x)中的 w 与 b。

我们可以看到数据集 D 中,每个 x 都是一个向量,每个 x 中有 n 个元素。当 n 大于 2 的时候,我们把 f(x)称为多元线性回归,当 n 等于 1 或者 2 的时候,我们一般直接称为一元线性回归或者二元线性回归。

当 n=1 时,f(x)就是一条直线,也就是说,我们希望找到一条直线拟合出数据的整体趋势。如下图所示:

Drawing 9.png

当 n=2 时,f(x)就是一个平面,这个时候我们希望用一个平面拟合出数据的整体趋势。如下图所示:

Drawing 11.png

模型训练与评估

我们先来回忆一下机器学习的流程。

  1. 确定好训练数据与评估数据。

  2. 选择损失函数、优化方法进行训练。

  3. 利用评估数据评估模型。

接下来我们就通过线性回归模型的训练与评估来复习一下以上流程。

第 1 步与第 3 步在这里就先略过,我在后面还会通过一个房价预测的案例带你回顾机器学习的流程,在那里会有相关的例子。

损失函数

对于训练数据 D,我们需要根据每一条数据(xi, yi)来训练线性回归模型。我们希望 f(x)的输出可以无限接近真实的结果 y,那我们怎样衡量预测结果与真实结果的差距呢?

最常用的方式是使用均方误差(Mean Squared Error, MSE)。均方误差是回归任务中经常使用的一种损失函数(cost function 或者是 loss function)。其定义如下:

Drawing 13.png

在均方误差中,xi 与 yi 是数据中的具体数值,自变量就是 f(x)的权重 w 与偏移项 b。所以均方误差 J 是一个关于权重 w 与偏移项 b 的函数。

既然均方误差是衡量我们预测值与真实值之间的差距的,那我们只要让它获得最小的 w 与 b 的值,就能够得到预测值最接近真实值的回归方程了。

模型优化

模型优化就是获得权重 w 与 b 的过程,这些 w 与 b 可以让损失函数取得最小值。在《04|函数与优化方法:模型的自我学习(上)》中我讲到了梯度下降的优化方法,现在我们就用这个方法来优化我们的权重 w 与 b。

我们先回忆一下梯度下降的原理。

假如我们在一座高山上,想走到山脚,但是我们又不认识路,那我们应该怎么办呢?我们可以以当前点为基准,向坡面最陡的方向前进,每前进一步,就依据现在的位置再寻找最陡的方向,这样一点点地前进,就可以成功到达山脚。

到达山脚的这个过程就是求 J(w, b)最小值的过程。

在给定点找到最陡的方向,就是函数中给定点的梯度 ,然后朝着梯度相反的方向前进,就能让函数值下降的最快。

确定了梯度,那么每次走多长呢,这取决于我们设置的学习率。

为了举例方便,我假设回归方式是一元回归方程,所以回归方程只有一个 w。

首先我们要计算 J(w, b)的梯度:

Drawing 15.png

其中,w 与 b 的偏导数为:

Drawing 17.png
Drawing 19.png

由于我们的训练样本 m 往往很大,采用以上的做法每次需要消耗巨大的计算资源。为了解决这个问题,随机梯度下降法与小批量梯度下降法营运而生。

接下来,我们就自己写一个每次只取一条数据的小批量随机梯度下降算法来看看效果如何,即每次计算梯度时,m=1。

首先我们生成数据,蓝色点为我们生成的数据。

np.random.seed(123)
w_real = 2
b_real = -15
xlim = [-15, 15]
x_gen = np.random.randint(low=xlim[0], high=xlim[1], size=30)
y_real = w_real * x_gen + b_real
plt.plot(x_gen, y_real, 'bo')

Drawing 21.png

我们假设权重 w=2, b=-15,生成训练数据。我们将 w 与 b 随机初始化,然后读取训练数据计算梯度,通过随机梯度下降法来更新 w 与 b,使其接近我们真实的值。过程如下:

def sgd(x, y, lr, epochs=1):# x 为训练集中的数据# y 为训练集中标签# lr 学习率# epochs 训练 epoch 数

    # 随机初始化 w 与 b
    w = random.random()
    b = random.random()
    for e in range(epochs):
        w_list.append(w)
        b_list.append(b)
        for x_ , y_ in zip(x, y):
            w = w - lr * x_ * (w * x_ + b - y_)
            b = b - lr * (w * x_ + b - y_)
    return w, b

用于记录 w 与 b 的变化

w_list = []
b_list = []
w, b = sgd(x_gen, y_real, 0.001, epochs=150
print(“w = {:.3f} , b = {:.3f}”.format(w, b))

输出:w = 2.006 , b = -14.732

设置学习率为 0.001,更新 150 个 Epoch。可以看到,最终输出的 w 与 b 已经很接近真实的 w 与 b 了。

讲到这里,我们就完成一个线性回归模型的完整训练了。

模型评估

当模型训练完成之后,我们需要对模型进行评估。评估指标可以回顾“03 课时”的内容,我也会在下面的这个例子中再次提到。

经典问题:房价预测

这一节我会用一个经典的波士顿房价预测实例,带你完整地回顾一遍机器学习的整个流程,并简单地进行一次问题排查。

数据加载

房价预测的数据集已经被包含在 sklearn 中了。所以我们可以通过 sklearn 直接加载房价数据。加载完数据之后,调用 train_test_split 方法,按照一定的比例将数据集拆成训练集与测试集。如下所示:

from sklearn import datasets
# 分割数据的模块,把数据集分为训练集和测试集
from sklearn.model_selection import train_test_split
# 导入回归方法
from sklearn.linear_model import SGDRegressor
boston_data = datasets.load_boston()
# sklearn 把数据分为了 data(输入)与 target(输出)两部分
data_X = boston_data.data
data_y = boston_data.target
# 将数据集分割成训练集与测试集,其中测试集占 30%
X_train, X_test, y_train, y_test = train_test_split(data_X,data_y, test_size=0.3)

模型训练

我们使用 sklearn 的 SGDRegressor 回归来进行预测,它的优化方法就是我刚才介绍的随机梯度下降。训练时我们会进行如下设置:

  • loss 为 squared_loss;

  • 学习率为 0.001;

  • 迭代次数为 2000 次。

model = SGDRegressor(loss='squared_loss', l1_ratio = 0.001, max_iter=2000, verbose=1)
model.fit(X_train, y_train)

我们来观察一下 loss:

Drawing 23.png

我们可以看到,loss 存在很大问题:loss 的数量级是 e 的 29 次幂。

任何机器学习项目,首要任务都是要观察数据、分析数据。把数据搞清楚之后,我们才能着手训练。所以,我们要回过头来看一下数据。盲目的训练是不会取得很好的结果的。

sklearn 把波士顿房价数据以一个字典的形式返回,我们可以看一下字典的 key 都有哪些:

boston_data.keys()
输出如下:
dict_keys(['data', 'target', 'feature_names', 'DESCR', 'filename'])

其中有一个 key 叫作 DESCR,它是用来描述我们数据集的信息的,如下所示:

print(boston_data['DESCR'])
输出如下:
.. _boston_dataset:
Boston house prices dataset
---------------------------
**Data Set Characteristics:**:Number of Instances: 506 :Number of Attributes: 13 numeric/categorical predictive. Median Value (attribute 14) is usually the target.:Attribute Information (in order):- CRIM     per capita crime rate by town- ZN       proportion of residential land zoned for lots over 25,000 sq.ft.- INDUS    proportion of non-retail business acres per town- CHAS     Charles River dummy variable (= 1 if tract bounds river; 0 otherwise)- NOX      nitric oxides concentration (parts per 10 million)- RM       average number of rooms per dwelling- AGE      proportion of owner-occupied units built prior to 1940- DIS      weighted distances to five Boston employment centres- RAD      index of accessibility to radial highways- TAX      full-value property-tax rate per $10,000- PTRATIO  pupil-teacher ratio by town- B        1000(Bk - 0.63)^2 where Bk is the proportion of blacks by town- LSTAT    % lower status of the population- MEDV     Median value of owner-occupied homes in $1000's:Missing Attribute Values: None:Creator: Harrison, D. and Rubinfeld, D.L.
This is a copy of UCI ML housing dataset.
https://archive.ics.uci.edu/ml/machine-learning-databases/housing/

This dataset was taken from the StatLib library which is maintained at Carnegie Mellon University.
The Boston house-price data of Harrison, D. and Rubinfeld, D.L. ‘Hedonic
prices and the demand for clean air’, J. Environ. Economics & Management,
vol.5, 81-102, 1978. Used in Belsley, Kuh & Welsch, ‘Regression diagnostics
…’, Wiley, 1980. N.B. Various transformations are used in the table on
pages 244-261 of the latter.
The Boston house-price data has been used in many machine learning papers that address regression
problems.

… topic:: References

  • Belsley, Kuh & Welsch, ‘Regression diagnostics: Identifying Influential Data and Sources of Collinearity’, Wiley, 1980. 244-261.
  • Quinlan,R. (1993). Combining Instance-Based and Model-Based Learning. In Proceedings on the Tenth International Conference of Machine Learning, 236-243, University of Massachusetts, Amherst. Morgan Kaufmann.

我们可以看到这个数据集一共有 506 条记录、13 个属性,并且没有缺失值。其实,这些内容即使我们不看它的描述,也应该通过代码自己确认一遍。

我们再具体看一下房价的数据:

import pandas as pd
# 加载数据到dataframe中
df = pd.DataFrame(data=boston_data['data'], columns=boston_data['feature_names'])
df.head()
输出如下:
CRIM ZN INDUS CHAS NOX RM AGE DIS RAD TAX PTRATIO B LSTAT
0.00632	18.0	2.31	0.0	0.538	6.575	65.2	4.0900	1.0	296.0	15.3	396.90	4.98
0.02731	0.0	7.07	0.0	0.469	6.421	78.9	4.9671	2.0	242.0	17.8	396.90	9.14
0.02729	0.0	7.07	0.0	0.469	7.185	61.1	4.9671	2.0	242.0	17.8	392.83	4.03
0.03237	0.0	2.18	0.0	0.458	6.998	45.8	6.0622	3.0	222.0	18.7	394.63	2.94
0.06905	0.0	2.18	0.0	0.458	7.147	54.2	6.0622	3.0	222.0	18.7	396.90	5.33

你有没有发现一个问题,就是每个属性的取值范围是不一样的,CRIM 那列普遍很小,B 的那列却取值很大。

在这个时候,你就要考虑是否应该把所有训练数据都正规化到一个取值范围里了,否则会影响到模型的训练。关于正规化,将会在《12 | 数据预处理:让模型学得更好》中介绍。

当然,在这种传统机器学习的项目中,挑选特征是最重要的事情。所以我们先看看是否使用全量特征进行训练。

我们来计算一下各个特征(包括目标属性)的相关性。

相关性的取值范围是-1 到 1,越接近 1 或者-1 代表越相关,越接近 0 则越不相关。相关系数大于 0 称为正相关,小于 0 称为负相关。

假如 A 与 B 正相关,则是说 A(B)会随着 B(A)的增大而增大,减小而减小。

假如 A 与 B 负相关,则是说 A(B)会随着 B(A)的增大而减小,减小而增大。

下面的例子我采用皮尔逊相关性探索属性间的关系,具体计算方式在这里我就不介绍了,你可以自行查询了解。它常用于传统的机器学习项目,在深度学习中很少使用。

# 观察属性间的相关性
import seaborn as sns
df['PRICE'] = boston_data.target
correlation_matrix = df.corr().round(2)
plt.figure(figsize=(10, 10))
sns.heatmap(data=correlation_matrix, annot=True)

Drawing 25.png

从上图我们可以发现两点:

  1. RM 与 PRICE 有很强的正相关性(0.7),而 LSTAT 与 PRICE 有很强的负相关性(-0.74)。

  2. 在线性模型中,很重要的一点是考虑是否剔除相关性很高的属性。RAD 与 TAX 有非常高的正相关性(0.91),DIS 与 AGE 的负相关性(-0.75)很高。

我们把第 1 点发现中与 price 相关的两个属性可视化之后,可以观察到正相关与负相关确实如我们刚才所介绍的那样。

Drawing 13.png

RM 与 LSTAT 算是两个非常重要的特征了,我们下面做 2 组实验对比一下,看看我们模型的模型是否收敛。

  • 实验 1:只用 RM 预测 PRICE

  • 实验 2:只用 LSTAT 预测 PRICE

实验 1 的训练 loss:

Drawing 28.png

实验 2 的 loss:

Drawing 30.png

实验 1 与实验 2 的 loss 都很好地收敛了,并且实验 2 收敛得更快一些。接下来我们可以看看评估结果。

模型评估

模型评估我们依然使用 sklearn 中自带的均方误差来进行。

实验 1 的评估:

from sklearn.metrics import mean_squared_error
y_predict = model.predict(X_test)
# 使用 mean_squared_error 模块,并输出评估结果。
print('The mean squared error of SGDRegressor is', mean_squared_error(y_test, y_predict))
Output:The mean squared error of SGDRegressor is 38.61589881811603

实验 2 的评估:

Output:The mean squared error of SGDRegressor is 49.33770954798598

从评估结果上看实验 1 要好一些。但你可能会好奇,实验 2 收敛得更快,为什么实验 2 的评估结果不好呢?这很正常,因为在验证集上的表现和训练时的收敛速度并没有直接关系。

实际问题往往更加复杂,如果在训练集上表现好,但在验证集中表现差,我们可能需要再好好地分析一下数据集哪里出现了问题,是不是训练集没有覆盖到验证集。

总结

这一课时我介绍了线性回归模型,那我们可不可以把线性回归模型用于分类问题呢?答案是可以的,只需要把回归方程 f(x)输出加一个 sigmoid 函数就可以了。

你可能会觉得这跟我们讲的神经元很像,其实就是一回事。就像我在介绍人工神经网络时讲到,一个神经元可以处理一个线性关系,有足够多个神经元就可以模拟出任意的非线性关系。

本课时通过线性回归模型对机器学习的整个流程进行了复习,并用波士顿房价预测的问题进行了实际的举例,课余时间你可以试一试同时使用 RM 与 LSTAT 对 PRICE 进行回归,或者提前对输入数据进行规则化处理,看看模型的效果会不会变好。

关于线性回归模型你还有什么疑问,欢迎在留言区留言。本节课涉及的代码我已经传到了 GitHub,你可以点击链接查看。

下一节,我将带你了解卷积神经网络(CNN),让你了解模型是如何观察世界的。


精选评论

**升:

打卡

**3185:

第六节的代码是python语言吗,我复制到python中无法运行

    讲师回复:

    本课程中提供的代码,不是完整版本的,只提供了最为核心的部分,要想真正执行,你需要增加更多的环节,这些会在后续介绍tensorflow的时候介绍。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_17860.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【牛客刷题--SQL篇】基础查询 SQL3查询结果去重SQL4查询结果限制返回行数SQL5将查询后的列重新命名

个人主页:与自己作战 牛客刷题系列篇:【SQL篇】【Python篇】【Java篇】 推荐刷题网站注册地址:【牛客网–SQL篇】 推荐理由:从0-1起步,循序渐进 文章目录一、基础查询1、基础查询1.1 SQL3 查询结果去重1.2 SQL4查询结果限制返回行…

Go设计模式学习准备——下载bilibili合集视频

需求 前段时间面试,被问到设计模式。说实话虽然了解面向对象、多态,但突然被问到设计模式,还要说清解决什么问题,自己是有些懵的,毕竟实习主要工作是在原项目基础进行CRUD,自己还是没有深度思考,所以只能简单介绍自己知道的简单工厂模式等。趁着回家这段假期,充电学习一…

VS2022编译错误:链接器工具错误 LNK2005

产生原因 自己在学习namespace时,参照C++ plus一书写了基本相同的代码,分别定义了h文件和两个CPP文件,其中一个CPP用来定义变量,一个CPP用来跑main(void)。文件代码如下: head.h文件 #pragma once #include<string> #include <iostream>namespace pers {struc…

对于在网页中实现数据库数据的遍历

基于套用了网站模板,然后连接了数据库 具体实现如下:1、定义一个实体类Uesr package org.example;public class User {private String name;private String id;private String teacher;private String whe;public void setName(String name){this.name=name;}public String g…

【ES6丨前端进阶基础 】ES6的关键字,新特性以及解构赋值

&#x1f482; 个人主页&#xff1a;Aic山鱼 个人社区&#xff1a;山鱼社区 &#x1f4ac; 如果文章对你有所帮助请点个&#x1f44d;吧!欢迎关注、点赞、收藏(一键三连)和订阅专目录 前言 什么是ecmascrpit 一&#xff0c;let关键字的特点 1.不能重复声明变量 2.块级作用…

【持久层框架】- SpringData - JPA

SpringData - JPA &#x1f604;生命不息&#xff0c;写作不止 &#x1f525; 继续踏上学习之路&#xff0c;学之分享笔记 &#x1f44a; 总有一天我也能像各位大佬一样 &#x1f3c6; 一个有梦有戏的人 怒放吧德德 &#x1f31d;分享学习心得&#xff0c;欢迎指正&#xff0c;…

图解mysql(六)——内存篇

6 揭开Buffer Pool的面纱 6.1 为什么要有Buffer Pool&#xff1f; 虽然说 MySQL 的数据是存储在磁盘里的&#xff0c;但是也不能每次都从磁盘里面读取数据&#xff0c;这样性能是极差的。 要想提升查询性能&#xff0c;加个缓存就行了嘛。所以&#xff0c;当数据从磁盘中取出…

基于BootStrap的农业信息数据收集与管理平台设计

目 录 1 前言 1 1.1 研究背景 1 1.2 研究内容 1 1.3 研究意义 1 1.4 本论文工作和章节内容 2 2 系统环境与技术支持 3 2.1 系统环境 3 2.2 B/S架构 3 2.3 服务端技术 3 2.3.1 MVC模式 3 2.3.2 Spring框架 4 2.3.3 Hibernate框架 4 2.3.4 MySQL数据库 5 2.4 前端技术 5 2.4.1 Bo…

基于JavaWeb的企业出差费用报销管理系统设计与实现

目录 第一章 绪论 1 1.1 出差报销管理系统的开发背景 1 1.2设计目的与意义 1 第二章 系统需求分析 2 2.1 可行性分析 2 2.1.1 操作可行性 2 2.1.2 经济可行性 2 2.1.3 技术可行性 2 2.2方案的设计与比较 2 2.2.1 C/S设计结构和B/S设计结构比较 2 2.2.2 系统模式的设计 3 2.2.3系…

大二数据库实验-MySQL语句(Employee、Department、Salary)

实验所用到的的几张表&#xff1a; 显示Employee表中姓王的记录。 显示salary 表中InCome大于2000的数据。 显示salary 表中InCome在2500到3000之间的数据。 显示Employee表中在1968年下半年出生的数据。 分别显示三个表中总记录条数。 显示salary表中收入和支出总…

(附源码)计算机毕业设计ssm财务管理系统

毕设帮助&#xff0c;指导&#xff0c;本源码分享&#xff0c;调试部署(见文末) 3.2.1系统开发流程 财务管理系统开发时&#xff0c;首先进行需求分析&#xff0c;进而对系统进行总体的设计规划&#xff0c;设计系统功能模块&#xff0c;数据库的选择等&#xff0c;本系统的…

权限控制WAF绕过之木马混淆免杀

目录 前言&#xff1a; 绕过思路&#xff1a; &#xff08;一&#xff09;变量覆盖 &#xff08;二&#xff09;加密混杂 未加密前&#xff1a; 接口加密 加密后&#xff1a; WAF验证&#xff1a; &#xff08;三&#xff09;异或生成 0x01 原理&#xff1a; 0x02 适用…

线段树什么的最讨厌了

发现如果正着从一颗线段树搜到这一个区间,很难搜。所以考虑从一个区间搜出一颗线段树。 对于一个区间 \([l,r]\),他的父亲区间只可能是 \([2*l-r-2,r],[2*l-r-1,r],[l,2*r-l],[l,2*r-l-1]\) 四种情况。 发现无论往哪一种方向走,\(\frac{l}{r-l+1}\) 的值都会除以2.那么这么做…

建立对单片机/嵌入式启动、运行的整体认知

文章目录一、51单片机的启动过程二、STM32的完整启动流程分析1. 根据boot引脚决定三种启动模式2. 启动后bootloader做了什么&#xff1f;3. bootloader中对内存的搬移和初始化4. ISP、IAP、ICP三种烧录方式5. 参考资料从上电到启动&#xff0c;一文读懂STM32启动全流程1、直接上…

m基于FPGA的cordic算法实现,输出sin和cos波形(包括仿真录像)

目录 1.源码获取方式 2.算法描述 3.部分程序 4.部分仿真图预览 1.源码获取方式 使用版本matlab2022a 获取方式1&#xff1a; 点击下载链接&#xff08;解压密码C123456&#xff09;&#xff1a; m基于FPGA的cordic算法实现,输出sin和cos波形 获取方式2&#xff1a; 如…

程序员的数学课21 神经网络与深度学习:计算机是如何理解图像、文本和语音的?

在上一讲的最后&#xff0c;我们提到过“浅层模型”和“深层模型”。其实&#xff0c;人工智能的早期并没有“浅层模型”的概念&#xff0c;浅层模型是深度学习出现之后&#xff0c;与之对应而形成的概念。在浅层模型向深层模型转变的过程中&#xff0c;神经网络算法无疑是个催…

Vue2 生命周期

Vue 生命周期 概述在使用 Vue 时,我们需要执行一些 JS 代码。比如我们需要在页面中添加一个定时器来固定间隔更新时间。这时我们可能会想到直接在,Vue 实例外书写 JS 代码。这种方法能完成操作,但是 Vue 并不建议这样写。Vue 建议尽量在 Vue 实例中完成所有的操作。这时我们…

Hadoop3.X安装教程(Ubuntu)

前提:一台纯净的Ubuntu机器(虚拟机安装教程略) ctrl + alt + T 打开bash,全程使用bash指令进行,以hadoop 和 java 8为例 首先换源进入root账户 sudo su -升级软件列表 apt-get update安装vim apt install vim中途询问直接输入Y确认下载hadoop和java 创建/data mkdir /data…

半导体中的缺陷和位错能级

点缺陷&#xff1a; 在一定的温度下&#xff0c;组成晶体的格点原子在平衡位置附近做振动&#xff0c;这些振动就会有强有弱&#xff0c;这样会使得一部分原子可以获得足够的能量&#xff0c;而挣脱周围电子对它的束缚&#xff0c;挤入间隙位置&#xff0c;这样的结果就形成了…

211西北大学,计算机、软件学硕和专硕专业课都变难了!

西北大学位于陕西省西安市&#xff0c;是一所211大学。西北大学计算机学科评估B-&#xff0c;软件工程学科评估B&#xff0c;计算机实力在211大学中处于中上游水平&#xff0c;还算不错。西北大学前段时间公布了23考研的招生目录&#xff0c;我们来看一下&#xff1a;西北大学2…