当前位置: 首页 > news >正文

cs231n assignment2 PyTorch

文章目录

    • Barebones PyTorch
      • Three-Layer ConvNet
      • Training a ConvNet
    • PyTorch Module API
      • Module API: Train a Three-Layer ConvNet
    • Part IV. PyTorch Sequential API
    • Part V. CIFAR-10 open-ended challenge
    • 搭建卷积神经网络需要注意的事项

Barebones PyTorch

Three-Layer ConvNet

使用pytorch抽象等级1的方式实现卷积神经网络。

three_layer_convnet()

out1 = F.conv2d(x, conv_w1, bias=conv_b1, stride=1, padding=(2,2))
relu1 = F.relu(out1)
out2 = F.conv2d(relu1, conv_w2, bias=conv_b2, stride=1, padding=(1,1))
relu2 = F.relu(out2)
scores = torch.mm(flatten(relu2), fc_w) + fc_b

Training a ConvNet

使用上面完成的卷积神经网络,训练模型。下面需要完成初始化参数。

conv_w1 = random_weight((channel_1,3,5,5))
conv_b1 = zero_weight((channel_1,))
conv_w2 = random_weight((channel_2,channel_1,3,3))
conv_b2 = zero_weight((channel_2,))
fc_w = random_weight((32*32*channel_2,10))
fc_b = zero_weight((10,))

PyTorch Module API

接下来的代码使用nn.Module来完成,与上面的代码进行比较可以发现,使用nn.Module的代码更具有层次性,符合面向对象的编程思想。并且在上面的代码中,需要手动实现参数初始化,但是在下面的代码中可以直接通过nn的函数来实现。

class ThreeLayerConvNet

def __init__(self, in_channel, channel_1, channel_2, num_classes):super().__init__()self.conv1 = nn.Conv2d(in_channel, channel_1, kernel_size=5,stride=1, padding=2)nn.init.kaiming_normal_(self.conv1.weight)self.conv2 = nn.Conv2d(channel_1, channel_2, kernel_size=3, stride=1,padding=1)nn.init.kaiming_normal_(self.conv2.weight)self.fc = nn.Linear(channel_2 * 32 * 32, num_classes)nn.init.kaiming_normal_(self.fc.weight)
def forward(self, x):scores = Nonerelu1 = F.relu(self.conv1(x))relu2 = F.relu(self.conv2(relu1))scores = self.fc(flatten(relu2)return scores

Module API: Train a Three-Layer ConvNet

使用已经写好的网络,训练一个模型,使之在CIFAR10的准确率达到45%以上。
PyTorch.ipynb

in_channel = 3
num_classes = 10
model = ThreeLayerConvNet(in_channel=in_channel, channel_1=channel_1, channel_2=channel_2, num_classes=num_classes)
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

Part IV. PyTorch Sequential API

可以从上面的代码中发现,使用模块API需要完成在__init__中的定义,以及在forward函数中实现每一层的连接。下面使用Sequential API,它把所有步骤合成为一个了,这也就决定了他的灵活性不如Module API,但是对于绝大多数场景来说是够用了。

model = nn.Sequential(nn.Conv2d(3,channel_1,kernel_size=5, padding=2, bias=True),nn.ReLU(),nn.Conv2d(channel_1, channel_2, kernel_size=3, padding=1, bias=True),nn.ReLU(),Flatten(),nn.Linear(channel_2 * 32 * 32, 10, bias=True)
)
optimizer = optim.SGD(model.parameters(), momentum=0.9, lr=learning_rate, nesterov=True)

Part V. CIFAR-10 open-ended challenge

使用自定义的网络结构,完成对CIFAR-10数据集的训练和分类,使之准确率达到70%以上,我才用的结构为
[conv-relu-pool]xN -> [affine]xM -> [softmax or SVM]

channel_1 = 20
channel_2 = 30
learning_rate=0.001model = nn.Sequential(nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(2, 2),nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(2, 2),nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(2, 2),Flatten(),nn.Linear(128 * 4 * 4, 20),nn.Linear(20, 10)
)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

最后的准确率有76%


搭建卷积神经网络需要注意的事项

在搭建网络的过程中,要注意每一层的结构大小,需要按照一定的公式进行计算。

卷积层:

F为卷积核大小,W为图片的宽,H为图片的高,S为步长,P为padding大小
D:图像深度(通道数),N:卷积核(过滤器)个数卷积层后输出的大小为:W' = (W - F + 2P)/S + 1H' = (H - F + 2P)/S + 1 
卷积后输出图像深度:N' = N

池化层:

    W:图像宽,H:图像高,D:图像深度(通道数)F:卷积核宽高,S:步长池化后的大小为:W=(W-F)/S+1H=(H-F)/S+1
池化后输出图像深度:N' = D(保持上一层不变)

http://www.taodudu.cc/news/show-4974463.html

相关文章:

  • Stanford cs231n'18 课程及作业详细解读
  • CS231n第一节
  • cs231n笔记总结
  • 【实验小结】cs231n assignment1 knn 部分
  • CS231n 两层神经网络反向传播实现
  • 【深度学习】cs231n计算机视觉 CNN(卷积神经网络)
  • FreeCAD错误:没有激活的实体 解决办法
  • springboot 整合mysql clickhouse 多数据源
  • 自定义数据源 整合 Mybatis-Plus-多租户
  • 2020FME博客大赛——FME在数据整合中的应用
  • 从零开始Tableau | 2.数据整合
  • 代码分析 | 单细胞转录组数据整合详解
  • 怎样的数据报表才能将公司全部业务数据整合在一起
  • 数据仓库、数据整合、ETL、ELT和EII之间的区别?
  • 生物信息学|MOLI:基于深度神经网络进行多组学数据整合并用于药物反应预测
  • 数据清洗 Chapter04 | 数据整合
  • 分享一篇 Science 里不同批次的单细胞数据整合及批次校正方法
  • 数据库数据整合
  • 数据挖掘二:数据整合
  • 数据整合基础知识介绍
  • 从零开始设计键值数据库(KEY-VALUE STORE)
  • MySQL键值
  • 常用键值表
  • JAVA怎么给手机发短信对接验证码短信接口DEMO示例
  • 手机短信验证码解决方案
  • python实现发送和获取手机短信验证码
  • 【Python web 开发】获取手机短信验证码接口(1)
  • 手机短信验证码接口在各领域的应用
  • dwg文件导入到supermap显示导入失败问题
  • 全新共享协作体验—CAD 2022新功能介绍