Implement the `teacher_forward` and `student_forward` methods,
and, if required, the `move_batch_to_device` method.
from squeezer import Distiller
class CustomDistiller(Distiller):
    """Distiller that feeds the batch's ``'data'`` entry to both models.

    Overrides the two hooks required by ``squeezer.Distiller``:
    ``teacher_forward`` and ``student_forward``.
    """

    def teacher_forward(self, batch):
        # assumes batch is a mapping with a 'data' key — TODO confirm against the loader
        return self.teacher(batch['data'])

    def student_forward(self, batch):
        return self.student(batch['data'])


from torch.nn.functional import mse_loss
from squeezer import AbstractDistillationPolicy
class DistillationPolicy(AbstractDistillationPolicy):
    """Distillation policy using plain MSE between student and teacher outputs."""

    def forward(self, teacher_output, student_output, batch, epoch):
        # batch and epoch are unused here but are part of the policy interface.
        loss_mse = mse_loss(student_output, teacher_output)
        # .item() detaches the scalar for logging; the tensor loss is returned
        # separately so gradients still flow.
        loss_dict = {'mse': loss_mse.item()}
        return loss_mse, loss_dict


from torch import optim
from squeezer.logging import TensorboardLogger
# --- Example training script ---
# Placeholder: supply a real iterable of batches (dicts with a 'data' key,
# per CustomDistiller above) — TODO fill in.
train_loader = ...
# Teacher and Student are example model classes defined elsewhere — not shown here.
teacher = Teacher()
student = Student()
# Writes TensorBoard logs under 'runs' for the run named 'experiment'.
logger = TensorboardLogger('runs', 'experiment')
# Only the student's parameters are optimized; the teacher stays fixed.
optimizer = optim.AdamW(student.parameters(), lr=3e-4)
policy = DistillationPolicy()
distiller = CustomDistiller(teacher, student, policy, optimizer=optimizer, logger=logger)
# Run distillation for 10 epochs.
distiller(train_loader, n_epochs=10)
# NOTE(review): what exactly is saved depends on Distiller.save — confirm.
distiller.save('path_to_some_directory')