Timm 加载本地 huggingface 模型
最近使用 Timm 自动加载在线 hf-hub 模型时,由于 huggingface 无法正常连接,在加载模型的时候是无法正常下载。解决办法就是本地电脑下载,再上传到服务器。以下载为例。
·
最近使用 Timm 自动加载在线 hf-hub 模型时,由于服务器存在网络限制 huggingface 无法正常连接,导致无法加载模型以及权重。解决办法就是本地电脑下载,再上传到服务器。
以下载 huggingface.co/MahmoodLab/UNI 为例。
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from huggingface_hub import login
login() # login with your User Access Token, found at https://huggingface.co/settings/tokens
# pretrained=True needed to load UNI weights (and download weights for the first time)
# init_values need to be passed in to successfully load LayerScale parameters (e.g. - block.0.ls1.gamma)
model = timm.create_model("hf-hub:MahmoodLab/uni", pretrained=True, init_values=1e-5, dynamic_img_size=True)
transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
model.eval()
本地下载
# 在可联网的机器运行,确保模型缓存
from huggingface_hub import snapshot_download
# 指定存储路径
download_path = "D:/Research/pre_training_models"
models = [
"MahmoodLab/uni"
]
for repo in models:
snapshot_download(repo_id=repo),
cache_dir=download_path
下载时可能存在一些获取 hf token 和对应模型库的邮箱认证等问题,可以自行 AI 获取解决步骤。
上传到服务器指定缓存目录
将文件复制到你的服务器 [用户名]/.cache/huggingface/hub
中,尝试复制到其他的路径发现 timm.create_model
无法正确识别,尽量还是放在该目录下。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def load_model(model_name, device):
if model_name == 'UNI':
model = timm.create_model("hf-hub:MahmoodLab/uni", pretrained=True, init_values=1e-5, dynamic_img_size=True) # PMID:38504018
else:
raise NotImplementedError(f'Model {model_name} not implemented !')
return model.to(device).eval()
uni_model = load_model('UNI', device)
uni_transform = create_transform(**resolve_data_config(uni_model.pretrained_cfg, model=uni_model))
权限设置
Hugging Face 似乎对 Gated Repo 的访问会在每次加载时进行在线令牌校验,即使本地有缓存。所以可能需要在终端进行 login():
# bash
huggingface-cli login
输入你的 token,获取自 https://huggingface.co/settings/tokens 。
更多推荐
所有评论(0)