Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one training pass over ``dataloader``, updating ``model`` in place.

    For every batch the inputs are moved to ``device``, the model is run
    forward, ``criterion`` is evaluated, gradients are back-propagated and
    ``optimizer.step()`` is called.

    Args:
        model: Module being trained; it is switched to train mode.
        dataloader: Any iterable of ``(images, labels)`` batches. It does not
            need to support ``len()``.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Target passed to ``Tensor.to`` for both images and labels.

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean of the per-batch loss values and
        the fraction of samples whose arg-max output along dim 1 equals the
        label.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        # The index of the largest output along dim 1 is the predicted class.
        preds = outputs.argmax(dim=1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

    # Batches are counted while iterating rather than via len(dataloader):
    # len() fails for generators and for DataLoaders over an IterableDataset
    # that defines no __len__.
    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")

    # NOTE(review): this averages per-batch losses, so a smaller final batch
    # weighs as much as a full one. If criterion uses reduction="mean",
    # weighting each batch by labels.size(0) would give the per-sample mean.
    epoch_loss = running_loss / num_batches
    epoch_acc = correct / total
    return epoch_loss, epoch_acc
def validate_one_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    The model is put in eval mode and every batch is run forward inside
    ``torch.no_grad()``; parameters are not updated.

    Args:
        model: Module being evaluated; it is switched to eval mode.
        dataloader: Any iterable of ``(images, labels)`` batches. It does not
            need to support ``len()``.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Target passed to ``Tensor.to`` for both images and labels.

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean of the per-batch loss values and
        the fraction of samples whose arg-max output along dim 1 equals the
        label.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    running_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            num_batches += 1
            # The index of the largest output along dim 1 is the predicted class.
            preds = outputs.argmax(dim=1)
            correct += preds.eq(labels).sum().item()
            total += labels.size(0)

    # Batches are counted while iterating rather than via len(dataloader):
    # len() fails for generators and for DataLoaders over an IterableDataset
    # that defines no __len__.
    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")

    # NOTE(review): per-batch mean, as in training; a smaller final batch
    # weighs as much as a full one.
    epoch_loss = running_loss / num_batches
    epoch_acc = correct / total
    return epoch_loss, epoch_acc