博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
pytorch(5)
阅读量:4099 次
发布时间:2019-05-25

本文共 1894 字,大约阅读时间需要 6 分钟。

莫烦pytorch,集中优化方式的比较

#optimizer.pyimport torchimport torch.nn.functional as Fimport torch.utils.data as Dataimport matplotlib.pyplot as pltfrom torch.autograd import Variable#hyper parametersLR=0.01BATCH_SIZE=32EPOCH=40x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)y=x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size()))# plt.scatter(x.numpy(), y.numpy())# plt.show()torch_dataset = Data.TensorDataset(x,y)loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)class Net(torch.nn.Module):    def __init__(self):        super(Net,self).__init__()        self.hidden=torch.nn.Linear(1,20)        self.predict=torch.nn.Linear(20,1)    def forward(self, x):        x = F.relu(self.hidden(x))        x=self.predict(x)        return xnet_SGD= Net()net_Momentum = Net()net_RMSprop = Net()net_Adam = Net()nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam]opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR,momentum=0.8)opt_RMSprop=torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9,0.99))optimizers= [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam]loss_func = torch.nn.MSELoss()losses_his=[[], [], [], []]for epoch in range(EPOCH):    for step, (batch_x, batch_y) in enumerate(loader):        b_x=batch_x        b_y=batch_y        for net , opt, l_his in zip(nets, optimizers, losses_his):            output = net(b_x)            loss = loss_func(output, b_y)            opt.zero_grad()            loss.backward()            opt.step()            l_his.append(loss.item())labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']for i, l_his in enumerate(losses_his):    plt.plot(l_his, label=labels[i])plt.legend(loc='best')plt.xlabel('Step')plt.ylabel('Loss')plt.ylim((0,0.5))plt.show()# optimizer = torch.optim.SGD()

转载地址:http://lwksi.baihongyu.com/

你可能感兴趣的文章
在JS中 onclick="save();return false;"return false是
查看>>
JSTL 常用标签总结
查看>>
内容里面带标签,在HTML显示问题,JSTL
查看>>
VS编译器运行后闪退,处理方法
查看>>
用div+css做下拉菜单,当鼠标移向2级菜单时,为什么1级菜单的a:hover背景色就不管用了?
查看>>
idea 有时提示找不到类或者符号
查看>>
JS遍历的多种方式
查看>>
ng-class的几种用法
查看>>
node入门demo-Ajax让前端angularjs/jquery与后台node.js交互,技术支持:mysql+html+angularjs/jquery
查看>>
神经网络--单层感知器
查看>>
注册表修改DOS的编码页为utf-8
查看>>
matplotlib.pyplot.plot()参数详解
查看>>
拉格朗日对偶问题详解
查看>>
MFC矩阵运算
查看>>
最小二乘法拟合:原理,python源码,C++源码
查看>>
ubuntu 安装mysql
查看>>
c# 计算器
查看>>
C# 简单的矩阵运算
查看>>
gcc 常用选项详解
查看>>
c++输入文件流ifstream用法详解
查看>>