diff --git a/examples/config.yaml b/examples/config.yaml new file mode 100644 index 0000000..e4259d0 --- /dev/null +++ b/examples/config.yaml @@ -0,0 +1,62 @@ +system: + cpus: 1 + gpus: 1 + seed: 1 + +dataset: + training: + s3vol01700: + images: ["~/dropbox/40_gt/13_wasp_sample3/vol_01700/img.h5",] + label: "~/dropbox/40_gt/13_wasp_sample3/vol_01700/label_v3.h5" + s3vol02299: + images: ["~/dropbox/40_gt/13_wasp_sample3/vol_02299/img.h5",] + label: "~/dropbox/40_gt/13_wasp_sample3/vol_02299/label_v3.h5" + s3vol02400: + images: ["~/dropbox/40_gt/13_wasp_sample3/vol_02400/img_zyx_2400-2656_5700-5956_2770-3026.h5",] + label: "~/dropbox/40_gt/13_wasp_sample3/vol_02400/label_v1_diane.h5" + # s3vol02684: + # images: ["~/dropbox/40_gt/13_wasp_sample3/vol_02284/img_zyx_2400-2656_5700-5956_2770-3026.h5",] + # label: "~/dropbox/40_gt/13_wasp_sample3/vol_0684/label_v1_diane.h5" + s3vol02794: + images: ["~/dropbox/40_gt/13_wasp_sample3/vol_02794/img_zyx_2794-3050_5811-6067_8757-9013.h5",] + label: "~/dropbox/40_gt/13_wasp_sample3/vol_02794/seg_v1_cropped.h5" + s3vol03290: + images: ["~/dropbox/40_gt/13_wasp_sample3/vol_03290/img_zyx_3290-3546_2375-2631_8450-8706.h5",] + label: "~/dropbox/40_gt/13_wasp_sample3/vol_03290/label_v1.h5" + s3vol03700: + images: ["~/dropbox/40_gt/13_wasp_sample3/vol_03700/img_zyx_3700-3956_5000-5256_4250-4506.h5",] + label: "~/dropbox/40_gt/13_wasp_sample3/vol_03700/label_v3.h5" + s3vol03998: + images: ["~/dropbox/40_gt/13_wasp_sample3/vol_03998/img.h5",] + label: "~/dropbox/40_gt/13_wasp_sample3/vol_03998/label_v1.h5" + s3vol04900: + images: ["~/dropbox/40_gt/13_wasp_sample3/vol_04900/img.h5",] + label: "~/dropbox/40_gt/13_wasp_sample3/vol_04900/label_v1.h5" + s3vol05250: + images: ["~/dropbox/40_gt/13_wasp_sample3/vol_05250/img_zyx_5250-5506_4600-4856_5500-5756.h5",] + label: "~/dropbox/40_gt/13_wasp_sample3/vol_05250/label_v3_remove_contact.h5" + s3vol05450: + images: 
["~/dropbox/40_gt/13_wasp_sample3/vol_05450/img_zyx_5450-5706_5350-5606_7000-7256.h5",] + label: "~/dropbox/40_gt/13_wasp_sample3/vol_05450/label_v4_chiyip.h5" + validation: + s3vol04000: + images: ["~/dropbox/40_gt/13_wasp_sample3/vol_04000/img_zyx_4000-4256_3400-3656_8150-8406.h5",] + label: "~/dropbox/40_gt/13_wasp_sample3/vol_04000/label_v3.h5" +model: + in_channels: 3 + out_channels: 3 + +train: + iter_start: 0 + iter_stop: 1000000 + class_rebalance: false + # batch size per GPU + # The dataprovider should provide nGPU*batch_size batches! + batch_size: 1 + output_dir: "./" + patch_size: [128, 128, 128] + learning_rate: 0.001 + #training_interval: 200 + #validation_interval: 2000 + training_interval: 2 + validation_interval: 4 \ No newline at end of file diff --git a/neutorch/data/dataset.py b/neutorch/data/dataset.py index a2883a9..0c3f795 100644 --- a/neutorch/data/dataset.py +++ b/neutorch/data/dataset.py @@ -1,15 +1,14 @@ +import math import random from functools import cached_property -import math import numpy as np import torch -from yacs.config import CfgNode - from chunkflow.lib.cartesian_coordinate import Cartesian +from yacs.config import CfgNode -from neutorch.data.transform import * from neutorch.data.sample import SemanticSample +from neutorch.data.transform import * DEFAULT_PATCH_SIZE = Cartesian(128, 128, 128) @@ -246,7 +245,6 @@ def __next__(self): class AffinityMapDataset(SemanticDataset): def __init__(self, samples: list): - #patch_size: Cartesian = DEFAULT_PATCH_SIZE): super().__init__(samples) @cached_property @@ -267,12 +265,33 @@ def transform(self): Flip(), Transpose(), MissAlignment(), - # Tranform to affinity map - # there is a shrinking, so we put this transformation here - # rather than the label2target function. 
Label2AffinityMap(probability=1.),
         ])
 
+class BoundaryAugmentationDataset(SemanticDataset):  # SemanticDataset with the augmentation pipeline defined below
+    def __init__(self, samples: list):  # samples are handed unchanged to SemanticDataset.__init__
+        super().__init__(samples)
+
+    @cached_property
+    def transform(self):  # Compose pipeline; built on first access, then cached on the instance
+        return Compose([
+            NormalizeTo01(probability=1.),
+            AdjustBrightness(),
+            AdjustContrast(),
+            Gamma(),
+            OneOf([
+                Noise(),
+                GaussianBlur2D(),
+            ]),
+            MaskBox(),
+            Perspective2D(),
+            # RotateScale(probability=1.),
+            # DropSection(),
+            Flip(),
+            Transpose(),
+            MissAlignment(),
+        ])  # no Label2AffinityMap step here, unlike AffinityMapDataset.transform
+
 
 if __name__ == '__main__':
     from yacs.config import CfgNode
@@ -282,10 +301,8 @@ def transform(self):
         cfg = CfgNode.load_cfg(file)
     cfg.freeze()
 
-    sd = AffinityMapDataset(
+    sd = BoundaryAugmentationDataset(
         path_list=['/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_01700/rna_v1.h5'],
         sample_name_to_image_versions=cfg.dataset.sample_name_to_image_versions,
         patch_size=Cartesian(128, 128, 128),
     )
-
-    # print(sd.samples)
\ No newline at end of file
diff --git a/neutorch/data/patch.py b/neutorch/data/patch.py
index 1d3554f..8d6c829 100644
--- a/neutorch/data/patch.py
+++ b/neutorch/data/patch.py
@@ -1,8 +1,9 @@
 from functools import cached_property
-import numpy as np
 
+import numpy as np
 # from torch import tensor, device
 import torch
+
 # torch.multiprocessing.set_start_method('spawn')
 
 # from chunkflow.lib.cartesian_coordinate import Cartesian
diff --git a/neutorch/data/transform.py b/neutorch/data/transform.py
index 945c275..3caa248 100644
--- a/neutorch/data/transform.py
+++ b/neutorch/data/transform.py
@@ -1,25 +1,21 @@
-from abc import ABC, abstractmethod
 import random
+from abc import ABC, abstractmethod
 from functools import cached_property
 
-
-from chunkflow.lib.cartesian_coordinate import Cartesian
-# from copy import deepcopy
-
+import cv2
 import numpy as np
-
+from chunkflow.lib.cartesian_coordinate import Cartesian
+# from reneu.lib.segmentation import seg_to_affs
 from scipy.ndimage.filters import gaussian_filter
-# from scipy.ndimage import 
affine_transform - -import cv2 - -from skimage.util import random_noise from skimage.transform import swirl - -from reneu.lib.segmentation import seg_to_affs +from skimage.util import random_noise from .patch import Patch +# from copy import deepcopy + + +# from scipy.ndimage import affine_transform DEFAULT_PROBABILITY = .5 DEFAULT_SHRINK_SIZE = (0, 0, 0, 0, 0, 0) diff --git a/neutorch/train/affinity_map.py b/neutorch/train/affinity_map.py index b2a9daf..54b22e9 100644 --- a/neutorch/train/affinity_map.py +++ b/neutorch/train/affinity_map.py @@ -3,9 +3,10 @@ import click from yacs.config import CfgNode -from .base import TrainerBase from neutorch.data.dataset import AffinityMapDataset +from .base import TrainerBase + class AffinityMapTrainer(TrainerBase): def __init__(self, cfg: CfgNode) -> None: diff --git a/neutorch/train/base.py b/neutorch/train/base.py index cbae7a8..bad344c 100644 --- a/neutorch/train/base.py +++ b/neutorch/train/base.py @@ -1,25 +1,22 @@ +import os +import random from abc import ABC, abstractproperty from functools import cached_property from glob import glob - -import random -import os from time import time -from yacs.config import CfgNode import numpy as np - -from chunkflow.lib.cartesian_coordinate import Cartesian - import torch -from torch.utils.tensorboard import SummaryWriter +from chunkflow.lib.cartesian_coordinate import Cartesian from torch.utils.data import DataLoader -from neutorch.data.patch import collate_batch +from torch.utils.tensorboard import SummaryWriter +from yacs.config import CfgNode -from neutorch.model.IsoRSUNet import Model -from neutorch.model.io import save_chkpt, load_chkpt, log_tensor -from neutorch.loss import BinomialCrossEntropyWithLogits from neutorch.data.dataset import worker_init_fn +from neutorch.data.patch import collate_batch +from neutorch.loss import BinomialCrossEntropyWithLogits +from neutorch.model.io import load_chkpt, log_tensor, save_chkpt +from neutorch.model.IsoRSUNet import Model class 
TrainerBase(ABC):
diff --git a/neutorch/train/boundary_aug.py b/neutorch/train/boundary_aug.py
new file mode 100644
index 0000000..76c8450
--- /dev/null
+++ b/neutorch/train/boundary_aug.py
@@ -0,0 +1,113 @@
+import os
+from functools import cached_property
+from time import time
+
+import click
+import numpy as np
+import torch
+from torch.utils.tensorboard import SummaryWriter
+from yacs.config import CfgNode
+
+from neutorch.data.dataset import BoundaryAugmentationDataset
+from neutorch.model.io import log_tensor, save_chkpt
+
+from .base import SemanticTrainer, TrainerBase
+
+
+class BoundaryAugTrainer(SemanticTrainer):
+    """SemanticTrainer whose train/validation data come from BoundaryAugmentationDataset."""
+
+    def __init__(self, cfg: CfgNode) -> None:
+        """Check ``cfg``, initialise the base trainer with it and keep it on ``self.cfg``.
+
+        Args:
+            cfg: configuration node; the assertion fails for anything but a ``CfgNode``.
+        """
+        assert isinstance(cfg, CfgNode)
+        super().__init__(cfg)
+        self.cfg = cfg
+
+    @cached_property
+    def training_dataset(self):
+        """Dataset from ``from_config(self.cfg, is_train=True)``; built once, then cached."""
+        return BoundaryAugmentationDataset.from_config(self.cfg, is_train=True)
+
+    @cached_property
+    def validation_dataset(self):
+        """Dataset from ``from_config(self.cfg, is_train=False)``; built once, then cached."""
+        return BoundaryAugmentationDataset.from_config(self.cfg, is_train=False)
+
+    # NOTE: the string literal below is a disabled draft of a custom training
+    # loop and is never executed. Known problems to fix before reviving it:
+    # `iter_idx` is never initialised, the timer is stored as `pint` but read
+    # as `ping`, and the bare `self.optimizer` statement does nothing
+    # (presumably a missing `zero_grad()` call -- confirm).
+    """
+    def call(self):
+        writer = SummaryWriter(log_dir=self.cfg.train.output_dir)
+        accumulated_loss = 0. #floating point
+
+        for image, label in self.training_data_loader:
+            iter_idx += 1
+            if iter_idx > self.cfg.train.iter_stop:
+                print('exeeds maximum iteration:', self.cfg.train.iter_stop)
+                return
+
+            pint = time()
+            predict = self.model(image)
+            loss = self.loss_module(predict, label)
+            assert not torch.isnan(loss), 'loss is NaN.'
+
+            self.optimizer #
+            loss.backward()
+            self.optimizer.step()
+            accumulated_loss += loss.tolist()
+
+            if iter_idx % self.cfg.train.training_interval == 0 and iter_idx > 0:
+                per_voxel_loss = accumulated_loss / \
+                    self.cfg.train.training_interval / \
+                    self.voxel_num
+
+                print(f'iteration {iter_idx} takes {round(time()-ping, 3)} seconds with loss: {per_voxel_loss}')
+                accumulated_loss = 0.
+                predict = self.post_processing(predict)
+                writer.add_scalar('loss/train', per_voxel_loss, iter_idx)
+                log_tensor(writer, 'train/image', image, 'image', iter_idx)
+                log_tensor(writer, 'train/prediction', predict.detach(), 'image', iter_idx)
+                log_tensor(writer, 'train/label', label, 'segmentation', iter_idx)
+
+            if iter_idx % self.cfg.train.validation_interval == 0 and iter_idx > 0:
+                fname = os.path.join(self.cfg.train.output_dir, f'model_{iter_idx}.chkpt')
+                print(f'save model to {fname}')
+                save_chkpt(self.model, self.cfg.train.output_dir, iter_idx, self.optimizer)
+
+                print('evaluate prediction: ')
+                validation_image, validation_label = next(self.validation_data_iter)
+
+                with torch.no_grad():
+                    validation_predict = self.model(validation_image)
+                    validation_loss = self.loss_module(validation_predict, validation_label)
+                    validation_predict = self.post_processing(validation_predict)
+                    per_voxel_loss = validation_loss.tolist() / self.voxel_num
+                    print(f'iteration {iter_idx} takes {round(time()-ping, 3)} seconds with loss: {per_voxel_loss}')
+                    writer.add_scalar('loss/validation', per_voxel_loss, iter_idx)
+                    log_tensor(writer, 'validation/image', validation_image, 'image', iter_idx)
+                    log_tensor(writer, 'validation/prediction', validation_predict, 'image', iter_idx)
+                    log_tensor(writer, 'validation/label', validation_label, 'segmentation', iter_idx)
+
+        writer.close()
+    """
+
+
+@click.command()
+@click.option('--config-file', '-c',
+    type=click.Path(exists=True, dir_okay=False, file_okay=True, readable=True, resolve_path=True),
+    default='./config.yaml',
+    help='configuration file containing all the parameters.'
+)
+def main(config_file: str):
+    """Build a BoundaryAugTrainer from ``config_file`` and call it."""
+    from neutorch.data.dataset import load_cfg
+    cfg = load_cfg(config_file)
+    trainer = BoundaryAugTrainer(cfg)
+    trainer()
diff --git a/setup.py b/setup.py
index 4e9aff2..070a46a 100755
--- a/setup.py
+++ b/setup.py
@@ -18,6 +18,7 @@
         neutrain-denoise=neutorch.train.denoise:main
         neutrain-post=neutorch.train.post_synapses:main
         neutrain-affs=neutorch.train.affinity_map:main
+        neutrain-ba=neutorch.train.boundary_aug:main
     ''',
     classifiers=[
         'Development Status :: 4 - Beta',