007pyTorch入门——最大池化的使用
内容目录

合集请看:pyTorch入门合集
参考视频:https://www.bilibili.com/video/BV1hE411t7RN/?spm_id_from=333.337.search-card.all.click

池化层相关知识点:[[006深度学习入门——卷积神经网络#池化层]]

核心代码

class module(nn.Module):  
    def __init__(self):  
        super().__init__()  
        self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=True)  

    def forward(self, input):  
        output = self.maxpool1(input)  
        return output  

完整代码

import torch  
import torch.nn.functional as F  
from torch import nn  
from torch.nn import MaxPool2d  

# 输入图像的二维矩阵  
input = torch.tensor([[1,2,0,3,1],  
                      [0,1,2,3,1],  
                      [1,2,1,0,0],  
                      [5,2,3,1,1],  
                      [2,1,0,1,1]], dtype=torch.float32)  

input = torch.reshape(input,(-1, 1, 5, 5))  

class module(nn.Module):  
    def __init__(self):  
        super().__init__()  
        self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=True)  

    def forward(self, input):  
        output = self.maxpool1(input)  
        return output  

test = module()  
output = test(input)  
print(output)

池化层作用

缩小网络维度,加快运算速度

上一篇
下一篇