CelebA数据集读取+处理

August 13, 2021 · 深度学习 · 2182次阅读

功能说明

之所以写这段代码,是因为CelebA似乎没有方便直接使用的标准TensorFlow读取接口(虽然好像也有,但官网那不清不楚的文档实在惹不起)。本代码按文件编号依次读取CelebA图片(数据集共约20万张,准确数字为202,599张;示例代码中的循环默认只读取前99张,可以自己修改数量),并将每张图中心裁切为64*64大小;同时读取list_attr标签txt文档,把其中取值为-1/1的属性标签解析为整数列表(在main中再裁剪为0/1编码)。图片和标签最终均以list返回。

代码

import cv2
import scipy
import tensorflow as tf
# from sklearn import preprocessing
from tensorflow.keras.preprocessing import image
import numpy as np
import os
import math
import matplotlib.pyplot as plt
from keras.utils import np_utils
import imageio
from PIL import Image

# Directory holding the CelebA images. get_image() concatenates the file name
# onto this string directly, so the trailing slash is required.
celebAPath = "/Users/xueluopoi/Desktop/机器学习/科研/celebA/celebAImage/"
# CelebA attribute label file, parsed by handleAttr().
celebAAttrPath = "list_attr.txt"


def get_image(image_path, input_height=64, input_width=64,
              resize_height=64, resize_width=64,
              crop=True, grayscale=True, num_images=99):
    """Load sequentially numbered JPEGs and preprocess each one.

    Files are looked up as ``000001.jpg``, ``000002.jpg``, ... with
    ``image_path`` used as a plain string prefix (no separator is inserted).

    Args:
        image_path: Prefix prepended to every zero-padded file name.
        input_height: Height of the centre crop window (used when ``crop``).
        input_width: Width of the centre crop window (used when ``crop``).
        resize_height: Height of each returned image.
        resize_width: Width of each returned image.
        crop: Centre-crop before resizing when true; otherwise resize the
            whole image.
        grayscale: Forwarded to ``imread`` to select the loading mode.
        num_images: How many files to read, starting at ``000001.jpg``.
            Defaults to 99, the count that was previously hard-coded.

    Returns:
        A list with one ``transform`` result per file, in file-number order.
    """
    images = []
    for file_index in range(1, num_images + 1):
        # File names are the index left-padded with zeros to six digits.
        file_name = str(file_index).zfill(6) + ".jpg"
        pixels = imread(image_path + file_name, grayscale)
        # Only a 2-D (single-channel) array needs a channel axis added.
        # A colour array is already H x W x C; expanding it again would make
        # tf.image.resize treat the first axis as a batch dimension.
        if pixels.ndim == 2:
            pixels = tf.expand_dims(pixels, 2)
        images.append(transform(pixels, input_height, input_width,
                                resize_height, resize_width, crop))
    return images


def imread(path, grayscale=True):
    """Read an image file into a floating-point NumPy array.

    Args:
        path: Path of the image file.
        grayscale: When true, load through PIL converted to mode ``'L'`` and
            return a 2-D ``float32`` array. When false, load through
            ``cv2.imread`` and return a ``float64`` array.

    Returns:
        The pixel data as a NumPy array; values are not rescaled here.

    Raises:
        OSError: If ``cv2.imread`` returns ``None`` for ``path``.
    """
    if grayscale:
        # The context manager releases the file handle once pixels are copied.
        with Image.open(path) as opened:
            return np.array(opened.convert('L')).astype('float32')
    colorImage = cv2.imread(path)
    if colorImage is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError("cv2.imread could not read image: {}".format(path))
    # NOTE(review): OpenCV documents BGR channel order for colour images -
    # confirm downstream code does not assume RGB.
    # np.float was an alias of the builtin float (float64) and no longer
    # exists in NumPy >= 1.24, so the dtype is spelled out explicitly.
    return colorImage.astype(np.float64)


def transform(image, input_height, input_width,
              resize_height=64, resize_width=64, crop=True):
    """Resize an image (optionally centre-cropping first) and divide by 255.

    Args:
        image: Image array or tensor to process.
        input_height: Crop window height, used only when ``crop`` is true.
        input_width: Crop window width, used only when ``crop`` is true.
        resize_height: Output height.
        resize_width: Output width.
        crop: Select centre-crop-then-resize (true) or plain resize (false).

    Returns:
        A NumPy array holding the resized pixels divided by 255.
    """
    if not crop:
        resized = tf.image.resize(image, [resize_height, resize_width])
    else:
        resized = center_crop(image, input_height, input_width,
                              resize_height, resize_width)
    scaled = np.array(resized) / 255
    return scaled


def center_crop(x, crop_h, crop_w,
                resize_h=64, resize_w=64):
    """Cut a centred window out of ``x`` and resize it.

    Args:
        x: Image whose first two axes are height and width.
        crop_h: Height of the crop window.
        crop_w: Width of the crop window; ``None`` means "same as crop_h".
        resize_h: Output height.
        resize_w: Output width.

    Returns:
        The result of ``tf.image.resize`` applied to the cropped window.
    """
    if crop_w is None:
        crop_w = crop_h
    h, w = x.shape[:2]
    # Clamp the offsets at 0. If the window is larger than the image the raw
    # offset is negative, and a negative slice start counts from the far edge,
    # which would silently select a thin strip instead of the whole image.
    top = max(0, int(round((h - crop_h) / 2.)))
    left = max(0, int(round((w - crop_w) / 2.)))
    window = x[top:top + crop_h, left:left + crop_w]
    return tf.image.resize(window, [resize_h, resize_w])


def handleAttr(path):
    """Parse an attribute label file into a list of integer lists.

    The first two lines of the file are skipped (presumably the entry count
    and the attribute-name header - verify against the actual file). Every
    remaining non-blank line is split on whitespace, its first field (the
    file name) is dropped, and the remaining fields are converted to ``int``.

    Args:
        path: Path of the attribute text file.

    Returns:
        A list with one list of ints per data line, in file order.

    Raises:
        ValueError: If a field after the file name is not an integer.
    """
    label = []
    with open(path, "r") as Attr_file:
        for line_no, line in enumerate(Attr_file):
            if line_no < 2:
                continue  # header lines
            fields = line.split()
            if not fields:
                # Blank lines (e.g. a trailing newline at end of file) would
                # otherwise add an empty row and make the result ragged.
                continue
            # fields[0] is the image file name; the rest are the labels.
            label.append([int(num) for num in fields[1:]])
    return label


def main():
    """Load the images and labels, then print the first ten label rows.

    The labels are clipped to the range [0, 1] with ``tf.clip_by_value``
    before printing, so any negative label value becomes 0.
    """
    # The images are loaded but not used further in this demo; the name is
    # kept so the result can be inspected from a debugger.
    celebA_image = get_image(celebAPath)
    celebA_label = handleAttr(celebAAttrPath)
    celebA_label = tf.clip_by_value(celebA_label, 0, 1)
    print(celebA_label[0:10])


# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

python · Machine learning

最后编辑于3年前

添加新评论