Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 22 additions & 21 deletions Lotka-Volterra-Pytorch/efficient_kan/efficientkan.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
self.grid_size = grid_size
self.spline_order = spline_order

h = (grid_range[1] - grid_range[0]) / grid_size
h = (grid_range[1] - grid_range[0]) / grid_size  # 构造knot节点:就是把定义域切成小区间的分割点。B-spline 需要这些分割点来定义基函数。
grid = (
(
torch.arange(-spline_order, grid_size + spline_order + 1) * h
Expand All @@ -36,11 +36,11 @@ def __init__(
self.register_buffer("grid", grid)

self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
self.spline_weight = torch.nn.Parameter(
self.spline_weight = torch.nn.Parameter(  # 两个权重:线性权重和B-spline系数
torch.Tensor(out_features, in_features, grid_size + spline_order)
)
if enable_standalone_scale_spline:
self.spline_scaler = torch.nn.Parameter(
self.spline_scaler = torch.nn.Parameter(  # 每条边的缩放因子,控制函数的整体大小
torch.Tensor(out_features, in_features)
)

Expand All @@ -54,7 +54,7 @@ def __init__(
self.reset_parameters()

def reset_parameters(self):
torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)  # 基础线性权重:标准kaiming初始化
with torch.no_grad():
noise = (
(
Expand All @@ -63,18 +63,18 @@ def reset_parameters(self):
)
* self.scale_noise
/ self.grid_size
)
)  # 生成一种特殊设计的随机噪声,用于神经网络权重的随机初始化
self.spline_weight.data.copy_(
(self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
* self.curve2coeff(
self.grid.T[self.spline_order : -self.spline_order],
noise,
self.grid.T[self.spline_order : -self.spline_order],  # 内部网格点
noise,  # 对应的函数值
)
)
if self.enable_standalone_scale_spline:
# torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

# 为什么要如此初始化呢?如果直接赋值,得到的样条函数可能很不平滑;这样初始化的函数是平滑的、幅度小的,有利于训练起步。
def b_splines(self, x: torch.Tensor):
"""
Compute the B-spline bases for the given input tensor.
Expand All @@ -91,7 +91,7 @@ def b_splines(self, x: torch.Tensor):
self.grid
) # (in_features, grid_size + 2 * spline_order + 1)
x = x.unsqueeze(-1)
bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)  # 先判断x落在了哪个区间
for k in range(1, self.spline_order + 1):
bases = (
(x - grid[:, : -(k + 1)])
Expand All @@ -101,7 +101,7 @@ def b_splines(self, x: torch.Tensor):
(grid[:, k + 1 :] - x)
/ (grid[:, k + 1 :] - grid[:, 1:(-k)])
* bases[:, :, 1:]
)
)  # 然后由低阶样条函数递推到高阶样条函数,直觉上就像是在模糊化函数一样,让函数变得更光滑

assert bases.size() == (
x.size(0),
Expand Down Expand Up @@ -134,12 +134,12 @@ def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
result = solution.permute(
2, 0, 1
) # (out_features, in_features, grid_size + spline_order)

# 本质上就是在解最小二乘:已知一些采样点和对应的函数值,去反推B-spline系数
assert result.size() == (
self.out_features,
self.in_features,
self.grid_size + self.spline_order,
)
)  # 确认形状
return result.contiguous()

@property
Expand All @@ -155,12 +155,13 @@ def forward(self, x: torch.Tensor):
original_shape = x.shape
x = x.reshape(-1, self.in_features)

base_output = F.linear(self.base_activation(x), self.base_weight)
base_output = F.linear(self.base_activation(x), self.base_weight)  # 这就是一个普通的"激活函数 + 线性变换",和 MLP 一样
spline_output = F.linear(
self.b_splines(x).view(x.size(0), -1),
self.scaled_spline_weight.view(self.out_features, -1),
)
output = base_output + spline_output
)  # view(1, -1)展平成 (1, 16),把两个通道的基函数值拼起来

output = base_output + spline_output  # 合并起来

output = output.reshape(*original_shape[:-1], self.out_features)
return output
Expand Down Expand Up @@ -228,16 +229,16 @@ def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0)
sample-based regularization.
"""
l1_fake = self.spline_weight.abs().mean(-1)
regularization_loss_activation = l1_fake.sum()
p = l1_fake / regularization_loss_activation
regularization_loss_entropy = -torch.sum(p * p.log())
regularization_loss_activation = l1_fake.sum()  # L1:鼓励系数稀疏
p = l1_fake / regularization_loss_activation  # 归一化成概率
regularization_loss_entropy = -torch.sum(p * p.log())  # 熵:鼓励少数边活跃
return (
regularize_activation * regularization_loss_activation
+ regularize_entropy * regularization_loss_entropy
)


class KAN(torch.nn.Module):
class KAN(torch.nn.Module):  # 实现的是KAN网络怎么搭建
def __init__(
self,
layers_hidden,
Expand All @@ -255,7 +256,7 @@ def __init__(
self.spline_order = spline_order

self.layers = torch.nn.ModuleList()
for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):  # layers_hidden = [2, 10, 2]时,会创建两个KANLinear层,然后在forward里面依次通过;本质上还是和搭MLP一样,只是计算方式不同而已
self.layers.append(
KANLinear(
in_features,
Expand All @@ -282,4 +283,4 @@ def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0)
return sum(
layer.regularization_loss(regularize_activation, regularize_entropy)
for layer in self.layers
)
)