一、在pytorch中紧凑画出子图

(1)在一行里画出多张图像和对应标签

1)代码
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torchvision.transforms as transforms
from IPython import display

np.set_printoptions(threshold=100000000)

mnist_train = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True,
                                                transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True,
                                               transform=transforms.ToTensor())


def use_svg_display():
    """Use svg format to display plot in jupyter"""
    display.set_matplotlib_formats('svg')


# 本函数已保存在d2lzh包中方便以后使用
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]


# 本函数已保存在d2lzh包中方便以后使用
def show_fashion_mnist(images, labels):
    use_svg_display()
    # 这里的_表示我们忽略(不使用)的变量
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        # np.squeeze将(1,28,28)→(28,28)
        f.imshow(np.squeeze(img.numpy()))
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()


# 完成了torch.utils.data.DataLoader的功能
X, y = [], []
for i in range(10):
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))

2)效果展示
色偏原因分析:

(2)以矩阵的形式展示多张图片

1)代码
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms


trainset = torchvision.datasets.FashionMNIST(root='./data', train=True,
                                        download=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=49,
                                          shuffle=True, num_workers=0)

testset = torchvision.datasets.FashionMNIST(root='./data', train=False,
                                       download=True, transform=transforms.ToTensor())
testloader = torch.utils.data.DataLoader(testset, batch_size=49,
                                         shuffle=False, num_workers=0)

classes = ('t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot')


# functions to show an image
def imshow(img):
    print(img.shape)
    # img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()
print(images.shape)

# show images
imshow(torchvision.utils.make_grid(images, nrow=7, padding=1))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(49)))

2)效果展示

在这里插入图片描述

二、在matplotlib中紧凑画出子图

(1)区分 subplot 和 subplots

  • plt.subplot()在指定分割子图个数和定位子图时可以使用参数连写的方式如:plt.subplot(221)
  • plt.subplots(m,n)返回的值的类型为元组,其中包含两个元素:第一个为一个画布fig,第二个是子图ax,结果是一个mxn的矩阵,调用时要XXX=ax.[i,j]来调用。

(2)代码

import matplotlib.pyplot as plt

# # 写法一:
# # 定义fig
# fig = plt.figure()
# # 建立子图
# ax = fig.subplots(2, 2)  # 2*2

# 写法二:
fig, ax = plt.subplots(2, 2)

# 第一个图为
ax[0, 0].plot([2, 1], [3, 4])
# 第二个图为
ax[0, 1].plot([1, 2], [3, 4])
# 第三个图为
ax[1, 0].plot([1, 2], [4, 3])
# 第四个图为
ax[1, 1].plot([1, 2], [3, 4])
plt.show()

(3)效果展示

在这里插入图片描述

三、手动将一个文件夹下的图片拼接在一起

import PIL.Image as Image
import os

IMAGES_PATH = './loop_img/'  # 图片集地址
IMAGES_FORMAT = ['.jpg', '.JPG', '.png']  # 图片格式
IMAGE_SIZE = 256  # 每张小图片的大小
IMAGE_ROW = 2  # 图片间隔,也就是合并成一张图后,一共有几行
IMAGE_COLUMN = 2  # 图片间隔,也就是合并成一张图后,一共有几列
IMAGE_SAVE_PATH = 'final.jpg'  # 图片转换后的地址

# 获取图片集地址下的所有图片名称
image_names = [name for name in os.listdir(IMAGES_PATH) for item in IMAGES_FORMAT if
               os.path.splitext(name)[1] == item]

# 简单的对于参数的设定和实际图片集的大小进行数量判断
if len(image_names) != IMAGE_ROW * IMAGE_COLUMN:
    raise ValueError("合成图片的参数和要求的数量不能匹配!")


# 定义图像拼接函数
def image_compose():
    to_image = Image.new('RGB', (IMAGE_COLUMN * IMAGE_SIZE, IMAGE_ROW * IMAGE_SIZE))  # 创建一个新图
    # 循环遍历,把每张图片按顺序粘贴到对应位置上
    for y in range(1, IMAGE_ROW + 1):
        for x in range(1, IMAGE_COLUMN + 1):
            from_image = Image.open(IMAGES_PATH + image_names[IMAGE_COLUMN * (y - 1) + x - 1]).resize(
                (IMAGE_SIZE, IMAGE_SIZE), Image.ANTIALIAS)
            to_image.paste(from_image, ((x - 1) * IMAGE_SIZE, (y - 1) * IMAGE_SIZE))
    return to_image.save(IMAGE_SAVE_PATH)  # 保存新图


image_compose()  # 调用函数
  • 结果
    在这里插入图片描述

参考:

使用python将多张图片拼接成大图

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐