自动微分与反向传播大意

作者: 引线小白-本文永久链接:httpss://www.limoncc.com/post/4be777c8af12fc58/
知识共享许可协议: 本博客采用署名-非商业-禁止演绎4.0国际许可证

数学上一切都不理解,很大程度上是对符号的不清晰、不清楚、不理解造成的。究其根本是结合数学概念的符号练习太少了。

首先我们回忆一下向量函数对向量导数

1、数量函数对向量导数

1、数量函数对列向量的导数定义,结果是个列向量。

$$\begin{align}
\big[\nabla f\big]_{i\times 1}=\left[\frac{\partial{f(\bm{x})}}{\partial{\bm{x}}}\right]_{i\times 1}=\frac{\partial{f(\bm{x})}}{\partial{x_i}}
\end{align}$$

2、数量函数对行向量的导数定义,结果是个行向量。

$$\begin{align}
\left[\frac{\partial{f(\bm{x})}}{\partial{\bm{x}^\text{T}}}\right]_{1\times j}=\frac{\partial{f(\bm{x})}}{\partial{x_j}}
\end{align}$$

2、向量函数对向量导数

1、列向量函数对行向量的导数定义
$$\begin{align}
\bm{J}=\left[\frac{\partial{\bm{f}(\bm{x})}}{\partial{\bm{x}^\text{T}}}\right]_{i\times j}=\frac{\partial{f_i}}{\partial{x_j}}
\end{align}$$

2、行向量函数对列向量的导数定义

$$\begin{align}
\left[\frac{\partial{\bm{f}^\text{T}(\bm{x})}}{\partial{\bm{x}}}\right]_{j\times i}=\frac{\partial{f_i}}{\partial{x_j}}
\end{align}$$

3、链式法则

一般情况下我们喜欢列向量函数对行向量导数的表示方法。但是在链式法则中,数量函数对向量导数,一般使用数量函数对行向量导数,也就是说我们认为梯度是行向量,这样就是保持了统一:因为有链式法则

$$\begin{align}
h(\bm{t})=f\big(\bm{x}(\bm{t})\big)&\to \mathrm{D}h=\mathrm{D}f \cdot\mathrm{D}\bm{x}\\
\bm{h}(\bm{t})=\bm{f}\big(\bm{x}(\bm{t})\big)&\to \mathrm{D}\bm{h}=\mathrm{D}\bm{f}\cdot\mathrm{D}\bm{x}
\end{align}$$
链式法则将会保持一致和简单性,即由外向内,从左到右书写。

注意梯度一般我们默认是列向量,这样梯度的导数才是海塞矩阵哈。千万不要混乱了。所以请务必熟悉我讲的以上三点。到底是行对列,还是列对行,然后求出的维度是多少,务必熟练理解定义。

4、自动微分的一个典型例子

现在我们来看看神经网络的自动微分为啥要用反向传播,我们先定义点符号,以方便表示:
1、$\displaystyle \ell:$ 损失函数
2、$\displaystyle \bm{x}\to\bm{z}_1\to\bm{z}_2\to\bm{y}\to\ell$: 一个三层神经网络= 输入层 $\bm{x}$+ 隐藏层 $\bm{z}_1$ +隐藏层 $\bm{z}_2$+ 输出层 $\bm{y}$

我们先后向求导哈:

$$\begin{align}
\frac{d\ell}{d\bm{x}^\text{T}}
=\frac{d\ell}{d\bm{y}^\text{T}}
\cdot\frac{d\bm{y}}{d\bm{z}_2^\text{T}}
\cdot\frac{d\bm{z}_2}{d\bm{z}_1^\text{T}}
\cdot\frac{d\bm{z}_1}{d\bm{x}^\text{T}}
\end{align}$$

貌似也没什么特别的哈,来我们让它特别起来,首先我们来明确一下维度

$$\begin{align}
\bm{x}_{10\times 1}\to\bm{z}1_{6\times 1}\to\bm{z}2_{8\times 1}\to\bm{y}_{2\times 1}\to\ell_{1\times 1}
\end{align}$$

我们给她加上维度

$$\begin{align}
\bigg[\frac{d\ell}{d\bm{x}^\text{T}}\bigg]_{1\times 10}
=\bigg[\frac{d\ell}{d\bm{y}^\text{T}}\bigg]_{1\times 2}
\cdot\bigg[\frac{d\bm{y}}{d\bm{z}_2^\text{T}}\bigg]_{2\times 8}
\cdot\bigg[\frac{d\bm{z}_2}{d\bm{z}_1^\text{T}}\bigg]_{8\times 6}
\cdot\bigg[\frac{d\bm{z}_1}{d\bm{x}^\text{T}}\bigg]_{6\times 10}
\end{align}$$

我来看一个矩阵乘法的简单事实
$$\begin{align}
\bm{D}_{n\times s}= \bm{A}_{n\times p}\cdot\bm{B}_{p\times m}\cdot\bm{C}_{m\times s}
\end{align}$$
1、首先计算 $\displaystyle \bm{A}_{n\times p}\cdot\bm{B}_{p\times m}$。你要计算 $\displaystyle n\times p \times m$次乘法。对吧,想不通自己回忆矩阵数乘定义,并计算一翻。看不出显然要打板子哦。
2、其次计算 $\displaystyle \big[\bm{A}_{n\times p}\cdot\bm{B}_{p\times m}\big]_{n\times m}\cdot\bm{C}_{m\times s}$,你要计算 $\displaystyle n\times m\times s$次乘法
3、最后计算出 $\displaystyle \bm{D}$,你一共要计算

$$\begin{align}
\mathop{\text{cal_cost}}(\bm{D})=n\times p \times m+n\times m\times s
\end{align}$$

如果我们反过来计算呢?
$$\begin{align}
\big[\bm{D}^\text{T}\big]_{s\times n}= \big[\bm{C}^\text{T}\big]_{s\times m}\cdot\big[\bm{B}^\text{T}\big]_{m\times p}\cdot\big[\bm{A}^\text{T}\big]_{p\times n}
\end{align}$$

计算出 $\displaystyle \bm{D}^\text{T}$,你一共要计算

$$\begin{align}
\mathop{\text{cal_cost}}(\bm{D}^\text{T})=s\times m \times p+s\times p\times n
\end{align}$$

现在在看看这个后向求导计算
$$\begin{align}
\bigg[\frac{d\ell}{d\bm{x}^\text{T}}\bigg]_{1\times 10}
=\bigg[\frac{d\ell}{d\bm{y}^\text{T}}\bigg]_{1\times 2}
\cdot\bigg[\frac{d\bm{y}}{d\bm{z}_2^\text{T}}\bigg]_{2\times 8}
\cdot\bigg[\frac{d\bm{z}_2}{d\bm{z}_1^\text{T}}\bigg]_{8\times 6}
\cdot\bigg[\frac{d\bm{z}_1}{d\bm{x}^\text{T}}\bigg]_{6\times 10}
\end{align}$$

计算成本是

$$\begin{align}
\mathop{\text{cal_cost}}\bigg(\frac{d\ell}{d\bm{x}^\text{T}}\bigg)=1\times 2 \times 8+1\times 8\times 6+1\times 6\times 10 = 124
\end{align}$$

对于前向计算呢

$$\begin{align}
\bigg[\frac{d\ell}{d\bm{x}}\bigg]_{10\times 1}
=\bigg[\frac{d\bm{z}_1^\text{T}}{d\bm{x}}\bigg]_{10\times 6}
\cdot\bigg[\frac{d\bm{z}_2^\text{T}}{d\bm{z}_1}\bigg]_{6\times 8}
\cdot\bigg[\frac{d\bm{y}^\text{T}}{d\bm{z}_2}\bigg]_{8\times 2}
\cdot\bigg[\frac{d\ell}{d\bm{y}}\bigg]_{2\times 1}
\end{align}$$

$$\begin{align}
\mathop{\text{cal_cost}}\bigg(\frac{d\ell}{d\bm{x}}\bigg)=10\times 6 \times 8+10\times 8\times 2+10\times 2\times 1 = 660
\end{align}$$

实际上我们有,如果看不出显然,请自我反思

$$\begin{align}
\forall \dim[\bm{\ell}]\leqslant dim[\bm{x}]\to\mathop{\text{cal_cost}}\bigg(\frac{d\ell}{d\bm{x}^\text{T}}\bigg)
\leqslant
\mathop{\text{cal_cost}}\bigg(\frac{d\ell}{d\bm{x}^\text{T}}\bigg)\\
\forall \dim[\bm{\ell}]> dim[\bm{x}]\to\mathop{\text{cal_cost}}\bigg(\frac{d\ell}{d\bm{x}^\text{T}}\bigg)>\mathop{\text{cal_cost}}\bigg(\frac{d\ell}{d\bm{x}^\text{T}}\bigg)
\end{align}$$

5、自动微分一个普遍的写法

特别的如果我们不将 $\displaystyle \ell$视为标量的损失函数,而是一个可以向量优化目标 $\displaystyle \bm{\ell}$。同时我们特别的定义
第 $i$层雅可比矩阵 $\displaystyle \bm{J}_i =\frac{d\bm{z}_i^\text{T}}{d\bm{z}_{i-1}}$

如果是标量损失,最后一层导数 $\displaystyle \bm{v}_{output}:=\displaystyle \bm{J}_n=\frac{d\ell}{d\bm{y}}$

那么有前向计算公式

$$\begin{align}
\frac{d\bm{\ell}^\text{T}}{d\bm{x}} =\prod_{i=1}^{n} \bm{J}_i ^\text{T}
=\bm{J}_1 ^\text{T}\cdot\bm{J}_2 ^\text{T}\cdots\bm{J}_{n-1} ^\text{T}\cdot\bm{J}_n ^\text{T}
\end{align}$$

后向计算公式
$$\begin{align}
\frac{d\bm{\ell}}{d\bm{x}^\text{T}} =\prod_{i=0}^{n-1} \bm{J}_{n-i}
=\bm{J}_{n}\cdot\bm{J}_{n-1}\cdots\bm{J}_{2}\cdot\bm{J}_{1}
\end{align}$$

对于标量损失函数前向计算公式

$$\begin{align}
\frac{d\ell}{d\bm{x}} =\prod_{i=1}^{n-1} \bm{J}_i ^\text{T}\cdot\bm{v}_{output}
==\bm{J}_1 ^\text{T}\cdot\bm{J}_2 ^\text{T}\cdots\bm{J}_{n-1} ^\text{T}\cdot\bm{v}_{output}
\end{align}$$

后向计算公式
$$\begin{align}
\frac{d\ell}{d\bm{x}^\text{T}}
=\bm{v}_{output}^\text{T}\cdot\prod_{i=1}^{n-1} \bm{J}_{n-i}
=\bm{v}_{output}^\text{T}\cdot\bm{J}_{n-1}\cdots\bm{J}_{2}\cdot\bm{J}_{1}
\end{align}$$

而对于一个熟练上述概念的人来说,以上所有内容其实应该是显然的,你应该做到在脑海里一闪而过就得出结论。如果不能做到,请打牢固概念基础


版权声明
引线小白创作并维护的柠檬CC博客采用署名-非商业-禁止演绎4.0国际许可证。
本文首发于柠檬CC [ https://www.limoncc.com ] , 版权所有、侵权必究。
本文永久链接httpss://www.limoncc.com/post/4be777c8af12fc58/
如果您需要引用本文,请参考:
引线小白. (Apr. 25, 2017). 《自动微分与反向传播大意》[Blog post]. Retrieved from https://www.limoncc.com/post/4be777c8af12fc58
@online{limoncc-4be777c8af12fc58,
title={自动微分与反向传播大意},
author={引线小白},
year={2017},
month={Apr},
date={25},
url={\url{https://www.limoncc.com/post/4be777c8af12fc58}},
}

'