Created May 5, 2020 02:39
-
-
Save xuhdev/58c494ccfb6ed3f8236b85fc1e4964b3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Training loop.
# Relies on `net`, `train_loader`, `criterion` and `optimizer` being defined
# earlier in the script (not visible in this excerpt).

# Number of mini-batches per loss report. The printed loss is the mean
# training loss over that window.
LOG_EVERY = 3000

# Make sure layers such as dropout / batch-norm use training behaviour.
# NOTE(review): assumes `net` is a torch.nn.Module; confirm against its definition.
net.train()
for epoch in range(1):  # loop over the dataset multiple times
    running_loss = 0.0
    # each batch is a pair [inputs, labels]
    for i, (inputs, labels) in enumerate(train_loader):
        # zero the parameter gradients so they do not accumulate across steps
        optimizer.zero_grad()

        # forward
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        # backward (differentiate)
        loss.backward()
        # optimize (update)
        optimizer.step()

        # print statistics: average loss over the last LOG_EVERY mini-batches
        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:
            print(f'Epoch: {epoch + 1}, Iteration: {i + 1}, loss: {running_loss / LOG_EVERY}')
            running_loss = 0.0
# Test accuracy.
# Relies on `net` and `test_loader` being defined earlier in the script
# (not visible in this excerpt), and on `torch` being imported.

# Switch dropout / batch-norm layers to inference behaviour; otherwise the
# measured accuracy reflects training-mode randomness and batch statistics.
# NOTE(review): assumes `net` is a torch.nn.Module; confirm against its definition.
net.eval()
correct = 0
total = 0
with torch.no_grad():  # no autograd graph is needed for evaluation
    for images, labels in test_loader:
        outputs = net(images)
        # The class with the highest output score is the prediction
        # (argmax returns the first index on ties, like torch.max(dim=1)).
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total:
    print(f'Accuracy of the network on the test images: {(100 * correct / total)} %')
else:
    # An empty test loader would otherwise raise ZeroDivisionError.
    print('Accuracy of the network on the test images: no test samples')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.