PyTorch 用 nn 实现 softmax 回归（多分类逻辑回归 / multinomial logistic 回归）

it2024-04-02  62

"""Softmax regression (multinomial logistic regression) on Fashion-MNIST.

Each 28x28 image is flattened to a 784-dim vector and mapped to 10 class
scores (logits) by a single nn.Linear layer. nn.CrossEntropyLoss applies
log-softmax internally, so the model returns raw logits.
"""
import torch
from torch import nn
from torch.nn import init
import torch.optim as optim

num_input = 28 * 28  # flattened image size (input dimension)
num_output = 10      # number of classes (output dimension)


class SoftmaxReg(nn.Module):
    """A single linear layer that produces unnormalized class scores (logits)."""

    def __init__(self, n_feature, n_outputs):
        super().__init__()
        self.linear = nn.Linear(n_feature, n_outputs)

    def forward(self, input):
        # No softmax here: nn.CrossEntropyLoss applies log-softmax itself, and
        # argmax over logits already gives the predicted class.
        return self.linear(input)


def _flatten(X):
    """Collapse every dimension after the batch dimension: (B, ...) -> (B, features)."""
    return X.reshape(X.size(0), -1)


def load_fashion_mnist(batch_size=256, root='~/Datasets/FashionMNIST'):
    """Return (train_iter, test_iter) DataLoaders for Fashion-MNIST.

    The training loader shuffles and the test loader does not. Both datasets are
    downloaded to ``root`` if they are missing. torchvision is imported here,
    not at module level, so the model and training helpers can be used
    without it.
    """
    import torchvision
    import torchvision.transforms as transforms

    mnist_train = torchvision.datasets.FashionMNIST(
        root=root, train=True, download=True, transform=transforms.ToTensor())
    mnist_test = torchvision.datasets.FashionMNIST(
        root=root, train=False, download=True, transform=transforms.ToTensor())
    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False)
    return train_iter, test_iter


def train_epoch(net, train_iter, loss, optimizer):
    """Run one optimization pass over ``train_iter``.

    Returns the mean per-sample training loss for the epoch, or 0.0 if
    ``train_iter`` yields no samples.
    """
    total_loss, total = 0.0, 0
    for X, y in train_iter:
        output = net(_flatten(X))
        l = loss(output, y)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
        # CrossEntropyLoss averages over the batch by default. Re-weight by
        # batch size so a short final batch does not skew the epoch mean.
        total_loss += l.item() * y.size(0)
        total += y.size(0)
    return total_loss / total if total else 0.0


def evaluate_accuracy(net, data_iter):
    """Return the fraction of samples in ``data_iter`` that ``net`` classifies correctly.

    Returns 0.0 if ``data_iter`` yields no samples.
    """
    correct, total = 0, 0
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for X, y in data_iter:
            logits = net(_flatten(X))
            # softmax is monotonic, so argmax of the logits is the predicted class.
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += y.size(0)
    return correct / total if total else 0.0


def main(num_epochs=10, batch_size=256, lr=0.03):
    """Train softmax regression on Fashion-MNIST and print per-epoch metrics."""
    train_iter, test_iter = load_fashion_mnist(batch_size)

    net = SoftmaxReg(num_input, num_output)
    init.normal_(net.linear.weight, mean=0, std=0.01)
    init.constant_(net.linear.bias, val=0)

    # Takes raw logits; it computes the scalar loss (it does not output probabilities).
    loss = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr)

    for epoch in range(1, num_epochs + 1):
        train_loss = train_epoch(net, train_iter, loss, optimizer)
        train_acc = evaluate_accuracy(net, train_iter)
        test_acc = evaluate_accuracy(net, test_iter)
        print("epoch%d,损失:%f,训练集上正确率%f,测试集上的正确率%f" % (epoch, train_loss, train_acc, test_acc))


if __name__ == "__main__":
    main()
最新回复(0)