作者: 引线小白-本文永久链接:httpss://www.limoncc.com/post/2823bb8e386a0878/
知识共享许可协议: 本博客采用署名-非商业-禁止演绎4.0国际许可证
摘要: 本文意在理清向量和矩阵微分的基础问题。若有错误,请大家指正。
关键词:Delta更新规则,向量微分,RNN,对角加秩一矩阵,DPLR
[TOC]
一、引言
承前所述,联想记忆和Delta更新规则构成了现代循环神经网络(RNN)演进的重要基础。具体而言,现代RNN的一个日趋统一的范式是:基于Delta更新规则构建动态状态,实现实时的在线元学习。在这种范式下,状态本质上扮演着瞬时记忆的角色,它会根据当前输入进行迭代更新。 因此,不同现代RNN变体的核心差异,本质在于对状态更新机制的设计理念不同——即对如何管理和利用这种瞬时记忆的核心认识存在差异。
基于Delta规则的状态更新机制会自然引入了一种类似经典 Householder 变换形式。标准的Householder变换定义如下:给定一个非零向量 $\bm{k} \in \mathbb{R}^n$,其对应的Householder矩阵 $\bm{H}$ 为:
$$\begin{align}
\bm{H} = \bm{E} - \beta \bm{k} \bm{k}^\T, \quad \beta = \frac{2}{\bm{k}^\top \bm{k}}
\end{align}$$
其中 $\beta$ 是一个由向量 $\bm{v}$ 的模长唯一确定的标量。
然而,在RNN的状态更新机制中,虽然采用了形如 $\bm{H} = \bm{E} - \beta \bm{k}\bm{k}^\T$ (或类似结构)的变换形式,但其参数化方式与严格意义上的Householder变换存在关键区别:$\beta$ 通常被设计为一个独立的、可学习的标量参数,而非固定地等于 $\frac{2}{\bm{k}^\top \bm{k}}$。向量 $\bm{k}$ 本身一般也是由网络(如前一层状态、当前输入)计算得到的可学习函数输出。计算Householder变换连积的WY表示方法可以很好将状态的顺序计算转变为并行计算。从而实现现代RNN的并行计算。接下来将重点讨论以下状态形式的并行计算:
$$\begin{align}
\bm{S}_t
= \bm{S}_{t-1}\big[\bm{E}-\beta_t\bm{k}_t \bm{k}_t^\T\big]+ \eta_t \bm{v}_t\bm{k}_t^\T
\end{align}$$
无论形式如何,这里的关键结构是对角加秩一(DPLR:Diagonal Plus Low-Rank) $\bm{E}-\beta_t\bm{k}_t \bm{k}_t^\T$,其中 $\bm{E}$是对角矩阵, $\beta_t\bm{k}_t \bm{k}_t^\T$的秩 $\mathrm{rank}[\beta_t\bm{k}_t \bm{k}_t^\T]=1$, 对角加秩一的几何意义在RNN的复兴02^1中已经解释,这里上个图回忆一下:

对角加秩一矩阵的实质是通过迭代变换构建记忆状态: 对于任意向量 $\bm{s}$,变换后沿 $\bm{k}$方向的分量被衰减、保留或放大,而正交分量保持不变,从而实现精细化的记忆控制。 $\beta$用于动态调整记忆衰减强度,确保模型仅遗忘与当前任务无关的信息。
二、对角加秩一矩阵变换
2.1、对角加秩一矩阵(DPLR)连积
已知DPLR变换 $\displaystyle \bm{H}_\tau=\bm{E}-\beta_{\tau}\bm{k}_{\tau}\bm{k}_{\tau}^\T$,对于DPLR变换的累积有如下 $\bm{W}^\T\bm{Y}$表示方法:
$$\begin{align}
\bm{P}=\prod_{\tau=1}^{t}\bm{H}_\tau
=\prod_{\tau=1}^{t}\Big[\bm{E}-\beta_{\tau}\bm{k}_{\tau}\bm{k}_{\tau}^\T\Big]
=\bm{E}-\sum_{\tau}^t \bm{w}_\tau \bm{k}_\tau ^\T
=\bm{E}-\bm{W}^\T\bm{Y}^\T
=\bm{E}-\bm{W}^\T\bm{K}
\end{align}$$
可以将连乘运算转换为连加运算,其中 $\bm{W}^\T=[\bm{w}_1,\cdots,\bm{w}_t]$, $\bm{Y}^\T=\bm{K}^\T=[\bm{k}_1,\cdots,\bm{k}_t]$。 其中 $\bm{w}_\tau \mathbf{k}_\tau^\T$是节点 $\tau$ 对全局状态的影响,$\sum$表示所有节点影响的叠加,负号则是历史路径的衰减效应。
证明:
$$\begin{align}
\bm{P}
&= \bigg[\bm{E}-\sum_{\tau}^{t-1} \bm{w}_\tau \bm{k}_\tau ^\T\bigg] \bigg[\bm{E}-\beta_{t}\bm{k}_{t}\bm{k}_{t}^\T \bigg]\\
&=\bm{E}-\sum_{\tau}^{t-1} \bm{w}_\tau \bm{k}_\tau ^\T-\beta_{t}\bm{k}_{t}\bm{k}_{t}^\T +\sum_{\tau}^{t-1} \bm{w}_\tau \bm{k}_\tau ^\T\beta_{t}\bm{k}_{t}\bm{k}_{t}^\T\\
&=\bm{E}-\sum_{\tau}^{t-1} \bm{w}_\tau \bm{k}_\tau ^\T - \underbrace{\beta_{t}\bigg[\bm{k}_{t}-\sum_{\tau}^{t-1} \bm{w}_\tau \bm{k}_\tau ^\T\bm{k}_{t}\bigg]}_{\bm{w}_t}\bm{k}_{t}^\T\\
&=\bm{E}-\sum_{\tau}^t \bm{w}_\tau \bm{k}_\tau ^\T\\
&=\bm{E}-\bm{W}^\T\bm{K}
\end{align}$$
证明过程中还实现了 $\bm{w}_t$的构造方法:
$$\begin{align}
\bm{w}_t = \beta_{t}\bm{k}_{t}-\beta_{t}\sum_{\tau}^{t-1} \bm{w}_\tau \bm{k}_\tau ^\T\bm{k}_{t}=\beta_{t}\bm{k}_{t}-\beta_{t}\sum_{\tau}^{t-1} \bm{k}_\tau ^\T\bm{k}_{t}\bm{w}_\tau
\end{align}$$
2.2、信息传播算子[^3]
令 $\bm{W}^\T=[\bm{w}_1,\cdots,\bm{w}_t]$ 以便整体看待。基于$\bm{w}_t$的构造方法展开,很快就能看到 $\bm{W}$编码了节点信息的传播。先展开几项,以便获得直觉:
$$\begin{align}
\bm{w}_1&= \beta_1 \bm{k}_1\\
\bm{w}_2&= \beta_2 \bm{k}_2 -\beta_{2}\bm{k}_1 ^\T\bm{k}_{2}\bm{w}_1\\
\bm{w}_3&= \beta_3 \bm{k}_3 -\beta_{3}\bm{k}_1 ^\T\bm{k}_{3}\bm{w}_1-\beta_{3}\bm{k}_2 ^\T\bm{k}_{3}\bm{w}_2\\
&\vdots\\
\bm{w}_t&=\beta_t \bm{k}_t-\beta_{t}\bm{k}_1 ^\T\bm{k}_{t}\bm{w}_1-\beta_{t}\bm{k}_2 ^\T\bm{k}_{t}\bm{w}_2-
\cdots-\beta_{t}\bm{k}_{t-1} ^\T\bm{k}_{t}\bm{w}_{t-1}
\end{align}$$
写成矩阵形式
$$\begin{align}
\begin{bmatrix}
\bm{w}_1 ^\T\\
\bm{w}_2^\T\\
\bm{w}_3^\T\\
\vdots\\
\bm{w}_t^\T\\
\end{bmatrix}
=\begin{bmatrix}
\beta_1 \bm{k}_1^\T\\
\beta_2 \bm{k}_2^\T\\
\beta_3 \bm{k}_3^\T\\
\vdots\\
\beta_t \bm{k}_t^\T\\
\end{bmatrix}+ \begin{bmatrix}
0&0&0&\cdots&0\\
-\beta_{2}\bm{k}_1 ^\T\bm{k}_{2}&0&0&\cdots&0\\
-\beta_{3}\bm{k}_1 ^\T\bm{k}_{3}&-\beta_{3}\bm{k}_2 ^\T\bm{k}_{3}&0&\cdots&0\\
\vdots\\
-\beta_{t}\bm{k}_1 ^\T\bm{k}_{t}&-\beta_{t}\bm{k}_2 ^\T\bm{k}_{t}&\cdots&-\beta_{t}\bm{k}_{t-1} ^\T\bm{k}_{t}&0
\end{bmatrix}
\begin{bmatrix}
\bm{w}_1^\T\\
\bm{w}_2^\T\\
\bm{w}_3^\T\\
\vdots\\
\bm{w}_t^\T\\
\end{bmatrix}
\end{align}$$
如果令 $\bm{b}_\tau=\beta_\tau \bm{k}_\tau$, $\bm{B}^\T=[\beta_1 \bm{k}_1,\cdots,\beta_t \bm{k}_t] = \mathrm{diag}[\bm{\beta}]\bm{K} $,于是有 $\bm{W} = \bm{B}+\bm{A}\bm{W}$,易得:
$$\begin{align}
\bm{W} &=\big[\bm{E}-\bm{A}\big]^{-1}\bm{B}\\
\end{align}$$
令 $\bm{T}=\big[\bm{E}-\bm{A}\big]^{-1}$,其中下三角矩阵 $\bm{A}$中的元素 $A_{ij} = -\beta_i (\bm{k}_j^\T\bm{k}_i)$, 其中 $i>j$。这样就得到了 $\bm{W}$的计算方法。
2.3、下三角矩阵矩阵求逆
2.3.1、前向替代法 (Forward Substitution)
如何计算$\bm{T}=\big[\bm{E}-\bm{A}\big]^{-1}$,充分利用 $\bm{A}$是下三角矩阵的事实,对于 $\big[\bm{E}-\bm{A}\big]\bm{T} = \bm{E}$, 展开第 $j$列:
$$\begin{align}
\bm{T}[:,j] - \bm{A} * \bm{T}[:,j] = \bm{e}_j
\end{align}$$
由于 $\bm{A}$是严格下三角, $\bm{T}$是下三角矩阵,可得递推关系:
$$\begin{align}
\bm{T}[i,j] = \begin{cases}
0 & \text{if } i < j \
1 & \text{if } i = j \
\sum_{k=j}^{i-1} \bm{A}[i,k] \bm{T}[k,j] & \text{if } i > j
\end{cases}
\end{align}$$
来个例子,做做数学符号体操:
对于 $$\begin{align}
\bm{A}=\begin{bmatrix}
0 & 0 & 0 & 0\
A_{21} & 0 & 0 & 0\\
A_{31} & A_{32} & 0 & 0\\
A_{41} & A_{42} & A_{43} & 0\\
\end{bmatrix}
\end{align}$$
- 第1列 (j=1):
$i=1\to \bm{T}[1,1]=1$
$i=2\to \bm{T}[2,1]=\bm{A}[2,1]\bm{T}[1,1]=A_{21}\times 1= A_{21}$
$i=3\to \bm{T}[3,1]=\bm{A}[3,1]\bm{T}[1,1]+\bm{A}[3,2]\bm{T}[2,1]=A_{31}+A_{32}A_{21}$
$i=4\to \bm{T}[4,1]=\bm{A}[4,1]\bm{T}[1,1]+\bm{A}[4,2]\bm{T}[2,1]+\bm{A}[4,3]\bm{T}[3,1]=A_{41}+A_{42}A_{21}+A_{43}A_{31}+A_{42}A_{32}A_{21}$
- 第2列 (j=2):
$i=2\to \bm{T}[2,2]=1$
$i=3\to \bm{T}[3,2]=\bm{A}[3,2]\bm{T}[2,2]=A_{32}$
$i=4\to \bm{T}[4,2]=\bm{A}[4,2]\bm{T}[2,2]+\bm{A}[4,3]\bm{T}[3,2]=A_{42}+A_{43}A_{32}$
- 第3列 (j=3):
$i=3\to \bm{T}[3,3]=1$
$i=4\to \bm{T}[4,3]=\bm{A}[4,3]\bm{T}[3,3]=A_{43}$
- 第4列 (j=4):
$i=4\to \bm{T}[4,]=1$
$$\begin{align}
\bm{T}=
\begin{bmatrix}
1 & 0 & 0 & 0\
A_{21} & 1 & 0 & 0\\
A_{31}+A_{32}A_{21} & A_{32} & 1 & 0\\
A_{41}+A_{42}A_{21}+A_{43}A_{31}+A_{42}A_{32}A_{21} & A_{42}+A_{43}A_{32} & A_{43} & 1\\
\end{bmatrix}
\end{align}$$
观察上述公式,其实有 $\bm{T}[i,j]=A[i,1:i-1]\bm{T}[1:i-1,j]$, 这意味着
1 | def block_forward_substitution(A, block_size=32): |
2.3.2、Neumann级数展开
2.3.2.1、收敛证明
若谱范数 $|\bm{A}|_2=\lambda_{max}\big[\bm{A}^\T\bm{A}\big]<1$,则有
$$\begin{align}
\big[\bm{E}-\bm{A}\big]^{-1}=\bm{E}+\bm{A}+\bm{A}^2+\bm{A}^3+\cdots
\end{align}$$
令 $\bm{S}=\bm{E}+\bm{A}+\bm{A}^2+\bm{A}^3+\cdots$, $\bm{S}(N)=\sum_{n=0}^N\bm{A}^n$, 其中 $\bm{S}_0=\bm{E}$, $\bm{S}_1=\bm{E}+\bm{A}$
$$\begin{align}
\big[\bm{E}-\bm{A}\big]\bm{S}
=\lim_{N\to \infty}\big[\bm{E}-\bm{A}\big]\bm{S}(N)
=\lim_{N\to \infty}\bigg[\big[\bm{E}-\bm{A}\big]\sum_{n=0}^N\bm{A}^n\bigg]
=\sum_{n=0}^N\bigg[\bm{A}^n-\bm{A}^{n+1}\bigg]
\end{align}$$
$$\begin{align}
\bm{S}\big[\bm{E}-\bm{A}\big]
=\lim_{N\to \infty}\bm{S}(N)\big[\bm{E}-\bm{A}\big]
=\lim_{N\to \infty}\bigg[\sum_{n=0}^N\bm{A}^n\big[\bm{E}-\bm{A}\big]\bigg]
=\sum_{n=0}^N\bigg[\bm{A}^n-\bm{A}^{n+1}\bigg]
\end{align}$$
又有
$$\begin{align}
\sum_{n=0}^N\bigg[\bm{A}^n-\bm{A}^{n+1}\bigg]
&=\big[\bm{E}-\bm{A}\big]+\big[\bm{A}-\bm{A}^2\big]+\cdots + \big[\bm{A}^n-\bm{A}^{n+1}\big]\\
&=\bm{E}-\bm{A}^{n+1}\Leftarrow(|\bm{A}|_2<1\Rightarrow|\bm{A}^{n+1}|_2\leq |\bm{A}|_2^{n+1}\to 0)\\
&=\bm{E}
\end{align}$$
于是有 $\big[\bm{E}-\bm{A}\big]\bm{S}=\sum_{n=0}^N\bigg[\bm{A}^n-\bm{A}^{n+1}\bigg]=\bm{E}$,得 $\bm{S}=\big[\bm{E}-\bm{A}\big]^{-1}$,因此有:
$$\begin{align}
\big[\bm{E}-\bm{A}\big]^{-1}=\bm{E}+\bm{A}+\bm{A}^2+\bm{A}^3+\cdots
\end{align}$$
2.3.2.2、参数的限制
其中最重要的是谱范数 $|\bm{A}|_2=\lambda_{max}\big[\bm{A}^\T\bm{A}\big]<1$,知道 $A_{ij} = -\beta_i (\bm{k}_j^\T\bm{k}_i)$, 其中 $i>j$。Key向量一般是归一化的 $|\bm{k}_t|_2\leqslant 1,\forall t$,对于 $\beta_t\leqslant \frac{1}{D-1},\forall t$,其中 $D=\dim[\bm{k}_t]$是特征维度。
对于对角线元素
$$\begin{align}
(\bm{A}^\T\bm{A})_{ii}
&=\sum_{k=i+1}^D A_{ki}^2
=\sum_{k=i+1}^D\beta_{k}^2(\bm{k}_i^\T \bm{k}_k)^2\\
&\leqslant\sum_{k=i+1}^D\beta_{k}^2
\leqslant (D-i)\beta_{max}^2\\
&\leqslant \frac{D-i}{(D-1)^2}
\end{align}$$
对于非对角线元素
$$\begin{align}
(\bm{A}^\T\bm{A})_{ij}
&=\sum_{k=\max(i,j)+1}^DA_{ki}A_{kj}
\leqslant \sum_{k=\max(i,j)+1}^D \beta_k^2(\bm{k}_i^\T \bm{k}_k)(\bm{k}_j^\T \bm{k}_k)\\
&\leqslant (D-\max(i,j))\beta_k^2
\leqslant (D-\max(i,j))\beta_{max}^2\\
&\leqslant \frac{D-\max(i,j)}{(D-1)^2}
\end{align}$$
构造上界矩阵 $(\bm{A}^\T\bm{A})_{ij} \leqslant B_{ij}$
$$\begin{align}
B_{ij}\begin{cases}
(D-\max(i,j))\beta_{max}^2 & i\neq j\\
(D-i)\beta_{max}^2 & i=j
\end{cases}
\end{align}$$
这样就有
$$\begin{align}
\lambda_{max}[\bm{A}^\T\bm{A}]
&\leqslant \lambda_{max}[\bm{B}]
\leqslant|\bm{B}|_{\infty}
=\max_{i}\sum_j B_{ij}\\
&=\max_{i}\bigg(B_{ii}+\sum_{i\neq j}B_{ij}\bigg)\\
&=\max_{i}\bigg((D-i)\beta_{max}^2+\sum_{i\neq j}(D-\max(i,j))\beta_{max}^2\bigg)\\
&=\max_{i}\bigg((D-i)\beta_{max}^2+\beta_{max}^2\sum_{j=1}^{i-1}(D-i)+\beta_{max}^2\sum_{j=i+1}^{D}(D-j)\bigg)\\
&=\max_{i}\bigg((D-i)\beta_{max}^2+(i-1)(D-i)\beta_{max}^2+\frac{(D-i)(D-i-1)}{2}\beta_{max}^2\bigg)\\
&=\max_{i} \frac{(D-i)(D+i-1)}{2}\beta_{max}^2\\
&=\max_{i} \frac{1}{2}(-i^2+i+D^2-D)\beta_{max}^2\\
&=\frac{D(D-1)}{2}\beta_{max}^2\Leftarrow (i=1)\\
&< \frac{D}{2(D-1)}\Leftarrow (\beta_t\leqslant \frac{1}{D-1})\\
&<1\Leftarrow (D>2)\\
\end{align}$$
总结一下
- 归一化假设:$|\mathbf{k}_t| \leq 1$ 保证相似度有界
- $\beta$约束:$\beta_{\max} \leq \frac{1}{D-1}$ 控制影响强度
- 块大小影响:$D$ 越大,上界 $\frac{D}{2(D-1)}$ 越接近 0.5
此证明保证了在合理参数范围内,使用Neumann级数展开是数值稳定的。
2.3.2.3、精度控制
$$\begin{align}
|\big[\bm{E}-\bm{A}\big]^{-1}-\bm{S}(N)|
=\left|\sum_{i=N+1}^{\infty}\bm{A}^n\right|
\leqslant \sum_{i=N+1}^{\infty}|\bm{A}|^2
= \frac{|\bm{A}|^{N+1}}{1-|\bm{A}|}
<\epsilon
\end{align}$$
忽略分母 $1-|\bm{A}|$的保守估计(因 $|\bm{A}|\ll 1$时影响小),简化为
$$\begin{align}
|\bm{A}|^{N+1}<\epsilon
\Leftrightarrow (N+1)\log|\bm{A}|<\log \epsilon
\Leftrightarrow N> \frac{\log \epsilon}{\log|\bm{A}|}-1
\end{align}$$
最终截断项数取:
$$\begin{align}
N=\bigg[\frac{\log \epsilon}{\log|\bm{A}|}\bigg]
\end{align}$$
经验项数 $N \approx 10 \sim 20 $
当 $|\bm{A}| \leq 0.7$时,$N \approx 10 \sim 20$可满足典型精度 $\varepsilon \sim 10^{-6}$:
$|\bm{A}| = 0.7 \to N = \lceil \log(10^{-6}) / \log(0.7) \rceil = 25$
$|\bm{A}| = 0.5 \to N = 20$
$|\bm{A}| = 0.3 \to N = 11$
适用场景:$|\bm{A}|$ 较小(如 $ < 0.8$) 时,级数收敛快,较少项即可达到高精度。
2.3.2.4、代码的实现
有两种代码实现,第一种
1 | def neumann_inverse(A, N=10): |
第二种是利用 $\bm{S}(N+1)=\bm{E}+\bm{A}\bm{S}(N)$
1 | def richardson_iteration(A, max_iter=20): |
2.4、状态累计更新的图论
使用Neumann级数展开有
$$\begin{align}
\bm{W}
=\big[\bm{E}-\bm{A}\big]^{-1} \mathrm{diag}[\bm{\beta}]\bm{K}
=\big[\bm{E}+\bm{A}+\bm{A}^2+\bm{A}^3+\cdots+\bm{A}^\infty\big]\cdot\mathrm{diag}[\bm{\beta}]\cdot\bm{K}
\end{align}$$
对于其中的 $\bm{w}_{\tau}$, 令 $\bm{b}_\tau=\beta_\tau \bm{k}_\tau$, $\bm{B}^\T=[\beta_1 \bm{k}_1,\cdots,\beta_t \bm{k}_t]$ 实际有
$$\begin{align}
\bm{w}_\tau = \underbrace{\bm{b}_\tau \vphantom{\sum\sum_{j}}}_{\text{自身贡献}}+ \underbrace{\sum_{i}A_{\tau i}\bm{b}_\tau \vphantom{\sum\sum_{j}}}_{\text{相邻贡献}} + \underbrace{\sum_{i}\sum_{j}A_{\tau i}A_{ij}\bm{b}_j \vphantom{\sum\sum_{j}}}_{\text{跨步影响}} + \cdots
\end{align}$$
其中 $\tau>i>j$。来几个例子,做做数学符号体操, 对了解细节大有裨益,以Neumann级数展开视角,当 $t=4$时考察 $\bm{W}$其实有
$$\begin{align}\bm{W}&=
\underbrace{\begin{bmatrix}
\beta_1 \bm{k}_1^\T\\
\beta_2 \bm{k}_2^\T\\
\beta_3 \bm{k}_3^\T\\
\beta_4 \bm{k}_4^\T\\
\end{bmatrix}
}_{自身贡献}+
\underbrace{\begin{bmatrix}
0 & 0 & 0 & 0\
A_{21} & 0 & 0 & 0\\
A_{31} & A_{32} & 0 & 0\\
A_{41} & A_{42} & A_{43} & 0\\
\end{bmatrix}
}_{\bm{A}:相邻贡献}
\begin{bmatrix}
\beta_1 \bm{k}_1^\T\\
\beta_2 \bm{k}_2^\T\\
\beta_3 \bm{k}_3^\T\\
\beta_4 \bm{k}_4^\T\\
\end{bmatrix}\\&+
\underbrace{\begin{bmatrix}
0 & 0 & 0 & 0\
0 & 0 & 0 & 0\\
A_{32}A_{21} & 0 & 0 & 0\\
A_{42}A_{21}+A_{43}A_{31} & A_{43}A_{32} & 0 & 0\\
\end{bmatrix}
}_{\bm{A}^2:跨步影响}
\begin{bmatrix}
\beta_1 \bm{k}_1^\T\\
\beta_2 \bm{k}_2^\T\\
\beta_3 \bm{k}_3^\T\\
\beta_4 \bm{k}_4^\T\\
\end{bmatrix}+
\underbrace{\begin{bmatrix}
0 & 0 & 0 & 0\
0 & 0 & 0 & 0\\
0 & 0 & 0 & 0\\
A_{43}A_{32}A_{21} & 0 & 0 & 0\\
\end{bmatrix}
}_{\bm{A}^3:长程依赖}
\begin{bmatrix}
\beta_1 \bm{k}_1^\T\\
\beta_2 \bm{k}_2^\T\\
\beta_3 \bm{k}_3^\T\\
\beta_4 \bm{k}_4^\T\\
\end{bmatrix}\\&+
\underbrace{\begin{bmatrix}
0 & 0 & 0 & 0\
0 & 0 & 0 & 0\\
0 & 0 & 0 & 0\\
0 & 0 & 0 & 0\\
\end{bmatrix}
}_{\bm{A}^4}
\begin{bmatrix}
\beta_1 \bm{k}_1^\T\\
\beta_2 \bm{k}_2^\T\\
\beta_3 \bm{k}_3^\T\\
\beta_4 \bm{k}_4^\T\\
\end{bmatrix}+\bm{O}\bm{B}\\
\end{align}$$
展开就有
$$\begin{align}
\bm{w}_1 &= \beta_1\bm{k}_1\\
\bm{w}_2 &= \beta_2\bm{k}_2 + A_{21}(\beta_1\bm{k}_1)\\
\bm{w}_3 &= \beta_3\bm{k}_3 + A_{31}(\beta_1\bm{k}_1) + A_{32}(\beta_2\bm{k}_2) + A_{32}A_{21}(\beta_1\bm{k}_1)\\
\bm{w}_4 &= \beta_4\bm{k}_4 + A_{41}(\beta_1\bm{k}_1) + A_{42}(\beta_2\bm{k}_2) + A_{43}(\beta_3\bm{k}_3) \\&+ A_{42}A_{21}(\beta_1\bm{k}_1)+A_{43}A_{31}(\beta_1\bm{k}_1)+A_{43}A_{32}(\beta_2\bm{k}_2)\\&+A_{43}A_{32}A_{21}(\beta_1\bm{k}_1)\\
\end{align}$$
如果以 $\bm{b}_\tau=\beta_\tau \bm{k}_\tau$ 为节点,以下三角矩阵 $\bm{A}$ 中的元素 $A_{ij}=-\beta_i(\bm{k}_j^\T\bm{k}_i)$ 乘积为边, 其中 $i>j$ 。那么 $\bm{w}_\tau$ 的计算可视为一个图:
- 1、$\bm{w}_\tau$表示节点 $\tau$的累积路径信息, 编码来历史状态衰减信息。
- 2、$\bm{k}_\tau$是节点的特征向量,影响信息传播的方向
- 3、$\beta_\tau$是节点更新强度,控制新信息权重
- 4、$A_{ij}$是边权值,是历史状态衰减系数

这样的路径求和将顺序依赖转化为独立路径的叠加,$\bm{W}$的本质是图上的信息传播算子,是状态转移的累积效应。
- 1、当 $\beta_\tau \approx 0$ 时:节点孤立(无信息传播)
- 2、当 $\mathbf{k}_i \perp \mathbf{k}_j$ 时:边权为零(无信息传递)
- 3、当 $|\mathbf{k}_\tau| \to 0$ 时:节点影响消失
三、状态更新的并行
3.1、状态矩阵并行表示
有如下定理
$$\begin{align}
\bm{S}_t
= \bm{S}_{t-1}\big[\bm{E}-\beta_t\bm{k}_t \bm{k}_t^\T\big]+ \eta_t \bm{v}_t\bm{k}_t^\T
=\bm{S}_0\Big[\bm{E}-\sum_{\tau}^t \bm{w}_\tau \bm{k}_\tau ^\T\Big]+\sum_{\tau=1}^t \bm{u}_\tau \bm{k}_\tau^\T
=\bm{S}_0\big[\bm{E}-\bm{W}^\T\bm{K}\big]+ \bm{U}^\T\bm{K}
\end{align}$$
其中 $\bm{W}^\T=[\bm{w}_1,\cdots,\bm{w}_t]$, $\bm{U}^\T=[\bm{u}_1,\cdots,\bm{u}_t]$, $\bm{K}^\T=[\bm{k}_1,\cdots,\bm{k}_t]$
证明
$$\begin{align}\bm{S}_t
&=\Big[\bm{S}_{t-2}\big[\bm{E}-\beta_{t-1}\bm{k}_{t-1} \bm{k}_{t-1}^\T\big]+ \eta_{t-1} \bm{v}_{t-1}\bm{k}_{t-1}^\T\Big]\big[\bm{E}-\beta_t\bm{k}_t \bm{k}_t^\T\big]+ \eta_t \bm{v}_t\bm{k}_t^\T\\
&=\bm{S}_{t-2}\big[\bm{E}-\beta_{t-1}\bm{k}_{t-1} \bm{k}_{t-1}^\T\big]\big[\bm{E}-\beta_t\bm{k}_t \bm{k}_t^\T\big]+\eta_{t-1} \bm{v}_{t-1}\bm{k}_{t-1}^\T\big[\bm{E}-\beta_t\bm{k}_t \bm{k}_t^\T\big]+ \eta_t \bm{v}_t\bm{k}_t^\T\\
&=\bm{S}_0\underbrace{\prod_{\tau=1}^t \Big[\bm{E} - \beta_\tau \bm{k}_\tau \bm{k}_\tau^\T\Big]}_{\text{累积转移矩阵}} + \underbrace{\sum_{\tau=1}^t \Big[ \eta_\tau \bm{v}_\tau \bm{k}_\tau^\T \prod_{i=\tau+1}^t \big[\bm{E} - \beta_i \bm{k}_i \bm{k}_i^\T\big]\Big]}_{\text{累积更新项}}
\end{align}$$
其中 $\displaystyle \bm{E}-\beta_{t+1}\bm{k}_{t+1} \bm{k}_{t+1}^\T=\bm{E}-\bm{0}=\bm{E}$,对于左侧的累积转移矩阵,根据前面的推导易得结论。
$$\begin{align}
\bm{S}_0\prod_{\tau=1}^t \Big[\bm{E} - \beta_\tau \bm{k}_\tau \bm{k}_\tau^\T\Big]
=\bm{S}_0\Big[\bm{E}-\sum_{\tau}^t \bm{w}_\tau \bm{k}_\tau ^\T\Big]
=\bm{S}_0\big[\bm{E}-\bm{W}^\T\bm{K}\big]
\end{align}$$
现在关注右侧累积更新项 $\bm{L}(t)$
$$\begin{align}
\bm{L}(t)
&=\sum_{\tau=1}^t\Big[\eta_\tau \bm{v}_\tau \bm{k}_\tau^\T \prod_{i=\tau+1}^t \big[\bm{E} - \beta_i \bm{k}_i \bm{k}_i^\T\big]\Big]\\
&=\sum_{\tau=1}^{t-1}\Big[\eta_\tau \bm{v}_\tau \bm{k}_\tau^\T \prod_{i=\tau+1}^{t}\big[\bm{E} - \beta_i \bm{k}_i \bm{k}_i^\T\big]\Big]
+\eta_t \bm{v}_t\bm{k}_t^\T\\
&=\sum_{\tau=1}^{t-1}\Big[\eta_\tau \bm{v}_\tau \bm{k}_\tau^\T \prod_{i=\tau+1}^{t-1}\big[\bm{E} - \beta_i \bm{k}_i \bm{k}_i^\T\big]\Big]\big[\bm{E} - \beta_t \bm{k}_t \bm{k}_t^\T\big]
+\eta_t \bm{v}_t\bm{k}_t^\T\\
&=\bm{L}(t-1)\big[\bm{E} - \beta_t \bm{k}_t \bm{k}_t^\T\big]
+\eta_t \bm{v}_t\bm{k}_t^\T\\
&=\sum_{\tau=1}^{t-1} \bm{u}_\tau \bm{k}_\tau^\T\big[\bm{E} - \beta_t \bm{k}_t \bm{k}_t^\T\big]
+\eta_t \bm{v}_t\bm{k}_t^\T\\
&=\sum_{\tau=1}^{t-1} \bm{u}_\tau \bm{k}_\tau^\T
-\sum_{\tau=1}^{t-1} \bm{u}_\tau \bm{k}_\tau^\T\beta_t \bm{k}_t \bm{k}_t^\T
+\eta_t \bm{v}_t\bm{k}_t^\T\\
&=\sum_{\tau=1}^{t-1} \bm{u}_\tau \bm{k}_\tau^\T
+\Big[\underbrace{\eta_t\bm{v}_t-\beta_t\sum_{\tau=1}^{t-1}\bm{u}_\tau\big[\bm{k}_{\tau}^\T\bm{k}_t\big]}_{\bm{u}_t}\Big]\bm{k}_t^\T\\
&=\sum_{\tau=1}^{t} \bm{u}_\tau \bm{k}_\tau^\T\\
&=\bm{U}^\T\bm{K}
\end{align}$$
证明过程中还实现了 $\bm{u}_t$的构造方法
$$\begin{align}
\bm{u}_t
=\eta_t\bm{v}_t-\beta_t\sum_{\tau=1}^{t-1}\bm{u}_\tau\big[\bm{k}_{\tau}^\T\bm{k}_t\big]
=\eta_t\bm{v}_t-\beta_t\sum_{\tau=1}^{t-1}\big[\bm{k}_{\tau}^\T\bm{k}_t\big]\bm{u}_\tau
\end{align}$$
3.2、信息传播算子的计算
考虑对 $\bm{U}^\T=[\bm{u}_1,\cdots,\bm{u}_t]$的计算,可以整体看待。先展开
$$\begin{align}
\bm{u}_1&= \eta_1 \bm{v}_1\\
\bm{u}_2&= \eta_2 \bm{v}_2 -\beta_{2}\bm{k}_1 ^\T\bm{k}_{2}\bm{u}_1\\
\bm{u}_3&= \eta_3 \bm{v}_3 -\beta_{3}\bm{k}_1 ^\T\bm{k}_{3}\bm{u}_1-\beta_{3}\bm{k}_2 ^\T\bm{k}_{3}\bm{u}_2\\
&\vdots\\
\bm{u}_t&=\eta_t \bm{v}_t-\beta_{t}\bm{k}_1 ^\T\bm{k}_{t}\bm{u}_1-\beta_{t}\bm{k}_2 ^\T\bm{k}_{t}\bm{u}_2-
\cdots-\beta_{t}\bm{k}_{t-1} ^\T\bm{k}_{t}\bm{u}_{t-1}
\end{align}$$
写成矩阵形式
$$\begin{align}
\begin{bmatrix}
\bm{u}_1 ^\T\\
\bm{u}_2^\T\\
\bm{u}_3^\T\\
\vdots\\
\bm{u}_t^\T\\
\end{bmatrix}=
\begin{bmatrix}
\eta_1 \bm{v}_1^\T\\
\eta_2 \bm{v}_2^\T\\
\eta_3 \bm{v}_3^\T\\
\vdots\\
\eta_t \bm{v}_t^\T\\
\end{bmatrix}+ \begin{bmatrix}
0&0&0&\cdots&0\\
-\beta_{2}\bm{k}_1 ^\T\bm{k}_{2}&0&0&\cdots&0\\
-\beta_{3}\bm{k}_1 ^\T\bm{k}_{3}&-\beta_{3}\bm{k}_2 ^\T\bm{k}_{3}&0&\cdots&0\\
\vdots\\
-\beta_{t}\bm{k}_1 ^\T\bm{k}_{t}&-\beta_{t}\bm{k}_2 ^\T\bm{k}_{t}&\cdots&-\beta_{t}\bm{k}_{t-1} ^\T\bm{k}_{t}&0
\end{bmatrix}
\begin{bmatrix}
\bm{u}_1^\T\\
\bm{u}_2^\T\\
\bm{u}_3^\T\\
\vdots\\
\bm{u}_t^\T\\
\end{bmatrix}
\end{align}$$
于是有和Householder变换累积一样,可以定义同样的下三角矩阵 $\bm{A}$中的元素 $A_{ij} = -\beta_i (\bm{k}_j^\T\bm{k}_i)$,或者叫邻接矩阵,其中 $i>j$。定义 $\bm{U}^\T=[\bm{u}_1,\cdots,\bm{u}_t]$, $\bm{V}^\T=[\bm{v}_1,\cdots,\bm{v}_t]$
$$\begin{align}
\bm{U} = \mathrm{diag}[\bm{\eta}]\bm{V}+\bm{A}\bm{U}
\end{align}$$
$$\begin{align}
\bm{U}
= \big[\bm{E}-\bm{A}\big]^{-1}\mathrm{diag}[\bm{\eta}]\bm{V}
=\bm{T}\mathrm{diag}[\bm{\eta}]\bm{V}
\end{align}$$
3.3、状态更新的图论
来几个例子,做做数学符号体操, 对了解细节大有裨益。以Neumann级数展开视角,当 $t=4$时考察 $\bm{U}$其实有
$$\begin{align}
\bm{U}
&=
\underbrace{\begin{bmatrix}
\eta_1 \bm{v}_1^\T\\
\eta_2 \bm{v}_2^\T\\
\eta_3 \bm{v}_3^\T\\
\eta_4 \bm{v}_4^\T\\
\end{bmatrix}
}_{自身贡献}
+
\underbrace{\begin{bmatrix}
0 & 0 & 0 & 0\
A_{21} & 0 & 0 & 0\\
A_{31} & A_{32} & 0 & 0\\
A_{41} & A_{42} & A_{43} & 0\\
\end{bmatrix}
}_{\bm{A}:相邻贡献}
\begin{bmatrix}
\eta_1 \bm{v}_1^\T\\
\eta_2 \bm{v}_2^\T\\
\eta_3 \bm{v}_3^\T\\
\eta_4 \bm{v}_4^\T\\
\end{bmatrix}\\
&+
\underbrace{\begin{bmatrix}
0 & 0 & 0 & 0\
0 & 0 & 0 & 0\\
A_{32}A_{21} & 0 & 0 & 0\\
A_{42}A_{21}+A_{43}A_{31} & A_{43}A_{32} & 0 & 0\\
\end{bmatrix}
}_{\bm{A}^2:跨步影响}
\begin{bmatrix}
\eta_1 \bm{v}_1^\T\\
\eta_2 \bm{v}_2^\T\\
\eta_3 \bm{v}_3^\T\\
\eta_4 \bm{v}_4^\T\\
\end{bmatrix}
+
\underbrace{\begin{bmatrix}
0 & 0 & 0 & 0\
0 & 0 & 0 & 0\\
0 & 0 & 0 & 0\\
A_{43}A_{32}A_{21} & 0 & 0 & 0\\
\end{bmatrix}
}_{\bm{A}^3:长程依赖}
\begin{bmatrix}
\eta_1 \bm{v}_1^\T\\
\eta_2 \bm{v}_2^\T\\
\eta_3 \bm{v}_3^\T\\
\eta_4 \bm{v}_4^\T\\
\end{bmatrix}\\
&+
\underbrace{\begin{bmatrix}
0 & 0 & 0 & 0\
0 & 0 & 0 & 0\\
0 & 0 & 0 & 0\\
0 & 0 & 0 & 0\\
\end{bmatrix}
}_{\bm{A}^4}
\begin{bmatrix}
\eta_1 \bm{v}_1^\T\\
\eta_2 \bm{v}_2^\T\\
\eta_3 \bm{v}_3^\T\\
\eta_4 \bm{v}_4^\T\\
\end{bmatrix}+\bm{O}\bm{C}\\
\end{align}$$
展开有
$$\begin{align}
\bm{u}_1 &= \eta_1\bm{v}_1\\
\bm{u}_2 &= \eta_2\bm{v}_2 + A_{21}(\eta_1\bm{v}_1)\\
\bm{u}_3 &= \eta_3\bm{v}_3 + A_{31}(\eta_1\bm{v}_1) + A_{32}(\eta_2\bm{v}_2) + A_{32}A_{21}(\eta_1\bm{v}_1)\\
\bm{u}_4 &= \eta_4\bm{v}_4 + A_{41}(\eta_1\bm{v}_1) + A_{42}(\eta_2\bm{v}_2) + A_{43}(\eta_3\bm{v}_3) \\
&+ A_{42}A_{21}(\eta_1\bm{v}_1)+A_{43}A_{31}(\eta_1\bm{v}_1)+A_{43}A_{32}(\eta_2\bm{v}_2)\\
&+A_{43}A_{32}A_{21}(\eta_1\bm{v}_1)\\
\end{align}$$
如果以$\bm{c}_\tau=\eta_\tau \bm{v}_\tau$为节点,以下三角矩阵 $\bm{A}$中的元素 $A_{ij} = -\beta_i (\bm{k}_j^\T\bm{k}_i)$乘积为边, 其中 $i>j$。那么 $\bm{u}_\tau$的计算可视为一个图:
- 1、$\bm{u}_\tau$表示节点 $\tau$的更新路径信息, 编码了信息注入路径
- 2、$\bm{v}_\tau$是节点的特征向量,影响信息传播的方向
- 3、$\eta_\tau$是节点更新强度,控制新信息权重
- 4、$A_{ij}$是边权值,是状态衰减系数

这样的路径求和将顺序依赖转化为独立路径的叠加,$\bm{U}$的本质是图上的信息传播算子,是状态更新的累积效应。
- 1、当 $\eta_\tau \approx 0$ 时:节点孤立(无信息传播)
- 2、当 $\mathbf{k}_i \perp \mathbf{k}_j$ 时:边权为零(无信息传递)
- 3、当 $|\mathbf{v}_\tau| \to 0$ 时:节点影响消失
四、状态矩阵的并行计算
$$\begin{align}
\bm{S}_t
= \bm{S}_{t-1}\big[\bm{E}-\beta_t\bm{k}_t \bm{k}_t^\T\big]+ \eta_t \bm{v}_t\bm{k}_t^\T
=\bm{S}_0\Big[\bm{E}-\sum_{\tau}^t \bm{w}_\tau \bm{k}_\tau ^\T\Big]+\sum_{\tau=1}^t \bm{u}_\tau \bm{k}_\tau^\T
=\bm{S}_0\big[\bm{E}-\bm{W}^\T\bm{K}\big]+ \bm{U}^\T\bm{K}
\end{align}$$
其中 $\bm{W}^\T=[\bm{w}_1,\cdots,\bm{w}_t]$, $\bm{U}^\T=[\bm{u}_1,\cdots,\bm{u}_t]$, $\bm{K}^\T=[\bm{k}_1,\cdots,\bm{k}_t]$, $\bm{W}=\big[\bm{E}-\bm{A}\big]^{-1}\mathrm{diag}[\beta]\bm{K}$, $\bm{U}=\big[\bm{E}-\bm{A}\big]^{-1}\mathrm{diag}[\eta]\bm{V}$,下三角矩阵 $\bm{A}$中的元素 $A_{ij} = -\beta_i (\bm{k}_j^\T\bm{k}_i)\quad,i>j$。
最后以总体视角讨论一下更新公式,然后结束本文。对于上式,稍加变形有
$$\begin{align}
\bm{S}_t = \bm{S}_0+\big[\bm{U}-\bm{W}\bm{S}_0^\T\big]^\T\bm{K}
\end{align}$$
令 $\dim[\bm{Q}]=t\times d_q$, $\dim[\bm{K}]=t\times d_k$, $\dim[\bm{V}]=t\times d_v$, 通常 $d=d_q=d_k=d_v$。明晰一下更新公式中的维度:其中 $\dim[\bm{S}]=d_v\times d_k$、 $\dim[\bm{U}]=t\times d_v$、 $\dim[\bm{W}]=t\times d_k$, 上述变形后可以降低计算复杂度。实际有
$$\begin{align}
&O(\bm{S}_0+\big[\bm{U}-\bm{W}\bm{S}_0^\T\big]^\T\bm{K})\\
&=O(\bm{T}_1 = \bm{U} - \bm{W}\bm{S}_0^\T)
+O(\bm{T}_2 = \bm{T}_1\bm{K})
+O(\bm{S}_0+\bm{T}_2)\\
&=O(td^2+td+t^2d+d^2)\\
&=O(td^2+t^2d)\\
&O(\bm{S}_0\big[\bm{E}-\bm{W}^\T\bm{K}\big]+ \bm{U}^\T\bm{K})\\
&=O(\bm{T}_1 = \bm{E}-\bm{W}^\T\bm{K})
+O(\bm{T}_2 = \bm{S}_0\bm{T}_1)
+O(\bm{T}_3=\bm{U}^\T\bm{K})
+O(\bm{T}_2+\bm{T}_3)\\
&=O(t^2d+td+d^3+t^2d+d^2)\\
&=O(d^3+2t^2d)
\end{align}$$
当 $t \ll d$的时候,也就是说序列长度小于特征维度。对于长序列而言,可以分块计算,从而减小 $t$。说白了其实就是切分为若干段来计算,一张图说明所有:块内并行,块间依赖。

那么又有好事者问:先计算 $\bm{S}=\bm{S}_0+\big[\bm{U}-\bm{W}\bm{S}_0^\T\big]^\T\bm{K}$,再计算 $\bm{O}=\bm{Q}\bm{S}^\T$,和直接按 $\bm{O}=\bm{Q}\bm{S}^\T+(\bm{Q}\bm{K}^\T\odot\bm{M}) (\bm{U} - \bm{W}\bm{S}^\T) $公式计算,有什么不同?后面的复杂度要越高一点,因为前者只需三次矩阵乘法,后者需要4次矩阵乘法。
分开计算
- $\bm{T}_1 = \bm{U} - \bm{W}\bm{S}_0^\T$:$O(td^2+td)$
- $\bm{T}_2 = \bm{T}_1\bm{K}$:$O(t^2d)$
- $\bm{S} = \bm{S}_0+\bm{T}_2$:$O(d^2)$
- $\bm{O} = \bm{Q} \bm{S}^\T$:$O(td^2)$
总复杂度:$O(2td^2+t^2d+d^2+td)$
- $\bm{T}_1 = \bm{Q} \bm{K}^\T$:$O(t^2 d)$
- $\bm{T}_2 = \bm{T}_1 \odot \bm{M}$:$O(t^2)$ (轻量)
- $\bm{T}_3 = \bm{U} - \bm{W}\bm{S}_0^\T$:$O(td^2+td)$
- $\bm{T}_4 = \bm{T}_2 \bm{T}_3$:$O(t^2 d)$
- $\bm{T}_5 = \bm{Q} \bm{S}_0^\T$:$O(td^2)$
- $\bm{O} = \bm{T}_5 + \bm{T}_4$:$O(t^2)$
合计计算
总复杂度:$O(2td^2+2t^2d+2t^2+td)$
五、评述
1、充分利用对角加秩一(DPLR:Diagonal Plus Low-Rank)的特点,实现了基于Delta规则RNN的并行计算。
2、并行计算公式中最重要的其实是两个信息传播算子 $\bm{W}$、 $\bm{U}$的计算。尤其以下三角矩阵 $\bm{A}$中元素 $A_{ij} = -\beta_i (\bm{k}_j^\T\bm{k}_i)$最为关键。这就凸显了key矩阵 $\bm{K
}$的重要性。因为它元素的内积或者说自相关性直接影响每个token信息如何传播到瞬时记忆状态 $\bm{S}$。针对这个问题,在RWKV7中对 $\bm{K}$做了特殊处理,增强了 $\bm{K}$的表达能力。
3、现代RNN核心机制和计算问题解决后,如何进一步增强RNN的表达能力。就是本系列的下一篇内容了。
4、成文有些匆忙,难免有些错误,请大家批评指正。
参考文献
[^2]: Yang, S., Wang, B., Zhang, Y., Shen, Y., & Kim, Y. (2025, January 15). Parallelizing linear transformers with the delta rule over sequence length. arXiv. https://doi.org/10.48550/arXiv.2406.06484
| 版权声明 | ![]() |
| 由引线小白创作并维护的柠檬CC博客采用署名-非商业-禁止演绎4.0国际许可证。 本文首发于柠檬CC [ https://www.limoncc.com ] , 版权所有、侵权必究。 | |
| 本文永久链接 | httpss://www.limoncc.com/post/2823bb8e386a0878/ |
| 如果您需要引用本文,请参考: |
| 引线小白. (Aug. 17, 2025). 《RNN的复兴03:现代RNN的并行计算》[Blog post]. Retrieved from https://www.limoncc.com/post/2823bb8e386a0878 |
| @online{limoncc-2823bb8e386a0878, title={RNN的复兴03:现代RNN的并行计算}, author={引线小白}, year={2025}, month={Aug}, date={17}, url={\url{https://www.limoncc.com/post/2823bb8e386a0878}}, } |
