
@thomwolf
Last active July 12, 2022 22:21
Knowledge Distillation
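import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer

# KLDivLoss expects `input` as log-probabilities and `target` as probabilities;
# 'batchmean' divides the summed KL divergence by the batch size.
KD_loss = nn.KLDivLoss(reduction='batchmean')

def kd_step(teacher: nn.Module, student: nn.Module, temperature: float,
            inputs: torch.Tensor, optimizer: Optimizer):
    teacher.eval()    # frozen teacher: dropout/batch-norm in inference mode
    student.train()
    # Models are called positionally; adapt if forward() expects keyword arguments.
    with torch.no_grad():  # no gradients flow into the teacher
        logits_t = teacher(inputs)
    logits_s = student(inputs)
    # Soften both distributions with the same temperature.
    # Note: Hinton et al. (2015) multiply this loss by temperature**2 to keep
    # gradient magnitudes comparable across temperatures; it is left unscaled here.
    loss = KD_loss(input=F.log_softmax(logits_s / temperature, dim=-1),
                   target=F.softmax(logits_t / temperature, dim=-1))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()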
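
To see what dividing the logits by `temperature` does to the teacher's targets, here is a small sketch on made-up logits. The logit values and the helper names `softened_probs` and `entropy` are hypothetical and only for illustration: raising the temperature flattens the distribution (higher entropy) while keeping the most likely class unchanged, which exposes more of the teacher's relative preferences among the wrong classes.

```python
import torch
import torch.nn.functional as F

def softened_probs(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Softmax of logits / temperature, as used for the distillation targets."""
    return F.softmax(logits / temperature, dim=-1)

def entropy(probs: torch.Tensor) -> torch.Tensor:
    """Shannon entropy (in nats) of each row of a probability tensor."""
    return -(probs * probs.log()).sum(dim=-1)

# Hypothetical teacher logits for one example over 4 classes.
logits_t = torch.tensor([[4.0, 1.0, 0.5, -1.0]])

for T in (1.0, 2.0, 4.0):
    p = softened_probs(logits_t, T)
    print(f"T={T}: probs={p.squeeze().tolist()}, entropy={entropy(p).item():.3f}")
```
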

aobaruwa commented Jul 12, 2022

Is there a reason why the student logits pass through log_softmax and the teacher logits through a regular softmax, rather than both passing through softmax?
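
The asymmetry follows from the `nn.KLDivLoss` contract: PyTorch expects `input` as log-probabilities and `target` as probabilities (or as log-probabilities when `log_target=True`), and computes `target * (log(target) - input)` elementwise. So `log_softmax` on the student is just the log of the same softmax, computed in a numerically stable way. The sketch below, using random hypothetical logits, checks that the loss matches a hand-written KL(p_t || p_s), that the `log_target=True` form agrees, and that passing plain probabilities as `input` silently computes a different number.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch, classes, T = 8, 5, 2.0
logits_s = torch.randn(batch, classes)  # hypothetical student logits
logits_t = torch.randn(batch, classes)  # hypothetical teacher logits

log_p_s = F.log_softmax(logits_s / T, dim=-1)  # student: log-probabilities
p_t = F.softmax(logits_t / T, dim=-1)          # teacher: probabilities

kd = nn.KLDivLoss(reduction='batchmean')(log_p_s, p_t)

# Manual KL(p_t || p_s) = sum_c p_t * (log p_t - log p_s), averaged over the batch.
manual = (p_t * (p_t.log() - log_p_s)).sum(dim=-1).mean()

# Equivalent form: pass the teacher as log-probabilities too.
kd_log_target = nn.KLDivLoss(reduction='batchmean', log_target=True)(
    log_p_s, F.log_softmax(logits_t / T, dim=-1))

# Passing probabilities (not log-probabilities) as `input` computes something else.
wrong = nn.KLDivLoss(reduction='batchmean')(F.softmax(logits_s / T, dim=-1), p_t)

print(kd.item(), manual.item(), kd_log_target.item(), wrong.item())
```
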
