由于 CelebA 似乎没有标准的 TensorFlow 数据集接口(虽然好像也有,但官网的文档写得不清不楚,实在不敢用),所以写了这份代码。本代码一次性读取全部 202599 张图片并中心裁切为 64*64 大小(也可以自行限制读取数量),同时读取 list_attr 标签 txt 文档并将其处理为 one-hot 编码,最终均以 list 返回。
import cv2
import scipy
import tensorflow as tf
# from sklearn import preprocessing
from tensorflow.keras.preprocessing import image
import numpy as np
import os
import math
import matplotlib.pyplot as plt
from keras.utils import np_utils
import imageio
from PIL import Image
# File names are appended directly to this string, so it must end with a path separator.
celebAPath = "/Users/xueluopoi/Desktop/机器学习/科研/celebA/celebAImage/" # directory holding the CelebA images
celebAAttrPath = "list_attr.txt" # CelebA attribute-label file
def get_image(image_path, input_height=64, input_width=64,
              resize_height=64, resize_width=64,
              crop=True, grayscale=True, num_images=99):
    """Load sequentially numbered JPEG files and preprocess each one.

    Files are located as ``image_path + "NNNNNN.jpg"``, where ``NNNNNN`` is
    the 1-based image index zero-padded to six digits. ``image_path`` is used
    as a plain string prefix, so it must end with a path separator.

    Args:
        image_path: Directory prefix of the image files.
        input_height: Height of the centre-crop window (used when ``crop``).
        input_width: Width of the centre-crop window (used when ``crop``).
        resize_height: Height of every returned image.
        resize_width: Width of every returned image.
        crop: Centre-crop before resizing when true, resize only otherwise.
        grayscale: Read images as single-channel when true.
        num_images: How many images to read, starting at index 1. The
            default of 99 matches the count that used to be hard-coded.

    Returns:
        A list with one preprocessed array per image, as produced by
        ``transform``.
    """
    images = []
    for index in range(1, num_images + 1):
        file_path = image_path + str(index).zfill(6) + ".jpg"
        pixels = imread(file_path, grayscale)
        # Only 2-D (height, width) data needs an explicit channel axis.
        # Adding one to an image that already has channels would produce a
        # 4-D tensor, which tf.image.resize interprets as a batch.
        if pixels.ndim == 2:
            pixels = tf.expand_dims(pixels, 2)
        images.append(
            transform(pixels, input_height, input_width,
                      resize_height, resize_width, crop))
    return images
def imread(path, grayscale=True):
    """Read an image file into a floating-point NumPy array.

    Args:
        path: Path of the image file.
        grayscale: When true, decode with PIL and convert to mode ``'L'``
            (single channel); otherwise decode with ``cv2.imread``.

    Returns:
        A ``float32`` array for grayscale reads, or a ``float64`` array for
        colour reads. Colour channel order is whatever ``cv2.imread``
        produces (BGR according to the OpenCV documentation).

    Raises:
        OSError: If the colour image cannot be read. PIL raises its own
            errors for the grayscale path.
    """
    if grayscale:
        # The context manager releases the file handle; convert() returns an
        # independent, fully loaded image, so it stays valid afterwards.
        with Image.open(path) as source:
            gray = source.convert('L')
        return np.array(gray).astype('float32')

    colour = cv2.imread(path)
    if colour is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError("cannot read image file: %s" % path)
    # np.float was only an alias of the builtin float (float64) and has been
    # removed from NumPy; np.float64 keeps the same dtype.
    return colour.astype(np.float64)
def transform(image, input_height, input_width,
              resize_height=64, resize_width=64, crop=True):
    """Resize an image, optionally centre-cropping first, and divide by 255.

    Args:
        image: Image data accepted by ``center_crop`` / ``tf.image.resize``.
        input_height: Crop-window height, forwarded to ``center_crop``.
        input_width: Crop-window width, forwarded to ``center_crop``.
        resize_height: Output height.
        resize_width: Output width.
        crop: Centre-crop before resizing when truthy; resize only otherwise.

    Returns:
        A NumPy array holding the resized image divided by 255.
    """
    if not crop:
        resized = tf.image.resize(image, [resize_height, resize_width])
    else:
        resized = center_crop(image, input_height, input_width,
                              resize_height, resize_width)
    scaled = np.array(resized) / 255
    return scaled
def center_crop(x, crop_h, crop_w,
                resize_h=64, resize_w=64):
    """Cut a centred window out of an image and resize it.

    Args:
        x: Image whose first two dimensions are height and width.
        crop_h: Height of the crop window.
        crop_w: Width of the crop window; ``None`` means "same as crop_h".
        resize_h: Output height.
        resize_w: Output width.

    Returns:
        The result of ``tf.image.resize`` applied to the cropped window.
    """
    if crop_w is None:
        crop_w = crop_h
    h, w = x.shape[:2]
    # Clamp at 0: when the window is larger than the image the raw offset is
    # negative, and a negative slice start would wrap around from the end and
    # silently select the wrong region. Clamped, the whole extent is kept.
    top = max(0, int(round((h - crop_h) / 2.)))
    left = max(0, int(round((w - crop_w) / 2.)))
    window = x[top:top + crop_h, left:left + crop_w]
    return tf.image.resize(window, [resize_h, resize_w])
def handleAttr(path):
    """Parse an attribute-label text file into a list of integer rows.

    The first two lines of the file are skipped (presumably the image count
    and the attribute-name header -- confirm against the actual file). Every
    remaining non-blank line is split on whitespace, its first field (the
    file name) is dropped, and the remaining fields are converted to ``int``.

    Args:
        path: Path of the label text file.

    Returns:
        A list with one list of ints per labelled image, in file order.

    Raises:
        ValueError: If a label field is not an integer.
    """
    labels = []
    with open(path, "r") as attr_file:
        for line_no, line in enumerate(attr_file):
            if line_no < 2:
                continue  # header lines
            fields = line.split()
            if not fields:
                # A blank line (e.g. a trailing newline) would otherwise add
                # an empty row and make the label list ragged.
                continue
            labels.append([int(value) for value in fields[1:]])
    return labels
def main():
    """Load the CelebA images and labels and print the first ten label rows.

    Label values are clamped into [0, 1] with ``tf.clip_by_value`` before
    printing, so any negative entry becomes 0.
    """
    # The loaded images are not used further in this demo.
    get_image(celebAPath)
    labels = handleAttr(celebAAttrPath)
    labels = tf.clip_by_value(labels, 0, 1)
    print(labels[0:10])
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()