Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions tests/unit/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_unsqueeze_array(self):

class TestAxes:
def test_init(self):
assert Axes("XYZ").dims == "XYZ"
assert Axes("XYZ") == "XYZ"

with pytest.raises(ValueError) as excinfo:
Axes("XYZW")
Expand All @@ -114,7 +114,7 @@ def test_init(self):
def test_canonical_unsqueezed(self, canonical_dims):
shape = np.random.randint(2, 20, size=len(canonical_dims))
for axes in map(Axes, it.permutations(canonical_dims)):
assert axes.canonical(shape).dims == canonical_dims
assert axes.canonical(shape) == canonical_dims

def test_canonical_squeezed(self):
shape = (1, 60, 40)
Expand Down Expand Up @@ -287,7 +287,10 @@ def test_transform_expand(self):


def assert_transform(source, target, a, expected):
    """Assert that mapping array `a` from `source` axes to `target` axes
    produces `expected` (both shape and contents).

    `source`/`target` are axis specs accepted by Axes; `a` and `expected`
    are numpy arrays.
    """
    # Leftover debug prints removed; pytest reports failures with full detail.
    axes_mapper = Axes(source).mapper(Axes(target))
    assert axes_mapper.map_shape(a.shape) == expected.shape
    np.testing.assert_array_equal(axes_mapper.map_array(a), expected)

def assert_canonical_transform(source, a, expected):
    """Assert that mapping array `a` from `source` axes to its canonical
    axes (squeezed, TCZYX-ordered) produces `expected`.
    """
    # Leftover debug prints removed (one printed `target` labelled "Expected",
    # which was misleading); pytest reports failures with full detail.
    target = source.canonical(a.shape)
    axes_mapper = source.mapper(target)
    assert axes_mapper.map_shape(a.shape) == expected.shape
    np.testing.assert_array_equal(axes_mapper.map_array(a), expected)
61 changes: 61 additions & 0 deletions tiledb/bioimg/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import Literal

# Valid OME-NGFF unit names for spatial ("space") axes.
SpaceUnit = Literal[
    "angstrom", "attometer", "centimeter", "decimeter", "exameter",
    "femtometer", "foot", "gigameter", "hectometer", "inch", "kilometer",
    "megameter", "meter", "micrometer", "mile", "millimeter", "nanometer",
    "parsec", "petameter", "picometer", "terameter", "yard", "yoctometer",
    "yottameter", "zeptometer", "zettameter",
]

# Valid OME-NGFF unit names for temporal ("time") axes.
TimeUnit = Literal[
    "attosecond", "centisecond", "day", "decisecond", "exasecond",
    "femtosecond", "gigasecond", "hectosecond", "hour", "kilosecond",
    "megasecond", "microsecond", "millisecond", "minute", "nanosecond",
    "petasecond", "picosecond", "second", "terasecond", "yoctosecond",
    "yottasecond", "zeptosecond", "zettasecond",
]

# Unit symbol (as found in OME metadata) -> NGFF spatial unit name.
spaceUnitSymbolMap = {
    "Å": "angstrom",
    "am": "attometer",
    "cm": "centimeter",
    "dm": "decimeter",
    "Em": "exameter",
    "fm": "femtometer",
    "ft": "foot",
    "Gm": "gigameter",
    "hm": "hectometer",
    "in": "inch",
    "km": "kilometer",
    "Mm": "megameter",
    "m": "meter",
    "µm": "micrometer",
    "mi.": "mile",
    "mm": "millimeter",
    "nm": "nanometer",
    "pc": "parsec",
    "Pm": "petameter",
    "pm": "picometer",
    "Tm": "terameter",
    "yd": "yard",
    "ym": "yoctometer",
    "Ym": "yottameter",
    "zm": "zeptometer",
    "Zm": "zettameter",
}

# Unit symbol (as found in OME metadata) -> NGFF temporal unit name.
timeUnitSymbolMap = {
    "as": "attosecond",
    "cs": "centisecond",
    "d": "day",
    "ds": "decisecond",
    "Es": "exasecond",
    "fs": "femtosecond",
    "Gs": "gigasecond",
    "hs": "hectosecond",
    "h": "hour",
    "ks": "kilosecond",
    "Ms": "megasecond",
    "µs": "microsecond",
    "ms": "millisecond",
    "min": "minute",
    "ns": "nanosecond",
    "Ps": "petasecond",
    "ps": "picosecond",
    "s": "second",
    "Ts": "terasecond",
    "ys": "yoctosecond",
    "Ys": "yottasecond",
    "zs": "zeptosecond",
    "Zs": "zettasecond",
}
58 changes: 49 additions & 9 deletions tiledb/bioimg/converters/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Iterable, Iterator, MutableSequence, Sequence, Tuple
from typing import Any, Iterable, Iterator, MutableSequence, Sequence, Tuple, Optional, Union, Literal
from ..constants import SpaceUnit, TimeUnit

import numpy as np
from pyeditdistance.distance import levenshtein
Expand Down Expand Up @@ -112,7 +113,7 @@ def transform_tile(self, tile: MutableSequence[slice]) -> None:
self.transform_sequence(tile, fill_value=slice(None))

def transform_sequence(
self, sequence: MutableSequence[Any], fill_value: Any = None
self, sequence: MutableSequence[Any], fill_value: Any = None
) -> None:
for i in sorted(self.idxs):
sequence.insert(i, fill_value)
Expand Down Expand Up @@ -194,16 +195,39 @@ def transform_sequence(self, s: MutableSequence[Any]) -> None:
mapper.transform_sequence(s)


@dataclass(frozen=True)
class NGFFAxes:
    """A single NGFF axis description: name plus optional type and unit."""

    name: str
    type: Optional[Union[Literal['space', 'time', 'channel'], str]]
    unit: Optional[Union[SpaceUnit, TimeUnit]]

    def __repr__(self):
        # Render as an NGFF-style JSON object, e.g.
        # {"name": "t", "type": "time", "unit": "millisecond"},
        # omitting "type"/"unit" when they are unset.
        fields = [f'"name": "{self.name}"']
        if self.type:
            fields.append(f'"type": "{self.type}"')
        if self.unit:
            fields.append(f'"unit": "{self.unit}"')
        return "{" + ", ".join(fields) + "}"


@dataclass(frozen=True)
class Axes:
dims: str
dims: Sequence[NGFFAxes]
__slots__ = ("dims",)
CANONICAL_DIMS = "TCZYX"

def __init__(self, dims: Iterable[str]):
if not isinstance(dims, str):
dims = "".join(dims)
axes = set(dims)
def __init__(self, dims: Sequence[Union[str, NGFFAxes]]):
if not len(dims):
raise ValueError("Axes list cannot be empty.")

if isinstance(dims, str):
# Handle list of literals instead of str
dims = [NGFFAxes(d, None, None) for d in [*dims]]
else:
if all(isinstance(d, str) for d in dims):
dims = [NGFFAxes(d, None, None) for d in dims]

dims_names = [d.name for d in dims]
axes = set(dims_names)
if len(dims) != len(axes):
raise ValueError(f"Duplicate axes: {dims}")
for required_axis in "X", "Y":
Expand All @@ -214,14 +238,28 @@ def __init__(self, dims: Iterable[str]):
raise ValueError(f"{axes.pop()!r} is not a valid Axis")
object.__setattr__(self, "dims", dims)

def to_str(self) -> str:
    """Return the dimension names joined into a plain string, e.g. "TCZYX"."""
    # Generator feeds join directly; no need to materialize a list first.
    return "".join(d.name for d in self.dims)

def __repr__(self) -> str:
    """Repr of the axes list; each NGFFAxes renders as a JSON-like object,
    so the whole thing resembles NGFF "axes" metadata."""
    # f"{repr(x)}" was a redundant wrapper around repr(x).
    return repr(self.dims)

def __eq__(self, other: object) -> bool:
    """Compare with a plain dimension string (names only) or another Axes
    (full NGFFAxes equality, including type/unit).

    Returns NotImplemented for any other type so Python falls back to the
    reflected comparison (and `!=`/containment keep working) instead of
    raising from inside `==`.
    """
    if isinstance(other, str):
        return "".join(d.name for d in self.dims) == other
    if isinstance(other, Axes):
        return self.dims == other.dims
    return NotImplemented

def canonical(self, shape: Tuple[int, ...]) -> Axes:
    """
    Return a new Axes instance with the dimensions of this axes whose size in `shape`
    are greater than 1 and ordered in canonical order (TCZYX)
    """
    assert len(self.dims) == len(shape)
    # Collect the names of the non-squeezed dims once; the original rebuilt
    # a name list on every CANONICAL_DIMS membership test (quadratic).
    names = {dim.name for dim, size in zip(self.dims, shape) if size > 1}
    return Axes([name for name in self.CANONICAL_DIMS if name in names])

def mapper(self, other: Axes) -> AxesMapper:
"""Return an AxesMapper from this axes to other"""
Expand All @@ -234,7 +272,9 @@ def webp_mapper(self, num_channels: int) -> AxesMapper:
return CompositeAxesMapper(mappers)


def _iter_axes_mappers(s: str, t: str) -> Iterator[AxesMapper]:
def _iter_axes_mappers(s: Sequence[NGFFAxes], t: Sequence[NGFFAxes]) -> Iterator[AxesMapper]:
s = "".join([d.name for d in s])
t = "".join([d.name for d in t])
s_set = frozenset(s)
assert len(s_set) == len(s), f"{s!r} contains duplicates"
t_set = frozenset(t)
Expand Down
16 changes: 8 additions & 8 deletions tiledb/bioimg/converters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def to_tiledb(
with rw_group:
rw_group.w_group.meta.update(
reader.group_metadata,
axes=reader.axes.dims,
axes=repr(reader.axes),
pixel_depth=jsonpickle.encode(
dict(iter_pixel_depths_meta({**compressors, **scaled_compressors})),
unpicklable=False,
Expand Down Expand Up @@ -531,11 +531,11 @@ def _convert_level_to_tiledb(
# We need to calculate the min-max values per channel
# First find the indices of all axes except 'C' needed for numpy amin and amax
min_max_indices = tuple(
idx for idx, char in enumerate(source_axes.dims) if char != "C"
idx for idx, char in enumerate(source_axes.to_str()) if char != "C"
)

# Find the number of channels
channel_index = source_axes.dims.find("C")
channel_index = source_axes.to_str().find("C")
channel_count = source_shape[channel_index] if channel_index > -1 else 1

# Initialize a numpy 2D array to hold the min-max values per channel
Expand All @@ -551,12 +551,12 @@ def _convert_level_to_tiledb(
channel_min_max[:, 1] = np.repeat(min_value, channel_count)

level_metadata["axes"] = {
"originalAxes": [*reader.axes.dims],
"originalAxes": [*reader.axes.to_str()],
"originalShape": reader.level_shape(level),
"storedAxes": dim_names,
"storedShape": dim_shape,
"axesMapping": get_axes_translation(
compressor.get(level, tiledb.ZstdFilter(level=0)), reader.axes.dims
compressor.get(level, tiledb.ZstdFilter(level=0)), reader.axes.to_str()
),
}

Expand Down Expand Up @@ -620,7 +620,7 @@ def _create_image_pyramid(
preserve_axes: bool,
pyramid_kwargs: Mapping[str, Any],
) -> Tuple[Mapping[int, tiledb.Filter], Mapping[str, Any]]:
scaler = Scaler(reader.level_shape(base_level), reader.axes.dims, **pyramid_kwargs)
scaler = Scaler(reader.level_shape(base_level), reader.axes.to_str(), **pyramid_kwargs)

levels_metadata: MutableMapping[str, Any] = {"axes": []}

Expand All @@ -646,12 +646,12 @@ def _create_image_pyramid(

levels_metadata["axes"].append(
{
"originalAxes": [*reader.axes.dims],
"originalAxes": [*reader.axes.to_str()],
"originalShape": dim_shape,
"storedAxes": dim_names,
"storedShape": axes_mapper.map_shape(dim_shape),
"axesMapping": get_axes_translation(
scaler.compressors[level], reader.axes.dims
scaler.compressors[level], reader.axes.to_str()
),
}
)
Expand Down
13 changes: 10 additions & 3 deletions tiledb/bioimg/converters/ome_tiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .axes import Axes
from .base import ImageConverter, ImageReader, ImageWriter
from .metadata import qpi_image_meta, qpi_original_meta
from ..metadata import NGFFMetadata


class OMETiffReader(ImageReader):
Expand All @@ -35,6 +36,7 @@ def __init__(
# XXX ignore all but the first series
self._series = self._tiff.series[0]
omexml = self._tiff.ome_metadata
self._ome_metadata = self.ngff_metadata
self._metadata = tifffile.xml2dict(omexml) if omexml else {}

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
Expand All @@ -50,9 +52,10 @@ def logger(self, default_logger: logging.Logger) -> None:

@property
def axes(self) -> Axes:
    """Reader axes: prefer the axes extracted from NGFF/OME metadata when
    present, otherwise fall back to the TIFF series axes (sample axis "S"
    treated as channel "C")."""
    data_axes = self._series.axes.replace("S", "C")
    # NOTE(review): a bare `if self._ome_metadata:` is always true (plain
    # object, always truthy) and would raise AttributeError when no axes
    # were extracted — guard on the axes themselves instead.
    if getattr(self._ome_metadata, "axes", None):
        data_axes = self._ome_metadata.axes
    axes = Axes(data_axes)
    self._logger.debug(f"Reader axes: {axes}")
    return axes

@property
def channels(self) -> Sequence[str]:
Expand Down Expand Up @@ -313,6 +316,10 @@ def original_metadata(self) -> Dict[str, Any]:

return metadata

@property
def ngff_metadata(self) -> NGFFMetadata:
    """NGFF-style metadata (axes names/types/units) extracted from the
    underlying OME-TIFF file's OME-XML block."""
    return NGFFMetadata.from_ome_tiff(self._tiff)


class OMETiffWriter(ImageWriter):
def __init__(self, output_path: str, logger: logging.Logger, ome: bool = True):
Expand Down
4 changes: 4 additions & 0 deletions tiledb/bioimg/converters/ome_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,13 @@ def logger(self, default_logger: logging.Logger) -> None:

@property
def axes(self) -> Axes:
    """Reader axes taken from the multiscales NGFF "axes" metadata,
    upper-cased to this package's dimension-name convention."""
    # Resolved leftover merge conflict: keep the upstream debug logging and
    # the list-based Axes construction from the stashed change.
    axes = Axes([a["name"].upper() for a in self._multiscales.node.metadata["axes"]])
    self._logger.debug(f"Reader axes: {axes}")
    return axes

@property
def channels(self) -> Sequence[str]:
Expand Down
43 changes: 43 additions & 0 deletions tiledb/bioimg/metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Union, Sequence, Any
from .constants import spaceUnitSymbolMap, timeUnitSymbolMap
from .converters.axes import NGFFAxes
import tifffile
from tifffile import TiffFile
from typing_extensions import Self


class NGFFMetadata:
    """Minimal NGFF-style metadata extracted from other image formats.

    Only the ``axes`` field is currently populated: an ordered sequence of
    NGFFAxes (name, type, unit) matching the image's dimension order.
    """

    # Ordered axis descriptions; empty when none could be extracted.
    axes: Sequence[NGFFAxes]

    def __init__(self) -> None:
        # Always initialize the attribute so consumers can read `.axes`
        # without AttributeError even when no OME metadata was found
        # (previously the early-return paths left it undefined).
        self.axes = []

    @classmethod
    def from_ome_tiff(cls, tiff: TiffFile) -> Self:
        """Build NGFF metadata from an OME-TIFF file's OME-XML block.

        Returns an instance with empty ``axes`` when the file carries no
        (or invalid) OME metadata.
        """
        ome_metadata = tifffile.xml2dict(tiff.ome_metadata) if tiff.ome_metadata else {}
        metadata = cls()

        # If invalid OME metadata return empty NGFF metadata
        ome_images = ome_metadata.get('OME', {}).get('Image', [])
        if not ome_images:
            return metadata

        # 'Image' may be a single mapping or a list of them; use the first.
        if isinstance(ome_images, list):
            ome_pixels = ome_images[0].get('Pixels', {})
        else:
            ome_pixels = ome_images.get('Pixels', {})

        # Create 'axes' metadata field from the pixel dimension order.
        axes = []
        for axis in ome_pixels.get('DimensionOrder', ''):
            if axis in ('X', 'Y', 'Z'):
                # OME default spatial unit is micrometers ("µm"); unknown
                # symbols map to unit=None.
                unit = spaceUnitSymbolMap.get(ome_pixels.get(f'PhysicalSize{axis}Unit', "µm"))
                axes.append(NGFFAxes(name=axis, type='space', unit=unit))
            elif axis == 'C':
                axes.append(NGFFAxes(name=axis, type='channel', unit=None))
            elif axis == 'T':
                # OME default time unit is seconds ("s").
                unit = timeUnitSymbolMap.get(ome_pixels.get('TimeIncrementUnit', "s"))
                axes.append(NGFFAxes(name=axis, type='time', unit=unit))
            else:
                axes.append(NGFFAxes(name=axis, type=None, unit=None))
        metadata.axes = axes

        return metadata