【域适应论文汇总】未完结

文章目录

  • DANN:Unsupervised Domain Adaptation by Backpropagation (2015)
  • TADA:Transferable Attention for Domain Adaptation(2019 AAAI)
      • 1 局部注意力迁移:Transferable Local Attention
      • 2 全局注意力迁移:Transferable Global Attention
  • DAN:Learning transferable features with deep adaptation networks(JMLR 2015)
  • ADDA:Adversarial discriminative domain adaptation(CVPR 2017)
      • 1 报错
      • 2 代码
      • 3 判别器
      • 4 分类器
      • 5 adapt
  • MCD:Maximum classifier discrepancy for unsupervised domain adaptation(CVPR 2018)
  • MDD:Bridging theory and algorithm for domain adaptation
  • CDAN:Conditional Adversarial Domain Adaptation(Neural 2018)
  • MCC:Moment Matching for Multi-Source Domain Adaptation(ICCV 2019)
  • DAPL:Domain Adaptation via Prompt Learning(DA+prompt)(arXiv 2022)
  • 特征提取器优化

DANN:Unsupervised Domain Adaptation by Backpropagation (2015)

提出DANN
在这里插入图片描述

TADA:Transferable Attention for Domain Adaptation(2019 AAAI)

提出了TADA

  • 由多个区域级 鉴别器产生的局部注意力来突出可迁移的区域
  • 由单个图像级 鉴别器产生的全局注意力来突出可迁移的图像

通过注意力机制挑选出可迁移的图像以及图像中可以重点迁移的区域。因此作者提出了两个与注意力机制结合的迁移过程:

  • Transferable Local Attention
  • Transferable Global Attention。
    在这里插入图片描述

1 局部注意力迁移:Transferable Local Attention

在这里插入图片描述
TADA与DANN的思想相同,都是通过一个特征提取器 来提取特征,之后会将提取的特征输入到域判别器 。但是TADA不同之处在于它的域判别器有多个,并且每一个域判别器是针对专门的一块区域的。在DANN中域判别器是判断输入的所有特征组合起来是属于源域还是目标域,而在TADA中每个域判别器只需要判断当前的这一块区域是属于源域还是目标域的。通过这种做法,可以将源域的图片拆开,找出最有用的区域信息,并且将不可迁移的源域信息过滤掉,减小负迁移的风险。

2 全局注意力迁移:Transferable Global Attention

在这里插入图片描述

这一步骤和DANN的操作更为相似,作者的目的是找出哪些特征映射更值得迁移,不再将特征映射划分为各个区域,而是关注它的整体。

DAN:Learning transferable features with deep adaptation networks(JMLR 2015)

代码

  • 在DAN中,所有特定于任务的层的隐藏表示都嵌入到一个可复制的内核Hilbert空间中,在这个空间中可以显式匹配不同域分布的平均嵌入。
  • 采用均值嵌入匹配的多核优化选择方法,进一步减小了domain间的差异。
  • DAN可以在有统计保证的情况下学习可转移的特性,并且可以通过核嵌入的无偏估计进行线性扩展。
    在这里插入图片描述

1 多层自适应

基本结构是AlexNet,其中三个全连接都已经和特定任务练习密切,当用于其他任务或数据集时会有较大误差,于是作者提出在最后的三个全连接层都使用MMD进行分布距离约束,从而使得模型具备更强的迁移能力。至于前边的卷积层,前三层提取到的是更为一般的特征,在预训练之后权重固定,4、5两层则要在预训练的基础上进行fine-tune(调整,以致达到最佳效果)

2 多核自适应

分布匹配主要依靠MMD作为分布距离约束来实现,而MMD的效果依赖于核函数的选择,单一核函数的表达能力是有限的,因此作者提出使用多核MMD (MK-MMD) 来作为损失

3 CNN经验误差

在这里插入图片描述- J:交叉熵损失函数

  • θ ( x i a ) θ(x_i^{a}) θ(xia) x i a x_i^{a} xia被分配到 y i a y_i^{a} yia的条件概率

4 优化目标

在这里插入图片描述

  • D s ℓ D^ℓ_s Ds:源域的第 ℓ ℓ 层隐藏表征
  • D t ℓ D^ℓ_t Dt:目标域的第 ℓ ℓ 层隐藏表征
  • d k 2 ( D s ℓ , D t ℓ ) d_k^2(D^ℓ_s, D^ℓ_t) dk2(Ds,Dt):MK-MMD评估值

5 learning Θ Θ Θ

MK-MMD计算内核功能的期望
在这里插入图片描述

6 learning β β β

多层执行MK-MMD匹配

ADDA:Adversarial discriminative domain adaptation(CVPR 2017)

在这里插入图片描述

  • 使用标记的源图像示例预训练源编码器CNN
  • 通过学习目标编码器CNN来执行对抗性适应,使得看到编码源和目标示例的鉴别器无法可靠地预测它们的域标签
  • 在测试过程中,目标图像与目标编码器一起映射到共享特征空间,并由源分类器进行分类

1 报错

  1. RuntimeError: result type Float can’t be cast to the desired output type Long
    acc /= len(data_loader.dataset)
    改成
    acc = acc / len(data_loader.dataset)

  2. 取ViT输出的池化后结果
    pred_tgt = critic(feat_tgt)
    增加 pooler_output
    pred_tgt = critic(feat_tgt.pooler_output)

  3. RuntimeError: output with shape [1, 28, 28] doesn’t match the broadcast shape [3, 28, 28]
    mnist和usps需要从灰度图片转成RGB图片,通道数从1变成3

transform = transforms.Compose([transforms.Resize((224, 224)),  # 调整大小为 224x224transforms.Grayscale(num_output_channels=3),  #转化成3通道transforms.ToTensor(),  # 将图像转换为张量])
  1. IndexError: invalid index of a 0-dim tensor. Use tensor.item() in Python or tensor.item<T>() in C++ to convert a 0-dim tensor to a number
    .data[0]
    改成
    .item()

2 代码

将lenet encoder换成vit

import torch
from transformers import ViTModel, ViTConfig
# 下载 vit-base-patch16-224-in21k def load_pretrained_vit_model():# Load pre-trained ViT-B/16 modelmodel_path = "./pretrained_models/pytorch_model.bin"config_path = "./pretrained_models/config.json"config = ViTConfig.from_json_file(config_path)vit_model = ViTModel.from_pretrained(pretrained_model_name_or_path=None,config=config,state_dict=torch.load(model_path),ignore_mismatched_sizes=True  # 忽略大小不匹配的错误)return vit_model

3 判别器

"""Discriminator model for ADDA."""
from torch import nn
class Discriminator(nn.Module):"""Discriminator model for source domain."""def __init__(self, input_dims, hidden_dims, output_dims):"""Init discriminator."""super(Discriminator, self).__init__()print("Shape of input_dims:", input_dims)self.restored = Falseself.layer = nn.Sequential(nn.Linear(input_dims, hidden_dims),nn.ReLU(),nn.Linear(hidden_dims, hidden_dims),nn.ReLU(),nn.Linear(hidden_dims, output_dims),nn.LogSoftmax())def forward(self, input):"""Forward the discriminator."""out = self.layer(input)return out

4 分类器

"""LeNet model for ADDA."""
import torch
import torch.nn.functional as F
from torch import nnclass LeNetClassifier(nn.Module):"""LeNet classifier model for ADDA."""def __init__(self, input_size):"""Init LeNet encoder."""super(LeNetClassifier, self).__init__()self.input_size = input_size# Add linear layers to adjust the size of the input feature to fit LeNet# vitself.fc1 = nn.Linear(input_size, 500)# swin# self.fc1 = nn.Linear(49 * 1024, 500)self.fc2 = nn.Linear(500, 10)def forward(self, feat):"""Forward the LeNet classifier."""# vitfeat = feat.pooler_output# swin# feat = feat.view(feat.size(0), -1)# Apply the linear layers and activation functionout = F.dropout(F.relu(self.fc1(feat)), training=self.training)out = self.fc2(out)return out

5 adapt

"""Adversarial adaptation to train target encoder."""
import os
import torch
import torch.optim as optim
from torch import nn
import params
from utils import make_variabledef train_tgt(src_encoder, tgt_encoder, critic,src_data_loader, tgt_data_loader,model_type):"""Train encoder for target domain."""##################### 1. setup network ###################### set train state for Dropout and BN layerstgt_encoder.train()critic.train()# setup criterion and optimizercriterion = nn.CrossEntropyLoss()optimizer_tgt = optim.Adam(tgt_encoder.parameters(),lr=params.c_learning_rate,betas=(params.beta1, params.beta2))optimizer_critic = optim.Adam(critic.parameters(),lr=params.d_learning_rate,betas=(params.beta1, params.beta2))len_data_loader = min(len(src_data_loader), len(tgt_data_loader))##################### 2. train network #####################for epoch in range(params.num_epochs):# zip source and target data pairdata_zip = enumerate(zip(src_data_loader, tgt_data_loader))for step, ((images_src, _), (images_tgt, _)) in data_zip:############################ 2.1 train discriminator ############################# make images variableimages_src = make_variable(images_src.cuda())images_tgt = make_variable(images_tgt.cuda())# zero gradients for optimizeroptimizer_critic.zero_grad()# extract and concat featuresfeat_src = src_encoder(images_src).pooler_outputfeat_tgt = tgt_encoder(images_tgt).pooler_outputfeat_concat = torch.cat((feat_src, feat_tgt), 0)# predict on discriminatorpred_concat = critic(feat_concat.detach())# prepare real and fake labellabel_src = make_variable(torch.ones(feat_src.size(0)).long().cuda())label_tgt = make_variable(torch.zeros(feat_tgt.size(0)).long().cuda())label_concat = torch.cat((label_src, label_tgt), 0)# compute loss for criticloss_critic = criterion(pred_concat, label_concat)loss_critic.backward()# optimize criticoptimizer_critic.step()pred_cls = torch.squeeze(pred_concat.max(1)[1])############################# 2.2 train target encoder ############################## zero gradients for optimizeroptimizer_critic.zero_grad()optimizer_tgt.zero_grad()# extract and target featuresfeat_tgt = tgt_encoder(images_tgt)# predict on discriminatorpred_tgt = critic(feat_tgt.pooler_output)# prepare fake labelslabel_tgt = make_variable(torch.ones(feat_tgt.last_hidden_state.size(0)).long().cuda())# compute loss for target encoderloss_tgt = criterion(pred_tgt, label_tgt)loss_tgt.backward()# optimize target encoderoptimizer_tgt.step()######################## 2.3 print step info ########################if (step + 1) % params.log_step == 0:print("Epoch [{}/{}] Step [{}/{}]:""d_loss={:.5f} g_loss={:.5f} acc={:.5f}".format(epoch + 1,params.num_epochs,step + 1,len_data_loader,loss_critic.item(),loss_tgt.item(),acc.item()))############################## 2.4 save model parameters ##############################if ((epoch + 1) % params.save_step == 0):# 保存模型时加上特征提取器的标识符if model_type == "vit":model_name = "ADDA-target-encoder-ViT-{}.pt".format(epoch + 1)elif model_type == "mobilevit":model_name = "ADDA-target-encoder-MobileViT-{}.pt".format(epoch + 1)elif model_type == "swin":model_name = "ADDA-target-encoder-Swin-{}.pt".format(epoch + 1)torch.save(tgt_encoder.state_dict(), os.path.join(params.model_root,model_name))# 保存最终模型时也加上特征提取器的标识符if model_type == "vit":final_model_name = "ADDA-target-encoder-ViT-final.pt"elif model_type == "mobilevit":final_model_name = "ADDA-target-encoder-MobileViT-final.pt"elif model_type == "swin":final_model_name = "ADDA-target-encoder-Swin-final.pt"torch.save(tgt_encoder.state_dict(), os.path.join(params.model_root,final_model_name))return tgt_encoder

MCD:Maximum classifier discrepancy for unsupervised domain adaptation(CVPR 2018)

最大分类器差异的领域自适应
引入两个独立的分类器F1、F2,用二者的分歧表示样本的置信度不高,需要重新训练。在这里插入图片描述
判别损失有两部分组成

MDD:Bridging theory and algorithm for domain adaptation

CDAN:Conditional Adversarial Domain Adaptation(Neural 2018)

条件生成对抗网络,在GAN基础上做的一种改进,通过给原始的GAN的生成器和判别器添加额外的条件信息,实现条件生成模型

复现代码:https://www.cnblogs.com/BlairGrowing/p/17099742.html

提出一个条件对抗性域适应方法(CDAN),对分类器预测中所传递的判别信息建立了对抗性适应模型。条件域对抗性网络(CDAN)采用了两种新的条件调节策略:

  • 多线性条件调节,通过捕获特征表示与分类器预测之间的交叉方差来提高分类器的识别率
  • 熵条件调节,通过控制分类器预测的不确定性来保证分类器的可移植性

MCC:Moment Matching for Multi-Source Domain Adaptation(ICCV 2019)

DAPL:Domain Adaptation via Prompt Learning(DA+prompt)(arXiv 2022)

代码:https://github.com/LeapLabTHU/DAPrompt
使用预训练的视觉语言模型,优化较少的参数,将信息嵌入到提示中,每个域中共享。
只有当图像和文本的领域和类别分别匹配的时候,他们才形成一对正例。

特征提取器优化

  • ViT
    已部署,测试中

  • Swin Transformer:基于 Transformer 结构的新型模型,计算复杂度可能更高一些(对性能要求较高)

  • MobileViT:CNN的轻量高效,transformer的自注意力机制和全局视野,在速度和内存消耗方面优秀(2021)
    文章:MobileViT: Light-Weight, General-Purpose, and Mobile-Friendly Vision Transformer

  • ConvNeXt:结合了CNN和 Transformer 的模型(2022)
    文章:A ConvNet for the 2020s
    ConvNeXt用100多行代码就能搭建完成,相比Swin Transformer拥有更简单,更快的推理速度以及更高的准确率

  • EfficientNetV2:Google 提出的一系列高效的卷积神经网络,通过使用复合缩放方法和网络深度调整策略,实现了在不同任务上的良好性能和高效计算(对移动设备友好)(2021)

  • MobileNetV3:针对移动设备的轻量级卷积神经网络,有更快的推理速度和更低的内存消耗(对移动设备友好)(2019)

PyTorch Hub 下载模型
https://huggingface.co/models

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2808630.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

调度服务看门狗配置

查看当前服务器相关的sqlserver服务 在任务栏右键&#xff0c;选择点击启动任务管理器 依次点击&#xff0c;打开服务 找到sqlserver 相关的服务&#xff0c; 确认这些服务是启动状态 将相关服务在看门狗中进行配置 选择调度服务&#xff0c;双击打开 根据上面找的服务进行勾…

打开 Camera app 出图,前几帧图像偏暗、偏色该怎样去避免?

1、问题背景 使用的安卓平台&#xff0c;客户的应用是要尽可能快的获取到1帧图像效果正常的图片。 但当打开 camera 启动出流后&#xff0c;前3-5帧图像是偏暗、偏色的&#xff0c;如下图所示&#xff0c;是抓取出流的前25帧图像&#xff0c; 前3帧颜色是偏蓝的&#xff0c;…

vue2和vue3 setup beforecreate create生命周期时间比较

创建一个vue程序&#xff0c;vue3可以兼容Vue2的写法&#xff0c;很流畅完全没问题 写了一个vue3组件 <template><div></div> </template><script lang"ts"> import {onMounted} from vue export default{data(){return {}},beforeCr…

操作符详解3

✨✨ 欢迎大家来到莉莉的博文✨✨ &#x1f388;&#x1f388;养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; 前面我们已经讲过算术操作符、赋值操作符、逻辑操作符、条件操作符和部分的单目操作 符&#xff0c;今天继续介绍一部分。 目录 1.操作符的分类 2…

【软件测试面试】要你介绍项目-如何说?完美面试攻略...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 1、测试面试时&am…

QtRVSim F扩展实现(一):解码阶段

跟踪解码流程可以发现&#xff0c;解码主要是读取 instruction.cpp 里存储的指令集数组进行解码的。 那么对于实现 F 扩展指令集&#xff0c;第一步能成功读取识别新指令&#xff0c;就需要从这里入手。 解码部分代码&#xff1a; static inline const struct InstructionMa…

12. Springboot集成Dubbo3(三)Dubbo-Admin

目录 1、前言 2、安装 2.1、下载Dubbo-admin 2.2、修改配置 2.3、编译前端 2.4、访问 2.5、加载自己的服务 2.6、服务测试 2.7、其他 3、小结 1、前言 Dubbo Admin是用于管理Dubbo服务的基于Web的管理工具。Dubbo Admin提供了一个用户友好的界面&#xff0c;用于在分…

C/C++暴力/枚举/穷举题目持续更新(刷蓝桥杯基础题的进!)

目录 前言 一、百钱买百鸡 二、百元兑钞 三、门牌号码&#xff08;蓝桥杯真题&#xff09; 四、相乘&#xff08;蓝桥杯真题&#xff09; 五、卡片拼数字&#xff08;蓝桥杯真题&#xff09; 六、货物摆放&#xff08;蓝桥杯真题&#xff09; 七、最短路径&#xff08;蓝…

二蛋赠书十六期:《高效使用Redis:一书学透数据存储与高可用集群》

很多人都遇到过这么一道面试题&#xff1a;Redis是单线程还是多线程&#xff1f;这个问题既简单又复杂。说他简单是因为大多数人都知道Redis是单线程&#xff0c;说复杂是因为这个答案其实并不准确。 难道Redis不是单线程&#xff1f;我们启动一个Redis实例&#xff0c;验证一…

【Java程序设计】【C00262】基于Springboot的会员制医疗预约服务管理系统(有论文)

基于Springboot的会员制医疗预约服务管理系统&#xff08;有论文&#xff09; 项目简介项目获取开发环境项目技术运行截图 项目简介 这是一个基于Springboot的会员制医疗预约服务管理信息系统&#xff0c;本系统分为三种角色&#xff1a;管理员、医生和会员&#xff1b; 在系统…

Web3 基金会推出去中心化之声计划:投入高额 DOT 和 KSM ,助力去中心化治理

作者&#xff1a;Web3 Foundation Team 编译&#xff1a;OneBlock 原文&#xff1a;https://medium.com/web3foundation/decentralized-voices-program-93623c27ae43 Web3 基金会为 Polkadot 和 Kusama 创建了去中心化之声计划&#xff08;Decentralized Voices Program&…

【生活】浅浅记录

各位小伙伴们好鸭&#xff0c;今天不是技术文章&#xff0c;浅浅记录一下最近几个月的收获&#x1f60a; 新的一年&#xff0c;一起努力&#xff0c;加油加油&#xff01;

vue3(vite)+electron打包踩坑记录(1)

vue3(vite)electron打包踩坑记录 - 打包vue 第一步 编译vue 使用vite构建vue&#xff0c;package.json如下 {"name": "central-manager","private": true,"version": "0.0.0","type": "commonjs",&q…

2023年总结与2024展望

今天是春节后上班第一天&#xff0c;你懂的&#xff0c;今天基本上是摸鱼状态&#xff0c;早上把我们负责的项目的ppt介绍完善了一下&#xff0c;然后写了一篇技术文章&#xff0c;《分布式系统一致性与共识算法》。接着就看了我近几年写的的年度总结&#xff0c;我一般不会在元…

代码随想录算法训练营day27|39. 组合总和、40.组合总和II

39. 组合总和 如下树形结构如下&#xff1a; 选取第二个数字5之后&#xff0c;剩下的数字要从5、3中取数了&#xff0c;不能再取2了&#xff0c;负责组合就重复了&#xff0c;注意这一点&#xff0c;自己做的时候没想明白这一点 如果是一个集合来求组合的话&#xff0c;就需…

【C++精简版回顾】12.友元函数

1.友元函数 1.class class MM { public:MM(int age,string name):age(age),name(name){}friend void print(MM mm); private:int age;string name;void print() {cout << age << "岁的" << name << "喜欢你" << endl;} }; f…

Redis如何修改key名称

点击上方蓝字关注我 近期出现过多次修改Redis中key名字的场景&#xff0c;本次简介一下如何修改Redis中key名称的方法。 1. 命令行方式修改在Redis中&#xff0c;可以使用rename命令来修改Key的名称。这个命令的基本语法如下&#xff1a; RENAME old_key new_key 在这里&#…

学习或从事鸿蒙开发工作,有学历要求吗?

目前安卓有2,000万的开发者。本科及以上学历占比为35%&#xff1b;iOS有2,400万开发者&#xff0c;本科及以上学历占比为40% 绝大多数的前端开发者都是大专及以下学历&#xff0c;在2023年华为开发者大会上余承东透露华为的开发者目前有200万&#xff0c;但鸿蒙开发者统计的数据…

【GAD】基于邻域重建的图异常检测

GAD-NR: Graph Anomaly Detection via Neighborhood Reconstruction 摘要contributionsMethodologyGAE via Neighborhood Reconstruction邻域重建整体重建损失 实验 WSDM2024Link Code | 摘要 图异常检测&#xff08;GAD&#xff09;是一种用于识别图中异常节点的技术&#x…

istio系列教程

istio学习记录——安装https://suxueit.com/article_detail/otVbfI0BWZdDRfKqvP3Gistio学习记录——体验bookinfo及可视化观测https://suxueit.com/article_detail/o9VdfI0BWZdDRfKqlv0r istio学习记录——kiali介绍https://suxueit.com/article_detail/pNVbfY0BWZdDRfKqX_0K …