1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
|
def loss_batch(model, loss_func, xb, yb, opt=None, metric=None):
    """Run one batch through *model*, compute its loss and optionally a metric.

    If *opt* is supplied, the loss is back-propagated, the optimizer takes one
    step and the gradients are then cleared.  The metric (if any) is computed
    on the predictions made before the optimizer step.

    Returns:
        A 3-tuple ``(loss_value, batch_size, metric_value)`` where
        ``loss_value`` is ``loss.item()``, ``batch_size`` is ``len(xb)`` and
        ``metric_value`` is ``metric(preds, yb)`` or ``None`` when no metric
        was given.
    """
    output = model(xb)
    batch_loss = loss_func(output, yb)

    if opt is not None:
        batch_loss.backward()
        opt.step()
        opt.zero_grad()

    score = None if metric is None else metric(output, yb)
    return batch_loss.item(), len(xb), score
def evaluate(model, loss_func, valid_dl, metric=None):
    """Evaluate *model* on every batch of *valid_dl* without tracking gradients.

    Per-batch losses (and metric values, if *metric* is given) are averaged,
    weighted by each batch's size, so the result is a per-sample average.

    Returns:
        ``(avg_loss, total, avg_metric)`` where ``total`` is the number of
        samples seen and ``avg_metric`` is ``None`` when no metric was given.

    Raises:
        ValueError: if *valid_dl* yields no batches.
    """
    with torch.no_grad():
        results = [loss_batch(model, loss_func, xb, yb, metric=metric)
                   for xb, yb in valid_dl]
    if not results:
        raise ValueError("evaluate() received an empty validation loader")

    # Use distinct names here: the original rebound `metric` to the tuple of
    # per-batch results, which made the `metric is not None` check below
    # always true and crashed on np.multiply((None, ...), nums).
    losses, nums, metric_values = zip(*results)
    total = np.sum(nums)
    avg_loss = np.sum(np.multiply(losses, nums)) / total
    avg_metric = None
    if metric is not None:
        avg_metric = np.sum(np.multiply(metric_values, nums)) / total
    return avg_loss, total, avg_metric
def accuracy(output, label):
    """Return the fraction of rows whose arg-max along dim 1 equals *label*.

    The denominator is the number of predictions, i.e. ``output.shape[0]``.
    """
    predicted = torch.max(output, dim=1).indices
    n_correct = (predicted == label).sum().item()
    return n_correct / len(predicted)
def fit(epochs, model, loss_fn, opt, train_dl, valid_dl, metric=None):
    """Train *model* for *epochs* passes over *train_dl*.

    After each epoch the model is evaluated on *valid_dl* and a one-line
    summary is printed: sample count, validation loss and, when *metric* is
    given, the validation metric labelled with the metric's ``__name__``.
    """
    for epoch in range(epochs):
        for xb, yb in train_dl:
            # No metric during training: its per-batch value was never used,
            # so computing it only wasted time.
            loss_batch(model, loss_fn, xb, yb, opt)

        val_loss, total, val_metric = evaluate(model, loss_fn, valid_dl, metric=metric)

        if metric is None:
            print("Epoch [{}/{}], total:{:.4f}, Loss: {:.4f}"
                  .format(epoch + 1, epochs, total, val_loss))
        else:
            # Callables such as functools.partial have no __name__.
            metric_name = getattr(metric, "__name__", "metric")
            print("Epoch [{}/{}], total:{:.4f}, Loss: {:.4f}, {}: {:.4f}"
                  .format(epoch + 1, epochs, total, val_loss, metric_name, val_metric))
# Build the model, then either train it from scratch and save the weights,
# or, when a saved checkpoint already exists, load those weights instead.
model = MnistModel()
if not path.exists('mnist-logistic.pth'):
    fit(5, model, F.cross_entropy,
        torch.optim.SGD(model.parameters(), lr=0.001),
        train_loder, val_loder, metric=accuracy)
    torch.save(model.state_dict(), 'mnist-logistic.pth')
else:
    model.load_state_dict(torch.load('mnist-logistic.pth'))
def prediction_img(img, model):
    """Return the class index *model* predicts for a single, un-batched image.

    A leading axis of size 1 is added to *img* before the forward pass, and
    the arg-max over dim 1 of the first (only) output row is returned as a
    Python int.
    """
    xb = img.unsqueeze(0)  # presumably model expects batched input
    # Inference only: no autograd graph is needed, which saves memory/time.
    with torch.no_grad():
        yb = model(xb)
    _, preds = torch.max(yb, dim=1)
    return preds[0].item()
# Visual spot-check: draw 10 random test samples, display each image, and
# print the model's predicted class once plt.show() returns.
for i in range(10):
    idx = randint(0, len(test_dataset) - 1)
    img, label = test_dataset[idx]
    pixels = np.array(img).squeeze()
    plt.imshow(pixels, cmap='gray')
    plt.show()
    print(prediction_img(img, model))
|