Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions examples/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
system:
cpus: 1
gpus: 1
seed: 1

dataset:
training:
s3vol01700:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_01700/img.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_01700/label_v3.h5"
s3vol02299:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_02299/img.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_02299/label_v3.h5"
s3vol02400:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_02400/img_zyx_2400-2656_5700-5956_2770-3026.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_02400/label_v1_diane.h5"
# s3vol02684:
# images: ["~/dropbox/40_gt/13_wasp_sample3/vol_02284/img_zyx_2400-2656_5700-5956_2770-3026.h5",]
# label: "~/dropbox/40_gt/13_wasp_sample3/vol_0684/label_v1_diane.h5"
s3vol02794:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_02794/img_zyx_2794-3050_5811-6067_8757-9013.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_02794/seg_v1_cropped.h5"
s3vol03290:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_03290/img_zyx_3290-3546_2375-2631_8450-8706.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_03290/label_v1.h5"
s3vol03700:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_03700/img_zyx_3700-3956_5000-5256_4250-4506.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_03700/label_v3.h5"
s3vol03998:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_03998/img.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_03998/label_v1.h5"
s3vol04900:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_04900/img.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_04900/label_v1.h5"
s3vol05250:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_05250/img_zyx_5250-5506_4600-4856_5500-5756.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_05250/label_v3_remove_contact.h5"
s3vol05450:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_05450/img_zyx_5450-5706_5350-5606_7000-7256.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_05450/label_v4_chiyip.h5"
validation:
s3vol04000:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_04000/img_zyx_4000-4256_3400-3656_8150-8406.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_04000/label_v3.h5"
model:
in_channels: 3
out_channels: 3

train:
iter_start: 0
iter_stop: 1000000
class_rebalance: false
# batch size per GPU
# The dataprovider should provide nGPU*batch_size batches!
batch_size: 1
output_dir: "./"
patch_size: [128, 128, 128]
learning_rate: 0.001
#training_interval: 200
#validation_interval: 2000
training_interval: 2
validation_interval: 4
40 changes: 29 additions & 11 deletions neutorch/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import math
import random
from functools import cached_property
import math

import numpy as np
import torch
from yacs.config import CfgNode

from chunkflow.lib.cartesian_coordinate import Cartesian
from yacs.config import CfgNode

from neutorch.data.transform import *
from neutorch.data.sample import SemanticSample
from neutorch.data.transform import *

DEFAULT_PATCH_SIZE = Cartesian(128, 128, 128)

Expand Down Expand Up @@ -246,7 +245,6 @@ def __next__(self):

class AffinityMapDataset(SemanticDataset):
def __init__(self, samples: list):
#patch_size: Cartesian = DEFAULT_PATCH_SIZE):
super().__init__(samples)

@cached_property
Expand All @@ -267,12 +265,33 @@ def transform(self):
Flip(),
Transpose(),
MissAlignment(),
# Tranform to affinity map
# there is a shrinking, so we put this transformation here
# rather than the label2target function.
Label2AffinityMap(probability=1.),
])

class BoundaryAugmentationDataset(SemanticDataset):
def __initi__(self, samples: list):
super.__init__(samples)

@cached_property
def transform(self):
return Compose([
NormalizeTo01(probability=1.),
AdjustBrightness(),
AdjustContrast(),
Gamma(),
OneOf([
Noise(),
GaussianBlur2D(),
]),
MaskBox(),
Perspective2D(),
# RotateScale(probability=1.),
# DropSection(),
Flip(),
Transpose(),
MissAlignment(),
])

if __name__ == '__main__':

from yacs.config import CfgNode
Expand All @@ -282,10 +301,9 @@ def transform(self):
cfg = CfgNode.load_cfg(file)
cfg.freeze()

sd = AffinityMapDataset(
sd = BoundaryAugmentationDataset(
path_list=['/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_01700/rna_v1.h5'],
sample_name_to_image_versions=cfg.dataset.sample_name_to_image_versions,
patch_size=Cartesian(128, 128, 128),
)

# print(sd.samples)

3 changes: 2 additions & 1 deletion neutorch/data/patch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from functools import cached_property
import numpy as np

import numpy as np
# from torch import tensor, device
import torch

# torch.multiprocessing.set_start_method('spawn')

# from chunkflow.lib.cartesian_coordinate import Cartesian
Expand Down
22 changes: 9 additions & 13 deletions neutorch/data/transform.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
from abc import ABC, abstractmethod
import random
from abc import ABC, abstractmethod
from functools import cached_property


from chunkflow.lib.cartesian_coordinate import Cartesian
# from copy import deepcopy

import cv2
import numpy as np

from chunkflow.lib.cartesian_coordinate import Cartesian
# from reneu.lib.segmentation import seg_to_affs
from scipy.ndimage.filters import gaussian_filter
# from scipy.ndimage import affine_transform

import cv2

from skimage.util import random_noise
from skimage.transform import swirl

from reneu.lib.segmentation import seg_to_affs
from skimage.util import random_noise

from .patch import Patch

# from copy import deepcopy


# from scipy.ndimage import affine_transform

DEFAULT_PROBABILITY = .5
DEFAULT_SHRINK_SIZE = (0, 0, 0, 0, 0, 0)
Expand Down
3 changes: 2 additions & 1 deletion neutorch/train/affinity_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import click
from yacs.config import CfgNode

from .base import TrainerBase
from neutorch.data.dataset import AffinityMapDataset

from .base import TrainerBase


class AffinityMapTrainer(TrainerBase):
def __init__(self, cfg: CfgNode) -> None:
Expand Down
21 changes: 9 additions & 12 deletions neutorch/train/base.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,22 @@
import os
import random
from abc import ABC, abstractproperty
from functools import cached_property
from glob import glob

import random
import os
from time import time

from yacs.config import CfgNode
import numpy as np

from chunkflow.lib.cartesian_coordinate import Cartesian

import torch
from torch.utils.tensorboard import SummaryWriter
from chunkflow.lib.cartesian_coordinate import Cartesian
from torch.utils.data import DataLoader
from neutorch.data.patch import collate_batch
from torch.utils.tensorboard import SummaryWriter
from yacs.config import CfgNode

from neutorch.model.IsoRSUNet import Model
from neutorch.model.io import save_chkpt, load_chkpt, log_tensor
from neutorch.loss import BinomialCrossEntropyWithLogits
from neutorch.data.dataset import worker_init_fn
from neutorch.data.patch import collate_batch
from neutorch.loss import BinomialCrossEntropyWithLogits
from neutorch.model.io import load_chkpt, log_tensor, save_chkpt
from neutorch.model.IsoRSUNet import Model


class TrainerBase(ABC):
Expand Down
100 changes: 100 additions & 0 deletions neutorch/train/boundary_aug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import os
from functools import cached_property
from time import time

import click
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from yacs.config import CfgNode

from neutorch.data.dataset import BoundaryAugmentationDataset
from neutorch.model.io import log_tensor, save_chkpt

from .base import SemanticTrainer, TrainerBase

class BoundaryAugTrainer(SemanticTrainer):
def __init__(self, cfg: CfgNode) -> None:
assert isinstance(cfg, CfgNode)
super().__init__(cfg)
self.cfg = cfg
breakpoint()

@cached_property
def training_dataset(self):
return BoundaryAugmentationDataset.from_config(self.cfg, is_train=True)

@cached_property
def validation_dataset(self):
return BoundaryAugmentationDataset.from_config(self.cfg, is_train=False)

"""
def call(self):
writer = SummaryWriter(log_dir=self.cfg.train.output_dir)
accumulated_loss = 0. #floating point

for image, label in self.training_data_loader:
iter_idx += 1
if iter_idx > self.cfg.train.iter_stop:
print('exeeds maximum iteration:', self.cfg.train.iter_stop)
return

pint = time()
predict = self.model(image)
loss = self.loss_module(predict, label)
assert not torch.isnan(loss), 'loss is NaN.'

self.optimizer #
loss.backward()
self.optimizer.step()
accumulated_loss += loss.tolist()

if iter_idx % self.cfg.train.training_interval == 0 and iter_idx > 0:
per_voxel_loss = accumulated_loss / \
self.cfg.train.training_interval / \
self.voxel_num

print(f'iteration {iter_idx} takes {round(time()-ping, 3)} seconds with loss: {per_voxel_loss}')
accumulated_loss = 0.
predict = self.post_processing(predict)
writer.add_scalar('loss/train', per_voxel_loss, iter_idx)
log_tensor(writer, 'train/image', image, 'image', iter_idx)
log_tensor(writer, 'train/prediction', predict.detach(), 'image', iter_idx)
log_tensor(writer, 'train/label', label, 'segmentation', iter_idx)

if iter_idx % self.cfg.train.validation_interval == 0 and iter_idx > 0:
fname = os.path.join(self.cfg.train.output_dir, f'model_{iter_idx}.chkpt')
print(f'save model to {fname}')
save_chkpt(self.model, self.cfg.train.output_dir, iter_idx, self.optimizer)

print('evaluate prediction: ')
validation_image, validation_label = next(self.validation_data_iter)

with torch.no_grad():
validation_predict = self.model(validation_image)
validation_loss = self.loss_module(validation_predict, validation_label)
validation_predict = self.post_processing(validation_predict)
per_voxel_loss = validation_loss.tolist() / self.voxel_num
print(f'iteration {iter_idx} takes {round(time()-ping, 3)} seconds with loss: {per_voxel_loss}')
writer.add_scalar('loss/validation', per_voxel_loss, iter_idx)
log_tensor(writer, 'validation/image', validation_image, 'image', iter_idx)
log_tensor(writer, 'validation/prediction', validation_predict, 'image', iter_idx)
log_tensor(writer, 'validation/label', validation_label, 'segmentation', iter_idx)

writer.close()
"""

@click.command()
@click.option('--config-file', '-c',
type=click.Path(exists=True, dir_okay=False, file_okay=True, readable=True, resolve_path=True),
default='./config.yaml',
help = 'configuration file containing all the parameters.'
)

def main(config_file: str):
from neutorch.data.dataset import load_cfg
cfg = load_cfg(config_file)
trainer = BoundaryAugTrainer(cfg)
trainer()


1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
neutrain-denoise=neutorch.train.denoise:main
neutrain-post=neutorch.train.post_synapses:main
neutrain-affs=neutorch.train.affinity_map:main
neutrain-ba=neutorch.train.boundary_aug:main
''',
classifiers=[
'Development Status :: 4 - Beta',
Expand Down