2024Datawhale AI夏令营---Inclusion・The Global Multimedia Deepfake Detection--学习笔记

赛题背景:

        其实总结起来就是一句话,这个项目是基于目前的深度伪装技术,就是通过大量人脸的原数据集进行模型训练之后,能够生成伪造的人脸视频。这项目就是教我们如何去实现这个DeepFake技术。

Task1:了解Deepfake和跑通baseline

代码架构如下:

  1. 模型定义:使用timm库创建一个预训练的resnet18模型。

  2. 训练/验证数据加载:使用torch.utils.data.DataLoader来加载训练集和验证集数据,并通过定义的transforms进行数据增强。

  3. 训练与验证过程

    1. 定义了train函数来执行模型在一个epoch上的训练过程,包括前向传播、损失计算、反向传播和参数更新。

    2. 定义了validate函数来评估模型在验证集上的性能,计算准确率。

  4. 性能评估:使用准确率(Accuracy)作为性能评估的主要指标,并在每个epoch后输出验证集上的准确率。

  5. 提交:最后,将预测结果保存到CSV文件中,准备提交到Kaggle比赛。

代码解释如下:

详见代码注释吧

这份代码后续还是要好好精读理解一下的吧,好好分析一下,顺便提升一下代码能力。--7.15

from PIL import Image
Image.open('/kaggle/input/deepfake/phase1/trainset/63fee8a89581307c0b4fd05a48e0ff79.jpg')import torch
torch.manual_seed(0)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = Trueimport torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
import timm
import timeimport pandas as pd
import numpy as np
import cv2
from PIL import Image
from tqdm import tqdm_notebooktrain_label = pd.read_csv('/kaggle/input/deepfake/phase1/trainset_label.txt')
val_label = pd.read_csv('/kaggle/input/deepfake/phase1/valset_label.txt')train_label['path'] = '/kaggle/input/deepfake/phase1/trainset/' + train_label['img_name']
val_label['path'] = '/kaggle/input/deepfake/phase1/valset/' + val_label['img_name']train_label['target'].value_counts()val_label['target'].value_counts()train_label.head(10)class AverageMeter(object):"""Computes and stores the average and current value"""def __init__(self, name, fmt=':f'):self.name = nameself.fmt = fmtself.reset()def reset(self):self.val = 0self.avg = 0self.sum = 0self.count = 0def update(self, val, n=1):self.val = valself.sum += val * nself.count += nself.avg = self.sum / self.countdef __str__(self):fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'return fmtstr.format(**self.__dict__)class ProgressMeter(object):def __init__(self, num_batches, *meters):self.batch_fmtstr = self._get_batch_fmtstr(num_batches)self.meters = metersself.prefix = ""def pr2int(self, batch):entries = [self.prefix + self.batch_fmtstr.format(batch)]entries += [str(meter) for meter in self.meters]print('\t'.join(entries))def _get_batch_fmtstr(self, num_batches):num_digits = len(str(num_batches // 1))fmt = '{:' + str(num_digits) + 'd}'return '[' + fmt + '/' + fmt.format(num_batches) + ']'def validate(val_loader, model, criterion):batch_time = AverageMeter('Time', ':6.3f')losses = AverageMeter('Loss', ':.4e')top1 = AverageMeter('Acc@1', ':6.2f')progress = ProgressMeter(len(val_loader), batch_time, losses, top1)# switch to evaluate modemodel.eval()with torch.no_grad():end = time.time()for i, (input, target) in tqdm_notebook(enumerate(val_loader), total=len(val_loader)):input = input.cuda()target = target.cuda()# compute outputoutput = model(input)loss = criterion(output, target)# measure accuracy and record lossacc = (output.argmax(1).view(-1) == target.float().view(-1)).float().mean() * 100losses.update(loss.item(), input.size(0))top1.update(acc, input.size(0))# measure elapsed timebatch_time.update(time.time() - end)end = time.time()# TODO: this should also be done with the ProgressMeterprint(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))return top1def predict(test_loader, model, tta=10):# switch to evaluate modemodel.eval()test_pred_tta = Nonefor _ in range(tta):test_pred = []with torch.no_grad():end = time.time()for i, (input, target) in tqdm_notebook(enumerate(test_loader), total=len(test_loader)):input = input.cuda()target = target.cuda()# compute outputoutput = model(input)output = F.softmax(output, dim=1)output = output.data.cpu().numpy()test_pred.append(output)test_pred = np.vstack(test_pred)if test_pred_tta is None:test_pred_tta = test_predelse:test_pred_tta += test_predreturn test_pred_ttadef train(train_loader, model, criterion, optimizer, epoch):batch_time = AverageMeter('Time', ':6.3f')losses = AverageMeter('Loss', ':.4e')top1 = AverageMeter('Acc@1', ':6.2f')progress = ProgressMeter(len(train_loader), batch_time, losses, top1)# switch to train modemodel.train()end = time.time()for i, (input, target) in enumerate(train_loader):input = input.cuda(non_blocking=True)target = target.cuda(non_blocking=True)# compute outputoutput = model(input)loss = criterion(output, target)# measure accuracy and record losslosses.update(loss.item(), input.size(0))acc = (output.argmax(1).view(-1) == target.float().view(-1)).float().mean() * 100top1.update(acc, input.size(0))# compute gradient and do SGD stepoptimizer.zero_grad()loss.backward()optimizer.step()# measure elapsed timebatch_time.update(time.time() - end)end = time.time()if i % 100 == 0:progress.pr2int(i)class FFDIDataset(Dataset):def __init__(self, img_path, img_label, transform=None):self.img_path = img_pathself.img_label = img_labelif transform is not None:self.transform = transformelse:self.transform = Nonedef __getitem__(self, index):img = Image.open(self.img_path[index]).convert('RGB')if self.transform is not None:img = self.transform(img)return img, torch.from_numpy(np.array(self.img_label[index]))def __len__(self):return len(self.img_path)import timm
model = timm.create_model('resnet18', pretrained=True, num_classes=2)
model = model.cuda()train_loader = torch.utils.data.DataLoader(FFDIDataset(train_label['path'].head(1000), train_label['target'].head(1000), transforms.Compose([transforms.Resize((256, 256)),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])), batch_size=40, shuffle=True, num_workers=4, pin_memory=True
)val_loader = torch.utils.data.DataLoader(FFDIDataset(val_label['path'].head(1000), val_label['target'].head(1000), transforms.Compose([transforms.Resize((256, 256)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])), batch_size=40, shuffle=False, num_workers=4, pin_memory=True
)criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), 0.005)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85)
best_acc = 0.0
for epoch in range(2):scheduler.step()print('Epoch: ', epoch)train(train_loader, model, criterion, optimizer, epoch)val_acc = validate(val_loader, model, criterion)if val_acc.avg.item() > best_acc:best_acc = round(val_acc.avg.item(), 2)torch.save(model.state_dict(), f'./model_{best_acc}.pt')test_loader = torch.utils.data.DataLoader(FFDIDataset(val_label['path'], val_label['target'], transforms.Compose([transforms.Resize((256, 256)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])), batch_size=40, shuffle=False, num_workers=4, pin_memory=True
)val_label['y_pred'] = predict(test_loader, model, 1)[:, 1]
val_label[['img_name', 'y_pred']].to_csv('submit.csv', index=None)

        本来是想直接在本地运行的,但是这个数据集实在是太大了,受限于操作和设备,只能在kaggle云运行这个代码咯,结果如下:

        提交上kaggle进行评分:

Inclusion・The Global Multimedia Deepfake Detection | Kaggle

        结果挺差的,毕竟这就是个普通的原始代码,啥都没优化的,参数也没调,只能先这样咯,后续task再来优化调整咔咔上分吧。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3246182.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

C语言 | Leetcode C语言题解之第237题删除链表中的节点

题目: 题解: /*** Definition for singly-linked list.* struct ListNode {* int val;* struct ListNode *next;* };*/void deleteNode(struct ListNode* node) {struct ListNode * p node->next;int temp;temp node->val;node->val…

C++从入门到起飞之——inline/nullptr关键字全方位剖析!

个人主页:秋风起,再归来~ C从入门到起飞 个人格言:悟已往之不谏,知来者犹可追 克心守己,律己则安! 目录 1、inline 2、nullptr 3.完结散花 1、inline • ⽤inline修饰的函数叫…

苹果手机的微信过期文件怎么恢复?3个小窍门,让你快速找回

在微信APP里,发送过的文件只能储存7天,7天之后就会自动清除,导致无法打开。那么,微信过期文件怎么恢复呢?别担心,今天我们就来分享3个实用的小窍门,帮助你轻松恢复苹果手机上过期的微信文件。赶…

React Native 自定义 Hook 获取组件位置和大小

在 React Native 中自定义 Hook useLayout 获取 View、Pressable 等组件的位置和大小的信息 import {useState, useCallback} from react import {LayoutChangeEvent, LayoutRectangle} from react-nativeexport function useLayout() {const [layout, setLayout] useState&l…

springcolud学习03Eureka

Eureka 模块 来实现服务治理 服务治理就是提供了微服务架构中各微服务实例的快速上线或下线且保持各服务能正常通信的能力的方案总称 建立eureka模型 导入依赖 <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XM…

Linux 上 TTY 的起源

注&#xff1a;机翻&#xff0c;未校对。 What is a TTY on Linux? (and How to Use the tty Command) What does the tty command do? It prints the name of the terminal you’re using. TTY stands for “teletypewriter.” What’s the story behind the name of the co…

浅谈Visual Studio 2022

Visual Studio 2022&#xff08;VS2022&#xff09;提供了众多强大的功能和改进&#xff0c;旨在提高开发者的效率和体验。以下是一些关键功能的概述&#xff1a;12 64位支持&#xff1a;VS2022的64位版本不再受内存限制困扰&#xff0c;主devenv.exe进程不再局限于4GB&#xf…

SQL常用数据过滤---IN操作符

在SQL中&#xff0c;IN操作符常用于过滤数据&#xff0c;允许在WHERE子句中指定多个可能的值。如果列中的值匹配IN操作符后面括号中的任何一个值&#xff0c;那么该行就会被选中。 以下是使用IN操作符的基本语法&#xff1a; SELECT column1, column2, ... FROM table_name WH…

用Vue3和WebCola实现3D图的在线展示

本文由ScriptEcho平台提供技术支持 项目地址&#xff1a;传送门 基于Cola.js的网络图绘制 应用场景 Cola.js是一个JavaScript库&#xff0c;用于绘制交互式网络图。它广泛应用于社交网络、知识图谱、生物网络等领域&#xff0c;帮助用户可视化和探索复杂的数据关系。 基本…

c语言唯一一个三目运算符

条件表达式由两个符号&#xff08;&#xff1f;和&#xff1a;&#xff09;组成&#xff0c;必须一起使用。要求有三个操作对象&#xff0c;称为三目运算符。 一般形式为 表达式1&#xff1f;表达式2&#xff1a;表达式3 理解如下&#xff1a; a>b?(maxa):(maxb); //相当…

用AI生成Springboot单元测试代码太香了

你好&#xff0c;我是柳岸花开。 在当今软件开发过程中&#xff0c;单元测试已经成为保证代码质量的重要环节。然而&#xff0c;编写单元测试代码却常常让开发者头疼。幸运的是&#xff0c;随着AI技术的发展&#xff0c;我们可以利用AI工具来自动生成单元测试代码&#xff0c;极…

JS+CSS特效:HTML+JS+CSS 实现精致的带二级菜单的头部菜单

本篇&#xff0c;我们来演示一个二级菜单是怎么做出来的。 案例效果图 因为本次内容主要目标是实现顶部的导航菜单&#xff0c;所以我们不关心其他内容。 第一步&#xff1a;清除浏览器默认样式 & 添加基本样式 *{ margin: 0px; padding: 0px; box-sizing: border-box; …

万界星空科技电线电缆行业MES系统核心功能

在日新月异的科技浪潮中&#xff0c;电线电缆行业作为国民经济的重要支柱&#xff0c;正面临着前所未有的挑战与机遇。如何在激烈的市场竞争中脱颖而出&#xff0c;实现生产效率与产品质量的双重飞跃&#xff0c;成为了每一家线缆企业亟需解决的课题。万界星空科技&#xff0c;…

电池放电倍率

电池放电倍率是指电池在单位时间内放电的速率与其额定容量之比 &#xff0c;放电倍率越大&#xff0c;表示电池能够在较短的时间内释放更多的电能。一般来说&#xff0c;电池的放电倍率会影响其使用时的性能和寿命。 电池的放电倍率主要取决于其设计和制造工艺。一般来说&#…

Github 2024-07-17 开源项目日报 Top10

根据Github Trendings的统计,今日(2024-07-17统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量非开发语言项目3Python项目3Rust项目2TypeScript项目2MDX项目1项目化学习 创建周期:2538 天协议类型:MIT LicenseStar数量:161973 个Fork数量…

气象数据文件名解析:使用正则表达式提取时间信息

气象数据文件名解析&#xff1a;使用正则表达式提取时间信息 前言 在处理大量气象数据文件时&#xff0c;文件名往往携带了关键的元数据信息&#xff0c;如日期、时间、地点、测量设备等。其中&#xff0c;时间信息尤为重要&#xff0c;因为它帮助我们理解数据的时效性和用于…

2024 50+行业大模型应用解决方案全解

第一章&#xff1a;以“生成”能力赋能产业智慧化 从当前大模型的行业应用发展中可以看到&#xff0c;现阶段的大模型更适合于企业的“生成”任务&#xff0c;而非“决策”任务。 “生成”任务主要指文本生成、对话系统、语言翻译等&#xff0c;大模型可以通过分析大量文本数…

算法项目报告:物流中的最短路径问题

问题描述 物流问题 有一个物流公司需要从起点A到终点B进行货物运输&#xff0c;在运输过程中&#xff0c;该公司需要途径多个不同的城市&#xff0c;并且在每个城市中都有一个配送站点。为了最大程度地降低运输成本和时间&#xff0c;该公司需要确定经过哪些配送站点&#xff…

<数据集>猫狗识别数据集<目标检测>

数据集格式&#xff1a;VOCYOLO格式 图片数量&#xff1a;3686张 标注数量(xml文件个数)&#xff1a;3686 标注数量(txt文件个数)&#xff1a;3686 标注类别数&#xff1a;2 标注类别名称&#xff1a;[cat, dog] 序号类别名称图片数框数1cat118811892dog24982498 使用标…

印尼语翻译通:AI驱动的智能翻译与语言学习助手

在这个多元文化交织的世界中&#xff0c;语言是连接我们的桥梁。印尼语翻译通&#xff0c;一款专为打破语言障碍而生的智能翻译软件&#xff0c;让您与印尼语的世界轻松接轨。无论是商务出差、学术研究&#xff0c;还是探索印尼丰富的文化遗产&#xff0c;印尼语翻译通都是您的…