内容目录
合集请看:pyTorch入门合集
参考视频:https://www.bilibili.com/video/BV1hE411t7RN/?spm_id_from=333.337.search-card.all.click
以ReLU和sigmoid为例
核心代码
class module(nn.Module):
def __init__(self):
super().__init__()
self.relu1 = ReLU()
self.sigmoid1 = Sigmoid()
def forward(self, input):
output = self.sigmoid1(input)
return output
完整代码
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import MaxPool2d, ReLU, Sigmoid
# 输入图像的二维矩阵
input = torch.tensor([[1, -0.5],
[-1,3]])
input = torch.reshape(input,(-1, 1, 2, 2))
class module(nn.Module):
def __init__(self):
super().__init__()
self.relu1 = ReLU()
self.sigmoid1 = Sigmoid()
def forward(self, input):
output = self.sigmoid1(input)
return output
test = module()
output = test(input)
print(output)
其实可以发现,写法都是类似的,注意参数即可