Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions config/configs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ checkpoints: "./checkpoints"

# Dataset Options
dataset: "TCGA"
subset: "dropGradNaN"

# Model options
pretrained: "gagan3012/swinv2_1024"
pretrained: ""
dis_gene: ['idh mutation', 'codeletion', 'PTEN', 'EGFR', 'CARD11', 'FGFR2']
float_gene: ['10q', '10p', '7p', '7q']
patch_size: 4
window_size: 8
patch_size: 16
window_size: 7

# training options
seed: 42
Expand Down
10 changes: 7 additions & 3 deletions datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def __init__(self, args, data, gene_list, split='train'):
A.VerticalFlip(p=.5),
A.RandomRotate90(p=.5),
A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=.5),
A.OneOf([
A.ElasticTransform(p=.5),
A.GridDistortion(p=.5),
A.OpticalDistortion(p=.5),
], p=.5),
A.OneOf([
A.RandomGridShuffle(grid=(3, 3), p=.5),
A.RandomGridShuffle(grid=(7, 7), p=.5),
Expand Down Expand Up @@ -68,9 +73,8 @@ def __getitem__(self, index):
common = self.spatial_transform(image=img)['image']
view1 = self.color_transform(image=common)['image']
view2 = self.color_transform(image=common)['image']
view3 = self.color_transform(image=common)['image']
return view1, view2, view3, dis_gene, float_gene, grade
return view1, view2, dis_gene, float_gene, grade
else:
img = self.test_transform(image=img)['image']
return img, dis_gene, float_gene, grade
return img, grade

61 changes: 27 additions & 34 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np
from torch.utils.data import DataLoader
from utils import yaml_config_hook, convert_model, train, MultiHeadContrastiveLoss, RegionContrastiveLoss, GeneGuidance
from utils import yaml_config_hook, convert_model, train, generate_splits
import warnings


Expand Down Expand Up @@ -59,7 +59,7 @@ def main(gpu, args, wandb_logger):
if rank == 0:
test_dataset = TCGADataset(args, data_cv_split, gene_names, split='test')
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True)
num_workers=args.workers, pin_memory=True)
else:
test_loader = None

Expand All @@ -68,47 +68,38 @@ def main(gpu, args, wandb_logger):
num_classes = train_dataset.num_classes

# model init
global_model = SwinTransformer(image_size=args.image_size, num_classes=num_classes,
pretrained=args.pretrained, patch_size=args.patch_size,
window_size=args.window_size,)
# add three dummy classes for the normal and other types
local_model = SwinTransformer(image_size=args.image_size, num_classes=num_classes + 4,
pretrained=args.pretrained, patch_size=args.patch_size,
window_size=args.window_size,)
projectors = ContrastiveProjectors(global_model.config.hidden_size, args.dis_gene)


global_model = global_model.cuda()
local_model = local_model.cuda()
projectors = projectors.cuda()

optim_params = [{'params': global_model.classifier.parameters()}, {'params': local_model.parameters()},
{'params': projectors.parameters(), 'lr_mult': 10}]
model = SwinTransformer(image_size=args.image_size, num_classes=num_classes,
pretrained=args.pretrained, patch_size=args.patch_size,
window_size=args.window_size, )
global_projectors = ContrastiveProjectors(model.config.hidden_size, args.dis_gene, teacher=True)
local_projectors = ContrastiveProjectors(model.config.hidden_size, args.dis_gene, teacher=False)

model = model.cuda()
global_projectors = global_projectors.cuda()
local_projectors = local_projectors.cuda()

optim_params = [{'params': model.parameters()}, {'params': local_projectors.parameters(), 'lr_mult': 10}]
optimizer = torch.optim.AdamW(optim_params, lr=args.lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)

if args.dataparallel:
global_model = convert_model(global_model)
local_model = convert_model(local_model)
projectors = convert_model(projectors)
global_model = DataParallel(global_model, device_ids=[int(x) for x in args.visible_gpus.split(",")])
local_model = DataParallel(local_model, device_ids=[int(x) for x in args.visible_gpus.split(",")])
projectors = DataParallel(projectors, device_ids=[int(x) for x in args.visible_gpus.split(",")])
model = convert_model(model)
local_projectors = convert_model(local_projectors)
model = DataParallel(model, device_ids=[int(x) for x in args.visible_gpus.split(",")])
local_projectors = DataParallel(local_projectors, device_ids=[int(x) for x in args.visible_gpus.split(",")])

else:
if args.world_size > 1:
global_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(global_model)
local_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(local_model)
projectors = torch.nn.SyncBatchNorm.convert_sync_batchnorm(projectors)
global_model = DDP(global_model, device_ids=[gpu], find_unused_parameters=True, broadcast_buffers=True)
local_model = DDP(local_model, device_ids=[gpu], find_unused_parameters=True, broadcast_buffers=False)
projectors = DDP(projectors, device_ids=[gpu], find_unused_parameters=True, broadcast_buffers=False)

models = (global_model, local_model, projectors)

model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
local_projectors = torch.nn.SyncBatchNorm.convert_sync_batchnorm(local_projectors)
model = DDP(model, device_ids=[gpu], find_unused_parameters=True, broadcast_buffers=False)
local_projectors = DDP(local_projectors, device_ids=[gpu], find_unused_parameters=True, broadcast_buffers=False)

models = (model, global_projectors, local_projectors)

train(loaders, models, optimizer, scheduler, args, wandb_logger)


if __name__ == '__main__':
# args
parser = argparse.ArgumentParser()
Expand All @@ -131,6 +122,8 @@ def main(gpu, args, wandb_logger):
if not os.path.exists(args.checkpoints):
os.makedirs(args.checkpoints)

generate_splits(args.subset)

# init wandb
if not args.debug:
wandb.login(key="cb1e7d54d21d9080b46d2b1ae2a13d895770aa29")
Expand All @@ -152,4 +145,4 @@ def main(gpu, args, wandb_logger):
)
mp.spawn(main, args=(args, wandb_logger,), nprocs=args.world_size, join=True)
else:
main(0, args, wandb_logger)
main(0, args, wandb_logger)
14 changes: 10 additions & 4 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def __init__(self, image_size, num_classes, pretrained="", patch_size=4, window_
self.swin = Swinv2Model(config, add_pooling_layer=True, use_mask_token=True)
if pretrained:
self.swin = Swinv2Model.from_pretrained(pretrained, config=config, add_pooling_layer=True, use_mask_token=True)
self.classifier = nn.Linear(self.swin.num_features, config.num_labels)
self.global_classifier = nn.Linear(self.swin.num_features, config.num_labels)
self.local_classifier = nn.Linear(self.swin.num_features, config.num_labels + 10)

# Initialize weights and apply final processing
# self.post_init()
Expand All @@ -28,12 +29,13 @@ def forward(self, x, token_mask=None):
return_dict = self.config.use_return_dict
outputs = self.swin(x, bool_masked_pos=token_mask, return_dict=return_dict)
features = outputs[1]
logits = self.classifier(features)
return features, logits
global_logits = self.global_classifier(features)
local_logits = self.local_classifier(features)
return features, global_logits, local_logits


class ContrastiveProjectors(nn.Module):
def __init__(self, hidden_dim, gene_list):
def __init__(self, hidden_dim, gene_list, teacher=False):
super(ContrastiveProjectors, self).__init__()
self.region_projector = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim, bias=False),
Expand All @@ -46,6 +48,10 @@ def __init__(self, hidden_dim, gene_list):
nn.ReLU(),
) for _ in gene_list]
)

if teacher:
for param in self.parameters():
param.requires_grad = False

def forward(self, features):
region_features = self.region_projector(features)
Expand Down
Loading