论文来自:https://arxiv.org/pdf/2305.15213.pdf

github项目:https://github.com/NWUzhouwei/GTNet

安装

配置如下所示

python=3.8
torch=2.1.0+cu121
torchvision=0.16.0
torchaudio=2.1.0
h5py=3.11.0
opencv-python=4.10.0
plyfile=1.0.3
scikit-learn=1.3.2
tqdm=4.66.5
numba=0.58.1

数据预处理

以S3DIS为例

打开data.py

如果没有下载数据集的话可以将135-137行的代码注释取消掉

    # print("******************")
    # download_S3DIS()
    # print("**************")

憨厚在大概最后的__name__ == ‘main’处修改,注释掉其它两个数据集的处理

if __name__ == '__main__':
    # train = ModelNet40(1024)
    # test = ModelNet40(1024, 'test')
    # data, label = train[0]
    # print(data.shape)
    # print(label.shape)
    #
    # trainval = ShapeNetPart(2048, 'trainval')
    # test = ShapeNetPart(2048, 'test')
    # data, label, seg = trainval[0]
    # print(data.shape)
    # print(label.shape)
    # print(seg.shape)

    train = S3DIS(4096)
    test = S3DIS(4096, 'test')
    data, seg = train[0]
    print(data.shape)
    print(seg.shape)

由于代码路径写得比较死,这里首先需要到第132行修改load_data_semseg函数以及121行的prepare_test_data_semseg函数的数据路径

第二,还需要修改prepare_data\\collect_indoor3d_data.py和prepare_data\\gen_indoor3d_h5.py的数据路径

第三,indoor3d_util.py中187-192行的注释需要取消掉

顺便吐槽一下,gen_indoor3d_h5.py中23行有个低级错误,meta给打成mete了

filelist = os.path.join(BASE_DIR, 'mete\\all_data_label.txt')

完成修改后运行data.py

开始训练

完成预处理操作后,开始运行main_cls.py,需要调整参数eval=False,同时注释掉215-222行的S3DISDataset的dataloader,并取消208-213行的注释改用S3DIS的dataloader

BUG和解决方案

1.FileNotFoundError: [Errno 2] Unable to synchronously open file 

原因是S3DISDataSet读取场景时读入了其它的文件,检查一下Area文件夹内除了文件夹是否还包含其它文件。

也可以通过在data.py的395行以下加入以下代码

if not os.path.isdir(os.path.join(area, scene)):
   continue

2.RuntimeError: selected index k out of range

这里的原因是在transformer_divide.py内50行的get_graph_feature函数传入的x是(batch_size,num_points,feature_size)的形式,但是传入的tensor维度反了。

解决方法将num_points修改为如下所示:

num_points = x.size(1)

3.RuntimeError: mat1 and mat2 shapes cannot be multiplied (36x4096 and 9x96)

原因在transformer_divide.py的128行,feature张量的shape搞反了,直接取消注释189行的:

res = res.permute(0, 2, 1)

4.RuntimeError: view size is not compatible with input tensor's size and stride

原因在transformer_divide.py的53行,view()需要连续地址的Tensor,可以改成下列所示:

x = x.contiguous().view(batch_size, -1, num_points)

5.RuntimeError: Given groups=1, weight of size [1024, 384, 1], expected input[4, 16384, 96] to have 384 channels, but got 16384 channels instead

修改model.py的440行为如下所示:

x = torch.cat((x1, x2, x3, x4), dim=2).permute(0, 2, 1)

同时取消注销448行和450行代码并修改为如下所示:

x = x.permute(0, 2, 1)
x = torch.cat((x, x1, x2, x3, x4), dim=2)  # (batch_size, 1024+64*3, num_points)
x = x.permute(0, 2, 1)

完成Debug,开始训练

实验结果

DataSets mAcc mious
S3DIS-Area5 0.554 0.480 
ShapeNet 0.753 0.823
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐