nn.Linear() 是 PyTorch 中定义全连接层(也称为线性层、全连接层或密集层)的函数。它的作用是对输入数据进行线性变换,也就是矩阵乘法操作。具体地说,它的作用是对输入向量进行线性变换并加上偏置,从而输出一个新的向量。

函数定义

torch.nn.Linear(in_features, out_features, bias=True)

参数说明

  1. in_features:输入特征的维度,即输入张量的最后一维的大小。它表示输入向量的长度。

    • 例如:如果输入是形状为 (batch_size, in_features) 的张量,则 in_features 指定为这个输入张量的特征数。
  2. out_features:输出特征的维度,即输出张量的最后一维的大小。它表示线性层将输入向量转换后的长度。

    • 例如:如果我们希望将输入的 in_features 维度变为 out_features 维度,则这个参数指定为输出向量的长度。
  3. bias:布尔值,默认为 True。如果 bias=True,则该层会有一个偏置向量(bias term),否则没有。

    • bias 会对输出结果加上一个偏置项,通常用于提升模型表现。如果 bias=False,则该线性层只进行矩阵乘法。

工作原理

  • nn.Linear 层实际上实现了线性变换: y=xW+b 其中:
    • x 是输入向量,大小为 in_features
    • W是权重矩阵,大小为 (out_features, in_features)
    • b 是偏置向量,大小为 out_features,只有在 bias=True 时存在。
    • y 是输出向量,大小为 out_features

这个线性变换相当于对输入的向量进行一次仿射变换:先通过矩阵 WW 进行变换,再加上偏置 b。

举例

假设我们有一个输入的维度是 10,输出的维度是 5,且有偏置项。我们可以定义一个线性层:


import torch
import torch.nn as nn

# 定义线性层:输入维度为10,输出维度为5
linear = nn.Linear(10, 5)

# 创建输入张量,假设输入是一个 batch 中的 3 个样本,每个样本维度为 10
input_data = torch.randn(3, 10)

# 前向传播
output_data = linear(input_data)

# 输出的形状是 (3, 5),即 batch 中每个样本经过线性层后变成 5 维
print(output_data.shape)  # torch.Size([3, 5])

权重和偏置

  • nn.Linear 层中有两个可学习的参数:权重 WWW 和偏置 bbb(如果有)。
    • linear.weight 是权重矩阵,形状为 (out_features, in_features)
    • linear.bias 是偏置向量,形状为 (out_features,),在 bias=True 的情况下存在。

可以通过访问这些参数来查看或者修改它们:


print(linear.weight) # 输出权重矩阵 
print(linear.bias) # 输出偏置向量

总结

nn.Linear() 是 PyTorch 中非常常见的层,它对输入进行线性变换,通常作为网络的全连接层,尤其是在神经网络的分类器部分。它的核心作用是将输入向量通过矩阵乘法变换到新的空间,并且可以加上偏置来增强表达能力。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐