NeRF学习——NeRF-Pytorch的源码解读

学习 github 上 NeRF 的 pytorch 实现项目(https://github.com/yenchenlin/nerf-pytorch)的一些笔记

1 参数

部分参数配置:

  1. 训练参数:

    • netdepth:神经网络的层数。默认值为8

    • netwidth:每层的通道数。默认值为256

    • netdepth_fine:精细网络的层数。默认值为8

    • netwidth_fine:精细网络每层的通道数。默认值为256

    • N_rand:批量大小(每个梯度步骤的随机光线数)。默认值为 32 × 32 × 4 32 \times 32 \times 4 32×32×4

    • lrate:学习率。默认值为5e-4

    • lrate_decay:指数学习率衰减(在1000步中)。默认值为250

    • chunk:并行处理的光线数,如果内存不足,可以减少这个值。默认值为1024*32

    • netchunk:并行通过网络发送的点数,如果内存不足,可以减少这个值。默认值为1024*64

    • no_batching:是否只从一张图像中取随机光线

    • no_reload:是否不从保存的检查点重新加载权重

    • ft_path:用于重新加载粗网络的特定权重npy文件。默认值为None

    • precrop_iters:在中心裁剪上训练的步数。默认值为0。如果这个值大于0,那么在训练的开始阶段,模型将只在图像的中心部分进行训练,这可以帮助模型更快地收敛

    • precrop_frac:用于中心裁剪的图像的比例。默认值为0.5。这个值决定了在进行中心裁剪时,应该保留图像的多少部分。例如,如果这个值为0.5,那么将保留图像中心的50%

  2. 渲染参数:

    • N_samples:每条光线的粗采样数。默认64

    • N_importance:每条光线的额外精细采样数(分层采样)。默认0

    • perturb:设置为0表示没有抖动,设置为1表示有抖动。抖动可以增加采样点的随机性。默认1

    • use_viewdirs:是否使用完整的5D输入,而不是3D。5D输入包括3D位置和2D视角

    • i_embed:设置为0表示使用默认的位置编码,设置为-1表示不使用位置编码。默认0

    • multires:位置编码的最大频率的对数(用于3D位置)。默认10

    • multires_views:位置编码的最大频率的对数(用于2D方向)。默认4

      我们设置 d = 10 d=10 d=10 用于位置坐标 ϕ ( x ) ϕ(\bf x) ϕ(x) ,所以输入是60维的向量; d = 4 d=4 d=4 用于相机位姿 ϕ ( d ) ϕ(\bf d) ϕ(d) 对应的则是24维

    • raw_noise_std:添加到 sigma_a 输出的噪声的标准偏差,用于正则化 sigma_a 输出。默认0

    • render_only:如果设置,那么不进行优化,只加载权重并渲染出 render_poses 路径

    • render_test:如果设置,那么渲染测试集,而不是 render_poses 路径

    • render_factor:降采样因子,用于加速渲染。设置为4或8可以快速预览。默认0

  3. LLFF(Light Field Photography)数据集:

    • factor:LLFF图像的降采样因子。默认值为8。这个值决定了在处理LLFF图像时,应该降低多少分辨率

    • no_ndc:是否不使用归一化设备坐标(NDC)。如果在命令行中指定了这个参数,那么其值为True。这个选项应该在处理非前向场景时设置

    • lindisp:是否在视差中线性采样,而不是在深度中采样。如果在命令行中指定了这个参数,那么其值为True

    • spherify:是否处理球形360度场景。如果在命令行中指定了这个参数,那么其值为True

    • llffhold:每N张图像中取一张作为LLFF测试集。默认值为8。这个值决定了在处理LLFF数据集时,应该把多少图像作为测试集

      # 加载数据时,每隔args.llffhold个图像取一张图形
      i_test = np.arange(images.shape[0])[::args.llffhold]
      

2 大致过程

2.1 加载LLFF数据
  1. load_llff_data 函数返回五个值:images(图像),poses(姿态),bds(深度范围),render_poses(渲染姿态)和i_test(测试图像索引)

    • hwf是从poses中提取的图像的高度宽度焦距
    images, poses, bds, render_poses, i_test = load_llff_data(.....)
    hwf = poses[0,:3,-1]
    poses = poses[:,:3,:4]
    
  2. 将图像数据集划分为三个部分:训练集(i_train)、验证集(i_val)和测试集(i_test

    # 每隔args.llffhold个图像取一张做测试集
    i_test = np.arange(images.shape[0])[::args.llffhold]
    # 验证集 = 测试集
    i_val = i_test
    # 所有不在测试集和验证集中的图像
    i_train = np.array([i for i in np.arange(int(images.shape[0])) if(i not in i_test and i not in i_val)])
    
2.2 创建神经网络模型
  1. 将采样点坐标和观察坐标通过位置编码 get_embedder 成63维和27维
  2. 实例化NeRF模型和NeRF精细模型
  3. 创建网络查询函数 network_query_fn() ,用于运行网络
  4. 创建 Adam 优化器
  5. 加载检查点(如果有),即从检查点中重新加载模型和优化器状态
  6. 创建用于训练和测试的渲染参数 render_kwargs_trainrender_kwargs_test
  7. 根据数据集类型(只有LLFF才行)和参数确定是否使用NDC
2.3 准备光线

使用批处理:

  1. 对于每一个姿态,使用get_rays_np函数获取光线原点和方向( ro+rd ),然后将所有的光线堆叠起来,得到rays
  2. 将射线的原点和方向与图像的颜色通道连接起来( ro+rd+rgb
  3. 对张量进行重新排列和整形,只保留训练集中的图像
  4. 对训练数据进行随机重排
2.4 训练迭代
  1. 设置训练迭代次数 N_iters = 200000 + 1

  2. 开始进行训练迭代

    • 准备光线数据:在每次迭代中,从rays_rgb中取出一批(批处理)光线数据,数量为参数值N_rand,并准备好目标值 target_s

      如果完成一个了周期(i_batch >= rays_rgb.shape[0] ),则对数据进行打乱

    • 渲染:使用渲染函数 render()

    • 计算损失:计算渲染结果的损失。这里使用了均方误差损失函数 img2mse() 来计算图像损失
      L = ∑ r ∈ R ∥ C ^ c ( r ) − C ( r ) ∥ 2 2 + ∥ C ^ f ( r ) − C ( r ) ∥ 2 2 \mathcal{L} = \sum_{\mathbf{r} \in \mathcal{R}} \left\| \hat{C}^c(\mathbf{r}) - C(\mathbf{r}) \right\|_2^2 + \left\| \hat{C}^f(\mathbf{r}) - C(\mathbf{r}) \right\|_2^2 L=rR C^c(r)C(r) 22+ C^f(r)C(r) 22

      img2mse = lambda x, y : torch.mean((x - y) ** 2)
      
    • 反向传播:进行反向传播,并执行优化

    • 更新学习率:这里采用指数衰减的学习率调度策略,学习率在每个一定的步骤(decay_steps)内以一定的速率(decay_rate)衰减

  3. 根据参数设置的频率输出相关状态、视频和测试集

3 神经网络模型

模型结构如下:

image-20240316162459526

  • 应用 ReLU 激活函数

  • 采样点坐标和观察坐标通过位置编码成63维和27维

  • 中间有一个跳跃连接在第四次 256->256 的线性层

    跳跃连接可以将某一层的输入直接传递到后面的层,从而避免梯度消失和表示瓶颈,提高网络的性能

4 体积渲染

4.1 render()

渲染主函数是调用 render() 函数:

def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,near=0., far=1.,use_viewdirs=False, c2w_staticcam=None,**kwargs):

其有两种用法:

  1. 测试用:

    rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
    

    c2w=c2w[:3,:4] 意味着光线的起点和方向是由函数内部通过相机参数计算得出的

    这个只在 render_path() 函数中用到,其在给定相机路径下渲染图像

    • 不训练只渲染时直接渲染时
    • 定期输出结果时
  2. 训练用:

    rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,verbose=i < 10, retraw=True,**render_kwargs_train)
    

    rays=batch_rays 意味着光线的起点和方向是预先计算好的,而不是由函数内部通过相机参数计算得出

    这个只在训练迭代时用到:Core optimization loop 中,对从rays_rgb中取出一批(批处理)光线进行渲染,得到的 rgb 值与 target_s (也来自预先计算好的 rays_rgb )计算 loss,来进行神经网络的训练

4.2 batchify_rays()

在主函数 render() 中,渲染工作是调用的 batchify_rays()

主要目的是将大量的光线分批处理,以避免在渲染过程中出现内存溢出(OOM)的问题

4.3 render_rays()

分批处理函数 batchify_rays() 中的渲染操作是由 render_rays() 进行,其是真正的渲染操作的函数

def render_rays(ray_batch,network_fn,network_query_fn,N_samples,retraw=False,lindisp=False,perturb=0.,N_importance=0,network_fine=None,white_bkgd=False,raw_noise_std=0.,verbose=False,pytest=False):

其参数:光线批次(ray_batch)、网络函数(network_fn)、网络查询函数(network_query_fn)、样本数量(N_samples)等等

返回:一个字典 ,包含了 RGB 颜色映射、视差映射、累积不透明度等信息

其大致过程为:

  1. 从光线批次中提取出光线的起点、方向、视线方向以及近远边界

    • 根据是否进行线性分布采样,计算出每个光线上的采样点的深度值

    • 若设置扰动( perturb ),则在每个采样间隔内进行分层随机采样

  2. 函数计算出每个采样点在空间中的位置

    pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]
    
  3. 然后使用 network_query_fn() 对每个采样点进行预测,得到原始的预测结果 raw

  4. 使用 raw2outputs()(请看下一节4.4) 函数将原始预测结果转换为 RGB 颜色映射、视差映射、累积不透明度等输出

  5. 若分层采样 N_importance > 0,调用 sample_pdf() 分层采样,并将这些额外的采样点传递给精细网络 network_fine 进行预测

  6. 最后,函数返回一个字典,包含了所有的输出结果

4.4 raw2outputs()

其将模型的原始预测转换为语义上有意义的值,主要基于论文中离散形式的积分方程实现:

累积不透明度函数 C ^ ( r ) \hat{C}(r) C^(r) 的估计公式如下:

C ^ ( r ) = ∑ i = 1 N T i ( 1 − exp ⁡ ( − σ i δ i ) ) c i \hat{C}(r) = \sum_{i=1}^{N} T_i (1 - \exp(-\sigma_i \delta_i)) c_i C^(r)=i=1NTi(1exp(σiδi))ci

其中,

  • N N N 是样本点的数量,
  • T i = exp ⁡ ( − ∑ j = 1 i − 1 σ j δ j ) T_i = \exp \left( - \sum_{j=1}^{i-1} \sigma_j \delta_j \right) Ti=exp(j=1i1σjδj) 是权重系数
  • δ i = t i + 1 − t i \delta_i = t_{i+1} - t_i δi=ti+1ti 表示相邻样本之间的距离
  • c i c_i ci 是颜色值
  • σ i \sigma_i σi 是不透明度值(体积密度)

根据代码,我们可以得出以下关系:

  • c i c_i ci 对应着 rgb = torch.sigmoid(raw[...,:3]),表示颜色值
  • σ i \sigma_i σi 对应着 raw[...,3],表示不透明度值

然后,我们可以根据公式中的每个项逐一解释如何在代码中实现:

  1. δ i = t i + 1 − t i \delta_i = t_{i+1} - t_i δi=ti+1ti:计算相邻样本之间的距离。在代码中:

     dists = z_vals[...,1:] - z_vals[...,:-1]
    
  2. 1 − exp ⁡ ( − σ i δ i ) 1 - \exp(-\sigma_i \delta_i) 1exp(σiδi):计算每个样本的不透明度。在代码中:

    raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)alpha = raw2alpha(raw[...,3] + noise, dists)
    
  3. T i = exp ⁡ ( − ∑ j = 1 i − 1 σ j δ j ) T_i = \exp \left( - \sum_{j=1}^{i-1} \sigma_j \delta_j \right) Ti=exp(j=1i1σjδj)​:计算权重系数。在代码中:

    即对 1 − ( 1 − exp ⁡ ( − σ i δ i ) ) 1 - (1 - \exp(-\sigma_i \delta_i)) 1(1exp(σiδi)) 累乘

    torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]
    
  4. C ^ ( r ) = ∑ i = 1 N T i ( 1 − exp ⁡ ( − σ i δ i ) ) c i \hat{C}(r) = \sum_{i=1}^{N} T_i (1 - \exp(-\sigma_i \delta_i)) c_i C^(r)=i=1NTi(1exp(σiδi))ci​​:计算累积不透明度。在代码中:

    w i = T i ( 1 − exp ⁡ ( − σ i δ i ) ) w_i = T_i(1 - \exp(-\sigma_i\delta_i)) wi=Ti(1exp(σiδi))

    weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]
    rgb_map = torch.sum(weights[...,None] * rgb, -2)  # [N_rays, 3]
    

最终,代码返回估计的 RGB 颜色、视差图、累积权重、权重以及估计的距离图

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2869919.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

代码随想录算法训练营三刷day25 | 回溯 之 216.组合总和III 17.电话号码的字母组合

三刷day25 216.组合总和III剪枝 17.电话号码的字母组合 216.组合总和III 题目链接 解题思路&#xff1a; 选取过程如图&#xff1a; 图中&#xff0c;可以看出&#xff0c;只有最后取到集合&#xff08;1&#xff0c;3&#xff09;和为4 符合条件。 递归三部曲 确定递归函数参…

如何通过人才测评系统来寻找个人的潜能

潜力这个词&#xff0c;有的时候真是虚无缥缈&#xff0c;人们总说人的潜力是无限&#xff0c;又总说人的潜力是有限的&#xff0c;想一想两句话也都有道理&#xff0c;人的潜能怎么可能无限大&#xff1f;但在某些时候&#xff0c;你也许可以做的更好&#xff0c;但是对于这个…

【Java基础】IO流(三):字符流的FileReader(文件字符输入流)和 FileWriter(文件字节输出流)

目录 字符流 1、FileReader&#xff08;字符输入流&#xff09; 1.1、无参的read( )方法示例 ​编辑 1.2、有参的read(char[ ] buffer)方法示例 2、FileWriter&#xff08;字符输出流&#xff09; 字符流 字符流的底层其实就是字节流&#xff0c;即字符流 字节流 字符集…

sqllab第二十六关通关笔记

知识点&#xff1a; 空格替换 %09 %0a %0b %0c %0d %a0 (%2b)or替换&#xff1a;|| ||是不需要空格区分的and替换&#xff1a;&& &&同样不需要空格区分的双写绕过&#xff0c;但是绕过后需要和内容进行空格区分的&#xff0c;要不然不发挥作用&#xff1b;这关…

unity学习(61)——hierarchy和scene的全新认识+模型+皮肤+动画controller

刚刚开始&#xff0c;但又结束的感觉&#xff1f; 1.对hierarchy和scene中的内容有了全新的认识 一定要清楚自己写过几个scene&#xff1b;每个scene之间如何跳转&#xff1b;build setting是add当前的scene。 2.此时的相机需要与模型同级&#xff0c;不能在把模型放在相机下…

最细节操作 Linux LVM 逻辑卷管理

Linux LVM&#xff08;逻辑卷管理&#xff09; 周末愉快&#xff0c;今天带大家实战一下LVM! 一、LVM理论 LVM&#xff0c;即Logical Volume Manager&#xff0c;逻辑卷管理器&#xff0c;是一种硬盘的虚拟化技术&#xff0c;可以允许用户的硬盘资源进行灵活的调整和动态管理…

数学建模博弈理论与实践国防科大版

目录 4.博弈模型 4.1.Nash平衡点和帕雷托最优 4.2.囚徒困境 4.3.智猪博弈 4.4.脏脸之谜 5.军事问题数学建模 5.1.兰彻斯特作战模型 5.1.1.一般战斗模型 5.1.2游击战模型 5.1.3.混合战模型 5.2.硫磺岛战役 4.博弈模型 本讲介绍博弈模型&#xff0c;包括博弈论&#x…

C/C++整数和浮点数在内存中存储

1. 整数在内存中的存储&#xff1a; 整数的2进制表⽰⽅法有三种&#xff0c;即 原码、反码和补码 三种表⽰⽅法均有符号位和数值位两部分&#xff0c;符号位都是⽤0表⽰“正”&#xff0c;⽤1表⽰“负”&#xff0c;⽽数值位最 ⾼位的⼀位是被当做符号位&#xff0c;剩余的都是…

Redis远程连接本机——Docker

1. Docker拉取redis镜像并创建容器 1.1 拉取redis镜像 如果要指定redis版本&#xff0c;需要使用redis:&#xff08;版本&#xff09;&#xff0c;不写默认最新版本 docker pull redis1.2 创建容器并挂载配置文件 创建一个redis目录&#xff0c;并在其创建一个conf目录和一个d…

Rocky Linux 基本工具的安装

1.系统安装后先查看ip地址 ip addr 2.安装net工具 &#xff1a;ifconfig yum install net-tools 3.安装gcc &#xff1b;选择都选 y yum install gcc yum install gcc-c 4.安装tcl yum install -y tcl 5.安装lsof &#xff08;端口查看工具&#xff09; yum install l…

深度强化学习(七)策略梯度

深度强化学习(七)策略梯度 策略学习的目的是通过求解一个优化问题&#xff0c;学出最优策略函数或它的近似函数&#xff08;比如策略网络&#xff09; 一.策略网络 假设动作空间是离散的,&#xff0c;比如 A { 左 , 右 , 上 } \cal A\{左,右,上\} A{左,右,上}&#xff0c;策…

用SeaTunnel从SQL Server向Elasticsearch同步数据

文章目录 引言I 步骤1.1 环境准备1.2 配置JDBC插件1.3 编写SeaTunnel任务配置II Enable Sql Server CDC引言 SeaTunnel 的官网 https://seatunnel.apache.org/ Support SQL Server Version: server:2008 (Or later version for information only)Supported DataSource Info: …

布隆过滤器原理及应用场景

目录 一、布隆过滤器概述1.1 什么是布隆过滤器1.2 优缺点 二、布隆过滤器原理2.1 布隆过滤器的组成2.2 元素添加和查询 三、 应用场景参考资料 一、布隆过滤器概述 1.1 什么是布隆过滤器 布隆过滤器&#xff08;Bloom Filter&#xff09;是一种数据结构&#xff0c;用于快速检…

[蓝桥杯练习题]Fizz Buzz经典问题

return的艺术 #include<bits/stdc.h> using namespace std; int main(){ios::sync_with_stdio(0);cin.tie(nullptr);cout.tie(nullptr);int n;cin>>n;if(n%50&&n%30)return !(cout<<"FizzBuzz");if(n%30)return !(cout<<"Fizz&…

Microsoft Word 符号 / 特殊符号

Microsoft Word 符号 / 特殊符号 1. 插入 -> 符号 -> 其他符号 -> Wingdings 2References 1. 插入 -> 符号 -> 其他符号 -> Wingdings 2 ​ References [1] Yongqiang Cheng, https://yongqiang.blog.csdn.net/

PHP+golang开源办公系统CRM管理系统

基于ThinkPHP6 Layui MySQL的企业办公系统。集成系统设置、人事管理、消息管理、审批管理、日常办公、客户管理、合同管理、项目管理、财务管理、电销接口集成、在线签章等模块。系统简约&#xff0c;易于功能扩展&#xff0c;方便二次开发。 服务器运行环境要求 PHP > 7.…

AI - 决策树模型

&#x1f914;决策树算法 决策树的思想来源可以追溯到古希腊时期&#xff0c;当时的哲学家们就已经开始使用类似于决策树的图形来表示逻辑推理过程。然而&#xff0c;决策树作为一种科学的决策分析工具&#xff0c;其发展主要发生在20世纪。 在20世纪50年代&#xff0c;美国兰…

mac激活pycharm,python环境安装和包安装问题

1.PyCharm到官网下载就行 地址&#xff1a;Other Versions - PyCharm (jetbrains.com) 2.MacOS 下载python环境&#xff0c;地址&#xff1a; Python Releases for macOS | Python.org 3.PyCharm环境配置&#xff1a; 4. 如果包下载不下来可以换个源试试 pip install py…

【网络原理】TCP 协议中比较重要的一些特性(三)

目录 1、拥塞控制 2、延时应答 3、捎带应答 4、面向字节流 5、异常情况处理 5.1、其中一方出现了进程崩溃 5.2、其中一方出现关机&#xff08;正常流程的关机&#xff09; 5.3、其中一方出现断电&#xff08;直接拔电源&#xff0c;也是关机&#xff0c;更突然的关机&am…

Unity的AssetBundle资源运行内存管理的再次深入思考

大家好&#xff0c;我是阿赵。   这篇文章我想写了很久&#xff0c;是关于Unity项目使用AssetBundle加载资源时的内存管理的。这篇文章不会分享代码&#xff0c;只是分享思路&#xff0c;思路不一定正确&#xff0c;欢迎讨论。   对于Unity引擎的资源内存管理&#xff0c;我…