Skip to content

Commit 4685462

Browse files
author
Rama Vasudevan
committed
nexus compatibility
1 parent d8f8f42 commit 4685462

3 files changed

Lines changed: 474 additions & 1 deletion

File tree

sidpy/io/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,9 @@
22
User interface utilities
33
"""
44
from . import interface_utils
5+
from . import nexus
56
from .interface_utils import FileWidget, ChooseDataset
6-
__all__ = ['interface_utils', 'FileWidget', 'ChooseDataset']
7+
from .nexus import sidpy_to_nexus_hdf5, nexus_to_sidpy
8+
9+
__all__ = ['interface_utils', 'nexus', 'FileWidget', 'ChooseDataset',
10+
'sidpy_to_nexus_hdf5', 'nexus_to_sidpy']

sidpy/io/nexus.py

Lines changed: 352 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,352 @@
1+
"""Utilities for converting between sidpy datasets and NeXus HDF5."""
2+
3+
from __future__ import absolute_import, division, print_function, unicode_literals
4+
5+
import datetime
6+
import json
7+
8+
import h5py
9+
import numpy as np
10+
11+
from sidpy.sid import Dataset, Dimension
12+
13+
__all__ = ["sidpy_to_nexus_hdf5", "nexus_to_sidpy"]
14+
15+
16+
def _clean_name(name, fallback):
17+
if name is None:
18+
name = ""
19+
name = str(name).strip().replace("/", "_")
20+
return name if name else fallback
21+
22+
23+
def _ensure_unique_name(name, used_names):
24+
if name not in used_names:
25+
used_names.add(name)
26+
return name
27+
28+
index = 1
29+
while True:
30+
candidate = "{}_{}".format(name, index)
31+
if candidate not in used_names:
32+
used_names.add(candidate)
33+
return candidate
34+
index += 1
35+
36+
37+
def _decode_if_bytes(value):
38+
if isinstance(value, bytes):
39+
return value.decode("utf-8")
40+
if isinstance(value, np.bytes_):
41+
return value.decode("utf-8")
42+
if isinstance(value, np.ndarray):
43+
return [_decode_if_bytes(item) for item in value.tolist()]
44+
return value
45+
46+
47+
def _normalize_axes_attr(value):
    """Coerce an NXdata ``axes`` attribute into a list of axis names.

    A single (possibly bytes) string becomes a one-element list; lists,
    tuples, and bytes arrays become lists of ``str``.  ``None`` is
    returned for anything else.
    """
    decoded = _decode_if_bytes(value)
    if isinstance(decoded, str):
        return [decoded]
    if isinstance(decoded, (list, tuple)):
        return [str(entry) for entry in decoded]
    return None
54+
55+
56+
def _json_ready(value):
57+
if isinstance(value, bytes):
58+
return value.decode("utf-8", errors="replace")
59+
if isinstance(value, np.bytes_):
60+
return value.decode("utf-8", errors="replace")
61+
if isinstance(value, dict):
62+
return {str(_json_ready(key)): _json_ready(val) for key, val in value.items()}
63+
if isinstance(value, (list, tuple)):
64+
return [_json_ready(item) for item in value]
65+
if isinstance(value, np.ndarray):
66+
return value.tolist()
67+
if isinstance(value, np.generic):
68+
return value.item()
69+
return value
70+
71+
72+
def _write_json_dataset(parent, name, payload):
    """Store *payload* under ``parent[name]`` as a UTF-8 JSON string dataset."""
    serialized = json.dumps(_json_ready(payload))
    parent.create_dataset(name, data=serialized,
                          dtype=h5py.string_dtype(encoding="utf-8"))
75+
76+
77+
def _read_json_dataset(parent, name):
    """Load a JSON payload written by :func:`_write_json_dataset`.

    Returns an empty dict when the dataset is missing or holds an empty
    string.
    """
    if name not in parent:
        return {}
    raw = _decode_if_bytes(parent[name][()])
    return json.loads(raw) if raw else {}
85+
86+
87+
def _set_root_attrs(h5_file, default_entry):
    """Write NeXus bookkeeping attributes on the file's root group."""
    root_attrs = {
        "default": default_entry,
        "file_name": h5_file.filename,
        "file_time": datetime.datetime.now().isoformat(),
        "creator": "sidpy",
        "HDF5_Version": h5py.version.hdf5_version,
        "h5py_version": h5py.version.version,
    }
    for key, value in root_attrs.items():
        h5_file.attrs[key] = value
95+
96+
97+
def sidpy_to_nexus_hdf5(dataset, h5_path, entry_name="entry", nxdata_name="data",
                        signal_name="data", mode="w", compression=None):
    """
    Write a sidpy.Dataset to a NeXus-compatible HDF5 file.

    Parameters
    ----------
    dataset : sidpy.Dataset
        Dataset to serialize.
    h5_path : str or h5py.File
        Destination HDF5 path or open file handle.
    entry_name : str, optional
        Name of the NXentry group.  An existing group of the same name is
        replaced.
    nxdata_name : str, optional
        Name of the NXdata group.
    signal_name : str, optional
        Name of the primary signal dataset within NXdata.
    mode : str, optional
        File mode used when `h5_path` is a path.
    compression : str, optional
        Compression passed to h5py when creating the signal and variance
        datasets.

    Returns
    -------
    str or h5py.Dataset
        Signal dataset path for path-based writes, or the written
        h5py.Dataset when an open file handle is provided.

    Raises
    ------
    TypeError
        If `dataset` is not a sidpy.Dataset.
    """
    if not isinstance(dataset, Dataset):
        raise TypeError("dataset must be a sidpy.Dataset")

    entry_name = _clean_name(entry_name, "entry")
    nxdata_name = _clean_name(nxdata_name, "data")
    signal_name = _clean_name(signal_name, "data")

    close_file = False
    if isinstance(h5_path, h5py.File):
        h5_file = h5_path
    else:
        h5_file = h5py.File(h5_path, mode)
        close_file = True

    try:
        # Replace any pre-existing entry so repeated writes are idempotent.
        if entry_name in h5_file:
            del h5_file[entry_name]

        _set_root_attrs(h5_file, entry_name)

        entry = h5_file.create_group(entry_name)
        entry.attrs["NX_class"] = "NXentry"
        entry.attrs["default"] = nxdata_name
        entry.create_dataset("title", data=dataset.title or signal_name)

        nxdata = entry.create_group(nxdata_name)
        nxdata.attrs["NX_class"] = "NXdata"
        nxdata.attrs["signal"] = signal_name

        signal_kwargs = {}
        if compression is not None:
            signal_kwargs["compression"] = compression

        signal = nxdata.create_dataset(signal_name, data=np.array(dataset),
                                       **signal_kwargs)
        signal.attrs["units"] = dataset.units
        signal.attrs["quantity"] = dataset.quantity
        signal.attrs["data_type"] = dataset.data_type.name
        signal.attrs["modality"] = dataset.modality
        signal.attrs["source"] = dataset.source
        signal.attrs["title"] = dataset.title
        signal.attrs["long_name"] = dataset.data_descriptor

        used_names = {signal_name}
        axes_names = []
        for dim_index in range(dataset.ndim):
            axis = dataset._axes.get(dim_index)
            if axis is None:
                # "." marks a dimension without an axis dataset (NeXus
                # convention for the NXdata "axes" attribute).
                axes_names.append(".")
                continue

            axis_name = _ensure_unique_name(
                _clean_name(axis.name, "dim_{}".format(dim_index)), used_names)
            axis_dset = nxdata.create_dataset(axis_name,
                                              data=np.asarray(axis.values))
            axis_dset.attrs["units"] = axis.units
            axis_dset.attrs["quantity"] = axis.quantity
            axis_dset.attrs["dimension_type"] = axis.dimension_type.name
            axis_dset.attrs["long_name"] = "{} ({})".format(axis.quantity,
                                                            axis.units)
            nxdata.attrs["{}_indices".format(axis_name)] = dim_index
            axes_names.append(axis_name)

        nxdata.attrs["axes"] = np.asarray(
            axes_names, dtype=h5py.string_dtype(encoding="utf-8"))

        if dataset.variance is not None:
            # BUGFIX: the variance dataset now honors the same compression
            # settings as the signal (it was previously written uncompressed
            # regardless of the `compression` argument).
            nxdata.create_dataset("{}_errors".format(signal_name),
                                  data=np.array(dataset.variance),
                                  **signal_kwargs)

        # Sidpy-specific metadata goes in an NXcollection so generic NeXus
        # readers can ignore it.
        sidpy_collection = entry.create_group("sidpy_metadata")
        sidpy_collection.attrs["NX_class"] = "NXcollection"
        _write_json_dataset(sidpy_collection, "metadata", dataset.metadata)
        _write_json_dataset(sidpy_collection, "original_metadata",
                            dataset.original_metadata)
        _write_json_dataset(sidpy_collection, "provenance", dataset.provenance)

        h5_file.flush()

        if close_file:
            # The file is about to be closed, so return the path instead of
            # a handle that would be invalid.
            return signal.name
        return signal
    finally:
        if close_file:
            h5_file.close()
204+
205+
206+
def _resolve_default_child(parent, default_name):
    """Resolve a NeXus ``default`` attribute value to an h5py object.

    *default_name* may be a direct child name, an absolute HDF5 path, or a
    path relative to *parent*.  Returns ``None`` when nothing matches.
    """
    if default_name is None:
        return None

    default_name = _decode_if_bytes(default_name)
    if default_name in parent:
        return parent[default_name]

    if isinstance(default_name, str):
        if default_name.startswith("/"):
            return parent.file[default_name]
        relative = "{}/{}".format(parent.name.rstrip("/"), default_name)
        relative = relative.replace("//", "/")
        if relative in parent.file:
            return parent.file[relative]
    return None
221+
222+
223+
def _find_nxentry(h5_file):
    """Locate an NXentry group, honoring the root ``default`` attribute.

    Falls back to the first top-level group whose ``NX_class`` attribute
    is ``"NXentry"``; raises ``ValueError`` when none exists.
    """
    preferred = _resolve_default_child(h5_file, h5_file.attrs.get("default"))
    if isinstance(preferred, h5py.Group):
        if _decode_if_bytes(preferred.attrs.get("NX_class")) == "NXentry":
            return preferred

    for child_name in h5_file:
        child = h5_file[child_name]
        if not isinstance(child, h5py.Group):
            continue
        if _decode_if_bytes(child.attrs.get("NX_class")) == "NXentry":
            return child
    raise ValueError("Could not find an NXentry group in the provided file")
233+
234+
235+
def _find_nxdata(entry):
    """Locate an NXdata group inside *entry*, honoring ``default``.

    Falls back to the first child group whose ``NX_class`` attribute is
    ``"NXdata"``; raises ``ValueError`` when none exists.
    """
    preferred = _resolve_default_child(entry, entry.attrs.get("default"))
    if isinstance(preferred, h5py.Group):
        if _decode_if_bytes(preferred.attrs.get("NX_class")) == "NXdata":
            return preferred

    for child_name in entry:
        child = entry[child_name]
        if not isinstance(child, h5py.Group):
            continue
        if _decode_if_bytes(child.attrs.get("NX_class")) == "NXdata":
            return child
    raise ValueError("Could not find an NXdata group in the provided entry")
245+
246+
247+
def nexus_to_sidpy(h5_path, entry_path=None, nxdata_path=None, signal_name=None):
    """
    Read a NeXus HDF5 NXdata signal into a sidpy.Dataset.

    Parameters
    ----------
    h5_path : str or h5py.File
        Source HDF5 file path or open file handle.  When a path is given,
        the file is opened read-only and is intentionally left open on
        success (the returned dataset references the signal via
        ``dataset.h5_dataset``); it is closed if reading fails.
    entry_path : str, optional
        Explicit path to the NXentry group.
    nxdata_path : str, optional
        Explicit path to the NXdata group.
    signal_name : str, optional
        Explicit name of the signal dataset inside NXdata.

    Returns
    -------
    sidpy.Dataset
        Restored dataset.

    Raises
    ------
    TypeError
        If `nxdata_path` does not point to a group.
    ValueError
        If the NXdata group, the signal dataset, or the axes metadata is
        missing or inconsistent with the signal's shape.
    NotImplementedError
        If an axis dataset is not one-dimensional.
    """
    close_on_error = False
    if isinstance(h5_path, h5py.File):
        h5_file = h5_path
    else:
        h5_file = h5py.File(h5_path, "r")
        close_on_error = True

    try:
        if nxdata_path is not None:
            nxdata = h5_file[nxdata_path]
            if not isinstance(nxdata, h5py.Group):
                raise TypeError("nxdata_path must point to a group")
            if entry_path is None:
                entry = nxdata.parent
            else:
                entry = h5_file[entry_path]
        else:
            if entry_path is not None:
                entry = h5_file[entry_path]
            else:
                entry = _find_nxentry(h5_file)
            nxdata = _find_nxdata(entry)

        if _decode_if_bytes(nxdata.attrs.get("NX_class")) != "NXdata":
            raise ValueError("The selected group is not an NXdata group")

        if signal_name is None:
            signal_name = _decode_if_bytes(nxdata.attrs.get("signal"))
        signal_name = _clean_name(signal_name, "data")

        if signal_name not in nxdata:
            raise ValueError("Could not find signal dataset '{}' in NXdata".format(signal_name))

        signal = nxdata[signal_name]
        entry_title = _decode_if_bytes(entry["title"][()]) if "title" in entry else signal_name
        signal_title = _decode_if_bytes(signal.attrs.get("title", ""))
        dataset = Dataset.from_array(np.array(signal), title=signal_title or entry_title)

        dataset.units = _decode_if_bytes(signal.attrs.get("units", "generic"))
        dataset.quantity = _decode_if_bytes(signal.attrs.get("quantity", "generic"))

        data_type = _decode_if_bytes(signal.attrs.get("data_type", "UNKNOWN"))
        try:
            dataset.data_type = data_type
        except (KeyError, ValueError, Warning):
            # BUGFIX: the original caught only Warning, which Python does not
            # raise by default, so the UNKNOWN fallback never triggered.
            # NOTE(review): sidpy's setter appears to reject unknown names
            # with an exception (ValueError/KeyError) — confirm against the
            # installed sidpy version.
            dataset.data_type = "UNKNOWN"

        dataset.modality = _decode_if_bytes(signal.attrs.get("modality", "generic"))
        dataset.source = _decode_if_bytes(signal.attrs.get("source", "generic"))
        dataset.title = signal_title or entry_title

        axes_names = _normalize_axes_attr(nxdata.attrs.get("axes"))
        if axes_names is None:
            axes_names = ["dim_{}".format(index) for index in range(dataset.ndim)]

        if len(axes_names) != dataset.ndim:
            raise ValueError("NXdata axes metadata does not match signal rank")

        for dim_index, axis_name in enumerate(axes_names):
            # "." marks a dimension with no axis dataset (NeXus convention);
            # missing axis datasets are skipped silently as before.
            if axis_name == "." or axis_name not in nxdata:
                continue

            axis_dset = nxdata[axis_name]
            axis_values = np.asarray(axis_dset[()])
            if axis_values.ndim != 1:
                raise NotImplementedError("Only 1D NXdata axes are currently supported")
            if axis_values.shape[0] != dataset.shape[dim_index]:
                raise ValueError("Axis '{}' length does not match data dimension {}".format(axis_name, dim_index))

            dimension = Dimension(axis_values,
                                  name=axis_name,
                                  quantity=_decode_if_bytes(axis_dset.attrs.get("quantity", axis_name)),
                                  units=_decode_if_bytes(axis_dset.attrs.get("units", "generic")),
                                  dimension_type=_decode_if_bytes(axis_dset.attrs.get("dimension_type", "UNKNOWN")))
            dataset.set_dimension(dim_index, dimension)

        # Restore sidpy-specific metadata written by sidpy_to_nexus_hdf5.
        if "sidpy_metadata" in entry:
            sidpy_collection = entry["sidpy_metadata"]
            if isinstance(sidpy_collection, h5py.Group):
                dataset.metadata = _read_json_dataset(sidpy_collection, "metadata")
                dataset.original_metadata = _read_json_dataset(sidpy_collection, "original_metadata")
                provenance = _read_json_dataset(sidpy_collection, "provenance")
                if provenance:
                    dataset.provenance = provenance

        dataset.h5_dataset = signal
        return dataset
    except Exception:
        # BUGFIX: close the handle this function opened before propagating,
        # so a failed read does not leak an open file.
        if close_on_error:
            h5_file.close()
        raise

0 commit comments

Comments
 (0)