008pyTorch入门——非线性激活
内容目录

合集请看:pyTorch入门合集

参考视频:https://www.bilibili.com/video/BV1hE411t7RN/?spm_id_from=333.337.search-card.all.click

以ReLU和sigmoid为例

核心代码

class module(nn.Module):  
    def __init__(self):  
        super().__init__()  
        self.relu1 = ReLU()  
        self.sigmoid1 = Sigmoid()  

    def forward(self, input):  
        output = self.sigmoid1(input)  
        return output

完整代码

import torch  
import torch.nn.functional as F  
from torch import nn  
from torch.nn import MaxPool2d, ReLU, Sigmoid  

# 输入图像的二维矩阵  
input = torch.tensor([[1, -0.5],  
                            [-1,3]])  

input = torch.reshape(input,(-1, 1, 2, 2))  

class module(nn.Module):  
    def __init__(self):  
        super().__init__()  
        self.relu1 = ReLU()  
        self.sigmoid1 = Sigmoid()  

    def forward(self, input):  
        output = self.sigmoid1(input)  
        return output  

test = module()  
output = test(input)  
print(output)

其实可以发现,写法都是类似的,注意参数即可

上一篇
下一篇