python3.6
pip3 install torchvision==0.2.0
MNIST手写数字0-9训练集下载 链接:https://pan.baidu.com/s/1kUiARaVTdFNAfJoS5xx6pw 提取码:xttr
错误
raise NotSupportedError(base.range(), "slicing multiple dimensions at the same time isn't supported yet")
借用一下其他大佬的测试代码
import torch
import torch
.nn
as nn
import torch
.nn
.functional
as F
from torchvision
import datasets
, transforms
class Net(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        in_features: Size of each flattened input sample.
        out_features: Number of output scores produced per sample.
        hidden_features: Width of the hidden layer. Defaults to 512, the
            value that was previously hard-coded.
    """

    def __init__(self, in_features, out_features, hidden_features=512):
        super().__init__()
        self.dnn1 = nn.Linear(in_features, hidden_features)
        self.dnn2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Return the raw (un-normalised) output scores for ``x``.

        ``x`` must have ``in_features`` as its last dimension; no flattening
        is done here, so callers reshape images before calling the network.
        """
        x = F.relu(self.dnn1(x))
        # No softmax: the second layer's output is returned as-is.
        return self.dnn2(x)
# Training script: fits Net on MNIST with Adam for NUM_EPOCHS epochs and, after
# every epoch, prints the average loss/accuracy on the train and test splits.

NUM_EPOCHS = 5
BATCH_SIZE = 128
INPUT_SIZE = 28 * 28  # each image is flattened to this length before the net

net = Net(INPUT_SIZE, 10)

# download=False: the MNIST files must already be present under ../data/MNIST.
train_dataset = datasets.MNIST('../data/MNIST', train=True, download=False,
                               transform=transforms.ToTensor())
test_dataset = datasets.MNIST('../data/MNIST', train=False, download=False,
                              transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=BATCH_SIZE,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=BATCH_SIZE,
                                          shuffle=False)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.02)

for epoch in range(NUM_EPOCHS):
    # ---- training pass ----
    net.train()
    running_loss, running_acc = 0.0, 0.0
    for i, (img, label) in enumerate(train_loader):
        img = img.reshape(-1, INPUT_SIZE)
        out = net(img)
        loss = criterion(out, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = label.size(0)
        # criterion returns the batch mean, so weight it by the batch size to
        # accumulate a per-sample sum that is later divided by the dataset size.
        running_loss += loss.item() * batch_size
        _, predicted = torch.max(out, 1)
        correct = (predicted == label).sum().item()
        running_acc += correct

        # Divide by the real batch size: the last batch can be smaller than
        # BATCH_SIZE, so a fixed divisor would under-report its accuracy.
        print('Epoch [{}/{}], Step [{}/{}], Loss: {:.6f}, Acc: {:.6f}'.format(
            epoch + 1, NUM_EPOCHS, i + 1, len(train_loader),
            loss.item(), correct / batch_size))

    # ---- evaluation pass ----
    net.eval()
    test_loss, test_acc = 0.0, 0.0
    # No gradients are needed for evaluation; skipping autograd bookkeeping
    # saves memory and time without changing the reported numbers.
    with torch.no_grad():
        for img, label in test_loader:
            img = img.reshape(-1, INPUT_SIZE)
            out = net(img)
            loss = criterion(out, label)
            test_loss += loss.item() * label.size(0)
            _, predicted = torch.max(out, 1)
            test_acc += (predicted == label).sum().item()

    print("Train {} epoch, Loss: {:.6f}, Acc: {:.6f}, Test_Loss: {:.6f}, Test_Acc: {:.6f}".format(
        epoch + 1,
        running_loss / len(train_dataset),
        running_acc / len(train_dataset),
        test_loss / len(test_dataset),
        test_acc / len(test_dataset)))
效果