项目的网页

项目源代码

目录

引言

算法原理

一、网络结构

二、图像预处理和后处理

源代码解析

一、目录结构

二、demo_release.py

三、eccv16.py siggraph17.py

四、util.py

结果展示


引言

Colorful Image Colorization是一个较为稳定的黑白图片上色算法,其原理和代码也相对简单。在ECCV 2016上展示了第一代模型,后来又在SIGGRAPH 2017发布了改进之后的模型。优化后的模型更加稳定,上色效果也更加接近现实。

算法原理

一、网络结构

        算法的主要原理是,将一张LAB格式灰度图片的L通道,输入模型进行推理,推理出其AB通道,最后将原始的L通道和推理出的AB通道结合起来,得到一张上色的图片。

        一般常见的图片格式是RGB,有三个通道分别表示红色、绿色、蓝色。三个颜色组合出各种不同的颜色。而LAB图片格式的L通道表示图像的亮度,取值范围为0到100,颜色越大表示颜色越亮。AB的值域都是从-128到+128,A代表从绿色到红色的分量,B代表从蓝色到黄色的分量。

二、图像预处理和后处理

        所有输入的图像都会被resize成256×256的大小。为了最后输出图像的清晰度,所以需要保留原始输入图像的L通道。将图像转换成LAB格式,然后提取L通道,输入模型。输入模型的size应该是1×1×256×256。

        经过模型推理,输出的size应该是1×2×256×256。将预测的AB通道resize成原始图像大小,然后和原始图像的L通道合并,得到最后的结果。

源代码解析

一、目录结构

colorizers中是主要的模型代码以及图片处理的代码

imgs 存放输入图片

imgs_out 存放输出图片

demo_release.py 组织代码,运行程序的文件

二、demo_release.py

import argparse
import matplotlib.pyplot as plt

from colorizers import *

parser = argparse.ArgumentParser()
parser.add_argument('-i','--img_path', type=str, default='imgs/ansel_adams3.jpg') #输入图片路径
parser.add_argument('--use_gpu', action='store_true', help='whether to use GPU') #是否使用GPU
parser.add_argument('-o','--save_prefix', type=str, default='saved', help='will save into this file with {eccv16.png, siggraph17.png} suffixes') #保存图片路径
opt = parser.parse_args()

# load colorizers
colorizer_eccv16 = eccv16(pretrained=True).eval()  #初始化eccv16网络
colorizer_siggraph17 = siggraph17(pretrained=True).eval()  #初始化siggraph17网络
if(opt.use_gpu):
	colorizer_eccv16.cuda()
	colorizer_siggraph17.cuda()

img = load_img(opt.img_path)  #加载图片
(tens_l_orig, tens_l_rs) = preprocess_img(img, HW=(256,256))  #图片预处理
if(opt.use_gpu):
	tens_l_rs = tens_l_rs.cuda()

img_bw = postprocess_tens(tens_l_orig, torch.cat((0*tens_l_orig,0*tens_l_orig),dim=1))
out_img_eccv16 = postprocess_tens(tens_l_orig, colorizer_eccv16(tens_l_rs).cpu())  #推理和图片后处理
out_img_siggraph17 = postprocess_tens(tens_l_orig, colorizer_siggraph17(tens_l_rs).cpu())   #推理和图片后处理

plt.imsave('%s_eccv16.png'%opt.save_prefix, out_img_eccv16)  #保存图片
plt.imsave('%s_siggraph17.png'%opt.save_prefix, out_img_siggraph17)  #保存图片

        在此文件中可以自定义输入输出文件路径。

三、eccv16.py siggraph17.py

        定义eccv16和siggraph17的网络结构,下载对应的权重文件。

四、util.py

from PIL import Image
import numpy as np
from skimage import color
import torch
import torch.nn.functional as F
from IPython import embed

def load_img(img_path):  # 加载图片
	out_np = np.asarray(Image.open(img_path))  #将图片转为numpy数组
	if(out_np.ndim==2):
		out_np = np.tile(out_np[:,:,None],3)
	return out_np

def resize_img(img, HW=(256,256), resample=3):  # 调整图片大小
	return np.asarray(Image.fromarray(img).resize((HW[1],HW[0]), resample=resample))

def preprocess_img(img_rgb_orig, HW=(256,256), resample=3):  #图片与处理

	img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample)   #将图片调整为256×256
	
	img_lab_orig = color.rgb2lab(img_rgb_orig)    #rgb格式转换为lab格式
	img_lab_rs = color.rgb2lab(img_rgb_rs)

	img_l_orig = img_lab_orig[:,:,0]   #取图片的L通道
	img_l_rs = img_lab_rs[:,:,0]

	tens_orig_l = torch.Tensor(img_l_orig)[None,None,:,:]   #转换成tensor
	tens_rs_l = torch.Tensor(img_l_rs)[None,None,:,:]

	return (tens_orig_l, tens_rs_l)

def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'):   #图片后处理

	HW_orig = tens_orig_l.shape[2:]  #取原始图片的长宽
	HW = out_ab.shape[2:]

	if(HW_orig[0]!=HW[0] or HW_orig[1]!=HW[1]):   #恢复到原始图片大小
		out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode='bilinear')
	else:
		out_ab_orig = out_ab

	out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1)   # L+AB
	return color.lab2rgb(out_lab_orig.data.cpu().numpy()[0,...].transpose((1,2,0)))  #转换为rgb图片

结果展示

原图:

ECCV 2016:

SIGGRAPH 2017 :

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐