CharTextCNN(AG数据集---新闻主题分类)

文章目录

  • CharTextCNN
  • 一、文件目录
  • 二、语料集下载地址(本文选择AG)
  • 三、数据处理(data_loader.py)
  • 四、模型(chartextcnn.py)
  • 五、训练和测试
  • 实验结果


CharTextCNN

在这里插入图片描述

一、文件目录

在这里插入图片描述

二、语料集下载地址(本文选择AG)

AG News: https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz
DBPedia: https://s3.amazonaws.com/fast-ai-nlp/dbpedia_csv.tgz
Sogou news: https://s3.amazonaws.com/fast-ai-nlp/sogou_news_csv.tgz
Yelp Review Polarity: https://s3.amazonaws.com/fast-ai-nlp/yelp_review_polarity_csv.tgz
Yelp Review Full: https://s3.amazonaws.com/fast-ai-nlp/yelp_review_full_csv.tgz
Yahoo! Answers: https://s3.amazonaws.com/fast-ai-nlp/yahoo_answers_csv.tgz
Amazon Review Full: https://s3.amazonaws.com/fast-ai-nlp/amazon_review_full_csv.tgz
Amazon Review Polarity: https://s3.amazonaws.com/fast-ai-nlp/amazon_review_polarity_csv.tgz

三、数据处理(data_loader.py)

1.数据集加载
2.读取标签和数据
3.读取所有的字符
4.将句子ont-hot表示

import os
import torch
import json
import csv
import numpy as np
from torch.utils import dataclass AG_Data(data.DataLoader):def __init__(self,data_path,l0=1014):self.path = os.path.abspath('.')if "data" not in self.path:self.path +="/data"self.data_path = data_pathself.l0 = l0self.load_Alphabet()self.load(self.data_path)# 读取所有字符def load_Alphabet(self):with open(self.path+"/alphabet.json") as f:self.alphabet = "".join(json.load(f))# 下载数据,读取标签和数据def load(self, data_path,lowercase=True):self.label = []self.data = []# 数据集加载with open(self.path+data_path,"r") as f:# 默认读写用逗号做分隔符(delimiter),双引号作引用符(quotechar)datas = list(csv.reader(f, delimiter=',', quotechar='"'))for row in datas:self.label.append(int(row[0]) - 1)txt = " ".join(row[1:])if lowercase:txt = txt.lower()self.data.append(txt)self.y = self.label# 句子one-hot表示,X:batch_size*字符one-hot表示(feature)*句子中字符个数(length=1014),Y:标签def __getitem__(self, idx):X = self.oneHotEncode(idx)y = self.y[idx]return X, ydef oneHotEncode(self, idx):X = np.zeros([len(self.alphabet),self.l0])for index_char, char in enumerate(self.data[idx][::-1]):if self.char2Index(char) != -1:X[self.char2Index(char)][index_char] = 1.0return X# 返回字符的下标,字符存在输出下标,不存在输出-1.def char2Index(self,char):return self.alphabet.find(char)# 读取标签长度def __len__(self):return len(self.label)

四、模型(chartextcnn.py)

在这里插入图片描述
每层前都进行归一化:
在这里插入图片描述

import torch
import torch.nn as nn
import numpy as np
class CharTextCNN(nn.Module):def __init__(self,config):super(CharTextCNN,self).__init__()in_features = [config.char_num] + config.features[0:-1]out_features = config.featureskernel_sizes = config.kernel_sizesself.convs = []# bs*70*1014self.conv1 = nn.Sequential(nn.Conv1d(in_features[0], out_features[0], kernel_size=kernel_sizes[0], stride=1), # 一维卷积nn.BatchNorm1d(out_features[0]), # bn层nn.ReLU(), # relu激活函数层nn.MaxPool1d(kernel_size=3, stride=3) #一维池化层) # 卷积+bn+relu+pooling模块self.conv2  = nn.Sequential(nn.Conv1d(in_features[1], out_features[1], kernel_size=kernel_sizes[1], stride=1),nn.BatchNorm1d(out_features[1]),nn.ReLU(),nn.MaxPool1d(kernel_size=3, stride=3))self.conv3 = nn.Sequential(nn.Conv1d(in_features[2], out_features[2], kernel_size=kernel_sizes[2], stride=1),nn.BatchNorm1d(out_features[2]),nn.ReLU())self.conv4 = nn.Sequential(nn.Conv1d(in_features[3], out_features[3], kernel_size=kernel_sizes[3], stride=1),nn.BatchNorm1d(out_features[3]),nn.ReLU())self.conv5 = nn.Sequential(nn.Conv1d(in_features[4], out_features[4], kernel_size=kernel_sizes[4], stride=1),nn.BatchNorm1d(out_features[4]),nn.ReLU())self.conv6 = nn.Sequential(nn.Conv1d(in_features[5], out_features[5], kernel_size=kernel_sizes[5], stride=1),nn.BatchNorm1d(out_features[5]),nn.ReLU(),nn.MaxPool1d(kernel_size=3, stride=3))self.fc1 = nn.Sequential(nn.Linear(8704, 1024), # 全连接层 #((l0-96)/27)*256nn.ReLU(),nn.Dropout(p=config.dropout) # dropout层) # 全连接+relu+dropout模块self.fc2 = nn.Sequential(nn.Linear(1024, 1024),nn.ReLU(),nn.Dropout(p=config.dropout))self.fc3 = nn.Linear(1024, config.num_classes)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = self.conv4(x)x = self.conv5(x)x = self.conv6(x)x = x.view(x.size(0), -1) # 变成二维送进全连接层x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x

五、训练和测试

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from model import CharTextCNN
from data import AG_Data
from tqdm import tqdm
import numpy as np
import config as argumentparser
config = argumentparser.ArgumentParser() # 读入参数设置
config.features = list(map(int,config.features.split(","))) # 将features用,分割,并且转成int
config.kernel_sizes = list(map(int,config.kernel_sizes.split(","))) # kernel_sizes,分割,并且转成int
config.pooling = list(map(int,config.pooling.split(",")))if config.gpu and torch.cuda.is_available():  # 是否使用gputorch.cuda.set_device(config.gpu)
# 导入训练集
training_set = AG_Data(data_path="/AG/train.csv",l0=config.l0)
training_iter = torch.utils.data.DataLoader(dataset=training_set,batch_size=config.batch_size,shuffle=True,num_workers=0)
# 导入测试集
test_set = AG_Data(data_path="/AG/test.csv",l0=config.l0)test_iter = torch.utils.data.DataLoader(dataset=test_set,batch_size=config.batch_size,shuffle=False,num_workers=0)
model = CharTextCNN(config) # 初始化模型
if config.cuda and torch.cuda.is_available(): # 如果使用gpu,将模型送进gpumodel.cuda()
criterion = nn.CrossEntropyLoss() # 构建loss结构
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate) #构建优化器
loss  = -1
def get_test_result(data_iter,data_set):# 生成测试结果model.eval()data_loss = 0true_sample_num = 0for data, label in data_iter:if config.cuda and torch.cuda.is_available():data = data.cuda()label = label.cuda()else:data = torch.autograd.Variable(data).float()out = model(data)true_sample_num += np.sum((torch.argmax(out, 1) == label).cpu().numpy()) # 得到一个batch的预测正确的样本个数acc = true_sample_num / data_set.__len__()return data_loss,accfor epoch in range(config.epoch):model.train()process_bar = tqdm(training_iter)for data, label in process_bar:if config.cuda and torch.cuda.is_available():data = data.cuda()  # 如果使用gpu,将数据送进goulabel = label.cuda()else:data = torch.autograd.Variable(data).float()label = torch.autograd.Variable(label).squeeze()out = model(data)loss_now = criterion(out, autograd.Variable(label.long()))if loss == -1:loss = loss_now.data.item()else:loss = 0.95 * loss + 0.05 * loss_now.data.item()  # 平滑操作process_bar.set_postfix(loss=loss_now.data.item())  # 输出loss,实时监测loss的大小process_bar.update()optimizer.zero_grad()  # 梯度更新loss_now.backward()optimizer.step()test_loss, test_acc = get_test_result(test_iter, test_set)print("The test acc is: %.5f" % test_acc)

实验结果

输出测试集准确率:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/353926.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

[NOI2009] 描边

题目描述 小 Z 是一位杰出的数学家。聪明的他特别喜欢研究一些数学小问题。 有一天,他在一张纸上选择了 n 个点,并用铅笔将它们两两连接起来,构成 (�−1)22n(n−1)​ 条线段。由于铅笔很细,可以认为这些线段的宽度为…

【英文文本分类实战】之二——数据集挑选与划分

请参考本系列目录:【英文文本分类实战】之一——实战项目总览 下载本实战项目资源:神经网络实现英文文本分类.zip(pytorch) [1] 数据集平台 在阅读了大量的论文之后,由于每一篇论文都会提出一个模型,十分想…

使用Keras进行单模型多标签分类

原文:https://www.pyimagesearch.com/2018/05/07/multi-label-classification-with-keras/ 作者:Adrian Rosebrock 时间:2018年5月7日 源码:https://pan.baidu.com/s/1x7waggprAHQDjalkA-ctvg (wa61) 译者&…

图像分类实战:mobilenetv2从训练到TensorRT部署(pytorch)

文章目录 摘要mobilenetv2简介线性瓶颈倒残差 ONNXTensorRT项目结构训练数据增强Cutout和Mixup导入包设置全局参数图像预处理与增强读取数据设置模型定义训练和验证函数 测试模型转化及推理转onnxonnx推理转TensorRTTensorRT推理 摘要 本例提取了植物幼苗数据集中的部分数据做…

商品分类js html,json数据来制作商城的产品分类菜单

人们早就习惯了在互联网购物买东西,甚至有一部分朋友还是上瘾了。本篇PHP教程就来帮助您的电子商务项目实现最重要的产品类别的导航菜单系统。我已经使用PHP、MYSQL及JQuery实现了亚马逊样式的产品分类图像菜单,下面让我们来看一下如何使用json数据来制作商城的产品分类菜单。…

Amazon SPAPI PII权限申请问题汇总

亚马逊PII权限申请 官方文档地址:Selling Partner API 目录 亚马逊PII权限申请 Amazon PII开发者角色申请问题罗列: 接下来很多人都可能会遇到的拒绝原因: eg.1 eg.2 eg.3 eg.4 审计及远程 最后,坐等申请通过的case Amazon PII开发…

文本分类方案,飞浆PaddleNLP涵盖了所有

文章目录 1.前言2.核心技术2.1 文本分类方案全覆盖2.1.1 分类场景齐全2.1.2 多方案满足定制需求方案一:预训练模型微调方案二:提示学习方案三:语义索引 2.2 更懂中文的训练基座2.3 高效模型调优方案2.4 产业级全流程方案 3. 快速开始4. 常用中…

亚马逊中国站通过ASIN获取商品信息

目录 亚马逊中国站获取全部商品分类 亚马逊中国站获取商品列表 亚马逊中国站通过ASIN获取商品信息 亚马逊中国站获取商品库存信息 亚马逊国际站获取全部商品分类 亚马逊国际站获取商品列表 亚马逊国际站处理图形验证码 亚马逊国际站通过ASIN获取商品信息 亚马逊国际站获取商品…

亚马逊国际站获取商品列表

目录 亚马逊中国站获取全部商品分类 亚马逊中国站获取商品列表 亚马逊中国站通过ASIN获取商品信息 亚马逊中国站获取商品库存信息 亚马逊国际站获取全部商品分类 亚马逊国际站获取商品列表 亚马逊国际站处理图形验证码 亚马逊国际站通过ASIN获取商品信息 亚马逊国际站获取商品…

亚马逊国际站处理图形验证码

目录 亚马逊中国站获取全部商品分类 亚马逊中国站获取商品列表 亚马逊中国站通过ASIN获取商品信息 亚马逊中国站获取商品库存信息 亚马逊国际站获取全部商品分类 亚马逊国际站获取商品列表 亚马逊国际站处理图形验证码 亚马逊国际站通过ASIN获取商品信息 亚马逊国际站获取商品…

dvwa靶场通关(五)

第五关 File Upload(文件上传漏洞) File Upload,即文件上传漏洞,通常是由于对上传文件的类型、内容没有进行严格的过滤、检查,使得攻击者可以通过上传木马获取服务器的webshell权限 low low等级没有任何的防护 创建…

亚马逊国际站通过ASIN获取商品信息

目录 亚马逊中国站获取全部商品分类 亚马逊中国站获取商品列表 亚马逊中国站通过ASIN获取商品信息 亚马逊中国站获取商品库存信息 亚马逊国际站获取全部商品分类 亚马逊国际站获取商品列表 亚马逊国际站处理图形验证码 亚马逊国际站通过ASIN获取商品信息 亚马逊国际站获取商品…

亚马逊国际站获取全部商品分类

目录 亚马逊中国站获取全部商品分类 亚马逊中国站获取商品列表 亚马逊中国站通过ASIN获取商品信息 亚马逊中国站获取商品库存信息 亚马逊国际站获取全部商品分类 亚马逊国际站获取商品列表 亚马逊国际站处理图形验证码 亚马逊国际站通过ASIN获取商品信息 亚马逊国际站获取商品…

亚马逊中国站获取全部商品分类

目录 亚马逊中国站获取全部商品分类 亚马逊中国站获取商品列表 亚马逊中国站通过ASIN获取商品信息 亚马逊中国站获取商品库存信息 亚马逊国际站获取全部商品分类 亚马逊国际站获取商品列表 亚马逊国际站处理图形验证码 亚马逊国际站通过ASIN获取商品信息 亚马逊国际站获取商品…

推荐12个开放式免费收录网站的分类目录

很多做网站推广的网友都问2019网址分类目录还有用吗?一些长久网站或高权重、流量高的网站分类目录还有用的,不但能增加优质的外链,提高网站的权重,还能增加网站的曝光率。做seo网页优化的,很多都将会新站提交到各个分类目录网站&…

亚马逊分类目录_新版亚马逊分类目录v2.4程序源码官方分享下载

亚马逊分类目录程序是个跨平台的开源软件,具备来路、去路统计功能,支持两级分类,具有操作简单、功能强大、稳定性好、扩展性及安全性强、二次开发及后期维护方便,可以帮您迅速、轻松地构建起一个强大、专业的分类目录或网址导航网…

《大数据技术与应用》课程实验报告|week12|实验8|Pig——高级编程环境|验证评估函数

目录 一、实验内容 二、实验目的 三、实验设备 四、实验步骤 步骤一 步骤二 步骤三 步骤四 步骤五 步骤六 步骤七 步骤八 步骤九 步骤十 步骤十一 步骤十二 步骤十三 步骤十四 步骤十五 步骤十六 五、实验结果 六、实验小结 一、实验内容 验证19.5节中的…

微信Mac版客户端(支持发布朋友圈)v3.1.5(18841)正式版

微信Mac版客户端全新功能升级!!不仅支持查看朋友圈,还能发布朋友圈啦!!!微信正式版支持对朋友圈进行互动和点 赞等操作,还可以浏览朋友圈相册,这是一款运行在OS X上的 社交聊天工具&…

怒肝半月!Python 学习路线+资源大汇总

Python 学习路线 by 鱼皮。 原创不易,请勿抄袭,违者必究! 大家好,我是鱼皮,肝了十天左右的 Python 学习路线终于来了~ 和之前一样,在看路线前,建议大家先通过以下视频了解几个问题:…

计算机考研,这样选学校才是正解

写了一篇《启舰:对计算机专业来说学历真的重要吗?》,一时间N多同学咨询自身情况要不要考研,眼看有点Hold不住,索性又出了一篇《启舰:计算机专业有必要考研吗?》,结果,又有…