[Studying the Official PyTorch Docs, Part 7] PyTorch: Custom nn Modules

This series walks through the official PyTorch example code to get familiar with how the various building blocks of a CNN framework are implemented and how they fit together.

  • This post is a detailed annotation of the official tutorial PyTorch: Custom nn Modules, together with my own understanding; discussion is welcome.
  • nn.Module is even more powerful than you might expect: by subclassing it, you can define your own models with more complex behavior.
  • Example
    The following example implements a two-layer neural network as a custom Module subclass.
# -*- coding: utf-8 -*-
import torch


class TwoLayerNet(torch.nn.Module):  # a network definition has two parts: declaring the layers and defining forward()
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we instantiate two nn.Linear modules and assign them as
        member variables.
        """
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        h_relu = self.linear1(x).clamp(min=0)  # clamp(min=0) is an elementwise ReLU
        y_pred = self.linear2(h_relu)
        return y_pred


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10   # define the data dimensions

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)       # create random input and target data with the given dimensions
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = TwoLayerNet(D_in, H, D_out) # instantiate the network whose forward() is defined above

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the two
# nn.Linear modules which are members of the model.
criterion = torch.nn.MSELoss(reduction='sum') # build the loss function
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) # build the optimizer
for t in range(500): # training loop
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x) # forward pass: calling model(x) invokes forward() and computes the predicted y

    # Compute and print loss
    loss = criterion(y_pred, y) # compute the loss between the targets y and the predictions y_pred
    if t % 100 == 99:
        print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad() # clear the accumulated gradients before backprop
    loss.backward()
    optimizer.step()
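
As a quick check on the comment above the SGD constructor (this snippet is my own addition, not part of the official tutorial): printing the model shows the two registered Linear submodules, and named_parameters() lists exactly the learnable tensors that model.parameters() handed to the optimizer.

print(model)
# TwoLayerNet(
#   (linear1): Linear(in_features=1000, out_features=100, bias=True)
#   (linear2): Linear(in_features=100, out_features=10, bias=True)
# )
for name, param in model.named_parameters():
    print(name, tuple(param.shape))
# linear1.weight (100, 1000)
# linear1.bias (100,)
# linear2.weight (10, 100)
# linear2.bias (10,)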

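One step in the training loop deserves a note: autograd accumulates gradients into .grad by default, so without optimizer.zero_grad() each iteration would add its gradients on top of the previous ones. A tiny demonstration of the accumulation (again my own sketch, not from the tutorial):

w = torch.ones(3, requires_grad=True)
(w * 2).sum().backward()
(w * 2).sum().backward()  # second backward pass without zeroing in between
print(w.grad)             # tensor([4., 4., 4.]) -- the two passes accumulated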
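
Finally, to back up the claim that subclassing nn.Module scales to more complex models: forward() may reuse submodules and mix in arbitrary Tensor operations. Below is a minimal sketch of my own (SkipNet is a hypothetical name, not from the official docs) that adds a skip connection around a hidden layer:

import torch

class SkipNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(SkipNet, self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.hidden_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        h = self.input_linear(x).clamp(min=0)
        # arbitrary Tensor operators are allowed here: add a skip connection
        h = (h + self.hidden_linear(h)).clamp(min=0)
        return self.output_linear(h)

skip_model = SkipNet(1000, 100, 10)
print(skip_model(torch.randn(64, 1000)).shape)  # torch.Size([64, 10])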