【pytorch + matplotlib】将若干张图像拼接成一张图像(附代码,以FashionMNIST为例)(subplot 和 subplots区别)
文章目录一、在pytorch中紧凑画出子图(1)在一行里画出多张图像和对应标签1)代码2)效果展示色偏原因分析:(2)以矩阵的形式展示多张图片1)代码2)效果展示二、在matplotlib中紧凑画出子图(1)区分 subplot 和 subplots(2)代码(3)效果展示一、在pytorch中紧凑画出子图(1)在一行里画出多张图像和对应标签1)代码import matplotlib.py...
·
文章目录
一、在pytorch中紧凑画出子图
(1)在一行里画出多张图像和对应标签
1)代码
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torchvision.transforms as transforms
from IPython import display
np.set_printoptions(threshold=100000000)
mnist_train = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True,
transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True,
transform=transforms.ToTensor())
def use_svg_display():
"""Use svg format to display plot in jupyter"""
display.set_matplotlib_formats('svg')
# 本函数已保存在d2lzh包中方便以后使用
def get_fashion_mnist_labels(labels):
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
# 本函数已保存在d2lzh包中方便以后使用
def show_fashion_mnist(images, labels):
use_svg_display()
# 这里的_表示我们忽略(不使用)的变量
_, figs = plt.subplots(1, len(images), figsize=(12, 12))
for f, img, lbl in zip(figs, images, labels):
# np.squeeze将(1,28,28)→(28,28)
f.imshow(np.squeeze(img.numpy()))
f.set_title(lbl)
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
plt.show()
# 完成了torch.utils.data.DataLoader的功能
X, y = [], []
for i in range(10):
X.append(mnist_train[i][0])
y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))
2)效果展示
色偏原因分析:
- 下图每一张图像是由(28,28, 1)压缩(np.squeeze)到(28,28)的异色图像,
- 若要得到黑白图像,则应该将3个(28,28, 1)的图像拼接在一起就可以得到(28,28,3),绘制出来就是黑白图像了。
- 色偏问题原因:plt.imshow()在绘制2维图像时,0(最小值)显示深蓝色,1(最大值)显示黄色,其他数值显示由蓝到黄的过度颜色,所以(28,28)的图像画出来是蓝黄相间的。
- 具体的色偏色差问题分析请转至博客:【matplotlib + opencv】关于opencv和matplotlib绘制图像时,出现色差色偏的问题探讨,思考,解决。(plt.imshow()绘制的图像底色偏绿蓝偏黄)

(2)以矩阵的形式展示多张图片
1)代码
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
trainset = torchvision.datasets.FashionMNIST(root='./data', train=True,
download=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=49,
shuffle=True, num_workers=0)
testset = torchvision.datasets.FashionMNIST(root='./data', train=False,
download=True, transform=transforms.ToTensor())
testloader = torch.utils.data.DataLoader(testset, batch_size=49,
shuffle=False, num_workers=0)
classes = ('t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot')
# functions to show an image
def imshow(img):
print(img.shape)
# img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()
print(images.shape)
# show images
imshow(torchvision.utils.make_grid(images, nrow=7, padding=1))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(49)))
2)效果展示

二、在matplotlib中紧凑画出子图
(1)区分 subplot 和 subplots
- plt.subplot()在指定分割子图个数和定位子图时可以使用参数连写的方式如:plt.subplot(221)
- plt.subplots(m,n)返回的值的类型为元组,其中包含两个元素:第一个为一个画布fig,第二个是子图ax,结果是一个mxn的矩阵,调用时要XXX=ax.[i,j]来调用。
(2)代码
import matplotlib.pyplot as plt
# # 写法一:
# # 定义fig
# fig = plt.figure()
# # 建立子图
# ax = fig.subplots(2, 2) # 2*2
# 写法二:
fig, ax = plt.subplots(2, 2)
# 第一个图为
ax[0, 0].plot([2, 1], [3, 4])
# 第二个图为
ax[0, 1].plot([1, 2], [3, 4])
# 第三个图为
ax[1, 0].plot([1, 2], [4, 3])
# 第四个图为
ax[1, 1].plot([1, 2], [3, 4])
plt.show()
(3)效果展示

三、手动将一个文件夹下的图片拼接在一起
import PIL.Image as Image
import os
IMAGES_PATH = './loop_img/' # 图片集地址
IMAGES_FORMAT = ['.jpg', '.JPG', '.png'] # 图片格式
IMAGE_SIZE = 256 # 每张小图片的大小
IMAGE_ROW = 2 # 图片间隔,也就是合并成一张图后,一共有几行
IMAGE_COLUMN = 2 # 图片间隔,也就是合并成一张图后,一共有几列
IMAGE_SAVE_PATH = 'final.jpg' # 图片转换后的地址
# 获取图片集地址下的所有图片名称
image_names = [name for name in os.listdir(IMAGES_PATH) for item in IMAGES_FORMAT if
os.path.splitext(name)[1] == item]
# 简单的对于参数的设定和实际图片集的大小进行数量判断
if len(image_names) != IMAGE_ROW * IMAGE_COLUMN:
raise ValueError("合成图片的参数和要求的数量不能匹配!")
# 定义图像拼接函数
def image_compose():
to_image = Image.new('RGB', (IMAGE_COLUMN * IMAGE_SIZE, IMAGE_ROW * IMAGE_SIZE)) # 创建一个新图
# 循环遍历,把每张图片按顺序粘贴到对应位置上
for y in range(1, IMAGE_ROW + 1):
for x in range(1, IMAGE_COLUMN + 1):
from_image = Image.open(IMAGES_PATH + image_names[IMAGE_COLUMN * (y - 1) + x - 1]).resize(
(IMAGE_SIZE, IMAGE_SIZE), Image.ANTIALIAS)
to_image.paste(from_image, ((x - 1) * IMAGE_SIZE, (y - 1) * IMAGE_SIZE))
return to_image.save(IMAGE_SAVE_PATH) # 保存新图
image_compose() # 调用函数
- 结果

参考:
更多推荐


所有评论(0)