人工智能图片分类Python小程序

个人小作业,虽说做的很差,也算是一个学习的转化;主要用于分类自己下载的壁纸

1 背景

学期末需要一个学习成果的展示,高难度的自己做不来,模型也跑不动(电脑有点渣),刚好自己也有图片分类的需求,最后决定做了这个,确实也算做了一个自己用得到的小程序

2 项目说明

2.1 项目需求

需要自动加载指定目录所有图片,自行迁移至指定目录并存入不同的文件夹

2.2 实现思路

  1. 数据来源于各大壁纸网站,通过下载分类好的图片免去了自己手动分类的痛苦
  2. 将图片进行微缩处理,将1920 × \times × 1080的图片转化为192 × \times × 108的尺寸,不然尺寸太大硬件吃不消。
  3. 第二步可以将图片转化为单通道,数据量会小很多,但是测试过程中发现数据集较小时准确率比直接使用三通道要高一些,但是数据集大之后三通道的图片识别更加准确
  4. 目前数据集是共10000多张图片共五个分类(差不多自己电脑的上限),通过第二步、第三步的三通道缩小处理后,所有数据集大小约600MB,还在接受范围内。
  5. 模型的搭建与其他模型搭建基本一致

3 项目说明

3.1 项目结构

│  colorUi.ui	正在使用的UI界面文件
│  fun.py		对于模型函数的初步封装,为PyQt界面提供支持
│  main.py		入口部分
│  model.py		模型的训练、加载
│  ui.py		正在使用的UI界面py文件
│  ui.ui		老的UI界面文件
│  utils.py		一些读取图片处理图片的函数
├─fun_test			内含各类图片共100张,用于最后的功能测试
├─make_data_set		用于处理制作数据集
├─model				训练好的模型存储的路径
├─test				内含处理好的数据集的测试集,存储格式是是numpy数组的序列化,三通道维度信息(N,108.,192,3);标签一维数组
├─test_pic		测试集原始数据目录,路径下各种图片独占一个目录,用于通过make_data_set制作数据集,目录应与train_pic对应
│  ├─dongman	其中一个分类
│  ├─dongwu		其中一个分类
│  ├─fengjing	其中一个分类
│  ├─meinv		其中一个分类
│  └─youxi		其中一个分类
├─train			内含处理好的数据集的训练集,存储格式是是numpy数组的序列化,三通道维度信息(N,108.,192,3);标签一维数组
└─train_pic├─dongman	其中一个分类├─dongwu	其中一个分类├─fengjing	其中一个分类├─meinv		其中一个分类└─youxi		其中一个分类

3.2 源码说明

3.2.1 模型的创建、加载、训练

import json
import osimport cv2
import numpy
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tqdm import tqdmfrom utils import img_resizedef init_network():"""初始化神经网络,支持五种类型:return: 模型"""model = tf.keras.Sequential([tf.keras.layers.Conv2D(filters=48, kernel_size=(3, 3), padding='same', activation='relu', strides=1,input_shape=(108, 192, 3)),tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),# 抑制过拟合tf.keras.layers.Dropout(rate=0.6),tf.keras.layers.Conv2D(filters=24, kernel_size=(3, 3), padding='same', activation='relu', strides=1),# 2*2池化取最大值tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),# 抑制过拟合tf.keras.layers.Dropout(rate=0.6),# 维度拉伸成1维tf.keras.layers.Flatten(),# 第二层隐藏层,使用relu激活函数tf.keras.layers.Dense(256, activation='relu'),# 抑制过拟合tf.keras.layers.Dropout(rate=0.6),tf.keras.layers.Dense(256, activation='relu'),tf.keras.layers.Dropout(rate=0.5),tf.keras.layers.Dense(256, activation='relu'),tf.keras.layers.Dropout(rate=0.5),# 输出层tf.keras.layers.Dense(5, activation='softmax')])model.compile(loss='categorical_crossentropy',optimizer='adam', metrics=['accuracy'])model.summary()return modeldef getTrainData():"""获取训练集数据:return: train_images, train_labels, class_names"""fp = open('./train/train.json', 'r', encoding='utf8')class_names = json.load(fp)['support']fp.close()# 返回加载来的数据集pic_train_images = numpy.load('./train/train_pic.npy')train_images = pic_train_images.reshape(pic_train_images.shape[0], 108, 192, 3) / 255.0print(train_images.shape)train_labels = numpy.load('./train/train_labels.npy')print(numpy.load('./train/train_labels.npy').shape)return train_images, train_labels, class_namesdef getTestData():"""获取测试集包的数据:return: train_images, train_labels, class_names"""fp = open('./test/test.json', 'r', encoding='utf8')class_names = json.load(fp)['support']fp.close()# 返回加载来的数据集pic_test_images = numpy.load('./test/test_pic.npy')test_images = pic_test_images.reshape(pic_test_images.shape[0], 108, 192, 3) / 255.0print(test_images.shape)test_labels = numpy.load('./test/test_labels.npy')print(numpy.load('./test/test_labels.npy').shape)return test_images, test_labels, class_namesdef getTestImages():"""加载测试集1920*1080的壁纸"""path = './test_pic'imgs = []labels = []k = 0paths = os.listdir(path)paths.sort()for j in paths:pbar = tqdm(total=100)for i in os.listdir(path + '/' + j):pbar.update(100.0 / len(os.listdir(path + '/' + j)))pic_path = path + '/' + j + '/' + i# img = img_resize(cv2.imread(pic_path, cv2.IMREAD_GRAYSCALE))img = img_resize(cv2.imread(pic_path))if img.shape[0] != 108 or img.shape[1] != 192:os.remove(pic_path)continueimgs.append(img)labels.append(k)pbar.close()k = k + 1pic_test_images = np.array(imgs)test_images = pic_test_images.reshape(pic_test_images.shape[0], 108, 192, 3) / 255.0return test_images, np.array(labels)def getModel(train_mode=False):"""获取模型:param train_mode: 是否训练:return: 模型"""# 如果训练if train_mode:# 初始化神经网络model = init_network()# 加载数据集train_images, train_labels, _ = getTrainData()test_images, test_labels, _ = getTestData()print(train_images.shape)print(train_labels.shape)print(test_images.shape)print(test_labels.shape)# 开始训练,训练二十次,显示日志信息model.fit(train_images, keras.utils.to_categorical(train_labels), batch_size=128, epochs=100, verbose=2)# 评估模型,不输出预测结果test_loss, test_acc = model.evaluate(test_images, keras.utils.to_categorical(test_labels), verbose=2)# 输出损失值print('测试集损失:', test_loss)# 输出正确率print('测试集正确率:', test_acc)# 保存模型model.save('.\\model\\expll.h5')return model, test_loss, test_accelse:# 加载模型model = tf.keras.models.load_model('.\\model\\780_3x3_1_3_100_expll.h5')# 打印模型信息model.summary()test_images, test_labels, _ = getTestData()# 评估模型,不输出预测结果test_loss, test_acc = model.evaluate(test_images, keras.utils.to_categorical(test_labels), verbose=2)# print([np.where(i == np.max(i))[0][0] for i in model.predict(test_images)])return model, test_loss, test_acc# 训练模型
# if __name__ == '__main__':
#     model = getModel(True)

3.2.2 模型功能的封装,用于支持PyQt功能界面逻辑

import json
import os
import shutilimport numpy as np
from PyQt5.QtCore import *import utils
from model import getModeldef getModelSupportTypes(data):"""获取模型支持的分类:return:"""temp = ''for i in data:temp = temp + ' ' + ireturn tempdef getModelInfo(loss, acc):"""获取模型信息:return: 模型测试准确度"""return '测试集损失:{:.3f}\n测试集准确率:{:.3f}%'.format(loss, acc * 100)class Service(QObject):signalRunTime = pyqtSignal(str, bool)model = NonesignalWorking = pyqtSignal(bool)loadModelStatus = FalsesignalModelInfo = pyqtSignal(str)signalModelSupportTypes = pyqtSignal(str)def __init__(self):super().__init__()def predict(self, imgs: np.array):"""预测:param imgs: 预测图片集:return: 预测结果"""rs = self.model.predict(imgs)return [np.where(i == np.max(i))[0][0] for i in rs]def iniModel(self):"""初始化加载模型"""if self.loadModelStatus:self.signalRunTime.emit('模型加载中···', False)returnself.loadModelStatus = Trueself.signalRunTime.emit('正在加载模型···', False)self.model, loss, acc = getModel()with open('model/model.json', 'r', encoding='utf8') as fp:info = json.load(fp)self.signalModelInfo.emit('方法:' + info['way'] + '\n' + getModelInfo(loss, acc))self.signalModelSupportTypes.emit(getModelSupportTypes(info['support']))self.signalRunTime.emit('模型加载完成', False)self.loadModelStatus = Falsedef startRun(self, window):"""开始进行分类:param window: 窗口对象"""if len(window.getFromPath()) == 0 or len(window.getTargetPath()) == 0:self.signalRunTime.emit('\n存在路径为空\n', False)self.signalWorking.emit(False)returnlist_path = []self.signalRunTime.emit('\n检索中······\n', False)utils.getListDir(window.fromPath.toPlainText(), window.getRecursionPathStatus(), list_path, imageCallback=None,dirCallback=lambda x: self.signalRunTime.emit('检索检索到目录: {0}\n'.format(x), False))self.signalRunTime.emit('检索完成,共计{0}张图片\n'.format(len(list_path)), False)if len(list_path) == 0:self.signalWorking.emit(False)returnself.signalRunTime.emit('开始读取图片······', False)img = utils.get_data(list_path, lambda x: self.signalRunTime.emit('已加载: {0}\n'.format(x), False))self.signalRunTime.emit('读取图片完成', False)self.signalRunTime.emit('维度信息:{0}'.format(img.shape), False)self.signalRunTime.emit('进行分类识别中······', False)rs = self.predict(img)self.signalRunTime.emit('分类识别完成\n***********\n识别结果:\n***********\n***********\n***********\n', False)with open('.\\model\\model.json', encoding='utf8') as fp:supportTypes = json.load(fp)['support']outRunInfo = '\n'for i in zip(list_path, rs):outRunInfo = outRunInfo + '路径: {0}; 结果:{1}\n\n'.format(i[0], supportTypes[i[1]])self.signalRunTime.emit(outRunInfo + '\n\n***********\n***********\n识别结果输出结束\n***********\n***********\n',False)targetPathRoot = window.getTargetPath()for i in supportTypes:if not os.path.exists(targetPathRoot + '/' + i):os.mkdir(targetPathRoot + '/' + i)self.signalRunTime.emit('\n\n开始进行分类迁移······', False)onlyMoveMax = window.getOnlyNumber()with open('.\\model\\model.json', encoding='utf8') as fp:supportTypes = json.load(fp)['support']for j in range(0, int(len(list_path) * 1.0 / onlyMoveMax + 1)):for i in list(zip(list_path, rs))[onlyMoveMax * j:onlyMoveMax * (j + 1)]:try:self.signalRunTime.emit('来源: {0}; 迁移至:{1}\n\n'.format(i[0], (targetPathRoot + '/' + supportTypes[i[1]])), False)shutil.move(i[0], targetPathRoot + '/' + supportTypes[i[1]])except Exception as e:self.signalRunTime.emit('ERROR: {0}'.format(e, False))self.signalRunTime.emit('\n\n迁移结束,任务完成\n\n', False)self.signalWorking.emit(False)

3.2.3 入口部分

# -*- coding: utf-8 -*-
import os
import sys
from concurrent.futures import ThreadPoolExecutorfrom PyQt5.QtWidgets import *import fun
from ui import Ui_FormthreadPool = ThreadPoolExecutor(max_workers=20)def openPath(callback):# 选择图片path = QFileDialog.getExistingDirectory(None, "选择存储文件夹", os.getcwd())if path == "":return 0callback(path)class MainWindow(QWidget, Ui_Form):service = Noneimg = Noneworking = Falsedef __init__(self, service_):super(MainWindow, self).__init__()self.service = service_self.setupUi(self)def openFromPath(self):"""选择来源路径"""openPath(callback=lambda x: self.fromPath.setText(x))def openTargetPath(self):"""选择输出路径"""openPath(callback=lambda x: self.targetPath.setText(x))def outRuntimeInfo(self, data, refresh=True):"""输出运行时:param data: 日志:param refresh: 追加或清空再输出"""if refresh:self.runtimeInfor.setText(data)else:self.runtimeInfor.setText(self.runtimeInfor.toPlainText() + '\n' + data)self.runtimeInfor.moveCursor(self.runtimeInfor.textCursor().End)def getFromPath(self):"""获取源路径:return: 源路径"""return self.fromPath.toPlainText()def getTargetPath(self):"""获取输出路径:return: 输出路径"""return self.targetPath.toPlainText()def outSupportTypes(self, data):"""输出模型支持的类型:param data: 类型串"""self.modelType.setText(data)def outModelInfo(self, data):"""输出模型信息:param data: 模型信息"""self.modelInfor.setText(data)def getOnlyNumber(self):"""单次处理图片数量:return: 数量"""return self.onlyNumber.value()def getRecursionPathStatus(self):"""是否递归目录"""return self.recursionPath.checkState() == 2def startRun(self):"""开始分类"""if self.working:self.outRuntimeInfo('任务执行中', False)returntry:threadPool.submit(service.startRun, self)except Exception as e:print(e)def setWorking(self, status):self.working = statusif __name__ == '__main__':service = fun.Service()app = QApplication(sys.argv)# 初始化窗口m = MainWindow(service)m.btu_selectFromPath.clicked.connect(m.openFromPath)m.btu_selectTargetPath.clicked.connect(m.openTargetPath)m.btu_startRun.clicked.connect(m.startRun)m.setWindowTitle('1920*1080壁纸分类')m.show()service.signalRunTime.connect(m.outRuntimeInfo)service.signalWorking.connect(m.setWorking)service.signalModelInfo.connect(m.outModelInfo)service.signalModelSupportTypes.connect(m.outSupportTypes)threadPool.submit(service.iniModel)sys.exit(app.exec_())

3.2.4 UI界面

在这里插入图片描述

4 结语

  虽说很简单,或许显得很那么······没用,但是也是自己的一个小成果,也算是又做了一个对自己有用的工具吧!

项目文件所在地址,内含训练好的模型,目前支持五种:https://github.com/WindSnowLi/picture-classify
原文

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/1379980.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

人工智能最全图谱

在过去的几个月中,我一直在收集有关人工智能的相关资料。随着各种的问题被越来越频繁的提及,我决定整理并分享有关人工智能、神经网络、机器学习、深度学习与大数据的技术合辑。同时为了内容更加生动易懂,本文将会针对各个大类展开详细解析。…

人工智能AI - 以图搜图产品

网站首页 以图搜图产品 主要特性 底层使用特征向量相似度搜索单台服务器十亿级数据的毫秒级搜索近实时搜索,支持分布式部署随时对数据进行插入、删除、搜索、更新等操作支持在线用户管理与服务器性能监控,支持限制单用户登录 系统功能 搜索管理&#…

2021-05-19 人工智能图片识别

手写数字识别案例(第一版) 任务:输入:28*28的灰度图片 输出:0-9的数字标签 样本量:6万训练样本,1万测试样本 数据处理:读取数据和预处理操作 模型设计:网络结构 训练…

人工智能——图搜索

一.数据驱动和目标驱动搜索 以下情况建议使用目标驱动搜索: (1)目标或假设是在问题陈述中给出的。例如定理的证明,目标就是定理。 (2)与问题数据匹配的规则非常多,会产生大量分支…

手机声音同步到另一部手机_手机数据同步、丢失不再可怕

日常生活中,我们使用手机最大的难题可能就是手机资料的丢失了。熊孩子玩手机在你不注意的情况下把照片删掉了,换新手机资料的同步更是麻烦,还有甚者就是手机丢了,里面的数据资料全面化为泡影,想哭都没地儿哭。而现在不…

互联网日报 | 华为发布首款商用台式机;京东健康正式登陆港交所;苹果推出首款头戴式耳机...

今日看点 ✦ 京东健康港交所上市,募资265亿港元、总市值超3400亿港元 ✦ 华为发布首款商用台式机,商用PC布局更进一步 ✦ 淘宝特价版注册“1元更香”商标,每月最后一周定为“1元更香节” ✦ 大众汽车(安徽)正式揭牌&am…

富士康登陆A股 工业互联网的盛宴

富士康工业互联网(FII)于6月8日登陆A股,开盘大涨44.01%,报19.83元,目前FII总市值达3905亿元,超过海康威视、美的集团等企业,位居A股市值第14名,同时也成为A股市值最高的科技企业。 …

要闻君说: 百度云喜提信息安全首证;紫光展锐携5G芯片进击2019MWC;OPPO首发5G手机惊艳亮相……...

关注并标星星CSDN云计算 每周三次,打卡即read 更快、更全了解泛云圈精彩news go go go 大家好!偶是要闻君。活动多多、新闻不少,精神饱满的周一,学起来!!! 文/要闻君 一年一度,十分…

LVS/DR+Keepalived负载均衡实战(一)

引言 负载均衡这个概念对于一个IT老鸟来说再也熟悉不过了,当听到此概念的第一反应是想到举世闻名的nginx,但殊不知还有一个大名鼎鼎的负载均衡方案可能被忽略了,因为对于一般系统来说,很多应用场合中采用nginx基本已经满足需求&a…

【Java】数据交换 Json 和 异步请求 Ajax

🎄欢迎来到边境矢梦的csdn博文,本文主要讲解Java 中 数据交换和异步请求 Json&Ajax 的相关知识🎄 🌈我是边境矢梦,一个正在为秋招和算法竞赛做准备的学生🌈 🎆喜欢的朋友可以关注一下&#…

go语言从0基础到安全项目开发实战

一.环境搭建并helloworld 搭建环境比较简单 1.1安装SDK 到以下链接下 Go下载 - Go语言中文网 - Golang中文社区 下载windows版本64位zip包 https://studygolang.com/dl/golang/go1.20.7.windows-amd64.zip 1.2配置环境变量 不配置的话就只能在bin目录下才能运行go命令 …

linux安装ftp

一、安装 参考博客 https://blog.csdn.net/dafeigecsdn/article/details/126518069 rpm -qa |grep vsftpd # 查看是否安装ftp yum -y install vsftpd # 安装vsftpuseradd -d /home/lanren312 lanren312 # 指定在/home目录下创建用户 passwd lanren312 # 给用户设置密码 # 输…

20220209学速写

抖音上学速写感觉不太行呀。虽然看起来简单但感觉手很笨,感觉从基础入门后开始讲的,而我还缺少基础。。。

人物速写示范(30张图)

人物速写示范(30张图) 2007/01/11 10:59 扫描自《叶老师速写教学示范》——湖北美术出版社叶军,1964年生于湖北沙市,毕业于湖北美术学院,学士学位。现为湖北美术学院副教授,中国画系副主任,研究…

学习速写的方法有哪些?如何快速学会速写?

本文由“学美术上美术集网校”原创,图片素材来自网络,仅供学习分享 学习速写的方法有哪些?如何快速学会速写?很多初学绘画者,包括有些已经进行过一些素描训练的学画青少年想画速写,总感到无从下手。在与这些初学绘画者的接触中,我总是尽量告诉他们一些速写方面的训练方…

Vscode 速写 HTML

Vscode 速写 HTML 文章目录 Vscode 速写 HTML1. 快速生成HTML结构2. 快速生成标签3. 生成指定标签4. 插件 1. 快速生成HTML结构 输入 ! 后按 Tab <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name&qu…

速写篇—速写打型需要几步?这5步准确起型~

速写怎么打好型&#xff1f;速写打型需要哪些步骤&#xff1f;很多小伙伴在学习美术速写的时候都会遇到各种问题今天美术集网校带大家了解下速写如何打好型&#xff1a; 画速写人物真的很难吗?如果你画的人物得不到高分&#xff0c;你可能需要考虑一下是不是打形没有画好&…

速写想要拿高分?这些要点能提分~

速写怎么画&#xff1f;怎么画速写才能提高分&#xff1f;很多小伙伴在学习美术都会遇到各种问题今天美术集网校带大家了解下速写提高分的方法吧&#xff1a; 速写想要取得高分&#xff0c;首先就要先突破难点&#xff0c;找到短板&#xff0c;逐个克服才能更好的把握速写。 首…

学速写的步骤来啦,零基础学习更简单

最近美术集小编收到了很多新手学习速写的问题点&#xff0c;想要学习速写&#xff0c;应该从哪些步骤开始呢&#xff1f;今天广州美术集网校就帮大家整理了一些画速写的步骤&#xff0c;掌握好这些步骤&#xff0c;速写的学习就像开了加速器&#xff1a; ​ 第一&#xff0c;我…