1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186
| from os import path, mkdir from random import randint
import torch import numpy as np import torchvision from matplotlib import pyplot as plt from torchvision.datasets import MNIST from torchvision.transforms import ToTensor from torch.utils.data.sampler import SubsetRandomSampler from torch.utils.data.dataloader import DataLoader import torch.nn.functional as F import torch.nn as nn
# Training split of MNIST, stored under ./data; `download=True` is passed so
# torchvision may fetch the files when they are not there yet.
dataset = MNIST("./data", train=True, download=True, transform=ToTensor())

# Test split read from the same root directory. No download is requested
# here, so this presumably relies on the files fetched for `dataset` above.
test_dataset = MNIST('./data', train=False, transform=ToTensor())
def split_indices(n, rate):
    """Randomly partition the indices ``0 .. n-1`` into two groups.

    Args:
        n: Number of items to split.
        rate: Fraction of the items assigned to the validation group; the
            group size is ``int(n * rate)`` (truncated).

    Returns:
        A ``(train_indices, val_indices)`` pair of NumPy arrays that together
        hold every index exactly once.
    """
    holdout_size = int(n * rate)
    shuffled = np.random.permutation(n)
    # The first `holdout_size` shuffled indices form the validation group,
    # everything after them is used for training.
    val_part = shuffled[:holdout_size]
    train_part = shuffled[holdout_size:]
    return train_part, val_part
# Hold out 20% of the training set for validation.
train_indices, val_indices = split_indices(len(dataset), 0.2)

batch_size = 100

# Both loaders read from the same dataset; the samplers restrict each one to
# its own subset of indices.
train_sampler = SubsetRandomSampler(train_indices)
train_loder = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)

val_sampler = SubsetRandomSampler(val_indices)
val_loder = DataLoader(dataset, batch_size=batch_size, sampler=val_sampler)

# Size of one flattened input image (presumably 28x28 MNIST digits) and the
# number of output classes.
input_size = 28 * 28
num_classes = 10
class MnistModel(nn.Module):
    """Feed-forward classifier with a single hidden layer and ReLU activation."""

    def __init__(self, in_size, hidden_size, out_size):
        """Create the two linear layers.

        Args:
            in_size: Number of features per flattened input sample.
            hidden_size: Width of the hidden layer.
            out_size: Number of output scores per sample.
        """
        super().__init__()
        # Input -> hidden projection.
        self.linear1 = nn.Linear(in_size, hidden_size)
        # Hidden -> output projection.
        self.linear2 = nn.Linear(hidden_size, out_size)

    def forward(self, xb):
        """Return the output scores for the batch ``xb``.

        Every sample is flattened to one dimension (the first dimension is
        kept as the batch dimension) before it enters the first layer.
        """
        flat = xb.view(xb.size(0), -1)
        hidden = F.relu(self.linear1(flat))
        return self.linear2(hidden)
def get_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)
def to_device(data, device):
    """Move ``data`` to ``device``.

    Args:
        data: An object with a ``.to()`` method (e.g. a tensor or module), or
            an arbitrarily nested list/tuple of such objects.
        device: Target device passed on to ``.to()``.

    Returns:
        The result of ``data.to(device, non_blocking=True)`` for a single
        object. Lists and tuples are processed element by element and always
        come back as a list.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(item, device) for item in data]
class DeviceDataLoder():
    """Wrap an iterable of batches so each batch is moved to a device when read."""

    def __init__(self, dl, device):
        """Remember the wrapped loader ``dl`` and the target ``device``."""
        self.dl = dl
        self.device = device

    def __iter__(self):
        """Yield every batch of the wrapped loader after moving it to the device."""
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Return the length reported by the wrapped loader."""
        return len(self.dl)
# Wrap both loaders so their batches are moved to the selected device as they
# are iterated.
train_dl = DeviceDataLoder(dl=train_loder, device=get_device())
valid_dl = DeviceDataLoder(dl=val_loder, device=get_device())
def loss_batch(model, loss_func, xb, yb, opt=None, metric=None):
    """Compute the loss for one batch and optionally take an optimizer step.

    Args:
        model: Callable that maps the inputs ``xb`` to predictions.
        loss_func: Callable ``loss_func(preds, yb)`` returning the batch loss.
        xb: Batch of inputs.
        yb: Batch of targets.
        opt: Optional optimizer. When given, gradients are computed, one
            update step is applied and the gradients are reset afterwards.
        metric: Optional callable ``metric(preds, yb)``.

    Returns:
        A tuple ``(loss_value, batch_size, metric_result)`` where
        ``loss_value`` is ``loss.item()``, ``batch_size`` is ``len(xb)`` and
        ``metric_result`` is ``None`` when no metric was supplied.
    """
    predictions = model(xb)
    batch_loss = loss_func(predictions, yb)

    if opt is not None:
        batch_loss.backward()
        opt.step()
        # Reset gradients so they do not carry over into the next batch.
        opt.zero_grad()

    metric_result = metric(predictions, yb) if metric is not None else None
    return batch_loss.item(), len(xb), metric_result
def evaluate(model, loss_func, valid_dl, metric=None):
    """Compute the average loss (and metric) of ``model`` over ``valid_dl``.

    Args:
        model: Callable that maps a batch of inputs to predictions.
        loss_func: Callable ``loss_func(preds, yb)`` returning the batch loss.
        valid_dl: Iterable yielding ``(xb, yb)`` batches.
        metric: Optional callable ``metric(preds, yb)`` evaluated per batch.

    Returns:
        A tuple ``(avg_loss, total, avg_metric)``. Both averages are weighted
        by batch size, ``total`` is the number of evaluated samples, and
        ``avg_metric`` is ``None`` when no metric was supplied.

    Raises:
        ValueError: If ``valid_dl`` yields no batches.
    """
    # No gradients are needed for evaluation.
    with torch.no_grad():
        results = [loss_batch(model, loss_func, xb, yb, metric=metric)
                   for xb, yb in valid_dl]

    if not results:
        raise ValueError('valid_dl produced no batches to evaluate')

    # Keep the per-batch values under their own names: reusing `metric` here
    # would hide the caller's argument and make the `is not None` check below
    # succeed even when no metric function was passed.
    losses, nums, metrics = zip(*results)
    total = np.sum(nums)
    avg_loss = np.sum(np.multiply(losses, nums)) / total

    avg_metric = None
    if metric is not None:
        avg_metric = np.sum(np.multiply(metrics, nums)) / total
    return avg_loss, total, avg_metric
def fit(epochs, lr, model, loss_func, train_dl, valid_dl, opt_fn=None, metric=None):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Args:
        epochs: Number of passes over ``train_dl``.
        lr: Learning rate handed to the optimizer constructor.
        model: Module to train; its ``parameters()`` are optimized.
        loss_func: Callable ``loss_func(preds, yb)`` returning the batch loss.
        train_dl: Iterable of ``(xb, yb)`` training batches.
        valid_dl: Iterable of ``(xb, yb)`` validation batches.
        opt_fn: Optimizer constructor called as ``opt_fn(params, lr=lr)``;
            ``torch.optim.SGD`` is used when omitted.
        metric: Optional callable ``metric(preds, yb)`` reported per epoch.

    Returns:
        A tuple ``(loss_history, metric_history)`` with one validation loss
        and one validation metric entry per epoch.
    """
    optimizer_factory = torch.optim.SGD if opt_fn is None else opt_fn
    opt = optimizer_factory(model.parameters(), lr=lr)

    loss_history, metric_history = [], []

    for epoch in range(epochs):
        # Training pass: one optimizer step per batch.
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        # Validation pass at the end of the epoch.
        val_loss, total, val_metric = evaluate(model, loss_func, valid_dl, metric)
        loss_history.append(val_loss)
        metric_history.append(val_metric)

        if metric is None:
            print(f'Epoch [{epoch + 1}/{epochs}], Loss: {val_loss:.4f}')
        else:
            print(f'Epoch [{epoch + 1}/{epochs}], Loss: {val_loss:.4f}, Metric: {val_metric:.4f}')

    return loss_history, metric_history
def accuracy(output, label):
    """Return the fraction of rows in ``output`` whose arg-max equals ``label``.

    Args:
        output: Tensor of scores; the maximum is taken along dimension 1.
        label: Tensor of expected class indices, one per row of ``output``.

    Returns:
        Number of matching predictions divided by the number of predictions,
        as a Python float.
    """
    preds = torch.max(output, dim=1).indices
    n_correct = torch.sum(label == preds).item()
    return n_correct / len(preds)
# Build the classifier (hidden layer of width 32) and move it to the selected
# device.
model = MnistModel(input_size, 32, num_classes)
to_device(model, get_device())

if path.exists('./tutorial5/mnist-logistic.pth'):
    # Reuse previously trained weights. `map_location` remaps the stored
    # tensors to the current device, so a checkpoint written on a GPU machine
    # still loads when only the CPU is available.
    model.load_state_dict(
        torch.load('./tutorial5/mnist-logistic.pth', map_location=get_device())
    )
else:
    # No checkpoint yet: train for 5 epochs with SGD (lr=0.5), then save.
    loss_history, metric_history = fit(5, 0.5, model, F.cross_entropy,
                                       train_dl, valid_dl,
                                       opt_fn=torch.optim.SGD, metric=accuracy)
    try:
        mkdir('./tutorial5')
    except FileExistsError:
        # The directory can already exist without the checkpoint file (for
        # example after an interrupted run); that must not discard the
        # weights that were just trained.
        pass
    torch.save(model.state_dict(), './tutorial5/mnist-logistic.pth')
def prediction_img(img, model):
    """Predict the class index of a single image.

    Args:
        img: Tensor holding one sample without a batch dimension.
        model: Callable mapping a batch of inputs to per-class scores. When
            it exposes ``parameters()`` (as ``nn.Module`` does), the input is
            moved to the device of its first parameter before the call.

    Returns:
        The index of the highest score along dimension 1, as a Python int.
    """
    # The model expects a batch, so add a leading batch dimension of size 1.
    xb = img.unsqueeze(0)

    # Put the input on the same device as the model's weights; a model that
    # was moved to the GPU cannot be fed a CPU tensor.
    parameters = getattr(model, 'parameters', None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        xb = xb.to(first_param.device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        yb = model(xb)

    _, preds = torch.max(yb, dim=1)
    return preds[0].item()
# Show ten randomly chosen test images, printing the model's prediction after
# each one.
for _ in range(10):
    sample_idx = randint(0, len(test_dataset) - 1)
    img, label = test_dataset[sample_idx]

    # Drop size-1 dimensions so the image can be drawn as a 2-D grayscale picture.
    pixels = np.array(img).squeeze()
    plt.imshow(pixels, cmap='gray')
    plt.show()

    print(prediction_img(img, model))
|