训练深度学习模型时,过拟合是最常见的问题之一。模型在训练集上表现很好,但在验证集或真实数据上效果明显变差。它不是 PyTorch 独有的问题,而是机器学习模型容量、数据规模、训练策略和任务复杂度共同作用的结果。
本文以 PyTorch 训练流程为例,说明如何判断过拟合,以及常用处理手段:数据划分、监控曲线、正则化、Dropout、数据增强、早停和模型简化。
什么是过拟合
过拟合可以简单理解为模型“记住了训练样本”,但没有学到足够通用的规律。典型表现是:
- 训练 loss 持续下降。
- 训练 accuracy 持续上升。
- 验证 loss 下降一段时间后开始上升。
- 验证 accuracy 停滞甚至下降。
只看训练集指标很容易误判。一个模型训练准确率 99%,并不代表它在新数据上也能达到同样效果。
正确划分数据集
判断过拟合的前提是有独立验证集:
1 2 3 4 5 6 7 8 9 10 11
| from torch.utils.data import random_split
dataset = MyDataset()
train_size = int(len(dataset) * 0.8) val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split( dataset, [train_size, val_size], )
|
训练集用于更新参数,验证集只用于评估模型。不要在训练过程中对验证集反向传播,也不要根据测试集结果反复调参。
如果数据存在时间顺序,例如订单、传感器、日志,不建议随机切分。应该按时间切分,避免未来数据泄漏到训练集。
记录训练和验证指标
一个最小训练循环:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| def train_one_epoch(model, loader, criterion, optimizer, device): model.train() total_loss = 0.0
for x, y in loader: x = x.to(device) y = y.to(device)
optimizer.zero_grad() output = model(x) loss = criterion(output, y) loss.backward() optimizer.step()
total_loss += loss.item() * x.size(0)
return total_loss / len(loader.dataset)
|
验证循环要关闭梯度,并切换到 eval 模式:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| def evaluate(model, loader, criterion, device): model.eval() total_loss = 0.0 correct = 0
with torch.no_grad(): for x, y in loader: x = x.to(device) y = y.to(device)
output = model(x) loss = criterion(output, y) total_loss += loss.item() * x.size(0)
pred = output.argmax(dim=1) correct += (pred == y).sum().item()
avg_loss = total_loss / len(loader.dataset) acc = correct / len(loader.dataset) return avg_loss, acc
|
model.eval() 会影响 Dropout、BatchNorm 等层的行为;torch.no_grad() 可以减少显存占用并加速验证。
观察指标曲线
主训练循环可以记录每个 epoch 的指标:
1 2 3 4 5 6 7 8 9 10
| for epoch in range(epochs): train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device) val_loss, val_acc = evaluate(model, val_loader, criterion, device)
print( f"epoch={epoch} " f"train_loss={train_loss:.4f} " f"val_loss={val_loss:.4f} " f"val_acc={val_acc:.4f}" )
|
如果出现训练 loss 继续降低,而验证 loss 连续多个 epoch 上升,就要怀疑过拟合。
建议保存日志到 TensorBoard、CSV 或实验管理工具中,曲线比单个数值更容易判断趋势。
使用 weight decay
权重衰减是常见正则化方法,可以限制模型参数过大:
1 2 3 4 5
| optimizer = torch.optim.Adam( model.parameters(), lr=1e-3, weight_decay=1e-4, )
|
weight_decay 不宜盲目设置过大。过大可能导致欠拟合,表现为训练集和验证集指标都不好。可以从 1e-5、1e-4、1e-3 等量级尝试。
增加 Dropout
对于全连接层较多的模型,可以加入 Dropout:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| class Classifier(nn.Module): def __init__(self, input_dim, num_classes): super().__init__() self.net = nn.Sequential( nn.Linear(input_dim, 256), nn.ReLU(), nn.Dropout(p=0.3), nn.Linear(256, 128), nn.ReLU(), nn.Dropout(p=0.3), nn.Linear(128, num_classes), )
def forward(self, x): return self.net(x)
|
Dropout 在训练模式下随机丢弃部分神经元,在验证模式下关闭。因此训练和验证时必须正确调用 model.train() 和 model.eval()。
数据增强
图像任务中,数据增强通常比单纯调整模型更有效:
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| from torchvision import transforms
train_transform = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness=0.2, contrast=0.2), transforms.ToTensor(), ])
val_transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), ])
|
训练集可以使用随机增强,验证集应使用稳定变换。不要对验证集做随机裁剪、随机翻转,否则指标会不稳定。
文本、表格、时序任务也可以做增强,但要保证增强后的样本仍然符合业务语义。
早停机制
早停可以在验证集不再提升时停止训练:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
| best_loss = float("inf") patience = 5 bad_epochs = 0
for epoch in range(epochs): train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device) val_loss, val_acc = evaluate(model, val_loader, criterion, device)
if val_loss < best_loss: best_loss = val_loss bad_epochs = 0 torch.save(model.state_dict(), "best.pt") else: bad_epochs += 1
if bad_epochs >= patience: print("early stopping") break
|
训练结束后应加载验证集表现最好的模型,而不是最后一个 epoch 的模型。
简化模型
如果数据量很小,而模型参数很多,过拟合会更明显。可以尝试:
- 减少网络层数。
- 减少隐藏层维度。
- 使用预训练模型并冻结部分层。
- 增加训练数据。
- 合并稀疏类别或清理噪声标签。
不要只依赖调参解决数据质量问题。标签错误、样本重复、训练和验证分布不一致都会造成指标异常。
小结
判断过拟合要同时观察训练集和验证集指标。处理过拟合可以从数据、模型和训练策略三个方向入手:保证验证集独立,记录指标曲线,加入 weight decay 和 Dropout,使用合理数据增强,必要时早停或简化模型。最终目标不是让训练集指标最高,而是让模型在未见过的数据上表现稳定。