一、下载:

1、代码

https://gitee.com/zhang_minyi/unet-pytorch/

为了怕gitee断供,我还弄了百度网盘:

链接:https://pan.baidu.com/s/1GYBka-wzQMmisVwlnU40jA

提取码:u8tr

--来自百度网盘超级会员V2的分享

2、数据

训练所需的unet_voc.pth和unet_medical.pth可在百度网盘中下载。

链接: https://pan.baidu.com/s/1AUBpqsSgamoQGEYpNjJg7A 提取码: i3ck

VOC拓展数据集的百度网盘如下:

链接: https://pan.baidu.com/s/1BrR7AUM1XJvPWjKMIy2uEw 提取码: vszf

我也狡兔三窟,把数据搞了百度网盘:

链接:https://pan.baidu.com/s/1Zr4ThzjaWNQRrL7DT_Vsaw

提取码:jtqd

--来自百度网盘超级会员V2的分享

二、创建环境(注意python>=3.8才行)

conda create --name unet python=3.8

之后切换到unet环境下:

conda activate unet

三、安装pytorch

pip install torchvision-0.14.0+cu117-cp38-cp38-win_amd64.whl

链接:https://pan.baidu.com/s/1bdW2GmxD5N8NOujefEX6NQ

提取码:yrbk

--来自百度网盘超级会员V2的分享

pip install torch-1.13.0+cu117-cp38-cp38-win_amd64.whl

链接:https://pan.baidu.com/s/1XQh3fK-GYbjIO5Q7HJPEpQ

提取码:09zl

--来自百度网盘超级会员V2的分享

四、安装其他库

1、torchsummary

pip install torchsummary

五、修改一个bug

from torchvision.models.utils import load_state_dict_from_url

更改为:

from torch.hub import load_state_dict_from_url

这是版本导致的。

六、还有两个大坑

  1. train.py

预训练模型要改成True,他会自动下载预训练模型。我这里做个百度网盘吧

链接:https://pan.baidu.com/s/1C7OcXHTzENHm0KHp-SYS6A

提取码:g5qw

--来自百度网盘超级会员V2的分享

2. unet.py

你应该改成log文件夹下的最后一次训练pth

3.显示图片会出错

OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program.

怎么办?

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐