'''
@File: minist.py
@Contact: 2257925767@qq.com
@Author: wangyu
@Version:
@Description: 手写数字识别的核心代码
    (Core code for handwritten-digit recognition.)
    还存在一些问题没有解决---自己得到的数据比老师的代码迭代的次数要少很多
    (Open issue: my run shows far fewer iterations than the instructor's code.)
env: pytorch 1.3.1
@DateTime: 2020/8/22下午5:04
'''
import torch
import torchvision
from matplotlib import pyplot as plt
from torch import nn
from torch import optim
from torch.nn import functional as F

from util import plot_image, plot_curve, one_hot
batch_size = 512

# Directory torchvision downloads/caches MNIST into; shared by both splits so
# the train and test datasets are guaranteed to resolve to the same location.
MNIST_ROOT = "mnist_data"

# One preprocessing pipeline for both splits: ToTensor() followed by
# Normalize(mean=0.1307, std=0.3081). These constants are the values commonly
# quoted as the MNIST training-set mean/std.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Training batches are reshuffled each epoch; test batches keep a fixed order.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(MNIST_ROOT, train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(MNIST_ROOT, train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size,
    shuffle=False,
)

# Sanity-check one training batch: print shapes and value range, then plot it.
x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image_test')
class Net(nn.Module):
    """Fully connected classifier with layer sizes 784 -> 256 -> 64 -> 10.

    ReLU follows each hidden layer; the final layer's output is returned
    without any activation.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return the raw 10-way scores for inputs whose last dim is 28*28."""
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.fc3(x)
net = Net()
# SGD with momentum over every trainable parameter of the network.
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

# One loss value per training batch, accumulated for the final loss curve.
train_loss = []
for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        # Reshape the batch to (batch, 28*28) rows for the fully connected net.
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        # one_hot comes from util; presumably it encodes labels as one-hot
        # vectors matching out's shape so MSE can compare them -- confirm.
        y_onehot = one_hot(y)
        loss = F.mse_loss(out, y_onehot)

        # Clear stale gradients, backpropagate, then apply the update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        train_loss.append(loss_value)
        # Log progress every 10 batches.
        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss_value)

plot_curve(train_loss)
# Evaluate on the test set. Inference needs no gradients, so the forward
# passes run under torch.no_grad() to avoid building an autograd graph and
# retaining intermediate activations for every batch.
total_correct = 0
with torch.no_grad():
    for x, y in test_loader:
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        # Predicted class = index of the largest output score per row.
        pred = out.argmax(dim=1)
        correct = pred.eq(y).sum().float().item()
        total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)

# Plot the model's predictions for the first test batch.
with torch.no_grad():
    x, y = next(iter(test_loader))
    out = net(x.view(x.size(0), 28 * 28))
    pred = out.argmax(dim=1)
plot_image(x, pred, 'test')