"""Utilities for converting between sidpy datasets and NeXus HDF5."""

from __future__ import absolute_import, division, print_function, unicode_literals

import datetime
import json

import h5py
import numpy as np

from sidpy.sid import Dataset, Dimension

__all__ = ["sidpy_to_nexus_hdf5", "nexus_to_sidpy"]


def _clean_name(name, fallback):
    if name is None:
        name = ""
    name = str(name).strip().replace("/", "_")
    return name if name else fallback


def _ensure_unique_name(name, used_names):
    if name not in used_names:
        used_names.add(name)
        return name

    index = 1
    while True:
        candidate = "{}_{}".format(name, index)
        if candidate not in used_names:
            used_names.add(candidate)
            return candidate
        index += 1


def _decode_if_bytes(value):
    if isinstance(value, bytes):
        return value.decode("utf-8")
    if isinstance(value, np.bytes_):
        return value.decode("utf-8")
    if isinstance(value, np.ndarray):
        return [_decode_if_bytes(item) for item in value.tolist()]
    return value


def _normalize_axes_attr(value):
    value = _decode_if_bytes(value)
    if isinstance(value, str):
        return [value]
    if isinstance(value, (list, tuple)):
        return [str(item) for item in value]
    return None


def _json_ready(value):
    if isinstance(value, bytes):
        return value.decode("utf-8", errors="replace")
    if isinstance(value, np.bytes_):
        return value.decode("utf-8", errors="replace")
    if isinstance(value, dict):
        return {str(_json_ready(key)): _json_ready(val) for key, val in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_ready(item) for item in value]
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    return value


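# Nested sidpy metadata (metadata, original_metadata, provenance) cannot be
# written directly as HDF5 attributes, so each dict is serialized to a single
# JSON string dataset inside an NXcollection group and parsed back on read.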
def _write_json_dataset(parent, name, payload):
    string_dtype = h5py.string_dtype(encoding="utf-8")
    parent.create_dataset(name, data=json.dumps(_json_ready(payload)), dtype=string_dtype)


def _read_json_dataset(parent, name):
    if name not in parent:
        return {}
    raw = parent[name][()]
    raw = _decode_if_bytes(raw)
    if not raw:
        return {}
    return json.loads(raw)


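# Root-level attributes follow the NeXus conventions for a discoverable file:
# "default" names the NXentry to visualize, and file_name/file_time plus the
# library versions mirror the attributes NeXus recommends on the root group.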
def _set_root_attrs(h5_file, default_entry):
    time_stamp = datetime.datetime.now().isoformat()
    h5_file.attrs["default"] = default_entry
    h5_file.attrs["file_name"] = h5_file.filename
    h5_file.attrs["file_time"] = time_stamp
    h5_file.attrs["creator"] = "sidpy"
    h5_file.attrs["HDF5_Version"] = h5py.version.hdf5_version
    h5_file.attrs["h5py_version"] = h5py.version.version


def sidpy_to_nexus_hdf5(dataset, h5_path, entry_name="entry", nxdata_name="data",
                        signal_name="data", mode="w", compression=None):
    """
    Write a sidpy.Dataset to a NeXus-compatible HDF5 file.

    Parameters
    ----------
    dataset : sidpy.Dataset
        Dataset to serialize.
    h5_path : str or h5py.File
        Destination HDF5 file path or open, writable file handle.
    entry_name : str, optional
        Name of the NXentry group.
    nxdata_name : str, optional
        Name of the NXdata group.
    signal_name : str, optional
        Name of the primary signal dataset within NXdata.
    mode : str, optional
        File mode used when ``h5_path`` is a path.
    compression : str, optional
        Compression filter passed to h5py when creating the signal dataset.

    Returns
    -------
    str or h5py.Dataset
        Path of the signal dataset when ``h5_path`` is a file path (the file
        is closed on return), or the written h5py.Dataset when an open file
        handle is provided.
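
    Examples
    --------
    A minimal sketch of a write; ``demo.nxs`` is a hypothetical output path::

        >>> import numpy as np
        >>> import sidpy
        >>> data = sidpy.Dataset.from_array(np.arange(12.0).reshape(3, 4))
        >>> data.title = 'demo'
        >>> sidpy_to_nexus_hdf5(data, 'demo.nxs')  # doctest: +SKIP
        '/entry/data/data'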
    """
    if not isinstance(dataset, Dataset):
        raise TypeError("dataset must be a sidpy.Dataset")

    entry_name = _clean_name(entry_name, "entry")
    nxdata_name = _clean_name(nxdata_name, "data")
    signal_name = _clean_name(signal_name, "data")

    close_file = False
    if isinstance(h5_path, h5py.File):
        h5_file = h5_path
    else:
        h5_file = h5py.File(h5_path, mode)
        close_file = True

    try:
        if entry_name in h5_file:
            del h5_file[entry_name]

        _set_root_attrs(h5_file, entry_name)

        entry = h5_file.create_group(entry_name)
        entry.attrs["NX_class"] = "NXentry"
        entry.attrs["default"] = nxdata_name
        entry.create_dataset("title", data=dataset.title or signal_name)

        nxdata = entry.create_group(nxdata_name)
        nxdata.attrs["NX_class"] = "NXdata"
        nxdata.attrs["signal"] = signal_name

        signal_kwargs = {}
        if compression is not None:
            signal_kwargs["compression"] = compression

        signal = nxdata.create_dataset(signal_name, data=np.array(dataset), **signal_kwargs)
        signal.attrs["units"] = dataset.units
        signal.attrs["quantity"] = dataset.quantity
        signal.attrs["data_type"] = dataset.data_type.name
        signal.attrs["modality"] = dataset.modality
        signal.attrs["source"] = dataset.source
        signal.attrs["title"] = dataset.title
        signal.attrs["long_name"] = dataset.data_descriptor

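        # NXdata bookkeeping: the "axes" attribute lists one name per data
        # dimension, with "." as the NeXus placeholder for a dimension that
        # has no axis dataset; each AXISNAME_indices attribute records which
        # data dimension the axis spans.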
        used_names = {signal_name}
        axes_names = []
        for dim_index in range(dataset.ndim):
            axis = dataset._axes.get(dim_index)
            if axis is None:
                axes_names.append(".")
                continue

            axis_name = _ensure_unique_name(_clean_name(axis.name, "dim_{}".format(dim_index)),
                                            used_names)
            axis_dset = nxdata.create_dataset(axis_name, data=np.asarray(axis.values))
            axis_dset.attrs["units"] = axis.units
            axis_dset.attrs["quantity"] = axis.quantity
            axis_dset.attrs["dimension_type"] = axis.dimension_type.name
            axis_dset.attrs["long_name"] = "{} ({})".format(axis.quantity, axis.units)
            nxdata.attrs["{}_indices".format(axis_name)] = dim_index
            axes_names.append(axis_name)

        nxdata.attrs["axes"] = np.asarray(axes_names, dtype=h5py.string_dtype(encoding="utf-8"))

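        # NeXus uses FIELDNAME_errors for uncertainties of a field; note that
        # NXdata defines these as standard deviations, while sidpy's variance
        # array is stored here unconverted.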
        if dataset.variance is not None:
            nxdata.create_dataset("{}_errors".format(signal_name), data=np.array(dataset.variance))

        sidpy_collection = entry.create_group("sidpy_metadata")
        sidpy_collection.attrs["NX_class"] = "NXcollection"
        _write_json_dataset(sidpy_collection, "metadata", dataset.metadata)
        _write_json_dataset(sidpy_collection, "original_metadata", dataset.original_metadata)
        _write_json_dataset(sidpy_collection, "provenance", dataset.provenance)

        h5_file.flush()

        if close_file:
            return signal.name
        return signal
    finally:
        if close_file:
            h5_file.close()


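# NeXus discovery chain: the root "default" attribute names an NXentry, and an
# entry's "default" attribute names an NXdata group. The attribute value may
# be a relative child name or an absolute HDF5 path, so both are tried.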
def _resolve_default_child(parent, default_name):
    default_name = _decode_if_bytes(default_name)
    if not isinstance(default_name, str):
        return None

    if default_name in parent:
        return parent[default_name]

    if default_name.startswith("/"):
        if default_name in parent.file:
            return parent.file[default_name]
        return None
    candidate = "{}/{}".format(parent.name.rstrip("/"), default_name).replace("//", "/")
    if candidate in parent.file:
        return parent.file[candidate]
    return None


def _find_nxentry(h5_file):
    default_entry = _resolve_default_child(h5_file, h5_file.attrs.get("default"))
    if isinstance(default_entry, h5py.Group) and _decode_if_bytes(default_entry.attrs.get("NX_class")) == "NXentry":
        return default_entry

    for key in h5_file:
        obj = h5_file[key]
        if isinstance(obj, h5py.Group) and _decode_if_bytes(obj.attrs.get("NX_class")) == "NXentry":
            return obj
    raise ValueError("Could not find an NXentry group in the provided file")


def _find_nxdata(entry):
    default_nxdata = _resolve_default_child(entry, entry.attrs.get("default"))
    if isinstance(default_nxdata, h5py.Group) and _decode_if_bytes(default_nxdata.attrs.get("NX_class")) == "NXdata":
        return default_nxdata

    for key in entry:
        obj = entry[key]
        if isinstance(obj, h5py.Group) and _decode_if_bytes(obj.attrs.get("NX_class")) == "NXdata":
            return obj
    raise ValueError("Could not find an NXdata group in the provided entry")


def nexus_to_sidpy(h5_path, entry_path=None, nxdata_path=None, signal_name=None):
    """
    Read a NeXus HDF5 NXdata signal into a sidpy.Dataset.

    Parameters
    ----------
    h5_path : str or h5py.File
        Source HDF5 file path or open file handle.
    entry_path : str, optional
        Explicit path to the NXentry group.
    nxdata_path : str, optional
        Explicit path to the NXdata group.
    signal_name : str, optional
        Explicit name of the signal dataset inside NXdata.

    Returns
    -------
    sidpy.Dataset
        Restored dataset.
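
    Notes
    -----
    When ``h5_path`` is a file path, the file is opened read-only and kept
    open so that the returned dataset's ``h5_dataset`` attribute remains
    usable; close it later via ``dataset.h5_dataset.file.close()``.

    Examples
    --------
    A minimal sketch reading back the hypothetical ``demo.nxs`` file from the
    :func:`sidpy_to_nexus_hdf5` example::

        >>> dataset = nexus_to_sidpy('demo.nxs')  # doctest: +SKIP
        >>> dataset.shape  # doctest: +SKIP
        (3, 4)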
    """
    if isinstance(h5_path, h5py.File):
        h5_file = h5_path
    else:
        h5_file = h5py.File(h5_path, "r")

    if nxdata_path is not None:
        nxdata = h5_file[nxdata_path]
        if not isinstance(nxdata, h5py.Group):
            raise TypeError("nxdata_path must point to a group")
        if entry_path is None:
            entry = nxdata.parent
        else:
            entry = h5_file[entry_path]
    else:
        if entry_path is not None:
            entry = h5_file[entry_path]
        else:
            entry = _find_nxentry(h5_file)
        nxdata = _find_nxdata(entry)

    if _decode_if_bytes(nxdata.attrs.get("NX_class")) != "NXdata":
        raise ValueError("The selected group is not an NXdata group")

    if signal_name is None:
        signal_name = _decode_if_bytes(nxdata.attrs.get("signal"))
    signal_name = _clean_name(signal_name, "data")

    if signal_name not in nxdata:
        raise ValueError("Could not find signal dataset '{}' in NXdata".format(signal_name))

    signal = nxdata[signal_name]
    entry_title = _decode_if_bytes(entry["title"][()]) if "title" in entry else signal_name
    signal_title = _decode_if_bytes(signal.attrs.get("title", ""))
    dataset = Dataset.from_array(np.array(signal), title=signal_title or entry_title)

    dataset.units = _decode_if_bytes(signal.attrs.get("units", "generic"))
    dataset.quantity = _decode_if_bytes(signal.attrs.get("quantity", "generic"))

    data_type = _decode_if_bytes(signal.attrs.get("data_type", "UNKNOWN"))
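    # sidpy's data_type setter raises a plain Warning (rather than ValueError)
    # for unrecognized names, hence the unusual except clause.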
    try:
        dataset.data_type = data_type
    except Warning:
        dataset.data_type = "UNKNOWN"

    dataset.modality = _decode_if_bytes(signal.attrs.get("modality", "generic"))
    dataset.source = _decode_if_bytes(signal.attrs.get("source", "generic"))
    dataset.title = signal_title or entry_title

    axes_names = _normalize_axes_attr(nxdata.attrs.get("axes"))
    if axes_names is None:
        axes_names = ["dim_{}".format(index) for index in range(dataset.ndim)]

    if len(axes_names) != dataset.ndim:
        raise ValueError("NXdata axes metadata does not match signal rank")

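    # Rebuild sidpy Dimensions from the NXdata axis datasets; "." entries and
    # missing axis datasets leave sidpy's auto-generated default dimension in
    # place.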
    for dim_index, axis_name in enumerate(axes_names):
        if axis_name == ".":
            continue
        if axis_name not in nxdata:
            continue

        axis_dset = nxdata[axis_name]
        axis_values = np.asarray(axis_dset[()])
        if axis_values.ndim != 1:
            raise NotImplementedError("Only 1D NXdata axes are currently supported")
        if axis_values.shape[0] != dataset.shape[dim_index]:
            raise ValueError("Axis '{}' length does not match data dimension {}".format(axis_name, dim_index))

        dimension = Dimension(axis_values,
                              name=axis_name,
                              quantity=_decode_if_bytes(axis_dset.attrs.get("quantity", axis_name)),
                              units=_decode_if_bytes(axis_dset.attrs.get("units", "generic")),
                              dimension_type=_decode_if_bytes(axis_dset.attrs.get("dimension_type", "UNKNOWN")))
        dataset.set_dimension(dim_index, dimension)

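    # Restore sidpy-specific metadata from the NXcollection written by
    # sidpy_to_nexus_hdf5; files produced by other tools simply lack it.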
    if "sidpy_metadata" in entry:
        sidpy_collection = entry["sidpy_metadata"]
        if isinstance(sidpy_collection, h5py.Group):
            dataset.metadata = _read_json_dataset(sidpy_collection, "metadata")
            dataset.original_metadata = _read_json_dataset(sidpy_collection, "original_metadata")
            provenance = _read_json_dataset(sidpy_collection, "provenance")
            if provenance:
                dataset.provenance = provenance

    dataset.h5_dataset = signal
    return dataset