nn.Linear()线性层(全连接层)
是 PyTorch 中非常常见的层,它对输入进行线性变换,通常作为网络的全连接层,尤其是在神经网络的分类器部分。它的核心作用是将输入向量通过矩阵乘法变换到新的空间,并且可以加上偏置来增强表达能力。
·
nn.Linear() 是 PyTorch 中定义全连接层(也称为线性层、全连接层或密集层)的函数。它的作用是对输入数据进行线性变换,也就是矩阵乘法操作。具体地说,它的作用是对输入向量进行线性变换并加上偏置,从而输出一个新的向量。
函数定义
torch.nn.Linear(in_features, out_features, bias=True)
参数说明
-
in_features:输入特征的维度,即输入张量的最后一维的大小。它表示输入向量的长度。- 例如:如果输入是形状为
(batch_size, in_features)的张量,则in_features指定为这个输入张量的特征数。
- 例如:如果输入是形状为
-
out_features:输出特征的维度,即输出张量的最后一维的大小。它表示线性层将输入向量转换后的长度。- 例如:如果我们希望将输入的
in_features维度变为out_features维度,则这个参数指定为输出向量的长度。
- 例如:如果我们希望将输入的
-
bias:布尔值,默认为True。如果bias=True,则该层会有一个偏置向量(bias term),否则没有。bias会对输出结果加上一个偏置项,通常用于提升模型表现。如果bias=False,则该线性层只进行矩阵乘法。
工作原理
nn.Linear层实际上实现了线性变换: y=xW+b 其中:- x 是输入向量,大小为
in_features。 - W是权重矩阵,大小为
(out_features, in_features)。 - b 是偏置向量,大小为
out_features,只有在bias=True时存在。 - y 是输出向量,大小为
out_features。
- x 是输入向量,大小为
这个线性变换相当于对输入的向量进行一次仿射变换:先通过矩阵 WW 进行变换,再加上偏置 b。
举例
假设我们有一个输入的维度是 10,输出的维度是 5,且有偏置项。我们可以定义一个线性层:
import torch
import torch.nn as nn
# 定义线性层:输入维度为10,输出维度为5
linear = nn.Linear(10, 5)
# 创建输入张量,假设输入是一个 batch 中的 3 个样本,每个样本维度为 10
input_data = torch.randn(3, 10)
# 前向传播
output_data = linear(input_data)
# 输出的形状是 (3, 5),即 batch 中每个样本经过线性层后变成 5 维
print(output_data.shape) # torch.Size([3, 5])
权重和偏置
nn.Linear层中有两个可学习的参数:权重 WWW 和偏置 bbb(如果有)。linear.weight是权重矩阵,形状为(out_features, in_features)。linear.bias是偏置向量,形状为(out_features,),在bias=True的情况下存在。
可以通过访问这些参数来查看或者修改它们:
print(linear.weight) # 输出权重矩阵
print(linear.bias) # 输出偏置向量
总结
nn.Linear() 是 PyTorch 中非常常见的层,它对输入进行线性变换,通常作为网络的全连接层,尤其是在神经网络的分类器部分。它的核心作用是将输入向量通过矩阵乘法变换到新的空间,并且可以加上偏置来增强表达能力。
更多推荐


所有评论(0)