Position Interpolation

With RoPE, once a model has been trained with context length L, performance degrades sharply as soon as the input exceeds L. Some papers propose fine-tuning the model on inputs longer than L to gradually enlarge the original window, but this is costly and the gains are limited (see the experimental results in Figure 4-1 below). #card
(Figure 4-1)

Idea: compress the position values that fall beyond L back into the range [0, L]. #card
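A minimal sketch of this idea applied to RoPE position indices (illustrative only: the helper name rope_angles is my own, and the head dimension d = 4096 // 32 = 128 and base θ = 10000 are taken from the experiment code further below). Each position index is rescaled by L/L' so that every position maps back inside the trained range.

import torch

def rope_angles(positions: torch.Tensor, d: int = 128, theta: float = 10000.0) -> torch.Tensor:
    # Angles s * theta_j used by RoPE, with theta_j = theta^(-2j/d).
    freqs = 1.0 / (theta ** (torch.arange(0, d, 2).float() / d))
    return torch.outer(positions.float(), freqs)  # shape: (len(positions), d/2)

L_train, L_target = 2048, 4096
positions = torch.arange(0, L_target)

# Extrapolation: positions beyond L_train are fed in unchanged (this is what breaks).
angles_extrapolated = rope_angles(positions)

# Position Interpolation: rescale positions by L_train / L_target, so the largest
# index maps back to just under L_train and everything stays in the trained range.
angles_interpolated = rope_angles(positions * (L_train / L_target))

print(angles_extrapolated.shape, angles_interpolated.shape)

The only change relative to plain RoPE is the constant rescaling of the position index; the frequencies θ_j themselves are untouched.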

  • In Figure 4-2, the extrapolation bound is derived using the Abel transformation (summation by parts).

(Figure 4-2)

The paper also presents experiments contrasting extrapolation and interpolation; see Figure 4-3: #card
(Figure 4-3)
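As a reminder of what is being fitted here: in RoPE the attention score a(s) depends on the relative position s only through the basis functions e^{i s θ_j}, with complex coefficients h_j determined by the query and key vectors, so the regression in the code below fits a(s) as a real linear combination of sin(s θ_j) and cos(s θ_j):

$$a(s) = \mathrm{Re}\left[\sum_{j=0}^{d/2-1} h_j \, e^{\mathrm{i} s \theta_j}\right], \qquad \theta_j = \theta^{-2j/d}, \quad \theta = 10000.$$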

The code reproducing the experiment in Figure 4-3 is as follows: #card

import torch
import matplotlib.pyplot as plt

# build basis function
d = 4096 // 32
theta = 10000
# Frequency computation
freqs = 1.0 / (theta ** (torch.arange(0, d, 2).float() / d))
# construct basis function
L = 2048
x = torch.arange(0, L)
# basis functions
xfreq = torch.outer(x, freqs)
print(xfreq.shape)
y = torch.randn(x.shape[0])
# do linear regression
X = torch.cat([xfreq.sin(), xfreq.cos()], dim=1)

eps = 1e-5 # small regularization term
# linear regression: solve for coeffs so that X @ coeffs closely approximates y
coeffs = torch.linalg.solve(X.t() @ X + torch.eye(X.shape[1]) * eps, X.t() @ y)

# Extrapolation: evaluate the fitted function on positions up to 2L, beyond the fit range
x2 = torch.arange(0, 2*L)
xfreq2 = torch.outer(x2, freqs)
X2 = torch.cat([xfreq2.sin(), xfreq2.cos()], dim=1)
y2 = X2 @ coeffs

# Interpolation: evaluate on a dense fractional grid inside the training range
x3 = torch.arange(25, 75, 0.125)
xfreq3 = torch.outer(x3, freqs)
X3 = torch.cat([xfreq3.sin(), xfreq3.cos()], dim=1)
y3 = X3 @ coeffs

plt.figure(figsize=(16, 5))

plt.subplot(1, 3, 1)
plt.plot(x2[:L], y2[:L], "r")
plt.scatter(x, y)
plt.ylabel("attention score $a(s)$")
plt.xlabel("Positional difference $s$")

plt.subplot(1, 3, 2)
plt.plot(x2, y2, "r")
plt.scatter(x, y)
plt.axvline(L, color="k", linestyle="--", linewidth=0.5)
plt.title("Effect of Extrapolation")
plt.xlabel("Positional difference $s$")

plt.subplot(1, 3, 3)
plt.plot(x3, y3, "r")
for i in range(25, 75):
    plt.axvline(i, color="k", linestyle="--", linewidth=0.5)
plt.title("Effect of Interpolation")
plt.xlabel("Positional difference $s$")
plt.savefig('PI.png', dpi=300, bbox_inches='tight')
# plt.show()