nn.Linear ()的用法
在PyTorch中,nn.Linear。
·
在PyTorch中,nn.Linear 是一个用于实现线性变换(全连接层)的模块,其用法总结如下:
基本语法
torch.nn.Linear(in_features, out_features, bias=True)
- in_features:输入张量的特征数(最后一个维度的大小)。
- out_features:输出张量的特征数。
- bias:是否包含偏置项(默认为
True)。
核心功能
对输入数据执行线性变换:
y=xAT+b
其中:
- x:输入张量,形状为
(..., in_features)。 - A:权重矩阵,形状为
(out_features, in_features)。 - b:偏置项,形状为
(out_features,)(当bias=True时)。 - 输出形状:
(..., out_features)。
使用示例
示例1:基础用法
import torch
import torch.nn as nn
# 定义线性层:输入4维,输出2维
linear_layer = nn.Linear(4, 2)
# 输入数据(3个样本,每个样本4个特征)
x = torch.randn(3, 4) # 形状:(3, 4)
# 前向传播
output = linear_layer(x)
print(output.shape) # 输出形状:(3, 2)
示例2:多维输入处理
# 输入形状:(batch, seq_len, in_features)
x = torch.randn(5, 10, 20) # 5个样本,序列长度10,每个元素20维
linear_layer = nn.Linear(20, 30)
output = linear_layer(x) # 输出形状:(5, 10, 30)
示例3:禁用偏置项
linear_layer = nn.Linear(4, 2, bias=False)
print(linear_layer.bias) # 输出:None
参数初始化
自定义权重和偏置的初始化:
# 定义模型
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 5)
# 使用Xavier均匀初始化权重
nn.init.xavier_uniform_(self.linear.weight)
# 偏置初始化为0
if self.linear.bias is not None:
nn.init.zeros_(self.linear.bias)
def forward(self, x):
return self.linear(x)
常见注意事项
-
输入维度匹配:输入张量的最后一个维度必须与
in_features一致。- ❌ 错误示例:
nn.Linear(10, 5)处理形状为(3, 3)的输入。 - ✅ 正确示例:输入形状应为
(..., 10)。
- ❌ 错误示例:
-
多维输入支持:自动处理除最后一维外的其他维度(如批处理、序列长度等)。
-
设备一致性:确保模型和张量在同一设备(CPU/GPU)上。
-
展平操作:处理图像等数据时,需先展平空间维度:
x = torch.flatten(x, start_dim=1) # 展平为 (batch, features)
典型应用场景
- 全连接神经网络。
- 作为Transformer中的投影层。
- 分类任务的最后一层(输出类别概率)。
通过灵活使用 nn.Linear,可以构建复杂的神经网络结构,适用于各种深度学习任务。
更多推荐



所有评论(0)