基于 PyTorch 的 ResNet 垃圾图片分类
项目链接
数据集下载链接
1. 数据集预处理
1.1 画图片的宽高分布散点图
import os

import matplotlib.pyplot as plt
import PIL.Image as Image


def plot_resolution(dataset_root_path):
    """Scatter-plot the width and height of every image under a directory.

    Walks ``dataset_root_path`` recursively with ``os.walk`` and reads the
    ``(width, height)`` of each file through PIL. Files PIL cannot open
    (e.g. stray non-image files) are reported and skipped instead of
    aborting the whole scan.

    Args:
        dataset_root_path: Root directory of the dataset.
    """
    image_size_list = []  # (width, height) of every readable image
    for root, _dirs, files in os.walk(dataset_root_path):
        for file in files:
            image_full_path = os.path.join(root, file)
            try:
                # `with` closes the file handle as soon as the size is read.
                with Image.open(image_full_path) as image:
                    image_size_list.append(image.size)
            except OSError:
                # PIL's UnidentifiedImageError is a subclass of OSError.
                print("无法读取, 跳过", image_full_path)
    print(image_size_list)

    image_width_list = [width for width, _ in image_size_list]
    image_height_list = [height for _, height in image_size_list]

    plt.rcParams['font.sans-serif'] = ['SimHei']  # font able to render the Chinese labels
    plt.rcParams['font.size'] = 8
    plt.rcParams['axes.unicode_minus'] = False  # keep minus signs from rendering as boxes
    plt.scatter(image_width_list, image_height_list, s=1)
    plt.xlabel('宽')
    plt.ylabel('高')
    plt.title('图像宽高分布散点图')
    plt.show()


if __name__ == '__main__':
    # Raw string: "\d" in a normal literal is an invalid escape sequence.
    dataset_root_path = r"F:\数据与代码\dataset"
    plot_resolution(dataset_root_path)
运行结果:
注: os.walk 的详细用法可参考 Python 官方文档中 os.walk 的说明。
1.2 画出数据集的各个类别图片数量的条形图
文件组织结构(数据集根目录下的每个子文件夹对应一个类别, 类别文件夹中直接存放该类别的图片):
def plot_bar(dataset_root_path):
    """Draw a bar chart of how many files each class folder contains.

    Expects one sub-directory per class directly under ``dataset_root_path``.
    Each class's count is the number of non-directory entries directly inside
    its folder; nested sub-folders are not descended into. A red horizontal
    line marks the mean count across classes, which is also printed.

    Args:
        dataset_root_path: Root directory whose immediate sub-directories
            are the classes.

    Raises:
        ValueError: If no class sub-directories are found.
    """
    # Only the first os.walk step is needed: its `dirs` are the class folders.
    # The default avoids StopIteration when the root is missing or unreadable.
    _, class_dirs, _ = next(os.walk(dataset_root_path), (None, [], []))

    file_name_list = []  # class (folder) names
    file_num_list = []  # file count of each class, same order as the names
    for class_name in class_dirs:
        class_path = os.path.join(dataset_root_path, class_name)
        # Pair each name with its own count so nested folders or walk order
        # can no longer misalign the two lists.
        _, _, class_files = next(os.walk(class_path), (None, [], []))
        file_name_list.append(class_name)
        file_num_list.append(len(class_files))

    if not file_num_list:
        raise ValueError("no class sub-directories found in %r" % dataset_root_path)

    mean = sum(file_num_list) / len(file_num_list)
    print("mean= ", mean)

    # Configure fonts before creating the figure so every text element uses them.
    plt.rcParams['font.sans-serif'] = ['SimHei']  # font able to render the Chinese labels
    plt.rcParams['font.size'] = 8
    plt.rcParams['axes.unicode_minus'] = False  # keep minus signs from rendering as boxes

    bar_positions = list(range(len(file_name_list)))
    _, ax = plt.subplots()
    ax.bar(bar_positions, file_num_list, 0.5)  # positions, heights, bar width
    # axhline spans the whole axes, so the mean line is visible even with one class.
    ax.axhline(mean, color="red")
    ax.set_xticks(bar_positions)
    # NOTE(review): 98 degrees looks like a typo for 90 - confirm.
    ax.set_xticklabels(file_name_list, rotation=98)
    ax.set_ylabel("类别数量")
    ax.set_title("各个类别数量分布条形图")
    plt.show()
运行结果:
1.3 删除宽高有问题的图片
import os
import PIL.Image as Image

MIN = 200  # smallest allowed width/height, in pixels
MAX = 2000  # largest allowed width/height, in pixels
ratio = 0.5  # smallest allowed short-side / long-side ratio (rejects narrow, elongated images)


def delete_img(dataset_root_path, min_side=MIN, max_side=MAX, min_ratio=ratio):
    """Delete images whose size or aspect ratio is outside the allowed range.

    Walks ``dataset_root_path`` recursively. An image is deleted when its
    width or height is below ``min_side`` or above ``max_side``, or when
    short side / long side is below ``min_ratio``. Every offending path is
    printed while scanning and again before it is removed. Files PIL cannot
    open are reported and left untouched.

    Args:
        dataset_root_path: Root directory of the dataset.
        min_side: Minimum allowed width/height (defaults to ``MIN``).
        max_side: Maximum allowed width/height (defaults to ``MAX``).
        min_ratio: Minimum allowed short/long side ratio (defaults to ``ratio``).
    """
    delete_img_list = []  # paths of images to delete
    for root, _dirs, files in os.walk(dataset_root_path):
        for file in files:
            img_full_path = os.path.join(root, file)
            try:
                # Close the handle right away: on Windows an open file
                # cannot be removed by os.remove below.
                with Image.open(img_full_path) as img:
                    img_size = img.size
            except OSError:
                # PIL's UnidentifiedImageError is a subclass of OSError.
                print("无法读取, 跳过", img_full_path)
                continue

            width, height = img_size
            long_side = max(width, height)
            short_side = min(width, height)
            too_small = width < min_side or height < min_side
            too_large = width > max_side or height > max_side
            # Multiplication instead of division avoids dividing by zero.
            too_narrow = short_side < min_ratio * long_side
            if too_small or too_large or too_narrow:
                delete_img_list.append(img_full_path)
                print("不满足要求", img_full_path, img_size)

    for img_path in delete_img_list:
        print("正在删除", img_path)
        os.remove(img_path)


if __name__ == '__main__':
    # Raw string: "\d" in a normal literal is an invalid escape sequence.
    dataset_root_img = r'F:\数据与代码\dataset'
    delete_img(dataset_root_img)
删除完成后, 再次运行 1.1 和 1.2 的代码, 即可得到处理后数据集的宽高分布和各类别图片数量。