如何拓展大模型上下文长度

26 min

最近需要将 Qwen2.5 Math 7B 的模型上下文从 4k 扩展到 16k,记录一下是怎么实现的。这里采用的是 NTK aware 的方法,原因是足够简单且可行。

RoPE

Qwen 系列模型使用了 RoPE 位置编码,所以首先回顾一下 RoPE 的方法:对于给定的 Query (qq) 和 Key (kk) 向量,它们在送入注意力计算之前,会根据其在序列中的绝对位置 ii,将每两个维度视为复平面上的一个点,进行旋转。由于旋转的相对性,天然就有相对位置编码的特性。

对一个维度为 dd 的向量 x=[x0,x1,,xd1]x=[x_0,x_1,\dots,x_{d-1}](在序列中的位置为 mm),将第 2i2i2i+12i+1 位作为一个二维向量进行旋转:

(x2ix2i+1)=(cos(mθi)sin(mθi)sin(mθi)cos(mθi))(x2ix2i+1)\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{pmatrix} \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}

其中 θi\theta_{i} 是一个和位置 ii 相关的角频率,计算方式为:

θi=b2i/d\theta_{i}=b^{-2i/d}

bb 为基频,在 Qwen2.5 Math 7B 中为 1000010000。在实际代码中,旋转矩阵 RΘ,md\mathbf{R}_{\mathbf{\Theta},m}^{d} 非常稀疏,需要使用如下方式进行计算以提高计算效率:

RΘ,mdx=(x0x1x2x3xd2xd1)(cosmθ0cosmθ0cosmθ1cosmθ1cosmθd21cosmθd21)+(x1x0x3x2xd1xd2)(sinmθ0sinmθ0sinmθ1sinmθ1sinmθd21sinmθd21)\mathbf{R}^d_{\mathbb{\Theta},m} \mathbf{x} = \begin{pmatrix} x_0 \\ x_1 \\ x_{2} \\ x_{3} \\ \cdots \\ x_{d-2} \\ x_{d-1} \end{pmatrix} \otimes \begin{pmatrix} \cos m \theta_0 \\ \cos m\theta_{0} \\ \cos m\theta_{1} \\ \cos m\theta_{1} \\ \cdots \\ \cos m\theta_{\frac{d}{2}-1} \\ \cos m\theta_{\frac{d}{2}-1} \end{pmatrix} + \begin{pmatrix} -x_{1} \\ x_{0} \\ -x_{3} \\ x_{2} \\ \cdots \\ -x_{d-1} \\ x_{d-2} \end{pmatrix} \otimes \begin{pmatrix} \sin m \theta_0 \\ \sin m\theta_{0} \\ \sin m\theta_{1} \\ \sin m\theta_{1} \\ \cdots \\ \sin m\theta_{\frac{d}{2}-1} \\ \sin m\theta_{\frac{d}{2}-1} \end{pmatrix}

在 LLaMA 中的实现 1

# 生成旋转矩阵
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    # 计算词向量元素两两分组之后,每组元素对应的旋转角度\theta_i
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # 生成 token 序列索引 t = [0, 1,..., seq_len-1]
    t = torch.arange(seq_len, device=freqs.device)
    # freqs.shape = [seq_len, dim // 2]
    freqs = torch.outer(t, freqs).float()  # 计算m * \theta

    # 计算结果是个复数向量
    # 假设 freqs = [x, y]
    # 则 freqs_cis = [cos(x) + sin(x)i, cos(y) + sin(y)i]
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

# 旋转位置编码计算
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # xq.shape = [batch_size, seq_len, dim]
    # xq_.shape = [batch_size, seq_len, dim // 2, 2]
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)

    # 转为复数域
    xq_ = torch.view_as_complex(xq_)
    xk_ = torch.view_as_complex(xk_)

    # 应用旋转操作,然后将结果转回实数域
    # xq_out.shape = [batch_size, seq_len, dim]
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)
    return xq_out.type_as(xq), xk_out.type_as(xk)

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.wq = Linear(...)
        self.wk = Linear(...)
        self.wv = Linear(...)

        self.freqs_cis = precompute_freqs_cis(dim, max_seq_len * 2)

    def forward(self, x: torch.Tensor):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(batch_size, seq_len, dim)
        xk = xk.view(batch_size, seq_len, dim)
        xv = xv.view(batch_size, seq_len, dim)

        # attention 操作之前,应用旋转位置编码
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # scores.shape = (bs, seqlen, seqlen)
        scores = torch.matmul(xq, xk.transpose(1, 2)) / math.sqrt(dim)
        scores = F.softmax(scores.float(), dim=-1)
        output = torch.matmul(scores, xv)  # (batch_size, seq_len, dim)
  # ......

对位置在 mm 处的 query qmq_{m} 和位置在 nn 处的 key knk_{n},计算内积就变成了

(RΘ,mdqm)(RΘ,mdkn)=qm(RΘ,mdRΘ,nd)kn=qmRΘ,mndkn\begin{align} (\mathbf{R}_{\mathbf{\Theta},m}^{d}q_{m})^{\top}(\mathbf{R}_{\mathbf{\Theta},m}^{d}k_{n}) & = q_{m}^{\top}\left({\mathbf{R}_{\mathbf{\Theta},m}^{d}}^{\top}\mathbf{R}_{\mathbf{\Theta},n}^{d}\right) k_{n} \\ & = q_{m}^{\top}\mathbf{R}_{\mathbf{\Theta },m-n}^{d}k_{n} \end{align}

的确只和相对位置有关。

θi\theta_{i},当 ii 较小时(xx 的低维部分),θi1\theta_{i}\approx 1,是高频信息(即在 mθim\theta_{i} 旋转速度更快),用于捕捉短距离注意力关系;而当 ii 较大时(xx 的高维部分),是低频信息,用于捕捉长距离注意力关系。总的来说就是低维高频、高维低频。至此,对于 RoPE 的理解就足够了,可以开始推导如何扩展上下文长度了。

扩展上下文长度

虽然好像 RoPE 编码的定义使其好像很容易进行上下文扩展(扩展上下文长度就相当于增大 mm 的取值范围),但是注意,由于基频是预先定义好的,所以所有维度的频率都是固定的,面对更长的长度,mm 的取值范围增大,mθim\theta_{i} 的变化范围也会增大,多出来的范围就是模型并没有学习过的,对高频短距离信息尤为如此;从另一个视角来看,上下文长度的增加会带来更多的低频长距离信息,这是模型没有学习过的。因此必须做出一些必要的调整以适应增加的低频长距离信息。与此同时,需要注意的是,高频的短距离信息几乎不应该有变化,因此调整时也要注意到这些信息的保存。

由于低频、高频信息和 θi\theta_{i} 直接相关,而 θi=b2i/d\theta_{i}=b^{-2i/d} 中的 dd 是随模型固定的,ii 应该尽量避免变化,所以自然的想法是调整原有的基频 bb。从上面说到的,上下文长度扩展的程度对低频信息的引入有很大的关联,因此引入一个变量 ss 为新、旧最大上下文长度的比值,并定义新的基频

b=bskb' = b\cdot s^{k}

这里 kk 是一个待定的常数。带入 θi\theta_{i} 可得

θi=(b)2i/d=b2i/ds2ik/d=θis2ik/d\theta'_{i} = (b')^{-2i/d} = b^{-2i/d}\cdot s^{-2ik/d} = \theta_{i}\cdot s^{-2ik/d}

对高频信息,当 i0i\to 0 时,s2ik/d1s^{-2ik/d}\to 1,故 θiθi\theta'_{i}\approx\theta_{i},满足之前讨论到的“高频的短距离信息几乎不应该有变化”。

而对低频信息,即当 id2i\to \frac{d}{2} 时(注意 ii 的取值范围是 [0,d/21][0, d / 2-1]),由于扩展 ss 倍上下文,所以希望能够在低频部分进行 s1s^{-1} 的线性插值(角频率缩小 ss 倍),即希望 s2ik/ds1s^{-2ik/d}\approx s^{-1},得到

k=d2imax=dd2k = \frac{d}{2i_{\text{max}}}=\frac{d}{d-2}

那么也就得到了最终的基频修改公式:

b=bsdd2b'=b\cdot s^{\frac{d}{d-2}}

在实际应用中,发现由于 dd 通常比较大,即使是取 b=bsb'=b\cdot s 也能有不错的效果。并且原作者提到的,对 LLaMA 7B 的模型,你甚至不需要做任何的微调就能适应新的上下文长度!2

扩展 Qwen2.5 Math 7B 上下文长度

回到正题,终于可以扩展上下文长度了!Qwen2.5 Math 7B 的默认参数配置 3 为:

{
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151643,
  "hidden_act": "silu",
  "hidden_size": 3584,
  "initializer_range": 0.02,
  "intermediate_size": 18944,
  "max_position_embeddings": 4096,
  "max_window_layers": 28,
  "model_type": "qwen2",
  "num_attention_heads": 28,
  "num_hidden_layers": 28,
  "num_key_value_heads": 4,
  "rms_norm_eps": 1e-06,
  "rope_theta": 10000,
  "sliding_window": 4096,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.44.0",
  "use_cache": true,
  "use_mrope": false,
  "use_sliding_window": false,
  "vocab_size": 152064
}

这里的 rope_theta 就是基频 bb,由于使用了多头注意力机制,所以实际的向量维度为 hidden_size / num_attention_heads,即 128128,新旧最大上下文长度之比 s=16k/4k=4s=16\text{k} / 4\text{k}=4,那么可以算得

b=bsdd241810.5b'=b\cdot s^{\frac{d}{d-2}} \approx 41810.5

如果取 ss 的指数为 11,那么

bbs=40000.0b'\approx b\cdot s=40000.0

max_position_embeddingsrope_theta 的值分别修改为 16384163844000040000 即可。

结束……了?

当然没有!实际上,上面所推导的方法正是 NTK Aware 方法 2,基于这个方法,还有一些其他的变体,例如 Dynamic NTK Interpolation4、NTK-by-parts Interpolation5、YaRN6 等,可以阅读 这篇博客 来迅速了解这些方法。

另外,上面的推导仅仅只是对 NTK aware 算法的一个近似求解,具体的推导过程可以看下面的内容(对数学要求较高,可以跳过不看)。

神经内切核

首先我们需要了解 NTK(Neural Tangent Kernel,神经内切核)的理论 7。对于一个由参数 θRP\mathbf{\theta}\in \mathbb{R}^{P} 参数化的神经网络 f(x;θ)f(\mathbf{x};\mathbf{\theta}),其神经内切核 Θ(x,x)\mathbf{\Theta}(\mathbf{x},\mathbf{x}') 是一个衡量输入 x\mathbf{x}x\mathbf{x}' 之间关系的核函数,定义为

Θ(x,x)=θf(x;θ),θf(x;θ)\mathbf{\Theta}(\mathbf{x},\mathbf{x}')= \langle \nabla_{\theta}f(\mathbf{x};\theta), \nabla_{\theta}f(\mathbf{x}';\mathbf{\theta}) \rangle

根据 Jacot 等 7 的研究,可以知道在无限宽度和无穷小学习率的极限下,神经网络的训练动态可以用这个核函数来精确描述。模型在函数空间中的演化等价于一个核回归问题,其使用的核就是 NTK。因此,NTK 的性质(尤其是其谱特性)决定了模型的学习能力和泛化行为。

NTK Aware 方法

RoPE 核函数

将 RoPE 视为一个特征映射 ϕ:NCd/2\phi:\mathbb{N}\to \mathbb{C}^{d/2},将一个位置索引 mm 映射到一个 d2\frac{d}{2} 维的复数向量:

ϕ(m)=1d/2(eimθ0eimθ1eimθd/21)\phi(m)=\frac{1}{\sqrt{ d/2 }}\begin{pmatrix} e^{\mathrm{i}m\theta_{0}} \\ e^{\mathrm{i}m\theta_{1}} \\ \dots \\ e^{\mathrm{i}m\theta_{d/2-1}} \end{pmatrix}

其中 θi=b2i/d\theta_{i}=b^{-2i/d}(d/2)1/2(d / 2)^{-1/2} 用于归一化。

然后将注意力机制中 RoPE 的贡献抽象为一个核函数 K(m,n)K(m,n),用于衡量位置 mmnn 之间的相似性,有

K(m,n)=Re(ϕ(m),ϕ(n))=Re(2dk=0d/21ei(mn)θk)=2di=0d/21cos(mn)θi\begin{align} K(m,n) & = \mathrm{Re}(\langle \phi(m),\phi(n) \rangle) \\ & = \mathrm{Re}\left( \frac{2}{d}\sum_{k=0}^{d/2-1} e^{\mathrm{i}(m-n)\theta_{k}} \right) \\ & = \frac{2}{d} \sum_{i=0}^{d/2-1} \cos(m-n)\theta_{i} \end{align}

只和 mnm-n 有关,为了简单,记核函数 K(m,n)=K(Δm)=K(Δm;b)K(m,n)=K(\Delta m)=K(\Delta m; b)f(x)=cos(Δmx)f(x)=\cos(\Delta m\cdot x),则 K(Δm)=i=0d/21f(θi)K(\Delta m)=\sum_{i=0}^{d/2-1}f(\theta_{i})。另外也可以发现,任意两个位置之间的相似性和所有频率 {θi}i=0,1,,d/21\{ \theta_{i} \}_{i=0,1,\dots,d/2-1} 相关,也就是说,

NOTE

模型内部对位置差异的感知,是由所有频率分量共同贡献的结果

在进行上下文长度扩展前后,希望核函数的性质尽可能保持不变,也就是说

NOTE

在新的、扩展后的上下文中,任意两个位置 mmnn 的核函数值 K(Δm;b)K(\Delta m;b'),与在原始上下文中、按比例缩放后的位置 m/s,n/sm / s, n / s 的核函数值 K(ms,ns;b)K\left(\frac{m}{s}, \frac{n}{s};b \right) 尽可能保持一致:

K(m,n;b)K(ms,ns;b)K(m,n;b') \approx K\left( \frac{m}{s}, \frac{n}{s};b \right)

i=0d/21cos(Δmθi)i=0d/21cos(Δmsθi)\sum_{i=0}^{d/2-1} \cos(\Delta m\cdot \theta'_{i}) \approx \sum_{i=0}^{d/2-1} \cos\left( \frac{\Delta m}{s} \theta_{i} \right)

这里 θi=b2i/d\theta_{i}'=b'^{-2i/d}。那么问题就转化为如何计算或近似出上式的左右两边(实际上只需要左边就行了),当向量维度 dd 足够大时,求和可以转化为积分近似,即:

K(Δm;b)=2di=0d/21f(θi)=2di=0d/21f(b2i/d)2d0d/2f(b2i/d)di=01f(bx)dx=01cos(Δmbx)dx\begin{align} K(\Delta m;b) & = \frac{2}{d} \sum_{i=0}^{d/2-1} f(\theta_{i}) \\ & = \frac{2}{d}\sum_{i=0}^{d/2-1} f(b^{-2i/d}) \\ & \approx \frac{2}{d} \int _{0}^{d/2} f(b^{-2i/d}) \, \mathrm{d}i \\ & = \int _{0}^{1}f(b^{-x}) \, dx =\int_{0}^{1}\cos(\Delta m\cdot b^{-x})\mathrm{d}x \end{align}

代入上式得到

01cos(Δmbx)dx01cos(Δmsbx)dx\int_0^1 \cos(\Delta m \cdot b'^{-x}) \mathrm{d}x \approx\int_0^1 \cos\left(\frac{\Delta m}{s} \cdot b^{-x}\right) \mathrm{d}x

谱密度函数

然而,求解这个积分等式非常困难,需要通过其他途径进行分析。将核函数 K(m,n)K(m,n) 生成的核矩阵 [K]mn[K]_{mn} 视为一个大型矩阵,从随机矩阵理论的视角,这个矩阵的谱特性(如谱密度、谱半径)与核函数的分析性质(如其傅里叶变换)紧密相关。保持谱特性的稳定是实现稳健上下文扩展的关键。我们来计算核函数 K(m,n)K(m,n) 的谱密度 K^(ω;b)\hat{K}(\omega;b)

K^(ω;b)=K(Δm;b)eiωΔmd(Δm)=(01cos(Δmbx)dx)d(Δm)=1201(eiΔmbx+eiΔmbx)eiωΔmdxd(Δm)=1201[eiΔm(bxω)dΔm+eiΔm(bx+ω)dΔm]dx\begin{align} \hat{K}(\omega;b) & =\int_{-\infty}^{\infty} K(\Delta m;b)e^{-\mathrm{i}\omega\Delta m} \, \mathrm{d}(\Delta m) \\ & = \int_{-\infty}^{\infty} \left( \int_{0}^{1} \cos(\Delta m\cdot b^{-x})\, \mathrm{d}x \right) \, \mathrm{d}(\Delta m) \\ & = \frac{1}{2} \int_{-\infty}^{\infty} \int_{0}^{1} (e^{\mathrm{i}\Delta m\cdot b^{-x}} + e^{-\mathrm{i}\Delta m\cdot b^{-x}}) e^{-\mathrm{i}\omega\Delta m} \, \mathrm{d}x\mathrm{d}(\Delta m) \\ & = \frac{1}{2} \int _{0}^{1} \left[ \int_{-\infty}^{\infty} e^{\mathrm{i}\Delta m(b^{-x}-\omega)} \, \mathrm{d}\Delta m+\int_{-\infty}^{\infty} e^{-\mathrm{i}\Delta m(b^{-x}+\omega)} \, \mathrm{d}\Delta m \right] \, \mathrm{d}x \end{align}

代入狄拉克函数

δ(k)=12πeikxdx\delta(k)=\frac{1}{2\pi}\int_{-\infty}^{\infty} e^{\mathrm{i}kx} \, \mathrm{d}x

并由于 δ(x)=δ(x)\delta(x)=\delta(-x),得到

K^(ω;b)=1201[2πδ(bxω)+2πδ((bx+ω))]dx=π01δ(ωbx)+δ(ω+bx)dx\begin{align} \hat{K}(\omega;b) & = \frac{1}{2}\int _{0}^{1} [2\pi\delta(b^{-x}-\omega) + 2\pi\delta(-(b^{-x}+\omega))] \, \mathrm{d}x \\ & = \pi \int _{0}^{1} \delta(\omega-b^{-x}) + \delta(\omega+b^{-x}) \, \mathrm{d}x \end{align}

只考虑正频率 ω>0\omega>0,由于 b>1b>10<x<10<x<1,有 bx>0b^{-x}>0,则 δ(ω+bx)=0\delta(\omega+b^{-x})=0,上式简化为:

K^(ω;b)=π01δ(ωbx)dx\hat{K}(\omega;b)=\pi \int_{0}^{1} \delta(\omega-b^{-x}) \mathrm{d}x

然后分析这个谱的支撑集。为了使谱密度 K^(ω;b)>0\hat{K}(\omega;b) > 0,则必然存在 x[0,1]x\in[0,1],使得 ωbx=0\omega-b^{-x}=0,解得 x=lnω/lnbx=-\ln \omega / \ln b。为了让 x[0,1]x\in[0,1],可以得到

1bω1\frac{1}{b} \leq \omega \leq 1

那么就得到:

IMPORTANT

理想化的 RoPE 核(在无限长序列上)其谱密度函数的支撑集为 [1b,1]\left[\frac{1}{b}, 1 \right]

接着我们使用狄拉克函数的性质:

g(x)δ(f(x))dx=ig(xi)f(xi)\int g(x)\delta(f(x)) \mathrm{d}x = \sum_{i} \frac{g(x_{i})}{\lvert f'(x_{i}) \rvert }

其中 xix_{i}f(x)=0f(x)=0 的根。代入后可以得到谱密度:

K^(ω;b)=πωlnb, for ω[1b,1]\hat{K}(\omega;b)=\frac{\pi}{\omega \ln b}\text{, for } \omega\in\left[ \frac{1}{b},1 \right]

它在低频端点 ω=1b\omega=\frac{1}{b} 处取最大值,在高频端点 ω=1\omega=1 处取最小值。至此,谱密度函数的性质已经分析完毕。

非理想情况

然而以上是理想情况——因为我们不知不觉中,在计算谱密度的时候,假定了傅里叶变换的存在性,即 Δm\Delta m 可以取得 ±\pm \infty,也即序列长度可以达到 ++\infty。这显然是不可能的。我们不得不回到现实情况中来,也即序列长度最多只能到达 LL。此时核函数 K(m,n)K(m,n) 的谱矩阵是一个 L×LL\times L 的矩阵,更准确的说,是一个托普利茨矩阵(常对角矩阵)8,记为 TL(f)T_{L}(f)

CAUTION

除此之外,还有一个让这种理想情况不可能达到的原因,想想是什么?(在后面会有揭晓)

TL(f)=(K(0)K(1)K(2)K(1L)K(1)K(0)K(0)K(2L)K(L1)K(L2)K(L3)K(0))T_{L}(f) = \begin{pmatrix} K(0) & K(-1) & K(-2) & \cdots & K(1-L) \\ K(1) & K(0) & K(0) & \cdots & K(2-L) \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ K(L-1) & K(L-2) & K(L-3) & \cdots & K(0) \end{pmatrix}

这个矩阵的行为完全由生成函数 ff,即上一步计算得到的谱密度函数 K^(ω;b)\hat{K}(\omega;b) 决定 9

接下来,我们需要使用一个研究大型托普利茨矩阵行列式渐近行为的强大的引理:强斯格极限定理 (Strong Szegő Limit Theorem) 10,如下

NOTE

若生成函数 f(θ)f(\theta) 满足某些条件,例如 f>0f>0 且足够光滑,则:

limLdet(TL(f))G(f)L=E(f)\lim_{ L \to \infty } \frac{\det(T_{L}(f))}{G(f)^{L}} = E(f)

其中:

  • G(f)G(f)ff 的几何平均数 (Geometric Mean),代表了行列式的体行为 (bulk behavior)。对于充分大的 LL,有 det(TL(f))G(f)L\det(T_{L}(f))\approx G(f)^{L}
G(f)=exp(12π02πlnf(eiθ)dθ)G(f)=\exp\left( \frac{1}{2\pi}\int _{0}^{2\pi}\ln f(e^{i\theta}) \, \mathrm{d}\theta \right)
  • E(f)E(f) 是一个和 ff 的光滑性相关的边界项,依赖于 lnf\ln f 的傅里叶系数
E(f)=exp(k=1klnf^k2)E(f)=\exp\left( \sum_{k=1}^{\infty} k\left\lvert\widehat{\ln f}_{k} \right\rvert^{2} \right)

其中 lnf^k\widehat{\ln f}_{k}lnf\ln f 的第 kk 个傅里叶系数。

将上述引理应用于解释扩展上下文长度的稳定性质:回顾我们求出的谱密度函数:

f(ω)=K^(ω;b)={πωlnb1bω10otherwisef(\omega)=\hat{K}(\omega;b) = \begin{cases} \frac{\pi}{\omega \ln b} & \frac{1}{b} \leq \omega \leq 1 \\ \\ 0 & \text{otherwise} \end{cases}

明显,在两侧端点 1b\frac{1}{b}11 处,ff 均出现了跳变,结合强斯格极限定理,可以得到如下结论。

IMPORTANT

由于谱密度函数 ff 天然具有跳变不连续的性质,在简单的拓展上下文长度 LL 时,模型的离散频率分辨率 1L\sim \frac{1}{L} 和谱密度函数 ff 两侧端点位置之间的关系发生了改变,系统进入了一个对这种不连续性极其敏感的区域。从而使矩阵 TL(f)T_L(f) 出现了离群特征值,同时 LL\to \inftyE(f)E(f) 没有良好的收敛性质,从而导致了简单扩展上下文长度 LL 的崩溃。

数学量对应概念
det(TL(f))\det(T_L(f))TL(f)T_L(f) 的(离群)特征值
E(f)ff 的(不)连续性

离群特征值的出现原因

可以看到,在上面的结论中,我们使用了一些描述性的语言,以便于当前能顺利理解。接下来,我们顺理成章的需要研究:

这种不连续性是如何导致托普利茨矩阵 TL(f)T_{L}(f) 的离群特征值的出现,以及这一现象何时发生? (但是这太困难了,已经超出了我的知识范围,只好请教 gemini 大人了😭,贴在下面等有缘人来确认正确性)

首先让我们描述离群特征值是什么。根据 Fisher–Hartwig 猜想 11(已被证明)以及 Harold Widom12 的后续工作,可知对于一个具有跳变不连续的生成函数 f(ω)f(\omega),其对应的托普利茨矩阵 TL(f)T_{L}(f) 的谱在 LL\to \infty 时呈现如下结构:

  • 连续谱 (Continuous Spectrum):绝大多数(几乎所有)的特征值会密集地分布在由 f(ω)f(\omega) 的值域构成的区间内。这个区间我们称之为体谱 (bulk spectrum)
  • 离散谱 (Discrete Spectrum):在生成函数 f(ω)f(\omega) 的不连续点处,可能会有一个或多个特征值脱离体谱,成为孤立的离群点。

不稳定性就等价于这些离群特征值的出现和它们的病态行为。

接下来,是最后的推导(虽然并不详细)。

Harold Widom12 对特定生成函数,即可以表示为一个光滑函数和一个区间指示函数的乘积的函数,证明了积分方程定理(Integral Equation Theorem for Outliers):当 LL\to \infty 时,大多数特征值位于所谓的体谱区间内(由函数值域决定),而离群特征值的渐近行为,可以被一个定义在有限区间上的积分算子 KL\mathcal{K}_L 的特征值精确描述。该积分算子的核一般与 sinc 核形式(如 sin(c(xy))/(xy)\sin(c(x-y)) / (x-y))相关。这意味着可以将一个复杂的、L×LL\times L 离散矩阵的谱问题,转化成一个更易于分析的连续积分算子 KL\mathcal{K}_{L} 的问题。准确来说,TL(f)T_{L}(f) 的行列式的渐近行为被 KL\mathcal{K}_{L} 的 Fredholm 行列式 13 det(IKL)\det(I-\mathcal{K}_{L}) 描述。

通过对这个等效积分算子 KL\mathcal{K}_{L} 的研究,数学家们发现它的谱(特别是其最大特征值)存在相变 (Phase Transition)14 现象。相变的发生与否,由一个关键的无量纲参数控制,我们称之为序参量,记为 β\beta

在进行接下来的推导之前,首先需要对原始 RoPE 核进行适当的“修正”:回归原始 RoPE 核:

K(Δm;b)=i=0d/21cos(Δmb2i/d)K(\Delta m;b)=\sum_{i=0}^{d/2-1} \cos(\Delta m\cdot b^{-2i/d})

i=0i=0 这一项,即 cos(Δm)\cos(\Delta m) 这一项,因为它并非平方可积,所以不满足连续谱分析的前提,应该排除这一项,即稳定部分由 i{1,2,,d/21}i\in \{ 1,2,\dots,d/2-1 \} 这些项构成。这样做改变了生成函数的精确形状和支撑集,但可以验证,其定性行为(如边界不连续性)依然存在。那么有效维度就只有 deff=d2d_{\text{eff}}=d-2 个了(排除了第 0,10,1 两个维度)。

进行了这样的修正之后,x=2k/dx=2k/d 的范围就缩小为了 [2/d,(d2)/d][2 / d, (d-2) / d],对应的频率 ω=bx\omega=b^{-x} 的范围变为 [b(d2)/d,b2/d][b^{-(d-2)/d}, b^{-2/d}],也就是新的支撑集。

通过分析与算子 KL\mathcal{K}_{L} 相关的潘勒韦方程 15 的解的渐近行为,并对支撑集下边界 b(d2)/db^{-(d-2)/d} 进行分析,得到与这个边界相关的系统特征长度 λchar\lambda_{\text{char}},和基频 bb、有效维数 deffd_{\text{eff}} 的关系为:

λcharbd2d\lambda_{\text{char}} \propto b^{\frac{d-2}{d}}

这个特征长度可以被理解为系统内部结构能够保持相干的最大距离。定义上面提到的序参量 β\beta 为系统外在尺寸 LL 和这个内在特征长度 λchar\lambda_{\text{char}} 的比值:

β=LλcharLbd2d\beta=\frac{L}{\lambda_{\text{char}}} \propto L\cdot b^{-\frac{d-2}{d}}

Harold Widom 和他的合作者一起证明了:离群特征值的出现,精确地发生在序参量 β\beta 跨越一个特定的临界值 βc\beta_{c}​ 的时候,即,当 β<βc\beta <\beta_{c}​ 时,系统处于稳定状态,没有离群特征值;当 β>βc\beta>\beta_{c}​ 时,系统失稳,离群特征值从体谱中分裂出来。因此,系统保持在临界稳定状态的条件是:

β=βc=constant\beta=\beta_{c}=\text{constant}

那么为了能够稳定的扩展上下文,就需要有

Lbd2d=CL\cdot b^{-\frac{d-2}{d}} =C

CC 是一个常数。重新整理得到:

bLdd2b \propto L^{\frac{d}{d-2}}

将新旧基频 bb'bb 代入 L=sLL'=s\cdot L,得到

b(sL)dd2=sdd2Ldd2b'\propto(s\cdot L)^{\frac{d}{d-2}} = s^{\frac{d}{d-2}}\cdot L^{\frac{d}{d-2}}

那么就有

b=sdd2bb'=s^{\frac{d}{d-2}}\cdot b

总结

至此,艺术已成。

至此,我们从理论上证明了:

NOTE

当扩展上下文长度 LL 时,基频 bb 需要满足 b=bsd/(d2)b'=b\cdot s^{d/(d-2)} 的变化方式,才能保证模型不会崩溃。

回顾我们在非理想情况一节中使用的描述性语言:

由于谱密度函数 ff 天然具有跳变不连续的性质,在简单的拓展上下文长度 LL 时,模型的离散频率分辨率 1L\sim \frac{1}{L} 和谱密度函数 ff 两侧端点位置之间的关系发生了改变,系统进入了一个对这种不连续性极其敏感的区域。从而使矩阵 TL(f)T_L(f) 出现了离群特征值,同时 LL\to \inftyE(f)E(f) 没有良好的收敛性质,从而导致了简单扩展上下文长度 LL 的崩溃。

现在可以理解:

  • 系统进入了一个对这种不连续性极其敏感的区域:就是指随着 LL 的增大,β>βc\beta > \beta_c,从而导致离群特征值的产生,也正是相变现象。
  • 模型的离散频率分辨率 1L\sim \frac{1}{L} 和谱密度函数 ff 两侧端点位置之间的关系发生了改变
    • 对长度 LL 有限的离散序列,频率空间就是由离散傅里叶变换 (DFT) 的频率点构成,即 ΩL={ωk=2πkLk=0,1,2,,L1}\Omega_L = \{ \omega_k = \frac{2\pi k}{L} | k=0,1,2,\dots, L-1 \},格点分辨率 Δω=2πL1L\Delta \omega = \frac{2\pi}{L}\sim \frac{1}{L}
    • 随着 LL 的增大,如果 ff 满足良好的性质,那么 TL(fb)T_L(f_b) 的特征值会越来越精确地填充到体谱中;然而 ff 存在跳跃不连续点,所以这种情况不可能发生,从而导致离群特征值出现的必然性。说“关系”发生改变,也正是说离散特征值不能精确填充到体谱中去,逐渐变成了离群特征值。

总结一下我们的证明结论:

NOTE

对于一个由固定的谱密度函数 fbf_b(其支撑集 IbI_b 的边界存在跳变不连续)生成的托普利茨算子/矩阵序列 {TL(fb)}LZ+\{ T_L(f_b) \}_{L\in \mathbb{Z}^+},存在一个由上下文长度 LL 和系统参数 (b,d)(b ,d) 共同决定的无量纲序参量 βLb(d2)/d\beta \propto L \cdot b^{-(d-2)/d}。当 LL 增大并使得 β\beta 超过一个临界阈值 βc\beta_c 时,系统发生相变:TL(fb)T_L(f_b) 的谱结构发生定性改变,一个或多个离群特征值会从体谱 IbI_b 中分裂出来。这个相变点 Lcritb(d2)/dL_{\text{crit}} \propto b^{(d-2)/d} 定义了在不修改基数 bb 的情况下,上下文长度能够保持稳定的上限。因此,当扩展上下文长度 LL 时,基频 bb 需要满足 b=bsd/(d2)b'=b\cdot s^{d/(d-2)} 的变化方式,才能保证模型不会崩溃。

Footnotes

  1. 代码注释来自 https://zhuanlan.zhihu.com/p/647109286

  2. https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ 2

  3. https://huggingface.co/Qwen/Qwen2.5-Math-7B/blob/main/config.json

  4. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/

  5. https://github.com/jquesnelle/yarn/pull/1

  6. https://arxiv.org/abs/2309.00071

  7. https://arxiv.org/abs/1806.07572 2

  8. https://zh.wikipedia.org/zh-hans/%E5%B8%B8%E5%B0%8D%E8%A7%92%E7%9F%A9%E9%99%A3

  9. 这里对 ffθ\theta 的使用有点泛滥,请不要将此处的 ffθ\theta 和上文中的 ff 相混淆

  10. https://en.wikipedia.org/wiki/Szeg%C5%91_limit_theorems

  11. https://www.sciencedirect.com/science/article/pii/0024379594901872

  12. https://en.wikipedia.org/wiki/Harold_Widom 2

  13. https://en.wikipedia.org/wiki/Fredholm_determinant

  14. https://en.wikipedia.org/wiki/Phase_transition

  15. https://en.wikipedia.org/wiki/Painlev%C3%A9_transcendents