前言
在使用Pytorch训练模型时,用到python中的item()函数,如:
train_loss += loss.item()
现对item()函数用法做出总结。item()函数的作用是从包含单个元素的张量中取出该元素值,并保持该元素的类型不变。,即:该元素为整形,则返回整形,该元素为浮点型,则返回浮点型。官网解释如下:
Pytorch官网:https://pytorch.org/docs/stable/tensors.html?highlight=item#torch.Tensor.item
实验
做个测试:
import torchx = torch.randn(2, 2)print(x)
print(x[0,0])
print(x[0,0].item())
Output:
tensor([[-0.1405, 2.4767],[-0.6847, 0.0057]])
tensor(-0.1405)
-0.14052967727184296
总结
- 计算loss或者accuracy时,经常使用item()函数,而不是直接取对应的元素x[i,j]。
- item()函数取值时,保持该元素的类型不变。