使用GCN根据颗粒图像预测对应性能

之前做一个小实验写的代码,本想创建个git repo,想了想好像没必要,直接用篇博文记录一下吧。
对应资源 : https://download.csdn.net/download/rayso9898/87865298

0. 大纲

0.1 代码说明

  1. dataGeneration.py ->
    RSA生成n张图像,可以指定颗粒个数为m,半径为r。

  2. exGraph.py ->
    将一张sim仿真图像,转换成一个node_feature matrix, n * 4 : n个节点,3个特征-质心x y,等效直径,面积
    有r12 - r22共 6 种不同粒径的图像,组成字典,每种类别有100张图的node_feature

  3. create_dataset.py ->
    主要是 create_graph函数

  4. gcn.py ->
    gcn回归模型

  5. task_script ->
    提供模型训练 预测保存 loss查看等功能

0.2 数据说明

  1. sim_data.zip -> dataGeneration.py 生成
    r12-r22共6个文件夹,每个文件夹100张图像。

  2. sim_gdata.pt -> exGraph.py 生成
    sim_gdata structure:
    all - r12 r14 r16 r18 r20 r22
    r12 - img1 img2 … img100
    img1 - particle1 … paritle n
    particle1 - [x,y,dalimeter,area]

  3. dataset_img_property.pt -> create_dataset.py 生成
    x = [] # 图所对应的节点 num4 (x,y,r,area)
    pos = []# 图所对应节点的位置 num
    2 (x,y)
    y = stress[i] # 图所对应的力学性能数据
    x = torch.tensor(x,dtype=torch.float32)
    # 构造一张img的一个图
    y = torch.tensor([y],dtype=torch.float32)
    pos = torch.tensor(pos,dtype=torch.float32)
    edge_index = knn_graph(pos, k=5)
    g = Data(x=x, y=y, pos=pos,edge_index=edge_index)
    all.append(g)

  4. gcn.pt ->
    模型 后续加载即可使用

  5. loss_gcn.pt
    训练过程中的训练集和测试集loss变化

1. dataGeneration.py ->随机序列吸附法(RSA)生成颗粒图像

可以指定生成颗粒的个数和颗粒半径,以及生成的图像个数。
在这里插入图片描述
在这里插入图片描述

# -*- coding: utf-8 -*-
# @Time    : 2021/5/1 15:12
# @Author  : Ray_song
# @File    : dataGeneration.py
# @Software: PyCharmimport cv2
import math
import numpy as npdef calcDis(a,b,c,d):return math.sqrt(math.pow(a-c,2)+math.pow(b-d,2))def generate(n,r,filename,save_path):'''基于RSA算法,生成一张图像:param n: 该图像中有n个颗粒:param r: 颗粒的半径大小为r:param filename: 图像的名字  如 1 2 3 ....:param save_path: 图像的保存路径 - 文件夹:return: None'''# 1.创建白色背景图片d = 512img = np.ones((d, d, 3), np.uint8) * 0#testinglist = []center_x = np.random.randint(0, high=d)center_y = np.random.randint(0, high=d)list.append([center_x,center_y])# 随机半径与颜色radius = rcolor = (0, 255, 0)cv2.circle(img, (center_x, center_y), radius, color, -1)# 2.循环随机绘制实心圆for i in range(1, n):flag = True# 随机中心点while flag:center_x_new = np.random.randint(radius, high=d-radius)center_y_new = np.random.randint(radius, high=d-radius)panduan = Truefor per in list:Dis = calcDis(center_x_new, center_y_new, per[0], per[1])if Dis<2*r:panduan = Falsebreakelse:continueif panduan:list.append([center_x_new,center_y_new])cv2.circle(img, (center_x_new, center_y_new), radius, color, -1)break# 3.显示结果# cv2.imshow("img", img)# cv2.waitKey()# cv2.destroyAllWindows()# 4.保存结果root = f'{save_path}/{filename}.jpg'cv2.imwrite(root,img)def main():# example1 : 随机生成 100张 颗粒个数为80 半径为20 的 图像save_path = 'sim_data/r20'for i in range(100):generate(80,20,i+1,save_path)if __name__ == '__main__':main()

2. exGraph.py

将一张sim仿真图像,转换成一个node_feature matrix, n * 4 : n个节点,3个特征-质心x y,等效直径,面积
有r12 - r22共 6 种不同粒径的图像,组成字典,每种类别有100张图的node_feature
# -*- coding: utf-8 -*-
# @Time    : 2021/11/3 22:18
# @Author  : Ray_song
# @File    : exGraph.py
# @Software: PyCharmimport os
import torch# 计算面积占比
def countArea(img):# 返回面积占比area = 0size = img.shapeheight,width = size[0],size[1]for i in range(height):for j in range(width):if img[i, j] == 255:area += 1total = height * widthratio = area / totalreturn ratiodef distributionFit(img):'''计算一张照片中所有的 质心坐标xy、颗粒直径、面积,从一张图像中构建graph:param img: RSA生成的图像:return: 图像对应的graph,n*4, n个node, 4个feature'''import numpy as npimport cv2 as cvimg_color = cv.imread(img,1) # countors现实阶段使用img = cv.imread(img,0)# 遍历文件夹中所有的图像thresh_mode = 'THRESH_BINARY+THRESH_OTSU'# 阈值分割内容thresh_down = 127thresh_up = 256if thresh_mode == 'THRESH_BINARY':ret, thresh = cv.threshold(img, thresh_down, thresh_up, cv.THRESH_BINARY)elif thresh_mode == 'THRESH_BINARY_INV':ret, thresh = cv.threshold(img, thresh_down, thresh_up, cv.THRESH_BINARY_INV)elif thresh_mode == 'THRESH_TRUNC':ret, thresh = cv.threshold(img, thresh_down, thresh_up, cv.THRESH_TRUNC)elif thresh_mode == 'THRESH_TOZERO':ret, thresh = cv.threshold(img, thresh_down, thresh_up, cv.THRESH_TOZERO)elif thresh_mode == 'THRESH_TOZERO_INV':ret, thresh = cv.threshold(img, thresh_down, thresh_up, cv.THRESH_TOZERO_INV)elif thresh_mode == 'THRESH_BINARY+THRESH_OTSU':ret, thresh = cv.threshold(img, thresh_down, thresh_up, cv.THRESH_BINARY + cv.THRESH_OTSU)elif thresh_mode == 'No':thresh = imgcontours, hierarchy = cv.findContours(thresh, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_NONE)# 显示contours的代码,可以取消注释进行显示# cv.drawContours(img_color,contours,-1,(255,0,0),1)# cv.imshow('1',img_color)# cv.waitKey(0)graph = []for i in range(len(contours)):node = []cnt = contours[i]M = cv.moments(cnt)  # 计算图像矩# print(M)cx = int(M['m10'] / M['m00'])  # 重心的x坐标cy = int(M['m01'] / M['m00'])  # 重心的y坐标area = cv.contourArea(cnt)equi_diameter = np.sqrt(4 * area / np.pi)node.append(cx)node.append(cy)node.append(equi_diameter)node.append(area)graph.append(node)return graphdef exGraph():dirs = os.listdir('./sim_data')print(dirs)sim_data = {}list = ['r12','r14','r16','r18','r20','r22']index = 0for dir in dirs:path = r'sim_data'+'//'+dir# path = './sim_data/r12'imgs = os.listdir(path)print(len(imgs))con = []for i in range(len(imgs)):img_path = path+'/'+imgs[i]graph = distributionFit(img_path)con.append(graph)sim_data[list[index]] = conindex = index+1torch.save(sim_data,'sim_gdata.pt')if __name__ == '__main__':a = torch.load('./sim_gdata.pt')print(a)

3. create_dataset.py ->

主要是 create_graph函数
# -*- coding: utf-8 -*-
# @Time    : 2021/11/4 23:01
# @Author  : Ray_song
# @File    : create_dataset.py
# @Software: PyCharmimport torch
from torch_geometric.data import Data
from torch_cluster import knn_graph'''sim_gdata structure:all - r12  r14  r16  r18 r20 r22r12 - img1 img2 ... img100img1 - particle1 .. paritle nparticle1 - [x,y,dalimeter,area]
'''def create_graph():path = r'./sim_gdata.pt'data = torch.load(path)r = ['r12', 'r14', 'r16', 'r18', 'r20', 'r22']stress = [225, 230, 235, 240, 245, 250, 255]# 构造图all = []for i in range(len(data)):ri = r[i] # 字典 - 图数据的字典imgs = data[ri]y = stress[i]  # 图所对应的力学性能数据# 遍历ri中的每一张图for j in range(len(imgs)):x = [] # 图所对应的节点 num*4 (x,y,r,area)pos = []# 图所对应节点的位置 num*2  (x,y)img = imgs[j]# 遍历图中所有的节点for k in range(len(img)):# 单个节点的特征xi = []xi.append(img[k][0])xi.append(img[k][1])xi.append(img[k][2])xi.append(img[k][3])x.append(xi)# 位置信息(x,y)posi = []posi.append(img[k][0])posi.append(img[k][1])pos.append(posi)x = torch.tensor(x,dtype=torch.float32)# 构造一张img的一个图y = torch.tensor([y],dtype=torch.float32)pos = torch.tensor(pos,dtype=torch.float32)g = Data(x=x, y=y, pos=pos)g.edge_index = knn_graph(pos, k=5)all.append(g)torch.save(all,r'dataset_img_property.pt')return allif __name__ == '__main__':create_graph()

4. gcn.py ->

gcn回归模型

# -*- coding: utf-8 -*-
# @Time    : 2021/11/5 19:13
# @Author  : Ray_song
# @File    : gcn.py
# @Software: PyCharmimport torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_poolclass GCN(torch.nn.Module):def __init__(self, hidden_channels):super(GCN, self).__init__()# torch.manual_seed(12345)self.conv1 = GCNConv(4, hidden_channels)self.conv2 = GCNConv(hidden_channels, hidden_channels)# self.conv3 = GCNConv(hidden_channels, hidden_channels)self.lin = Linear(hidden_channels, 1)def forward(self, x, edge_index, batch):# 1. Obtain node embeddingsx = self.conv1(x, edge_index)x = x.relu()x = self.conv2(x, edge_index)# x = x.relu()# x = self.conv3(x, edge_index)# 2. Readout layerx = global_mean_pool(x, batch)  # [batch_size, hidden_channels]# 3. Apply a final regressorx = self.lin(x)return x

5. task_script ->

提供模型训练 预测保存 loss查看等功能
import numpy as np
import pandas as pd
import torch
# from torch_geometric.data import Data
# from torch_cluster import knn_graph
# import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D
# from torch_geometric.datasets import GeometricShapes
# from torch_cluster import knn_graph
from torch_geometric.loader import DataLoader
from gcn import GCN,GCNConvdef showDataset():path = r'./dataset_img_property.pt'dataset = torch.load(path)print()# print(f'Dataset: {dataset}:')print('====================')print(f'Number of graphs: {len(dataset)}')# print(f'Number of features: {dataset.num_features}')# print(f'Number of classes: {dataset.num_classes}')data = dataset[0]  # Get the first graph object.print()print(data)print('=============================================================')# Gather some statistics about the first graph.print(f'Number of nodes: {data.num_nodes}')print(f'Number of edges: {data.num_edges}')print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')print(f'Has isolated nodes: {data.has_isolated_nodes()}')print(f'Has self-loops: {data.has_self_loops()}')# print(f'Is undirected: {data.is_undirected()}')def split_train_val(path):dataset = torch.load(path)import randomrandom.shuffle(dataset)num = len(dataset)# print(dataset,num)ratio = 0.8train_dataset = dataset[:int(num*ratio)]val_dataset = dataset[int(num*ratio):]print(len(train_dataset),len(val_dataset))return train_dataset,val_datasetdef begin():import torchpath = r'./dataset_img_property.pt'dataset = torch.load(path)train_dataset,test_dataset = split_train_val(path)train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)model = GCN(hidden_channels=64)print(model)optimizer = torch.optim.Adam(model.parameters(), lr=0.001,weight_decay=1e-8)criterion = torch.nn.MSELoss()def train():model.train()train_loss = 0for data in train_loader:  # Iterate in batches over the training dataset.out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.loss = criterion(out.squeeze(-1), data.y)  # Compute the loss.train_loss = train_loss + loss.item()# print(f'train_loss:{train_loss/data.batch}')loss.backward()  # Derive gradients.optimizer.step()  # Update parameters based on gradients.optimizer.zero_grad()  # Clear gradients.def test(loader):model.eval()loss_ = 0count = 0for data in loader:  # Iterate in batches over the training/test dataset.out = model(data.x, data.edge_index, data.batch)loss = criterion(out.squeeze(-1), data.y)loss_ += loss.item()count += 1return loss_/countloss_t = []loss_v = []for epoch in range(1, 300):loss = []train()train_acc = test(train_loader)test_acc = test(test_loader)if test_acc<45:torch.save(model.state_dict(),'./gcn.pt')loss_t.append(train_acc)loss_v.append(test_acc)print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')loss.append(loss_t)loss.append(loss_v)torch.save(loss,'loss_gcn.pt')def draw_loss():import matplotlib.pyplot as pltloss = torch.load('loss_gcn.pt')plt.plot(loss[0])plt.plot(loss[1])plt.show()def prediction_all():from sklearn.metrics import mean_squared_errorprediction_classes = 1g_path = r'./dataset_img_property.pt'save_path = r'./prediction.csv'g_data = torch.load(g_path)model = GCN(hidden_channels=64)model_weight = r'./gcn.pt'model.load_state_dict(torch.load(model_weight))model.eval()# begin predicting...res = []val_data = g_data[:600]pre_dataloader = DataLoader(val_data, batch_size=1)y = []with torch.no_grad():for item in pre_dataloader:# predict classy.append(item.y.item())output = torch.squeeze(model(item.x,item.edge_index,item.batch))  # 将batch维度压缩掉prediction = output.numpy()if prediction_classes > 1:prediction = list(prediction)res.append(prediction)res = np.array(res)y = np.array(y)acc = mean_squared_error(res,y)print('acc',acc)res = pd.DataFrame(res)res.to_csv(save_path,header=None,index=None)if __name__ == '__main__':prediction_all()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/350316.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

电表网关BL102采集DL/T645电表的操作步骤

使用钡铼BL102网关&#xff1a;西门子S7-200PLC对接ThingsBoard流程 本文主要讲述了钡铼技术BL102物联网网关如何通过RS485采集DL/T645规约电表 BL102是一款采集西门子、三菱、欧姆龙、台达、AB、施耐德等各种PLC数据转换为Modbus TCP、OPC UA、MQTT、ThingsBoard等协议的工业…

JAVA开发(神乎其神的区块链技术之数据上链)

这是我第二遍写关于区块链的博文&#xff0c;前一篇文章《神乎其神的区块链概念和技术》主要介绍区块链的由来和基本概念。因为博主最近在做一个区块链项目&#xff0c;所以有时候也遇到一些概念性的知识需要去理解&#xff0c;比如数据的上链。谈到数据上链&#xff0c;我们先…

OpenCV(图像处理)-基于Python-图像的基本变换-平移-翻转-仿射变换-透视变换

1. 概述2. 接口介绍resize()flip()rotate()仿射变换warpAffine()getRotationMatrix2D()-变换矩阵1getAffineTransform()-变换矩阵2 透视变换warpPerspective()getPerspectiveTransform() 1. 概述 为了方便开发人员的操作&#xff0c;OpenCV还提供了一些图像变换的API&#xff…

计算机改名字后找不到网络,改了wifi名字后电脑搜不到网络怎么办? | 192路由网...

问&#xff1a;为什么我改了wifi名字后&#xff0c;我的电脑就搜不到wifi信号了&#xff1f; 答&#xff1a;修改wifi名称后如果搜索不到wifi信号了&#xff0c;可以按照下面的步骤进行操作&#xff0c;以解决此问题。 1. 如果将wifi名字改成了中文&#xff0c;建议你将其修改为…

柠檬班python自动化百度云_柠字取名2019-尚名网

柠字取名2019-尚名网 名字不是一个简单、随便的称号&#xff0c;它隐含着不容忽视的信息力量。寓意好的名字有积极的暗示作用&#xff0c;使人更有信心和勇气去实现理想&#xff0c;寓意欠佳的名字反之。可见&#xff0c;名字对人们而言是非常重要的&#xff0c;为人父母者一定…

为了取一个花名,我爬下了中草药网所有的名字!

很酷哦&#xff01;不过&#xff0c;对我这个选择恐惧症来说&#xff0c;也很纠结…我们先看一下有哪些要求吧&#xff1f; 中草药名&#xff1f;人参&#xff1f;西洋参&#xff1f;还有啥&#xff1f;&#xff1f;&#xff1f; 作为一个不怎么吃药的非医学生&#xff0c;这题…

使用MySQL查找姓名重名_查询名字有多少人重名,全国同名同姓查询全国姓名数据库...

查询名字有多少人重名&#xff0c;全国同名同姓查询全国姓名数据库 时间&#xff1a;2020-04-04 15:30:01 很多爸爸妈妈在帮孩子取姓名的时候&#xff0c;会想了解在全国范围内重名的人数&#xff0c;希望宝宝的名字不会跟太多人一样。或者有的小伙伴单纯想弄明白全中国同自己姓…

百度排名优化工具 V3.0 正式版

介绍 百度排名优化工具正式版是款可以迅速提升网址百度搜索排名的工具。软件拥有智能计算关键词点击数&#xff0c;点击规则自动添加等。软件还提供了维护模式&#xff0c;自动维护您的关键词排名&#xff0c;让您的关键词排名更加稳定可靠。百度排名优化工具可以将你的网站在…

给Android系统瘦身,安卓优化大师:给系统瘦身

安卓优化大师是一款基于Android平台的系统优化软件&#xff0c;最新版本界面设计简单&#xff0c;功能全面&#xff0c;可以帮助Android手机用户给系统瘦身&#xff0c;优化手机性能。 程序名称&#xff1a;安卓优化大师 平台&#xff1a;Android 类型&#xff1a;系统优化 软件…

Windows优化大师7.96版下载

Windows优化大师提供了全面且有效而简便安全的系统检测、系统优化、系统清理、系统维护四大功能模块以及数个附加的工具软件。它能够有效地帮助用户了解自己的计算机软硬件信息&#xff1b;简化操作系统设置步骤&#xff1b;提升计算机运行效率&#xff1b;清理系统运行时产生的…

SEO优化工具-免费SEO优化工具下载-SEO优化工具大全中心

什么是SEO优化工具&#xff1f;SEO优化工具&#xff08;Seo tools&#xff09;能在搜索引擎优化过程中起到辅助的作用&#xff0c;如数据查询工具、网站排名工具、网站流量分析功能&#xff0c;站群管理工具等&#xff0c;用来提高每个SEO人员工作中的效率。 seo优化工具&#…

Android性能优化之APK优化,完整版开放下载

前言 移动研发火热不停&#xff0c;越来越多人开始学习 android 开发。但很多人感觉入门容易成长很难&#xff0c;对未来比较迷茫&#xff0c;不知道自己技能该怎么提升&#xff0c;到达下一阶段需要补充哪些内容。市面上也多是谈论知识图谱&#xff0c;缺少体系和成长节奏感&a…

win10优化大师v1.0去插件免费版

名称&#xff1a;win10优化大师v1.0去插件免费版 版本&#xff1a;1.0 软件大小&#xff1a;5.70MB 软件语言&#xff1a;中文简体 软件授权&#xff1a;免费版 应用平台&#xff1a;Win10 win10优化大师是一款面向Win10操作系统提供的优化软件&#xff0c;提供常用系统功能的…

Android性能优化工具

一、性能优化工具基础 1.1 概述 在Android开发中&#xff0c;开发者可通过"系统跟踪"观察Android设备的运行情况并生成跟踪报告&#xff0c;在此基础上进行分析优化。Android 平台提供了多种获取跟踪信息的工具&#xff1a; Android Studio CPU 性能剖析器Systrace…

PS 的常见抠图工具

PS 的常见抠图工具 1. 套索工具2. 多边形套索工具3. 磁性套索工具4. 对象套索工具5. 快速套索工具6. 魔棒工具7. 其他 1. 套索工具 能完成快速抠图, 缺点是不好控制. 2. 多边形套索工具 绘制多边形区域抠图, 缺点是不够圆滑, 返回上步是 Backspace 键. 3. 磁性套索工具 吸附边缘…

PS抠图的6种方法

1. 魔棒工具 用于去除单色背景色图片。 选中魔棒工具后&#xff0c;可以点击选中图片中的背景色进行选取&#xff0c;选中后可以去除背景。魔棒工具一般用来去除背景色为单调色的背景&#xff0c;比如背景是白色或者其他纯色之类的。 在选择时可以选择容差\连续&#xff1a; 连…

【QQ聊天界面、创建模型、懒加载数据 Objective-C语言】

一、今天我们要做的就是这个案例 1.我们今天要做的案例,做好了之后的效果就是这样 这个案例,和昨天那个微博的案例是非常相像的, 哪些相像呢, 1)整体是不是也是能滚动啊, 2)能滚动,它不仅仅是一个UIScrollView 它里面,这个也是一行、两行、三行、四行、 所以说,…

Hive学习---7、企业级调优

1、企业级调优 1.1 计算资源配置 到此学习的计算环境为HIve on MR。计算资源的调整主要包括Yarn和MR。 1.1.1 Yarn资源配置 1、Yarn配置说明 需要调整的Yarn的参数均与CPU、内存等资源有关&#xff0c;核心配置参数如下&#xff1a; &#xff08;1&#xff09;yarn.nodeman…

代码随想录算法训练营第四十八天|198.打家劫舍|213.打家劫舍II|337.打家劫舍III

LeetCode198.打家劫舍 动态规划五部曲&#xff1a; 1&#xff0c;确定dp数组&#xff08;dp table&#xff09;以及下标的含义&#xff1a;dp[i]&#xff1a;考虑下标i&#xff08;包括i&#xff09;以内的房屋&#xff0c;最多可以偷窃的金额为dp[i]。 2&#xff0c;确定递…

MockServer 服务框架设计

【摘要】 大部分现有的 mock 工具只能满足 HTTP 协议下简单业务场景的使用。但是面对一些复杂的业务场景就显得捉襟见肘&#xff0c;比如对 socket 协议的应用进行 mock&#xff0c;或者对于支付接口的失败重试的定制化 mock 场景。为解决上述问题&#xff0c;霍格沃兹测试学院…