ResNet: Theory and Implementation

Theory

Background and Motivation

A basic observation:

  • As a network gets deeper, the original input gets progressively distorted and information is lost, so the model's training error increases (the "degradation problem").

Put abstractly:

  • degraded image + lost information = complete image

The idea:

  • lost information = complete image - degraded image

  • If we could somehow preserve the lost information (call it the residual) and add it back in later layers, we could counteract the distortion and thereby keep the error from growing.

  • This is the idea behind ResNet.

Residual Block

In a traditional deep neural network, suppose some layer's input is $\mathbf{x}$ and we want it to learn a mapping $H(\mathbf{x})$. Optimizing this mapping directly becomes difficult, especially when the network is very deep.

ResNet introduces an identity mapping and has each residual block learn a "residual" function $F(\mathbf{x}) = H(\mathbf{x}) - \mathbf{x}$, so that:


$$
 H(\mathbf{x}) = F(\mathbf{x}) + \mathbf{x} 
$$

This structure is called a skip connection: it adds the input directly to the output, so gradients can flow backward through the network more easily, which mitigates the vanishing-gradient problem in very deep networks.
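
A quick way to see why the shortcut helps gradients: since $H(\mathbf{x}) = F(\mathbf{x}) + \mathbf{x}$, we have $\partial H / \partial \mathbf{x} = \partial F / \partial \mathbf{x} + I$, so even when the convolutional branch contributes almost nothing, the gradient stays close to the identity. A minimal autograd sketch (the toy function f below is a deliberately weak stand-in for a block's convolutional path, not part of any real network):

import torch

x = torch.randn(4, requires_grad=True)
f = lambda t: 0.001 * t.pow(2)  # toy residual branch with tiny gradients

g_plain, = torch.autograd.grad(f(x).sum(), x)      # without the shortcut
g_res, = torch.autograd.grad((f(x) + x).sum(), x)  # with the shortcut

print(g_plain)  # values near 0: the signal barely propagates
print(g_res)    # roughly g_plain + 1: the identity term keeps gradients alive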

Data Flow Analysis

From a data-flow perspective, the input goes through the following steps inside a residual block:

  1. Convolution 1: the input $\mathbf{x}$ passes through the first convolutional layer for feature extraction, producing a feature map.

  2. Batch Normalization: normalizes the convolution output, which speeds up and stabilizes training.

  3. Activation (ReLU): introduces nonlinearity so the model can fit more complex functions.

  4. Convolution 2: a second convolution extracts more complex features.

  5. Batch Normalization: normalizes the convolution output again.

  6. Skip connection: the input $\mathbf{x}$ is added directly to the output of the two convolution + normalization steps, giving the final output $\mathbf{y} = F(\mathbf{x}) + \mathbf{x}$.

  7. Activation (ReLU): applies a final nonlinearity after the skip connection.

The full data flow looks like this:

x ---+--> [Conv1] -> [BN1] -> [ReLU] -> [Conv2] -> [BN2] --(+)--> [ReLU] --> y
     |                                                      ^
     +------------------------------------------------------+

Network Architecture

A typical ResNet stacks multiple residual blocks. Take ResNet-18 as an example (a quick shape walkthrough follows the list):

  • Stem: a 7x7 convolution with stride 2, followed by a 3x3 max pooling layer with stride 2.

  • Residual stages: 4 stages, each containing several basic residual blocks. The channel counts per stage are typically 64, 128, 256, and 512.

  • Global average pooling: reduces each feature map to a single global feature.

  • Fully connected layer: produces the final classification output.
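
As a sanity check of this stage layout, one can trace tensor shapes through torchvision's reference resnet18 (assumes torchvision is installed; the CIFAR-10 variant built below replaces the 7x7 stem and drops the max pooling):

import torch
import torchvision

model = torchvision.models.resnet18()
model.eval()  # avoid updating BatchNorm statistics during the check
out = torch.randn(1, 3, 224, 224)
for name, layer in model.named_children():
    if name == 'fc':
        out = torch.flatten(out, 1)  # torchvision flattens before the classifier
    out = layer(out)
    print(f'{name:8s} -> {tuple(out.shape)}')
# conv1 halves the resolution, maxpool halves it again, and each of
# layer2..layer4 halves it once more while doubling the channels to 512.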

Implementing ResNet

Residual Block

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    expansion = 1  # the channel expansion factor is 1 for BasicBlock

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # first convolutional layer
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # second convolutional layer
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # optional downsampling of the shortcut
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x

        out = self.conv1(x)       # convolution 1
        out = self.bn1(out)       # batch norm 1
        out = self.relu(out)      # ReLU

        out = self.conv2(out)     # convolution 2
        out = self.bn2(out)       # batch norm 2

        if self.downsample is not None:
            identity = self.downsample(x)  # adjust the shortcut's dimensions

        out += identity            # skip connection
        out = self.relu(out)       # ReLU

        return out

The core is simply doing out += identity before the final activation.
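
A quick shape check of the block (the downsample module here mirrors what _make_layer will construct in the full network below):

# same-shape block: the shortcut is a plain identity
block = BasicBlock(64, 64)
x = torch.randn(2, 64, 32, 32)
print(block(x).shape)  # torch.Size([2, 64, 32, 32])

# stride-2 block: a 1x1 conv on the shortcut matches channels and resolution
down = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128),
)
print(BasicBlock(64, 128, stride=2, downsample=down)(x).shape)  # torch.Size([2, 128, 16, 16])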

ResNet

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10):  # defaults to CIFAR-10
        super(ResNet, self).__init__()
        self.in_channels = 64
        # initial convolution
        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)  # CIFAR-10 needs no 7x7 conv or max pooling
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        # residual stages
        self.layer1 = self._make_layer(block, 64, layers[0])    # 64 channels
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)  # 128 channels, stride 2
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)  # 256 channels, stride 2
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)  # 512 channels, stride 2
        # global average pooling and classifier
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        # downsample the shortcut when the stride is not 1 or the channel counts differ
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))  # stride=1
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)    # initial convolution
        out = self.bn1(out)
        out = self.relu(out)

        out = self.layer1(out) # stage 1
        out = self.layer2(out) # stage 2
        out = self.layer3(out) # stage 3
        out = self.layer4(out) # stage 4

        out = self.avg_pool(out)  # global average pooling
        out = torch.flatten(out, 1)
        out = self.fc(out)         # classifier

        return out

Building ResNet-18

ResNet-18 consists of 4 stages, each containing 2 BasicBlocks.

def ResNet18(num_classes=10):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
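
A quick smoke test of the assembled model (the parameter count is approximate; this CIFAR variant should come out to roughly 11.17M):

model = ResNet18()
model.eval()  # avoid updating BatchNorm statistics during the smoke test
with torch.no_grad():
    print(model(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])
print(f'{sum(p.numel() for p in model.parameters()) / 1e6:.2f}M parameters')  # ~11.17M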

Dataset Preparation

We use the CIFAR-10 dataset for training and testing.

import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# data augmentation and preprocessing
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # random horizontal flip
    transforms.RandomCrop(32, padding=4),  # random crop with 4px padding
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

# load the data
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
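
A quick check that the loaders produce the expected shapes:

images, labels = next(iter(trainloader))
print(images.shape, labels.shape)  # torch.Size([128, 3, 32, 32]) torch.Size([128])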

Training and Testing

def train(model, device, trainloader, optimizer, criterion, epoch):
    model.train()
    running_loss = 0.0
    total = 0
    correct = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()          # zero the gradients
        outputs = model(inputs)        # forward pass
        loss = criterion(outputs, targets)  # compute the loss
        loss.backward()                # backward pass
        optimizer.step()               # update parameters

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 100 == 99:  # print every 100 batches
            print(f'Epoch [{epoch}], Batch [{batch_idx+1}], Loss: {running_loss / 100:.3f}, '
                  f'Accuracy: {100. * correct / total:.2f}%')
            running_loss = 0.0

def test(model, device, testloader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    print(f'Test Loss: {test_loss / len(testloader):.3f}, '
          f'Test Accuracy: {100. * correct / total:.2f}%')
    return 100. * correct / total

Main Program

import time

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'device: {device}')

model = ResNet18(num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1,
                      momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)  # LR is divided by 10 every 50 epochs

num_epochs = 100
best_acc = 0

for epoch in range(1, num_epochs + 1):
    start_time = time.time()
    train(model, device, trainloader, optimizer, criterion, epoch)
    acc = test(model, device, testloader, criterion)
    if acc > best_acc:
        best_acc = acc
        # save the best model
        torch.save(model.state_dict(), 'best_resnet18.pth')
    scheduler.step()
    end_time = time.time()
    print(f'Epoch [{epoch}] Done in {end_time - start_time:.2f}s\n')

print(f'Best accuracy {best_acc:.2f}%')

Checking Performance

# load the best model (map_location makes this work on CPU-only machines too)
model.load_state_dict(torch.load('best_resnet18.pth', map_location=device))
test_acc = test(model, device, testloader, criterion)
print(f'Test accuracy: {test_acc:.2f}%')

Test accuracy lands around 93%, which is decent. The jump right after epoch 50 in the log below is the StepLR schedule cutting the learning rate from 0.1 to 0.01.

$ python impl.py 
Files already downloaded and verified
Files already downloaded and verified
device: cuda
Epoch [1], Batch [100], Loss: 2.422, Accuracy: 20.01%
Epoch [1], Batch [200], Loss: 1.868, Accuracy: 24.83%
Epoch [1], Batch [300], Loss: 1.707, Accuracy: 28.58%
Test Loss: 1.544, Test Accuracy: 43.36%
Epoch [1] Done in 20.87s

Epoch [2], Batch [100], Loss: 1.518, Accuracy: 44.02%
Epoch [2], Batch [200], Loss: 1.406, Accuracy: 46.39%
Epoch [2], Batch [300], Loss: 1.322, Accuracy: 48.25%
Test Loss: 1.230, Test Accuracy: 57.20%
Epoch [2] Done in 20.73s

Epoch [3], Batch [100], Loss: 1.142, Accuracy: 59.45%
Epoch [3], Batch [200], Loss: 1.070, Accuracy: 60.82%
Epoch [3], Batch [300], Loss: 1.013, Accuracy: 61.90%
Test Loss: 1.114, Test Accuracy: 62.81%
Epoch [3] Done in 20.72s

Epoch [4], Batch [100], Loss: 0.906, Accuracy: 67.95%
Epoch [4], Batch [200], Loss: 0.838, Accuracy: 69.22%
Epoch [4], Batch [300], Loss: 0.808, Accuracy: 70.00%
Test Loss: 0.874, Test Accuracy: 69.42%
Epoch [4] Done in 20.72s

Epoch [5], Batch [100], Loss: 0.713, Accuracy: 75.28%
Epoch [5], Batch [200], Loss: 0.706, Accuracy: 75.52%
Epoch [5], Batch [300], Loss: 0.681, Accuracy: 75.67%
Test Loss: 0.752, Test Accuracy: 74.31%
Epoch [5] Done in 20.71s

Epoch [6], Batch [100], Loss: 0.617, Accuracy: 78.60%
Epoch [6], Batch [200], Loss: 0.620, Accuracy: 78.71%
Epoch [6], Batch [300], Loss: 0.597, Accuracy: 78.91%
Test Loss: 0.779, Test Accuracy: 73.72%
Epoch [6] Done in 20.66s

Epoch [7], Batch [100], Loss: 0.556, Accuracy: 80.91%
Epoch [7], Batch [200], Loss: 0.547, Accuracy: 81.07%
Epoch [7], Batch [300], Loss: 0.566, Accuracy: 80.84%
Test Loss: 0.697, Test Accuracy: 75.82%
Epoch [7] Done in 20.75s

Epoch [8], Batch [100], Loss: 0.520, Accuracy: 82.01%
Epoch [8], Batch [200], Loss: 0.528, Accuracy: 82.11%
Epoch [8], Batch [300], Loss: 0.530, Accuracy: 81.93%
Test Loss: 0.625, Test Accuracy: 78.78%
Epoch [8] Done in 20.72s

Epoch [9], Batch [100], Loss: 0.514, Accuracy: 82.43%
Epoch [9], Batch [200], Loss: 0.498, Accuracy: 82.74%
Epoch [9], Batch [300], Loss: 0.492, Accuracy: 82.85%
Test Loss: 0.528, Test Accuracy: 82.16%
Epoch [9] Done in 20.72s

Epoch [10], Batch [100], Loss: 0.469, Accuracy: 83.93%
Epoch [10], Batch [200], Loss: 0.474, Accuracy: 83.75%
Epoch [10], Batch [300], Loss: 0.497, Accuracy: 83.43%
Test Loss: 0.637, Test Accuracy: 78.31%
Epoch [10] Done in 20.69s

Epoch [11], Batch [100], Loss: 0.446, Accuracy: 84.35%
Epoch [11], Batch [200], Loss: 0.469, Accuracy: 84.33%
Epoch [11], Batch [300], Loss: 0.477, Accuracy: 84.14%
Test Loss: 0.527, Test Accuracy: 82.31%
Epoch [11] Done in 20.77s

Epoch [12], Batch [100], Loss: 0.427, Accuracy: 85.59%
Epoch [12], Batch [200], Loss: 0.472, Accuracy: 84.71%
Epoch [12], Batch [300], Loss: 0.472, Accuracy: 84.40%
Test Loss: 0.813, Test Accuracy: 73.67%
Epoch [12] Done in 20.71s

Epoch [13], Batch [100], Loss: 0.430, Accuracy: 85.02%
Epoch [13], Batch [200], Loss: 0.431, Accuracy: 85.09%
Epoch [13], Batch [300], Loss: 0.444, Accuracy: 84.98%
Test Loss: 0.821, Test Accuracy: 75.01%
Epoch [13] Done in 20.74s

Epoch [14], Batch [100], Loss: 0.410, Accuracy: 85.82%
Epoch [14], Batch [200], Loss: 0.431, Accuracy: 85.51%
Epoch [14], Batch [300], Loss: 0.425, Accuracy: 85.48%
Test Loss: 0.635, Test Accuracy: 79.43%
Epoch [14] Done in 20.73s

Epoch [15], Batch [100], Loss: 0.409, Accuracy: 86.09%
Epoch [15], Batch [200], Loss: 0.421, Accuracy: 85.82%
Epoch [15], Batch [300], Loss: 0.417, Accuracy: 85.93%
Test Loss: 0.450, Test Accuracy: 84.87%
Epoch [15] Done in 20.79s

Epoch [16], Batch [100], Loss: 0.385, Accuracy: 86.80%
Epoch [16], Batch [200], Loss: 0.422, Accuracy: 86.24%
Epoch [16], Batch [300], Loss: 0.417, Accuracy: 86.03%
Test Loss: 0.562, Test Accuracy: 81.20%
Epoch [16] Done in 20.70s

Epoch [17], Batch [100], Loss: 0.388, Accuracy: 86.75%
Epoch [17], Batch [200], Loss: 0.420, Accuracy: 86.22%
Epoch [17], Batch [300], Loss: 0.401, Accuracy: 86.27%
Test Loss: 0.840, Test Accuracy: 73.49%
Epoch [17] Done in 20.73s

Epoch [18], Batch [100], Loss: 0.400, Accuracy: 86.34%
Epoch [18], Batch [200], Loss: 0.387, Accuracy: 86.52%
Epoch [18], Batch [300], Loss: 0.399, Accuracy: 86.43%
Test Loss: 0.584, Test Accuracy: 81.25%
Epoch [18] Done in 20.70s

Epoch [19], Batch [100], Loss: 0.386, Accuracy: 87.20%
Epoch [19], Batch [200], Loss: 0.387, Accuracy: 86.97%
Epoch [19], Batch [300], Loss: 0.393, Accuracy: 86.82%
Test Loss: 0.570, Test Accuracy: 81.27%
Epoch [19] Done in 20.70s

Epoch [20], Batch [100], Loss: 0.386, Accuracy: 87.02%
Epoch [20], Batch [200], Loss: 0.391, Accuracy: 86.88%
Epoch [20], Batch [300], Loss: 0.398, Accuracy: 86.80%
Test Loss: 0.781, Test Accuracy: 76.02%
Epoch [20] Done in 20.72s

Epoch [21], Batch [100], Loss: 0.384, Accuracy: 86.77%
Epoch [21], Batch [200], Loss: 0.388, Accuracy: 86.75%
Epoch [21], Batch [300], Loss: 0.401, Accuracy: 86.61%
Test Loss: 0.619, Test Accuracy: 79.30%
Epoch [21] Done in 20.70s

Epoch [22], Batch [100], Loss: 0.372, Accuracy: 87.52%
Epoch [22], Batch [200], Loss: 0.362, Accuracy: 87.67%
Epoch [22], Batch [300], Loss: 0.378, Accuracy: 87.44%
Test Loss: 0.472, Test Accuracy: 84.23%
Epoch [22] Done in 20.71s

Epoch [23], Batch [100], Loss: 0.351, Accuracy: 88.22%
Epoch [23], Batch [200], Loss: 0.368, Accuracy: 87.98%
Epoch [23], Batch [300], Loss: 0.385, Accuracy: 87.66%
Test Loss: 0.591, Test Accuracy: 81.47%
Epoch [23] Done in 20.71s

Epoch [24], Batch [100], Loss: 0.343, Accuracy: 88.38%
Epoch [24], Batch [200], Loss: 0.362, Accuracy: 87.91%
Epoch [24], Batch [300], Loss: 0.380, Accuracy: 87.58%
Test Loss: 0.649, Test Accuracy: 79.41%
Epoch [24] Done in 20.71s

Epoch [25], Batch [100], Loss: 0.358, Accuracy: 87.67%
Epoch [25], Batch [200], Loss: 0.358, Accuracy: 87.72%
Epoch [25], Batch [300], Loss: 0.377, Accuracy: 87.50%
Test Loss: 0.614, Test Accuracy: 79.93%
Epoch [25] Done in 20.70s

Epoch [26], Batch [100], Loss: 0.355, Accuracy: 87.79%
Epoch [26], Batch [200], Loss: 0.351, Accuracy: 87.75%
Epoch [26], Batch [300], Loss: 0.375, Accuracy: 87.57%
Test Loss: 0.623, Test Accuracy: 79.35%
Epoch [26] Done in 20.71s

Epoch [27], Batch [100], Loss: 0.351, Accuracy: 88.19%
Epoch [27], Batch [200], Loss: 0.348, Accuracy: 88.20%
Epoch [27], Batch [300], Loss: 0.366, Accuracy: 87.99%
Test Loss: 0.673, Test Accuracy: 80.00%
Epoch [27] Done in 20.70s

Epoch [28], Batch [100], Loss: 0.349, Accuracy: 88.23%
Epoch [28], Batch [200], Loss: 0.361, Accuracy: 88.01%
Epoch [28], Batch [300], Loss: 0.369, Accuracy: 87.87%
Test Loss: 0.569, Test Accuracy: 82.28%
Epoch [28] Done in 20.71s

Epoch [29], Batch [100], Loss: 0.348, Accuracy: 88.22%
Epoch [29], Batch [200], Loss: 0.347, Accuracy: 88.32%
Epoch [29], Batch [300], Loss: 0.359, Accuracy: 88.14%
Test Loss: 0.478, Test Accuracy: 84.06%
Epoch [29] Done in 20.72s

Epoch [30], Batch [100], Loss: 0.359, Accuracy: 87.91%
Epoch [30], Batch [200], Loss: 0.346, Accuracy: 88.15%
Epoch [30], Batch [300], Loss: 0.360, Accuracy: 87.99%
Test Loss: 0.822, Test Accuracy: 74.72%
Epoch [30] Done in 20.73s

Epoch [31], Batch [100], Loss: 0.322, Accuracy: 88.74%
Epoch [31], Batch [200], Loss: 0.365, Accuracy: 88.21%
Epoch [31], Batch [300], Loss: 0.365, Accuracy: 88.01%
Test Loss: 0.517, Test Accuracy: 82.56%
Epoch [31] Done in 20.70s

Epoch [32], Batch [100], Loss: 0.334, Accuracy: 88.77%
Epoch [32], Batch [200], Loss: 0.325, Accuracy: 88.77%
Epoch [32], Batch [300], Loss: 0.353, Accuracy: 88.43%
Test Loss: 0.591, Test Accuracy: 80.54%
Epoch [32] Done in 20.70s

Epoch [33], Batch [100], Loss: 0.345, Accuracy: 88.23%
Epoch [33], Batch [200], Loss: 0.352, Accuracy: 88.08%
Epoch [33], Batch [300], Loss: 0.367, Accuracy: 87.92%
Test Loss: 0.494, Test Accuracy: 83.82%
Epoch [33] Done in 20.71s

Epoch [34], Batch [100], Loss: 0.339, Accuracy: 88.21%
Epoch [34], Batch [200], Loss: 0.345, Accuracy: 88.20%
Epoch [34], Batch [300], Loss: 0.351, Accuracy: 88.12%
Test Loss: 0.424, Test Accuracy: 85.62%
Epoch [34] Done in 20.76s

Epoch [35], Batch [100], Loss: 0.320, Accuracy: 89.38%
Epoch [35], Batch [200], Loss: 0.349, Accuracy: 88.67%
Epoch [35], Batch [300], Loss: 0.343, Accuracy: 88.48%
Test Loss: 0.448, Test Accuracy: 85.36%
Epoch [35] Done in 20.71s

Epoch [36], Batch [100], Loss: 0.328, Accuracy: 88.77%
Epoch [36], Batch [200], Loss: 0.339, Accuracy: 88.77%
Epoch [36], Batch [300], Loss: 0.374, Accuracy: 88.25%
Test Loss: 0.487, Test Accuracy: 84.15%
Epoch [36] Done in 20.71s

Epoch [37], Batch [100], Loss: 0.329, Accuracy: 88.97%
Epoch [37], Batch [200], Loss: 0.336, Accuracy: 88.75%
Epoch [37], Batch [300], Loss: 0.334, Accuracy: 88.63%
Test Loss: 0.484, Test Accuracy: 84.09%
Epoch [37] Done in 20.70s

Epoch [38], Batch [100], Loss: 0.340, Accuracy: 88.69%
Epoch [38], Batch [200], Loss: 0.329, Accuracy: 88.73%
Epoch [38], Batch [300], Loss: 0.344, Accuracy: 88.58%
Test Loss: 0.497, Test Accuracy: 83.75%
Epoch [38] Done in 20.72s

Epoch [39], Batch [100], Loss: 0.317, Accuracy: 89.23%
Epoch [39], Batch [200], Loss: 0.350, Accuracy: 88.67%
Epoch [39], Batch [300], Loss: 0.345, Accuracy: 88.57%
Test Loss: 0.410, Test Accuracy: 86.56%
Epoch [39] Done in 20.77s

Epoch [40], Batch [100], Loss: 0.331, Accuracy: 88.46%
Epoch [40], Batch [200], Loss: 0.335, Accuracy: 88.48%
Epoch [40], Batch [300], Loss: 0.326, Accuracy: 88.62%
Test Loss: 0.459, Test Accuracy: 85.06%
Epoch [40] Done in 20.71s

Epoch [41], Batch [100], Loss: 0.329, Accuracy: 88.83%
Epoch [41], Batch [200], Loss: 0.335, Accuracy: 88.59%
Epoch [41], Batch [300], Loss: 0.349, Accuracy: 88.40%
Test Loss: 0.532, Test Accuracy: 83.28%
Epoch [41] Done in 20.70s

Epoch [42], Batch [100], Loss: 0.301, Accuracy: 89.88%
Epoch [42], Batch [200], Loss: 0.340, Accuracy: 89.09%
Epoch [42], Batch [300], Loss: 0.349, Accuracy: 88.74%
Test Loss: 0.680, Test Accuracy: 78.78%
Epoch [42] Done in 20.72s

Epoch [43], Batch [100], Loss: 0.320, Accuracy: 89.06%
Epoch [43], Batch [200], Loss: 0.333, Accuracy: 88.81%
Epoch [43], Batch [300], Loss: 0.344, Accuracy: 88.68%
Test Loss: 0.401, Test Accuracy: 86.55%
Epoch [43] Done in 20.73s

Epoch [44], Batch [100], Loss: 0.343, Accuracy: 88.29%
Epoch [44], Batch [200], Loss: 0.335, Accuracy: 88.35%
Epoch [44], Batch [300], Loss: 0.333, Accuracy: 88.47%
Test Loss: 0.475, Test Accuracy: 84.56%
Epoch [44] Done in 20.70s

Epoch [45], Batch [100], Loss: 0.325, Accuracy: 89.16%
Epoch [45], Batch [200], Loss: 0.332, Accuracy: 88.90%
Epoch [45], Batch [300], Loss: 0.335, Accuracy: 88.85%
Test Loss: 0.491, Test Accuracy: 83.45%
Epoch [45] Done in 20.72s

Epoch [46], Batch [100], Loss: 0.328, Accuracy: 88.64%
Epoch [46], Batch [200], Loss: 0.335, Accuracy: 88.62%
Epoch [46], Batch [300], Loss: 0.332, Accuracy: 88.69%
Test Loss: 0.563, Test Accuracy: 81.67%
Epoch [46] Done in 20.70s

Epoch [47], Batch [100], Loss: 0.310, Accuracy: 89.43%
Epoch [47], Batch [200], Loss: 0.333, Accuracy: 88.90%
Epoch [47], Batch [300], Loss: 0.342, Accuracy: 88.62%
Test Loss: 0.556, Test Accuracy: 82.48%
Epoch [47] Done in 20.72s

Epoch [48], Batch [100], Loss: 0.316, Accuracy: 89.09%
Epoch [48], Batch [200], Loss: 0.321, Accuracy: 89.13%
Epoch [48], Batch [300], Loss: 0.333, Accuracy: 89.02%
Test Loss: 0.419, Test Accuracy: 86.28%
Epoch [48] Done in 20.72s

Epoch [49], Batch [100], Loss: 0.320, Accuracy: 88.99%
Epoch [49], Batch [200], Loss: 0.328, Accuracy: 88.95%
Epoch [49], Batch [300], Loss: 0.322, Accuracy: 88.91%
Test Loss: 0.474, Test Accuracy: 84.37%
Epoch [49] Done in 20.71s

Epoch [50], Batch [100], Loss: 0.308, Accuracy: 89.34%
Epoch [50], Batch [200], Loss: 0.333, Accuracy: 89.02%
Epoch [50], Batch [300], Loss: 0.327, Accuracy: 88.96%
Test Loss: 0.503, Test Accuracy: 84.12%
Epoch [50] Done in 20.70s

Epoch [51], Batch [100], Loss: 0.233, Accuracy: 92.05%
Epoch [51], Batch [200], Loss: 0.171, Accuracy: 93.20%
Epoch [51], Batch [300], Loss: 0.167, Accuracy: 93.58%
Test Loss: 0.219, Test Accuracy: 92.78%
Epoch [51] Done in 20.77s

Epoch [52], Batch [100], Loss: 0.139, Accuracy: 95.30%
Epoch [52], Batch [200], Loss: 0.125, Accuracy: 95.55%
Epoch [52], Batch [300], Loss: 0.133, Accuracy: 95.57%
Test Loss: 0.212, Test Accuracy: 92.96%
Epoch [52] Done in 20.89s

Epoch [53], Batch [100], Loss: 0.107, Accuracy: 96.63%
Epoch [53], Batch [200], Loss: 0.109, Accuracy: 96.41%
Epoch [53], Batch [300], Loss: 0.108, Accuracy: 96.40%
Test Loss: 0.199, Test Accuracy: 93.29%
Epoch [53] Done in 20.79s

Epoch [54], Batch [100], Loss: 0.097, Accuracy: 96.86%
Epoch [54], Batch [200], Loss: 0.098, Accuracy: 96.75%
Epoch [54], Batch [300], Loss: 0.092, Accuracy: 96.85%
Test Loss: 0.199, Test Accuracy: 93.65%
Epoch [54] Done in 20.78s

Epoch [55], Batch [100], Loss: 0.084, Accuracy: 97.21%
Epoch [55], Batch [200], Loss: 0.080, Accuracy: 97.30%
Epoch [55], Batch [300], Loss: 0.089, Accuracy: 97.20%
Test Loss: 0.200, Test Accuracy: 93.69%
Epoch [55] Done in 20.77s

Epoch [56], Batch [100], Loss: 0.073, Accuracy: 97.63%
Epoch [56], Batch [200], Loss: 0.082, Accuracy: 97.50%
Epoch [56], Batch [300], Loss: 0.071, Accuracy: 97.54%
Test Loss: 0.212, Test Accuracy: 93.24%
Epoch [56] Done in 20.71s

Epoch [57], Batch [100], Loss: 0.062, Accuracy: 98.00%
Epoch [57], Batch [200], Loss: 0.065, Accuracy: 98.00%
Epoch [57], Batch [300], Loss: 0.070, Accuracy: 97.87%
Test Loss: 0.212, Test Accuracy: 93.45%
Epoch [57] Done in 20.70s

Epoch [58], Batch [100], Loss: 0.059, Accuracy: 98.16%
Epoch [58], Batch [200], Loss: 0.057, Accuracy: 98.19%
Epoch [58], Batch [300], Loss: 0.069, Accuracy: 97.97%
Test Loss: 0.219, Test Accuracy: 93.21%
Epoch [58] Done in 20.71s

Epoch [59], Batch [100], Loss: 0.061, Accuracy: 97.95%
Epoch [59], Batch [200], Loss: 0.052, Accuracy: 98.14%
Epoch [59], Batch [300], Loss: 0.059, Accuracy: 98.12%
Test Loss: 0.211, Test Accuracy: 93.61%
Epoch [59] Done in 20.73s

Epoch [60], Batch [100], Loss: 0.046, Accuracy: 98.57%
Epoch [60], Batch [200], Loss: 0.049, Accuracy: 98.49%
Epoch [60], Batch [300], Loss: 0.057, Accuracy: 98.35%
Test Loss: 0.226, Test Accuracy: 93.26%
Epoch [60] Done in 20.69s

Epoch [61], Batch [100], Loss: 0.046, Accuracy: 98.53%
Epoch [61], Batch [200], Loss: 0.045, Accuracy: 98.51%
Epoch [61], Batch [300], Loss: 0.051, Accuracy: 98.45%
Test Loss: 0.214, Test Accuracy: 93.73%
Epoch [61] Done in 20.77s

Epoch [62], Batch [100], Loss: 0.040, Accuracy: 98.73%
Epoch [62], Batch [200], Loss: 0.046, Accuracy: 98.64%
Epoch [62], Batch [300], Loss: 0.043, Accuracy: 98.62%
Test Loss: 0.227, Test Accuracy: 93.34%
Epoch [62] Done in 20.72s

Epoch [63], Batch [100], Loss: 0.039, Accuracy: 98.73%
Epoch [63], Batch [200], Loss: 0.043, Accuracy: 98.67%
Epoch [63], Batch [300], Loss: 0.043, Accuracy: 98.63%
Test Loss: 0.221, Test Accuracy: 93.51%
Epoch [63] Done in 20.71s

Epoch [64], Batch [100], Loss: 0.039, Accuracy: 98.82%
Epoch [64], Batch [200], Loss: 0.039, Accuracy: 98.77%
Epoch [64], Batch [300], Loss: 0.036, Accuracy: 98.82%
Test Loss: 0.228, Test Accuracy: 93.55%
Epoch [64] Done in 20.71s

Epoch [65], Batch [100], Loss: 0.033, Accuracy: 98.98%
Epoch [65], Batch [200], Loss: 0.038, Accuracy: 98.85%
Epoch [65], Batch [300], Loss: 0.037, Accuracy: 98.81%
Test Loss: 0.230, Test Accuracy: 93.47%
Epoch [65] Done in 20.72s

Epoch [66], Batch [100], Loss: 0.035, Accuracy: 98.84%
Epoch [66], Batch [200], Loss: 0.036, Accuracy: 98.82%
Epoch [66], Batch [300], Loss: 0.037, Accuracy: 98.81%
Test Loss: 0.242, Test Accuracy: 93.31%
Epoch [66] Done in 20.73s

Epoch [67], Batch [100], Loss: 0.036, Accuracy: 98.88%
Epoch [67], Batch [200], Loss: 0.035, Accuracy: 98.88%
Epoch [67], Batch [300], Loss: 0.035, Accuracy: 98.88%
Test Loss: 0.232, Test Accuracy: 93.47%
Epoch [67] Done in 20.73s

Epoch [68], Batch [100], Loss: 0.036, Accuracy: 98.82%
Epoch [68], Batch [200], Loss: 0.034, Accuracy: 98.88%
Epoch [68], Batch [300], Loss: 0.035, Accuracy: 98.90%
Test Loss: 0.237, Test Accuracy: 93.47%
Epoch [68] Done in 20.70s

Epoch [69], Batch [100], Loss: 0.032, Accuracy: 99.02%
Epoch [69], Batch [200], Loss: 0.033, Accuracy: 99.00%
Epoch [69], Batch [300], Loss: 0.037, Accuracy: 98.94%
Test Loss: 0.240, Test Accuracy: 93.52%
Epoch [69] Done in 20.71s

Epoch [70], Batch [100], Loss: 0.032, Accuracy: 99.05%
Epoch [70], Batch [200], Loss: 0.038, Accuracy: 98.90%
Epoch [70], Batch [300], Loss: 0.039, Accuracy: 98.85%
Test Loss: 0.250, Test Accuracy: 93.27%
Epoch [70] Done in 20.72s

Epoch [71], Batch [100], Loss: 0.028, Accuracy: 99.14%
Epoch [71], Batch [200], Loss: 0.034, Accuracy: 99.06%
Epoch [71], Batch [300], Loss: 0.035, Accuracy: 99.03%
Test Loss: 0.262, Test Accuracy: 92.96%
Epoch [71] Done in 20.70s

Epoch [72], Batch [100], Loss: 0.034, Accuracy: 98.91%
Epoch [72], Batch [200], Loss: 0.031, Accuracy: 98.93%
Epoch [72], Batch [300], Loss: 0.039, Accuracy: 98.84%
Test Loss: 0.251, Test Accuracy: 93.20%
Epoch [72] Done in 20.71s

Epoch [73], Batch [100], Loss: 0.032, Accuracy: 98.99%
Epoch [73], Batch [200], Loss: 0.033, Accuracy: 98.98%
Epoch [73], Batch [300], Loss: 0.040, Accuracy: 98.90%
Test Loss: 0.247, Test Accuracy: 93.18%
Epoch [73] Done in 20.72s

Epoch [74], Batch [100], Loss: 0.033, Accuracy: 98.93%
Epoch [74], Batch [200], Loss: 0.033, Accuracy: 98.97%
Epoch [74], Batch [300], Loss: 0.039, Accuracy: 98.86%
Test Loss: 0.234, Test Accuracy: 93.52%
Epoch [74] Done in 20.73s

Epoch [75], Batch [100], Loss: 0.027, Accuracy: 99.13%
Epoch [75], Batch [200], Loss: 0.035, Accuracy: 99.00%
Epoch [75], Batch [300], Loss: 0.034, Accuracy: 98.98%
Test Loss: 0.256, Test Accuracy: 93.11%
Epoch [75] Done in 20.72s

Epoch [76], Batch [100], Loss: 0.041, Accuracy: 98.64%
Epoch [76], Batch [200], Loss: 0.036, Accuracy: 98.77%
Epoch [76], Batch [300], Loss: 0.031, Accuracy: 98.85%
Test Loss: 0.254, Test Accuracy: 93.17%
Epoch [76] Done in 20.72s

Epoch [77], Batch [100], Loss: 0.033, Accuracy: 98.95%
Epoch [77], Batch [200], Loss: 0.035, Accuracy: 98.89%
Epoch [77], Batch [300], Loss: 0.041, Accuracy: 98.79%
Test Loss: 0.287, Test Accuracy: 92.33%
Epoch [77] Done in 20.72s

Epoch [78], Batch [100], Loss: 0.040, Accuracy: 98.78%
Epoch [78], Batch [200], Loss: 0.038, Accuracy: 98.75%
Epoch [78], Batch [300], Loss: 0.038, Accuracy: 98.74%
Test Loss: 0.270, Test Accuracy: 92.80%
Epoch [78] Done in 20.71s

Epoch [79], Batch [100], Loss: 0.043, Accuracy: 98.56%
Epoch [79], Batch [200], Loss: 0.043, Accuracy: 98.52%
Epoch [79], Batch [300], Loss: 0.041, Accuracy: 98.53%
Test Loss: 0.243, Test Accuracy: 93.37%
Epoch [79] Done in 20.72s

Epoch [80], Batch [100], Loss: 0.036, Accuracy: 98.91%
Epoch [80], Batch [200], Loss: 0.043, Accuracy: 98.77%
Epoch [80], Batch [300], Loss: 0.041, Accuracy: 98.70%
Test Loss: 0.255, Test Accuracy: 93.22%