import torch
import torchvision
from torch
import nn
import torchvision
.transforms
as transforms
# FashionMNIST train / test splits, downloaded on first use and converted to
# tensors by the ToTensor transform.
mnist_train = torchvision.datasets.FashionMNIST(
    root='~/Datasets/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
mnist_test = torchvision.datasets.FashionMNIST(
    root='~/Datasets/FashionMNIST',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Mini-batch size shared by both loaders.
batch_size = 256

# Training batches are reshuffled on every pass; test batches keep their order.
train_iter = torch.utils.data.DataLoader(
    dataset=mnist_train, batch_size=batch_size, shuffle=True
)
test_iter = torch.utils.data.DataLoader(
    dataset=mnist_test, batch_size=batch_size, shuffle=False
)
class SoftmaxReg(nn.Module):
    """Single linear layer producing one raw score (logit) per output class.

    No softmax is applied here; the module returns the output of the linear
    layer unchanged.
    """

    def __init__(self, n_feature, n_outputs):
        """Create the linear layer mapping ``n_feature`` inputs to ``n_outputs`` scores."""
        super().__init__()
        self.linear = nn.Linear(n_feature, n_outputs)

    def forward(self, input):
        """Apply the linear layer to ``input`` and return the result."""
        return self.linear(input)
# Model dimensions: 28 * 28 input features, 10 output scores.
num_input = 28 * 28
num_output = 10
net = SoftmaxReg(num_input, num_output)

from torch.nn import init

# Overwrite the layer's default initialisation: weights drawn from a normal
# distribution (mean 0, std 0.01), bias set to zero.
init.normal_(net.linear.weight, mean=0, std=0.01)
init.constant_(net.linear.bias, val=0)

# CrossEntropyLoss takes the raw scores produced by the network.
loss = nn.CrossEntropyLoss()

import torch.optim as optim

# Plain SGD over all model parameters.
optimizer = optim.SGD(net.parameters(), lr=0.03)

# Number of passes over the training data.
num_epochs = 10
def _count_correct(data_iter):
    """Count the predictions of ``net`` that match the labels in ``data_iter``.

    ``data_iter`` must yield ``(features, labels)`` batches. Returns a
    ``(num_correct, num_samples)`` tuple of Python ints.
    """
    num_correct = 0
    num_samples = 0
    # Evaluation never calls backward(), so skip autograd bookkeeping instead
    # of building a graph for every batch.
    with torch.no_grad():
        for features, labels in data_iter:
            num_samples += len(labels)
            # Same reshaping as the training step: collapse everything after
            # the second dimension, then flatten the scores per sample.
            features = features.view(features.size()[0], features.size()[1], -1)
            scores = net(features)
            scores = scores.view(scores.size()[0], -1)
            # Softmax is monotonic, so argmax over the raw scores selects the
            # same class; Tensor.sum() avoids a Python-level loop over the
            # batch elements.
            num_correct += (scores.argmax(dim=1) == labels).sum().item()
    return num_correct, num_samples


for epoch in range(1, num_epochs + 1):
    # One optimisation pass over the training batches.
    for Xx, yy in train_iter:
        Xx = Xx.view(Xx.size()[0], Xx.size()[1], -1)
        output = net(Xx)
        l = loss(output.view(output.size()[0], -1), yy)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()

    # Accuracy on both splits after this epoch's updates.
    rightTrain, allTrain = _count_correct(train_iter)
    rightTest, allTest = _count_correct(test_iter)

    # NOTE: the reported loss is the loss of the last training batch of the
    # epoch, not an average over the epoch.
    print("epoch%d,损失:%f,训练集上正确率%f,测试集上的正确率%f" % (
        epoch, l.item(), rightTrain / allTrain, rightTest / allTest))
# Reprint notice: please credit the original source: https://lol.8miu.com/read-15304.html