在PyTorch中,nn.Linear 是一个用于实现线性变换(全连接层)的模块,其用法总结如下:

基本语法

torch.nn.Linear(in_features, out_features, bias=True)
  • in_features‌:输入张量的特征数最后一个维度的大小)。
  • out_features‌:输出张量的特征数
  • bias‌:是否包含偏置项(默认为 True)。

核心功能

对输入数据执行线性变换:

y=xAT+b

其中:

  • x:输入张量,形状为 (..., in_features)
  • A:权重矩阵,形状为 (out_features, in_features)
  • b:偏置项,形状为 (out_features,)(当 bias=True 时)。
  • 输出形状:(..., out_features)

使用示例

示例1:基础用法
import torch
import torch.nn as nn

# 定义线性层:输入4维,输出2维
linear_layer = nn.Linear(4, 2)

# 输入数据(3个样本,每个样本4个特征)
x = torch.randn(3, 4)  # 形状:(3, 4)

# 前向传播
output = linear_layer(x)
print(output.shape)  # 输出形状:(3, 2)

示例2:多维输入处理
# 输入形状:(batch, seq_len, in_features)
x = torch.randn(5, 10, 20)  # 5个样本,序列长度10,每个元素20维
linear_layer = nn.Linear(20, 30)
output = linear_layer(x)     # 输出形状:(5, 10, 30)

示例3:禁用偏置项
linear_layer = nn.Linear(4, 2, bias=False)
print(linear_layer.bias)  # 输出:None


参数初始化

自定义权重和偏置的初始化:

# 定义模型
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)
        # 使用Xavier均匀初始化权重
        nn.init.xavier_uniform_(self.linear.weight)
        # 偏置初始化为0
        if self.linear.bias is not None:
            nn.init.zeros_(self.linear.bias)

    def forward(self, x):
        return self.linear(x)


常见注意事项

  1. 输入维度匹配‌:输入张量的最后一个维度必须与 in_features 一致。

    • ❌ 错误示例:nn.Linear(10, 5) 处理形状为 (3, 3) 的输入。
    • ✅ 正确示例:输入形状应为 (..., 10)
  2. 多维输入支持‌:自动处理除最后一维外的其他维度(如批处理、序列长度等)。

  3. 设备一致性‌:确保模型和张量在同一设备(CPU/GPU)上。

  4. 展平操作‌:处理图像等数据时,需先展平空间维度:

    x = torch.flatten(x, start_dim=1)  # 展平为 (batch, features)
    

典型应用场景

  • 全连接神经网络。
  • 作为Transformer中的投影层。
  • 分类任务的最后一层(输出类别概率)。

通过灵活使用 nn.Linear,可以构建复杂的神经网络结构,适用于各种深度学习任务。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐