最容易理解的优化器的公式表示(BGD,SGD,SGDM,Adagrad,RMSProp,Adam)

在深度学习的训练中,我们依赖于权重系数来构建和调整模型。这些权重系数需要通过反向传播算法进行更新,以实现模型训练的目标。

权重的更新是基于梯度来计算的,而梯度的获取方式并非单一,它们由不同类型的优化器提供。这就是我们今天要深入探讨的优化器的角色。

首先为了后续符号的统一,我们先引入一个最为常见的损失函数: L = ∑ i = 1 m ( y i − f w , b ( x i ) ) 2 L= \sum_{i=1}^m(y_i - f_{w,b}(\bold{x}_i))^2 L=i=1m(yifw,b(xi))2这个损失函数是计算全体数据样本的损失函数,一般来说,只是随机采用一个样本 i i i 进行权重更新的话,则可以将损失函数写成: L ( i ) = ( y i − f w , b ( x i ) ) 2 L(i)= (y_i - f_{w,b}(\bold{x}_i))^2 L(i)=(yifw,b(xi))2如果要计算其基于 w j w_j wj 的偏导数,也就是梯度的话,我们可以得到 ∂ L ( i ) ∂ w j \frac{\partial{L(i)}}{\partial{w_j}} wjL(i),考虑上时间 t t t 来表示更新前( t − 1 t-1 t1)和更新后( t t t)的区别,可以写成: ∂ L ( i ) ∂ w t − 1 , j \frac{\partial{L(i)}}{\partial{w_{t-1,j}}} wt1,jL(i),因为梯度是根据前一个时刻的权重系数 w t − 1 , j w_{t-1,j} wt1,j 计算的。

后续为了简化,我会写成 ∇ L ( w t − 1 , j ) i \nabla L(w_{t-1,j})_i L(wt1,j)i,来表示损失函数在 t − 1 t-1 t1 时刻关于 w j w_j wj 的偏导数,也就是梯度。

按照我的理解,通过前一时刻( t − 1 t-1 t1)的权重系数计算出来的梯度,应该是算是前一时刻的梯度?

但是 Pytorch 官方实现各种优化器的时候,都是这么表示梯度的: g t ← ∇ θ f t ( θ t − 1 ) g_t \leftarrow \nabla_{\theta}f_t(\theta_{t-1}) gtθft(θt1),这就意味着,Pytorch 官方定义,通过 t − 1 t-1 t1 时刻的权重系数计算出来的事当前 t t t 时刻的梯度。

但是,我为了方便理解,将 t − 1 t-1 t1 时刻权重计算得到的梯度表示为 t − 1 t-1 t1 时刻的梯度,即 g t − 1 g_{t-1} gt1,这样对我来说更好理解一些,但是暂时还不知道坏处是啥?如果有大佬可以告诉我是否可行,我感觉梯度因为是实时计算的,不需要更新的,标不标准时刻感觉都可以。

接下来就来到正题,介绍各个优化器的公式:

BGD (Batch Gradient Descent)

其实这种方式和基于全体样本的梯度计算方式比较类似,或者说是一样的。

将一个 Batch 的数据( m m m 个样本)都考虑进来,我们要更新 w j w_j wj 权重,就对其求 w j w_j wj 的偏导,然后可以得到 ∇ L ( w t , j ) i \nabla L(w_{t,j})_i L(wt,j)i,但是因为这是 m m m 个样本的梯度求和,

所以我们需要求其平均梯度: g t − 1 , j = 1 m ⋅ ∑ i = 1 m ∇ L ( w t − 1 , j ) i g_{t-1, j} = \frac{1}{m}\cdot\sum_{i=1}^m\nabla L(w_{t-1,j})_i gt1,j=m1i=1mL(wt1,j)i权重更新公式可以写成: w t , j = w t − 1 , j − η ⋅ g t − 1 , j w_{t,j}= w_{t-1, j}-\eta \cdot g_{t-1,j} wt,j=wt1,jηgt1,j因为我们所有的权重都是和 w j w_j wj 相关的,所以为了让公式更简洁一点,我们就不必显式的将 j j j 的下标标注出来了,后面的公式都采用一样的操作。

也就是我们的 BGD 的公式可以写成: g t − 1 = 1 m ⋅ ∑ i = 1 m ∇ L ( w t − 1 , j ) i g_{t-1} = \frac{1}{m}\cdot\sum_{i=1}^m\nabla L(w_{t-1,j})_i gt1=m1i=1mL(wt1,j)i w t = w t − 1 − η ⋅ g t − 1 w_{t}= w_{t-1}-\eta \cdot g_{t-1} wt=wt1ηgt1

MBGD (Mini-Batch Gradient Descent)

这种方式是 BGD 和 SGD 的一种折中处理方式,采用一个 mini batch.

简单举个例子,采用一个 batchsize 为 k k k 的小样本集,权重系数的更新公式可以写成下面的形式:

梯度和权重更新公式可以表示为: g t − 1 = 1 k ⋅ ∑ i i + k − 1 ∇ L ( w t − 1 , j ) i g_{t-1} = \frac{1}{k}\cdot\sum_{i}^{i+k-1} \nabla L(w_{t-1, j})_i gt1=k1ii+k1L(wt1,j)i w t = w t − 1 − η ⋅ g t − 1 w_{t} = w_{t-1} - \eta \cdot g_{t-1} wt=wt1ηgt1除了梯度计算的样本数变少了,其他是完全一样的。

SGD (Stochastic Gradient Descent)

随机梯度下降,这个就不需要考虑全体样本了,因为全体样本毕竟还是耗时耗力,只要随机抽取一个样本 i i i 来计算梯度就可以了: g t − 1 = ∇ L ( w t − 1 , j ) i g_{t-1} = \nabla L(w_{t-1, j})_i gt1=L(wt1,j)i w t = w t − 1 − η ⋅ g t − 1 w_{t} = w_{t-1} - \eta \cdot g_{t-1} wt=wt1ηgt1只这种情况是最简单的了。

SGDM (Stochastic Gradient Descent with Momentum)

在随机梯度下降的基础上,加上一阶动量。

梯度: g t − 1 = ∇ L ( w t − 1 , j ) i g_{t-1}=\nabla L(w_{t-1, j})_i gt1=L(wt1,j)i一阶动量系数更新: v t = λ ⋅ v t − 1 + ( 1 − λ ) ⋅ g t − 1 v_{t}= \lambda \cdot v_{t-1} + (1-\lambda)\cdot g_{t-1} vt=λvt1+(1λ)gt1权重系数更新: w t = w t − 1 − η ⋅ v t w_{t}= w_{t-1} - \eta \cdot v_{t} wt=wt1ηvt

Adagrad

这里的梯度可以写成 g t − 1 = ∇ L ( w t − 1 , j ) i g_{t-1}=\nabla L(w_{t-1, j})_i gt1=L(wt1,j)i权重系数更新: w t = w t − 1 − η ⋅ 1 ∑ i = 0 t − 1 g i 2 + ϵ ⋅ g t − 1 w_{t} = w_{t-1} - \eta \cdot \frac{1}{\sqrt{\sum_{i=0}^{t-1}g_i^2} +\epsilon} \cdot g_{t-1} wt=wt1ηi=0t1gi2 +ϵ1gt1注意:这里面求和的 i i i 并不是指样本 i i i,只是对之前所有时间 t − 1 t-1 t1 求和的一个表示。

RMSProp

和 Adagrad 基本类似,只是加入了迭代衰减(二阶动量)

梯度: g t − 1 = ∇ L ( w t − 1 , j ) i g_{t-1} = \nabla L(w_{t-1,j})_i gt1=L(wt1,j)i二阶动量参数(其中 q 0 = g 0 2 q_0 = g_0^2 q0=g02 ): q t = α ⋅ q t − 1 + ( 1 − α ) ⋅   g t − 1 2 q_{t} = \alpha \cdot q_{t-1} + (1-\alpha)\cdot\ g_{t-1}^2 qt=αqt1+(1α) gt12权重系数: w t = w t − 1 − η ⋅ 1 q t + ϵ ⋅ g t − 1 w_{t} = w_{t-1} - \eta \cdot \frac{1}{\sqrt{q_t}+\epsilon} \cdot g_{t-1} wt=wt1ηqt +ϵ1gt1

Adam

这里我为 Adam 选择计算梯度的方式是随机小批量,也就是类似 MBGD 的方式,但是实际上,Adam 可以根据数据集的特点自己选择全量样本,随机样本或者随机小批量来进行梯度计算:
g t − 1 = 1 k ∑ i i + k − 1 ∇ L ( w t − 1 , j ) i g_{t-1} = \frac{1}{k}\sum_i^{i+k-1}\nabla L(w_{t-1,j})_i gt1=k1ii+k1L(wt1,j)i累计平方梯度(二阶动量,和 RMSProp 类似):
q t = α ⋅ q t − 1 + ( 1 − α ) ⋅ g t − 1 2 q_{t} = \alpha \cdot q_{t-1} + (1-\alpha)\cdot g_{t-1}^2 qt=αqt1+(1α)gt12累计梯度(一阶动量,和 SGDM 类似): v t = λ ⋅ v t − 1 + ( 1 − λ ) ⋅ g t − 1 v_{t} = \lambda \cdot v_{t-1} + (1-\lambda)\cdot g_{t-1} vt=λvt1+(1λ)gt1修正偏差: v ~ t = v t 1 − λ t ,        q ~ t = q t 1 − α t \tilde{v}_{t} = \frac{v_{t}}{1-\lambda^{t}}, \space\space\space\space\space\space \tilde{q}_t=\frac{q_{t}}{1-\alpha^{t}} v~t=1λtvt,      q~t=1αtqt这里的 λ \lambda λ α \alpha α 都是初始设定的固定参数,这里的上标 t t t 是求时间 t t t 次方。

权重系数更新的公式为: w t = w t − 1 − η ⋅ 1 q ~ t + ϵ ⋅ v ~ t w_{t} = w_{t-1} - \eta \cdot \frac{1}{\sqrt{\tilde{q}_{t}}+\epsilon}\cdot \tilde{v}_{t} wt=wt1ηq~t +ϵ1v~t这里的 ϵ \epsilon ϵ 是一个非常小的值,为了是防止分母为 0。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/611674.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【计算机毕业设计】springboot工资管理系统

人类现已迈入二十一世纪,科学技术日新月异,经济、资讯等各方面都有了非常大的进步,尤其是资讯与 网络技术的飞速发展,对政治、经济、军事、文化等各方面都有了极大的影响。 利用电脑网络的这些便利,发展一套工资管理系…

Unity 修复Sentinel key not found (h0007)错误

这个问题是第二次遇到了,上次稀里糊涂的解决了,也没当回事,这次又跑出来了,网上找的教程大部分都是出自一个人。 1.删除这个路径下的文件 C:\ProgramData\SafeNet Sentinel,注意ProgramData好像是隐藏文件 2.在Windows…

Mac安装激活--Typora,一个比记事本更加强大的纯文本软件

一、安装 1.首先到官网下载Mac版的Typora,下载地址:https://typoraio.cn/ (1)打开默认中文站 (2)往下滑,下载Mac版 2.下载完成后,会看到Typora.dmg文件,点击打开文件 3.打开Typ…

mac苹果电脑卡顿反应慢如何解决?2024最新免费方法教程

苹果电脑以其稳定的性能、出色的设计和高效的操作系统,赢得了广大用户的喜爱。然而,随着时间的推移,一些用户会发现自己的苹果电脑开始出现卡顿、反应慢等问题。这不仅影响使用体验,还会影响工作效率。那么,面对这些问…

luceda ipkiss教程 68:通过代码模板提高线路设计效率

在用ipkiss设计器件或者线路时,经常需要输入: from ipkiss3 import all as i3那么有什么办法可以快速输入这段代码呢?这里就可以利用Pycharm的 live template功能,只需要将文件:ipkiss.xml (luceda ipkiss教程 68&…

JetBrains的Java集成开发环境IntelliJ 2024.1版本在Windows/Linux系统的下载与安装配置

目录 前言一、IntelliJ在Windows安装二、IntelliJ在Linux安装三、Windows下使用配置四、Linux下使用配置总结 前言 ​ “ IntelliJ IDEA Ultimate是一款功能强大的Java集成开发环境(IDE)。它提供了丰富的功能和工具,可以帮助开发人员更高效地…

深入理解Java HashSet类及其实现原理

哈喽,各位小伙伴们,你们好呀,我是喵手。运营社区:C站/掘金/腾讯云;欢迎大家常来逛逛 今天我要给大家分享一些自己日常学习到的一些知识点,并以文字的形式跟大家一起交流,互相学习,一…

LabVIEW MEMS电容式压力传感器测试系统

LabVIEW MEMS电容式压力传感器测试系统 随着微电子技术的发展,MEMS(微电机系统)技术在各个领域得到了广泛应用。MEMS电容式压力传感器以其高灵敏度、小尺寸、低功耗等优点,在微传感器领域占据了重要的地位。然而,这些…

基于FPGA的音视频监视器,音视频接口采集器的应用

① 支持1路HDMI1路SDI 输入 ② 支持1路HDMI输出 ③ 支持1080P高清屏显示实时画面以 及叠加的分析结果 ④ 支持同时查看波形图(亮度/RGB)、 直方图、矢量图 ⑤ 支持峰值对焦、斑马纹、伪彩色、 单色、安全框遮幅标记 ⑥ 支持任意缩放画面,支…

TypeScript安装及编译

一、TypeScript是什么 ​ Type script 是微软基于 Javascript 开发的开源编程语言,是拥有类型的 Javascript 的超集,继承了js 所有语法,此外增加了一些自己语法。可以编译成普通、千净、完整的 JavaScript 代码。 目的: 不是创造…

【Linux】从零开始认识动静态库 - 静态库

送给大家一句话: 永不言弃,就是我的魔法! ——阿斯塔《黑色四叶草》 ଘ(੭ˊ꒳​ˋ)੭✧ଘ(੭ˊ꒳​ˋ)੭✧ଘ(੭ˊ꒳​ˋ)੭✧ ଘ(੭ˊ꒳​ˋ)੭✧ଘ(੭ˊ꒳​ˋ)੭✧ଘ(੭ˊ꒳​ˋ)੭✧ ଘ(੭ˊ꒳​ˋ)੭✧ଘ(੭ˊ꒳​ˋ)੭✧ଘ(੭ˊ꒳​ˋ)੭✧ 从零…

mysql数据库调优篇章1--日志篇

目录 1.认识数据库中日志的作用2.增加mysql数据库中my.ini 基本配置3.增加my.ini中参数配置4.查看已经执行过的sql语句过去执行时间5.找出慢查询的sql6.常用参数查询命令7.认识慢查询日志记录8.认识通用日志记录(记录增删改查操作)9.认识二进制文件binlo…

多维点分布的均匀性评估方法(NDD和Voronoi 图法)

评估多维点分布的均匀性是统计学和数据科学中的一个重要问题,特别是在模拟、空间分析和样本设计等领域。下面,我将详细介绍2种评估多维点分布均匀性的方法,包括它们的数学原理、实现公式以及各自的优缺点。 1. 最近邻距离法(Neare…

CTF例题和知识点

[ACTF2020 新生赛]Include 打开靶机发现一个超链接,点击之后出现一段话 “Can you find out the flag?” 查看源码注入,无果 仔细看url,发现有flag.php 根据题目提示,该题应该是文件包含漏洞,因此可以判断出此题是PH…

通俗的理解网关的概念的用途(三):你的数据包是如何到达下一层的

其实,这一章我写不好,因为这其中会涉及到一些计算和一些广播等概念,本人不善于此项。在此略述,可以参考。 每台设备的不同连接在获得有效的IP地址后,会根据IP地址的规则和掩码的规则,在操作系统和交换机&a…

自动控制原理学习--平衡小车的控制算法(三)

上一节PID的simulin仿真,这一节用LQR 一、模型 二、LQR LQR属于现代控制理论的一个很重要的点,这里推荐B站的【Advanced控制理论】课程(up主DR_CAN),讲得很好,这里引用了他视频里讲LQR的ppt。 LQR属于lo…

车载电子电器架构 —— 应用软件开发(中)

车载电子电器架构 —— 应用软件开发(中) 我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 屏蔽力是信息过载时代一个人的特殊竞争力,任何消耗你的人和事,多看一眼都是你的不对。非必要不费力证明…

医疗行业如何提升Windows操作系统登录的安全性

医疗行业使用账号和密码登录Windows系统时,可能会遇到一些痛点,这些痛点可能会影响工作效率、数据安全和用户体验。以下是一些主要的痛点: 1. 密码管理复杂性:医疗行业通常涉及大量的敏感数据和隐私信息,因此密码策略…

非模块化 Vue 开发的 bus 总线通信

个人感觉,JavaScript 非模块开发更适合新人上手,不需要安装配置一大堆软件环境,不需要编译,适合于中小项目开发,只需要一个代码编辑器即可开发,例如 vsCode。网页 html 文件通过 script 标签引入 JavaScrip…

学习笔记——字符串(单模+多模+练习题)

单模匹配 Brute Force算法(暴力) 算法思想 母串和模式串字符依次配对,如果配对成功则继续比较后面位置是否相同,如果出现匹配不成功的位置,则j(模式串当前的位置)从头开始,i&…
最新文章