# Squeezer (WIP)

## Usage

### Step 1: Define a Distiller class

Implement the `teacher_forward` and `student_forward` methods, and override `move_batch_to_device` if your batches need custom device placement.

```python
from squeezer import Distiller


class CustomDistiller(Distiller):
    def teacher_forward(self, batch):
        return self.teacher(batch['data'])

    def student_forward(self, batch):
        return self.student(batch['data'])
```
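The README does not show `move_batch_to_device` itself. As an illustration only, a helper for dict-of-tensor batches might look like the sketch below; the recursive handling of dicts and sequences is an assumption, not squeezer's actual implementation:

```python
# Hypothetical sketch: move every tensor-like value in a batch to a device.
# Anything exposing a .to(device) method (e.g. a torch.Tensor) is moved;
# plain values such as ints or strings are passed through unchanged.
def move_batch_to_device(batch, device):
    if hasattr(batch, 'to'):
        return batch.to(device)
    if isinstance(batch, dict):
        return {key: move_batch_to_device(value, device) for key, value in batch.items()}
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_batch_to_device(item, device) for item in batch)
    return batch
```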

### Step 2: Define a loss policy

Subclass `AbstractDistillationPolicy` and return the total loss together with a dict of scalar values for logging.

```python
from torch.nn.functional import mse_loss

from squeezer import AbstractDistillationPolicy


class DistillationPolicy(AbstractDistillationPolicy):
    def forward(self, teacher_output, student_output, batch, epoch):
        loss_mse = mse_loss(student_output, teacher_output)
        loss_dict = {'mse': loss_mse.item()}
        return loss_mse, loss_dict
```
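MSE on raw outputs is the simplest choice; classic knowledge distillation instead compares temperature-softened teacher and student distributions. A plain-Python sketch of that objective (the `softmax` and `kd_loss` helpers are illustrative and not part of squeezer; the T² scaling follows Hinton et al.'s formulation):

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(teacher_logits, student_logits, temperature=2.0):
    """KL(teacher || student) on softened distributions, scaled by T^2."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return kl * temperature ** 2
```

In practice the policy's `forward` would combine such a soft-target term with a hard-label loss, weighting the two with a mixing coefficient.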

### Step 3: Fit

```python
from torch import optim

from squeezer.logging import TensorboardLogger


train_loader = ...

teacher = Teacher()
student = Student()

logger = TensorboardLogger('runs', 'experiment')
optimizer = optim.AdamW(student.parameters(), lr=3e-4)
policy = DistillationPolicy()
distiller = CustomDistiller(teacher, student, policy, optimizer=optimizer, logger=logger)

distiller(train_loader, n_epochs=10)
distiller.save('path_to_some_directory')
```
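Conceptually, calling the distiller drives a loop like the one below. This is a plain-Python sketch of the control flow under assumed semantics, not squeezer's actual code; `distill` and its parameters are hypothetical names:

```python
def distill(train_loader, teacher_forward, student_forward, loss_fn, step, n_epochs):
    """Generic distillation loop: per batch, compute teacher and student
    outputs, combine them into a loss, and take an optimizer step."""
    history = []
    for epoch in range(n_epochs):
        for batch in train_loader:
            teacher_output = teacher_forward(batch)  # no gradients needed here
            student_output = student_forward(batch)
            loss, loss_dict = loss_fn(teacher_output, student_output, batch, epoch)
            step(loss)  # backward pass + optimizer update
            history.append(loss_dict)  # scalar values for the logger
    return history
```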

## About

Lightweight knowledge distillation pipeline
