From 4cb7995e42597d82153331aade71772d99b308ca Mon Sep 17 00:00:00 2001 From: YooSunyoung <17974113+YooSunYoung@users.noreply.github.com> Date: Thu, 4 Dec 2025 14:06:06 +0100 Subject: [PATCH 1/6] Use generic workflow. --- src/ess/nmx/_executable_helper.py | 146 +-------- src/ess/nmx/configurations.py | 149 +++++++++ src/ess/nmx/executables.py | 420 +++++++------------------ src/ess/nmx/nexus.py | 503 +++++++++++------------------- src/ess/nmx/types.py | 68 ++++ src/ess/nmx/workflows.py | 308 ++++++++++++++++++ tests/executable_test.py | 82 +++-- 7 files changed, 889 insertions(+), 787 deletions(-) create mode 100644 src/ess/nmx/configurations.py create mode 100644 src/ess/nmx/workflows.py diff --git a/src/ess/nmx/_executable_helper.py b/src/ess/nmx/_executable_helper.py index d8ca016..afa1843 100644 --- a/src/ess/nmx/_executable_helper.py +++ b/src/ess/nmx/_executable_helper.py @@ -10,11 +10,11 @@ from types import UnionType from typing import Literal, TypeGuard, TypeVar, Union, get_args, get_origin -from pydantic import BaseModel, Field +from pydantic import BaseModel from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined -from .types import Compression +from .configurations import InputConfig, OutputConfig, ReductionConfig, WorkflowConfig def _validate_annotation(annotation) -> TypeGuard[type]: @@ -140,148 +140,6 @@ def add_args_from_pydantic_model( return parser -class InputConfig(BaseModel): - # Add title of the basemodel - model_config = {"title": "Input Configuration"} - # File IO - input_file: list[str] = Field( - title="Input File", - description="Path to the input file. If multiple file paths are given," - " the output(histogram) will be merged(summed) " - "and will not save individual outputs per input file. ", - ) - swmr: bool = Field( - title="SWMR Mode", - description="Open the input file in SWMR mode", - default=False, - ) - # Detector selection - detector_ids: list[int] = Field( - title="Detector IDs", - description="Detector indices to process", - default=[0, 1, 2], - ) - # Chunking options - iter_chunk: bool = Field( - title="Iterate in Chunks", - description="Whether to process the input file in chunks " - " based on the hdf5 dataset chunk size. " - "It is ignored if hdf5 dataset is not chunked. " - "If True, it overrides chunk-size-pulse and chunk-size-events options.", - default=False, - ) - chunk_size_pulse: int = Field( - title="Chunk Size Pulse", - description="Number of pulses to process in each chunk. " - "If 0 or negative, process all pulses at once.", - default=0, - ) - chunk_size_events: int = Field( - title="Chunk Size Events", - description="Number of events to process in each chunk. " - "If 0 or negative, process all events at once." 
- "If both chunk-size-pulse and chunk-size-events are set, " - "chunk-size-pulse is preferred.", - default=0, - ) - - -class TimeBinUnit(enum.StrEnum): - ms = 'ms' - us = 'us' - ns = 'ns' - - -class TimeBinCoordinate(enum.StrEnum): - event_time_offset = 'event_time_offset' - time_of_flight = 'time_of_flight' - - -class WorkflowConfig(BaseModel): - # Add title of the basemodel - model_config = {"title": "Workflow Configuration"} - time_bin_coordinate: TimeBinCoordinate = Field( - title="Time Bin Coordinate", - description="Coordinate to bin the time data.", - default=TimeBinCoordinate.event_time_offset, - ) - nbins: int = Field( - title="Number of Time Bins", - description="Number of Time bins", - default=50, - ) - min_time_bin: int | None = Field( - title="Minimum Time Bin", - description="Minimum time edge of [time_bin_coordinate] in [time_bin_unit].", - default=None, - ) - max_time_bin: int | None = Field( - title="Maximum Time Bin", - description="Maximum time edge of [time_bin_coordinate] in [time_bin_unit].", - default=None, - ) - time_bin_unit: TimeBinUnit = Field( - title="Unit of Time Bins", - description="Unit of time bins.", - default=TimeBinUnit.ms, - ) - tof_lookup_table_file_path: str | None = Field( - title="TOF Lookup Table File Path", - description="Path to the TOF lookup table file. " - "If None, the lookup table will be computed on-the-fly.", - default=None, - ) - tof_simulation_min_wavelength: float = Field( - title="TOF Simulation Minimum Wavelength", - description="Minimum wavelength for TOF simulation in Angstrom.", - default=1.8, - ) - tof_simulation_max_wavelength: float = Field( - title="TOF Simulation Maximum Wavelength", - description="Maximum wavelength for TOF simulation in Angstrom.", - default=3.6, - ) - tof_simulation_seed: int = Field( - title="TOF Simulation Seed", - description="Random seed for TOF simulation.", - default=42, # No reason. - ) - - -class OutputConfig(BaseModel): - # Add title of the basemodel - model_config = {"title": "Output Configuration"} - # Log verbosity - verbose: bool = Field( - title="Verbose Logging", - description="Increase output verbosity.", - default=False, - ) - # File output - output_file: str = Field( - title="Output File", - description="Path to the output file.", - default="scipp_output.h5", - ) - compression: Compression = Field( - title="Compression", - description="Compress option of reduced output file.", - default=Compression.BITSHUFFLE_LZ4, - ) - - -class ReductionConfig(BaseModel): - """Container for all reduction configurations.""" - - inputs: InputConfig - workflow: WorkflowConfig = Field(default_factory=WorkflowConfig) - output: OutputConfig = Field(default_factory=OutputConfig) - - @property - def _children(self) -> list[BaseModel]: - return [self.inputs, self.workflow, self.output] - - T = TypeVar('T', bound=BaseModel) diff --git a/src/ess/nmx/configurations.py b/src/ess/nmx/configurations.py new file mode 100644 index 0000000..87e1ce9 --- /dev/null +++ b/src/ess/nmx/configurations.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +import enum + +from pydantic import BaseModel, Field + +from .types import Compression + + +class InputConfig(BaseModel): + # Add title of the basemodel + model_config = {"title": "Input Configuration"} + # File IO + input_file: list[str] = Field( + title="Input File", + description="Path to the input file. 
If multiple file paths are given," + " the output(histogram) will be merged(summed) " + "and will not save individual outputs per input file. ", + ) + swmr: bool = Field( + title="SWMR Mode", + description="Open the input file in SWMR mode", + default=False, + ) + # Detector selection + detector_ids: list[int] = Field( + title="Detector IDs", + description="Detector indices to process", + default=[0, 1, 2], + ) + # Chunking options + iter_chunk: bool = Field( + title="Iterate in Chunks", + description="Whether to process the input file in chunks " + " based on the hdf5 dataset chunk size. " + "It is ignored if hdf5 dataset is not chunked. " + "If True, it overrides chunk-size-pulse and chunk-size-events options.", + default=False, + ) + chunk_size_pulse: int = Field( + title="Chunk Size Pulse", + description="Number of pulses to process in each chunk. " + "If 0 or negative, process all pulses at once.", + default=0, + ) + chunk_size_events: int = Field( + title="Chunk Size Events", + description="Number of events to process in each chunk. " + "If 0 or negative, process all events at once." + "If both chunk-size-pulse and chunk-size-events are set, " + "chunk-size-pulse is preferred.", + default=0, + ) + + +class TimeBinUnit(enum.StrEnum): + ms = 'ms' + us = 'us' + ns = 'ns' + + +class TimeBinCoordinate(enum.StrEnum): + event_time_offset = 'event_time_offset' + time_of_flight = 'time_of_flight' + + +class WorkflowConfig(BaseModel): + # Add title of the basemodel + model_config = {"title": "Workflow Configuration"} + time_bin_coordinate: TimeBinCoordinate = Field( + title="Time Bin Coordinate", + description="Coordinate to bin the time data.", + default=TimeBinCoordinate.event_time_offset, + ) + nbins: int = Field( + title="Number of Time Bins", + description="Number of Time bins", + default=50, + ) + min_time_bin: int | None = Field( + title="Minimum Time Bin", + description="Minimum time edge of [time_bin_coordinate] in [time_bin_unit].", + default=None, + ) + max_time_bin: int | None = Field( + title="Maximum Time Bin", + description="Maximum time edge of [time_bin_coordinate] in [time_bin_unit].", + default=None, + ) + time_bin_unit: TimeBinUnit = Field( + title="Unit of Time Bins", + description="Unit of time bins.", + default=TimeBinUnit.ms, + ) + tof_lookup_table_file_path: str | None = Field( + title="TOF Lookup Table File Path", + description="Path to the TOF lookup table file. " + "If None, the lookup table will be computed on-the-fly.", + default=None, + ) + tof_simulation_min_wavelength: float = Field( + title="TOF Simulation Minimum Wavelength", + description="Minimum wavelength for TOF simulation in Angstrom.", + default=1.8, + ) + tof_simulation_max_wavelength: float = Field( + title="TOF Simulation Maximum Wavelength", + description="Maximum wavelength for TOF simulation in Angstrom.", + default=3.6, + ) + tof_simulation_seed: int = Field( + title="TOF Simulation Seed", + description="Random seed for TOF simulation.", + default=42, # No reason. 
+ ) + + +class OutputConfig(BaseModel): + # Add title of the basemodel + model_config = {"title": "Output Configuration"} + # Log verbosity + verbose: bool = Field( + title="Verbose Logging", + description="Increase output verbosity.", + default=False, + ) + # File output + output_file: str = Field( + title="Output File", + description="Path to the output file.", + default="scipp_output.h5", + ) + compression: Compression = Field( + title="Compression", + description="Compress option of reduced output file.", + default=Compression.BITSHUFFLE_LZ4, + ) + + +class ReductionConfig(BaseModel): + """Container for all reduction configurations.""" + + inputs: InputConfig + workflow: WorkflowConfig = Field(default_factory=WorkflowConfig) + output: OutputConfig = Field(default_factory=OutputConfig) + + @property + def _children(self) -> list[BaseModel]: + return [self.inputs, self.workflow, self.output] diff --git a/src/ess/nmx/executables.py b/src/ess/nmx/executables.py index 52ded28..2d83cb4 100644 --- a/src/ess/nmx/executables.py +++ b/src/ess/nmx/executables.py @@ -2,176 +2,35 @@ # Copyright (c) 2025 Scipp contributors (https://github.com/scipp) import logging import pathlib -from collections.abc import Callable -from dataclasses import dataclass -from typing import Literal +from collections.abc import Callable, Iterable +import sciline as sl import scipp as sc import scippnexus as snx +from ess.reduce.nexus.types import Filename, NeXusName, SampleRun +from ess.reduce.time_of_flight.types import TimeOfFlightLookupTable, TofDetector + from ._executable_helper import ( - ReductionConfig, build_logger, build_reduction_argument_parser, collect_matching_input_files, reduction_config_from_args, ) - -# Temporarily keeping them until we migrate GenericWorkflow here -from .mcstas.nexus import ( - _compute_positions, - _export_detector_metadata_as_nxlauetof, - _export_reduced_data_as_nxlauetof, - _export_static_metadata_as_nxlauetof, +from .configurations import ReductionConfig, WorkflowConfig +from .nexus import ( + export_detector_metadata_as_nxlauetof, + export_monitor_metadata_as_nxlauetof, + export_reduced_data_as_nxlauetof, + export_static_metadata_as_nxlauetof, ) - -# Temporarily keeping them until we migrate GenericWorkflow here -from .mcstas.types import NMXDetectorMetadata, NMXExperimentMetadata -from .types import Compression - - -def _validate_chunk_size(chunk_size: int) -> None: - """Validate the chunk size.""" - if not isinstance(chunk_size, int): - raise TypeError("Chunk size must be an integer.") - if chunk_size < -1: - raise ValueError("Invalid chunk size. It should be -1(for all) or > 0.") - - -def _retrieve_source_position(file: snx.File) -> sc.Variable: - da = file['entry/instrument/source'][()] - return _compute_positions(da, auto_fix_transformations=True)['position'] - - -def _retrieve_sample_position(file: snx.File) -> sc.Variable: - da = file['entry/sample'][()] - return _compute_positions(da, auto_fix_transformations=True)['position'] - - -def _retrieve_crystal_rotation(file: snx.File) -> sc.Variable: - if 'crystal_rotation' not in file['entry/sample']: - import warnings - - warnings.warn( - "No crystal rotation found in the Nexus file under " - "'entry/sample/crystal_rotation'. Returning zero rotation.", - RuntimeWarning, - stacklevel=2, - ) - return sc.vector([0, 0, 0], unit='deg') - - # Temporary way of storing crystal rotation. - # streaming-sample-mcstas module writes crystal rotation under - # 'entry/sample/crystal_rotation' as an array of three values. 
- return file['entry/sample/crystal_rotation'][()] - - -def _decide_fast_axis(da: sc.DataArray) -> str: - x_slice = da['x_pixel_offset', 0].coords['detector_number'] - y_slice = da['y_pixel_offset', 0].coords['detector_number'] - - if (x_slice.max() < y_slice.max()).value: - return 'y' - elif (x_slice.max() > y_slice.max()).value: - return 'x' - else: - raise ValueError( - "Cannot decide fast axis based on pixel offsets. " - "Please specify the fast axis explicitly." - ) - - -def _decide_step(offsets: sc.Variable) -> sc.Variable: - """Decide the step size based on the offsets assuming at least 2 values.""" - sorted_offsets = sc.sort(offsets, key=offsets.dim, order='ascending') - return sorted_offsets[1] - sorted_offsets[0] - - -@dataclass -class DetectorDesc: - """Detector information extracted from McStas instrument xml description.""" - - name: str - id_start: int # 'idstart' - num_x: int # 'xpixels' - num_y: int # 'ypixels' - step_x: sc.Variable # 'xstep' - step_y: sc.Variable # 'ystep' - start_x: float # 'xstart' - start_y: float # 'ystart' - position: sc.Variable # 'x', 'y', 'z' - # Calculated fields - rotation_matrix: sc.Variable - fast_axis_name: str - slow_axis_name: str - fast_axis: sc.Variable - slow_axis: sc.Variable - - -def build_detector_desc( - name: str, dg: sc.DataGroup, *, fast_axis: Literal['x', 'y'] | None = None -) -> DetectorDesc: - da: sc.DataArray = dg['data'] - _fast_axis = fast_axis if fast_axis is not None else _decide_fast_axis(da) - transformation_matrix = dg['transform_matrix'] - t_unit = transformation_matrix.unit - fast_axis_vector = ( - sc.vector([1, 0, 0], unit=t_unit) - if _fast_axis == 'x' - else sc.vector([0, 1, 0], unit=t_unit) - ) - slow_axis_vector = ( - sc.vector([0, 1, 0], unit=t_unit) - if _fast_axis == 'x' - else sc.vector([1, 0, 0], unit=t_unit) - ) - return DetectorDesc( - name=name, - id_start=da.coords['detector_number'].min().value, - num_x=da.sizes['x_pixel_offset'], - num_y=da.sizes['y_pixel_offset'], - start_x=da.coords['x_pixel_offset'].min().value, - start_y=da.coords['y_pixel_offset'].min().value, - position=dg['position'], - rotation_matrix=dg['transform_matrix'], - fast_axis_name=_fast_axis, - slow_axis_name='x' if _fast_axis == 'y' else 'y', - fast_axis=fast_axis_vector, - slow_axis=slow_axis_vector, - step_x=_decide_step(da.coords['x_pixel_offset']), - step_y=_decide_step(da.coords['y_pixel_offset']), - ) - - -def calculate_number_of_chunks(detector_gr: snx.Group, *, chunk_size: int = 0) -> int: - _validate_chunk_size(chunk_size) - event_time_zero_size = detector_gr.sizes['event_time_zero'] - if chunk_size == -1: - return 1 # Read all at once - else: - return event_time_zero_size // chunk_size + int( - event_time_zero_size % chunk_size != 0 - ) - - -def build_toa_bin_edges( - *, - min_toa: sc.Variable | int = 0, - max_toa: sc.Variable | int = int((1 / 14) * 1_000), # Default for ESS NMX - toa_bin_edges: sc.Variable | int = 250, -) -> sc.Variable: - if isinstance(toa_bin_edges, sc.Variable): - return toa_bin_edges - elif isinstance(toa_bin_edges, int): - min_toa = sc.scalar(min_toa, unit='ms') if isinstance(min_toa, int) else min_toa - max_toa = sc.scalar(max_toa, unit='ms') if isinstance(max_toa, int) else max_toa - return sc.linspace( - dim='event_time_offset', - start=min_toa.value, - stop=max_toa.to(unit=min_toa.unit).value, - unit=min_toa.unit, - num=toa_bin_edges + 1, - ) +from .types import ( + NMXDetectorMetadata, + NMXMonitorMetadata, + NMXSampleMetadata, + NMXSourceMetadata, +) +from .workflows import NMXWorkflow, 
compute_lookup_table, select_detector_names


 def _retrieve_input_file(input_file: list[pathlib.Path] | pathlib.Path) -> pathlib.Path:
@@ -201,6 +60,45 @@ def _retrieve_display(
     return logging.getLogger(__name__).info


+def compute_and_cache_lookup_table(
+    *,
+    wf: sl.Pipeline,
+    workflow_config: WorkflowConfig,
+    detector_names: Iterable[str],
+    display: Callable,
+) -> sl.Pipeline:
+    """Compute and cache the TOF lookup table in the workflow.
+
+    **Note**: ``wf`` is modified in-place and also returned.
+    """
+    # We compute one lookup table that covers the whole range
+    # to avoid multiple tof simulations.
+    if workflow_config.tof_lookup_table_file_path is None:
+        display("Computing TOF lookup table from simulation...")
+    else:
+        display("Loading TOF lookup table from file...")
+
+    lookup_table = compute_lookup_table(
+        base_wf=wf, workflow_config=workflow_config, detector_names=detector_names
+    )
+    wf[TimeOfFlightLookupTable] = lookup_table
+    return wf
+
+
+def _finalize_tof_bin_edges(
+    *, tof_das: sc.DataGroup, config: WorkflowConfig
+) -> sc.Variable:
+    tof_bin_edges = sc.concat(
+        tuple(tof_da.coords['tof'] for tof_da in tof_das.values()), dim='tof'
+    )
+    return sc.linspace(
+        dim='tof',
+        start=sc.min(tof_bin_edges),
+        stop=sc.max(tof_bin_edges),
+        num=config.nbins + 1,
+    )
+
+
 def reduction(
     *,
     config: ReductionConfig,
@@ -242,149 +140,73 @@ def reduction(
     output_file_path = pathlib.Path(config.output.output_file).resolve()
     display(f"Output file: {output_file_path}")

-    toa_bin_edges = build_toa_bin_edges(
-        min_toa=config.workflow.min_time_bin or 0,
-        max_toa=config.workflow.max_time_bin or int((1 / 14) * 1_000),
-        toa_bin_edges=config.workflow.nbins,
+    detector_names = select_detector_names(
+        input_files=[input_file_path], detector_ids=config.inputs.detector_ids
+    )
+
+    base_wf = NMXWorkflow()
+    # Insert input file path into the workflow for later use
+    base_wf[Filename] = input_file_path
+
+    base_wf = compute_and_cache_lookup_table(
+        wf=base_wf,
+        workflow_config=config.workflow,
+        detector_names=detector_names,
+        display=display,
     )
-    with snx.File(input_file_path) as f:
-        intrument_group = f['entry/instrument']
-        dets = intrument_group[snx.NXdetector]
-        detector_group_keys = list(dets.keys())
-        display(f"Found NXdetectors: {detector_group_keys}")
-        detector_id_map = {
-            det_name: dets[det_name]
-            for i, det_name in enumerate(detector_group_keys)
-            if i in config.inputs.detector_ids or det_name in config.inputs.detector_ids
-        }
-        if len(detector_id_map) != len(config.inputs.detector_ids):
-            raise ValueError(
-                f"Requested detector ids {config.inputs.detector_ids} "
-                "not found in the file.\n"
-                f"Found {detector_group_keys}\n"
-                f"Try using integer indices instead of names."
- ) - display(f"Selected detectors: {list(detector_id_map.keys())}") - source_position = _retrieve_source_position(f) - sample_position = _retrieve_sample_position(f) - crystal_rotation = _retrieve_crystal_rotation(f) - experiment_metadata = NMXExperimentMetadata( - sc.DataGroup( - { - 'crystal_rotation': crystal_rotation, - 'sample_position': sample_position, - 'source_position': source_position, - 'sample_name': sc.scalar(f['entry/sample/name'][()]), - } - ) + metadatas = base_wf.compute((NMXSampleMetadata, NMXSourceMetadata)) + export_static_metadata_as_nxlauetof( + sample_metadata=metadatas[NMXSampleMetadata], + source_metadata=metadatas[NMXSourceMetadata], + output_file=config.output.output_file, + ) + tof_das = sc.DataGroup() + detector_metas = sc.DataGroup() + for detector_name in detector_names: + cur_wf = base_wf.copy() + cur_wf[NeXusName[snx.NXdetector]] = detector_name + results = cur_wf.compute((TofDetector[SampleRun], NMXDetectorMetadata)) + detector_meta: NMXDetectorMetadata = results[NMXDetectorMetadata] + export_detector_metadata_as_nxlauetof( + detector_metadata=detector_meta, output_file=config.output.output_file ) - display(experiment_metadata) - display("Experiment metadata component:") - for name, component in experiment_metadata.items(): - display(f"{name}: {component}") - - _export_static_metadata_as_nxlauetof( - experiment_metadata=experiment_metadata, - output_file=output_file_path, + detector_metas[detector_name] = detector_meta + # Binning into 1 bin and getting final tof bin edges later. + tof_das[detector_name] = results[TofDetector[SampleRun]].bin(tof=1) + + tof_bin_edges = _finalize_tof_bin_edges(tof_das=tof_das, config=config.workflow) + + monitor_metadata = NMXMonitorMetadata( + tof_bin_coord='tof', + # TODO: Use real monitor data + # Currently NMX simulations or experiments do not have monitors + monitor_histogram=sc.DataArray( + coords={'tof': tof_bin_edges}, + data=sc.ones_like(tof_bin_edges[:-1]), + ), + ) + export_monitor_metadata_as_nxlauetof( + monitor_metadata=monitor_metadata, output_file=config.output.output_file + ) + + # Histogram detector counts + tof_histograms = sc.DataGroup() + for detector_name, tof_da in tof_das.items(): + det_meta: NMXDetectorMetadata = detector_metas[detector_name] + histogram = tof_da.hist(tof=tof_bin_edges) + tof_histograms[detector_name] = histogram + export_reduced_data_as_nxlauetof( + detector_name=det_meta.detector_name, + da=histogram, + output_file=config.output.output_file, + compress_mode=config.output.compression, ) - detector_grs = {} - for det_name, det_group in detector_id_map.items(): - display(f"Processing {det_name}") - if config.inputs.chunk_size_events <= 0: - dg = det_group[()] - else: - # Slice the first chunk for metadata extraction - dg = det_group['event_time_zero', 0 : config.inputs.chunk_size_events] - display("Computing detector positions...") - display(dg := _compute_positions(dg, auto_fix_transformations=True)) - detector = build_detector_desc(det_name, dg) - detector_meta = sc.DataGroup( - { - 'fast_axis': detector.fast_axis, - 'slow_axis': detector.slow_axis, - 'origin_position': sc.vector([0, 0, 0], unit='m'), - 'position': detector.position, - 'detector_shape': sc.scalar( - ( - dg['data'].sizes['x_pixel_offset'], - dg['data'].sizes['y_pixel_offset'], - ) - ), - 'x_pixel_size': detector.step_x, - 'y_pixel_size': detector.step_y, - 'detector_name': sc.scalar(detector.name), - } - ) - _export_detector_metadata_as_nxlauetof( - NMXDetectorMetadata(detector_meta), - 
output_file=output_file_path, - ) - - da: sc.DataArray = dg['data'] - event_time_offset_unit = da.bins.coords['event_time_offset'].bins.unit - display("Event time offset unit: %s", event_time_offset_unit) - toa_bin_edges = toa_bin_edges.to(unit=event_time_offset_unit, copy=False) - if config.inputs.chunk_size_events <= 0: - counts = da.hist(event_time_offset=toa_bin_edges).rename_dims( - x_pixel_offset='x', y_pixel_offset='y', event_time_offset='t' - ) - counts.coords['t'] = counts.coords['event_time_offset'] - - else: - num_chunks = calculate_number_of_chunks( - det_group, chunk_size=config.inputs.chunk_size_events - ) - display(f"Number of chunks: {num_chunks}") - counts = da.hist(event_time_offset=toa_bin_edges).rename_dims( - x_pixel_offset='x', y_pixel_offset='y', event_time_offset='t' - ) - counts.coords['t'] = counts.coords['event_time_offset'] - for chunk_index in range(1, num_chunks): - cur_chunk = det_group[ - 'event_time_zero', - chunk_index * config.inputs.chunk_size_events : ( - chunk_index + 1 - ) - * config.inputs.chunk_size_events, - ] - display(f"Processing chunk {chunk_index + 1} of {num_chunks}") - cur_chunk = _compute_positions( - cur_chunk, auto_fix_transformations=True - ) - cur_counts = ( - cur_chunk['data'] - .hist(event_time_offset=toa_bin_edges) - .rename_dims( - x_pixel_offset='x', - y_pixel_offset='y', - event_time_offset='t', - ) - ) - cur_counts.coords['t'] = cur_counts.coords['event_time_offset'] - counts += cur_counts - display("Accumulated counts:") - display(counts.sum().data) - - dg = sc.DataGroup( - counts=counts, - detector_shape=detector_meta['detector_shape'], - detector_name=detector_meta['detector_name'], - ) - display("Final data group:") - display(dg) - display("Saving reduced data to Nexus file...") - _export_reduced_data_as_nxlauetof( - dg, - output_file=output_file_path, - compress_counts=( - config.output.compression == Compression.BITSHUFFLE_LZ4 - ), - ) - detector_grs[det_name] = dg - - display("Reduction completed successfully.") - histograms = {name: det_gr['counts'] for name, det_gr in detector_grs.items()} - return sc.DataGroup(histogram=sc.DataGroup(histograms)) + + return sc.DataGroup( + metadata=detector_metas, + histogram=tof_histograms, + lookup_table=base_wf.compute(TimeOfFlightLookupTable), + ) def main() -> None: diff --git a/src/ess/nmx/nexus.py b/src/ess/nmx/nexus.py index 1fc5858..eb233b2 100644 --- a/src/ess/nmx/nexus.py +++ b/src/ess/nmx/nexus.py @@ -3,79 +3,31 @@ import io import pathlib import warnings -from functools import wraps from typing import Any import h5py import numpy as np import scipp as sc - -def _fallback_compute_positions(dg: sc.DataGroup) -> sc.DataGroup: - import warnings - - import scippnexus as snx - - warnings.warn( - "Using fallback compute_positions due to empty log entries. " - "This may lead to incorrect results. Please check the data carefully." 
- "The fallback will replace empty logs with a scalar value of zero.", - UserWarning, - stacklevel=2, - ) - - empty_transformations = [ - transformation - for transformation in dg['depends_on'].transformations.values() - if 'time' in transformation.value.dims - and transformation.sizes['time'] == 0 # empty log - ] - for transformation in empty_transformations: - orig_value = transformation.value - orig_value = sc.scalar(0, unit=orig_value.unit, dtype=orig_value.dtype) - transformation.value = orig_value - return snx.compute_positions(dg, store_transform='transform_matrix') - - -def _compute_positions( - dg: sc.DataGroup, auto_fix_transformations: bool = False -) -> sc.DataGroup: - """Compute positions of the data group from transformations. - - Wraps the `scippnexus.compute_positions` function - and provides a fallback for cases where the transformations - contain empty logs. - - Parameters - ---------- - dg: - Data group containing the transformations and data. - auto_fix_transformations: - If `True`, it will attempt to fix empty transformations. - It will replace them with a scalar value of zero. - It is because adding a time dimension will make it not possible - to compute positions of children due to time-dependent transformations. - - Returns - ------- - : - Data group with computed positions. - - Warnings - -------- - If `auto_fix_transformations` is `True`, it will warn about the fallback - being used due to empty logs or scalar transformations. - This is because the fallback may lead to incorrect results. - - """ - import scippnexus as snx - - try: - return snx.compute_positions(dg, store_transform='transform_matrix') - except ValueError as e: - if auto_fix_transformations: - return _fallback_compute_positions(dg) - raise e +from .configurations import Compression +from .types import ( + NMXDetectorMetadata, + NMXMonitorMetadata, + NMXSampleMetadata, + NMXSourceMetadata, +) + + +def _check_file( + filename: str | pathlib.Path | io.BytesIO, overwrite: bool +) -> pathlib.Path | io.BytesIO: + if isinstance(filename, str | pathlib.Path): + filename = pathlib.Path(filename) + if filename.exists() and not overwrite: + raise FileExistsError( + f"File '{filename}' already exists. Use `overwrite=True` to overwrite." + ) + return filename def _create_dataset_from_string(*, root_entry: h5py.Group, name: str, var: str) -> None: @@ -112,152 +64,29 @@ def _create_dataset_from_var( return dataset -@wraps(_create_dataset_from_var) -def _create_compressed_dataset(*args, **kwargs): - """Create dataset with compression options. - - It will try to use ``bitshuffle`` for compression if available. - Otherwise, it will fall back to ``gzip`` compression. - - [``Bitshuffle/LZ4``](https://github.com/kiyo-masui/bitshuffle) - is used for convenience. - Since ``Dectris`` uses it for their Nexus file compression, - it is compatible with DIALS. - ``Bitshuffle/LZ4`` tends to give similar results to - GZIP and other compression algorithms with better performance. - A naive implementation of bitshuffle/LZ4 compression, - shown in [issue #124](https://github.com/scipp/essnmx/issues/124), - led to 80% file reduction (365 MB vs 1.8 GB). - - """ - try: - import bitshuffle.h5 - - compression_filter = bitshuffle.h5.H5FILTER - default_compression_opts = (0, bitshuffle.h5.H5_COMPRESS_LZ4) - except ImportError: - warnings.warn( - UserWarning( - "Could not find the bitshuffle.h5 module from bitshuffle package. " - "The bitshuffle package is not installed or only partially installed. 
" - "Exporting to NeXus files with bitshuffle compression is not possible." - ), - stacklevel=2, - ) - compression_filter = "gzip" - default_compression_opts = 4 - - return _create_dataset_from_var( - *args, - **kwargs, - compression=compression_filter, - compression_opts=default_compression_opts, - ) - - -def _create_root_data_entry(file_obj: h5py.File) -> h5py.Group: - nx_entry = file_obj.create_group("NMX_data") - nx_entry.attrs["NX_class"] = "NXentry" - nx_entry.attrs["default"] = "data" - nx_entry.attrs["name"] = "NMX" - nx_entry["name"] = "NMX" - nx_entry["definition"] = "TOFRAW" - return nx_entry - - -def _create_sample_group(data: sc.DataGroup, nx_entry: h5py.Group) -> h5py.Group: - nx_sample = nx_entry.create_group("NXsample") - nx_sample["name"] = data['sample_name'].value - _create_dataset_from_var( - root_entry=nx_sample, - var=data['crystal_rotation'], - name='crystal_rotation', - long_name='crystal rotation in Phi (XYZ)', - ) - return nx_sample - - -def _create_instrument_group(data: sc.DataGroup, nx_entry: h5py.Group) -> h5py.Group: - nx_instrument = nx_entry.create_group("NXinstrument") - nx_instrument.create_dataset("proton_charge", data=data['proton_charge'].values) - - nx_detector_1 = nx_instrument.create_group("detector_1") - # Detector counts - _create_compressed_dataset( - root_entry=nx_detector_1, - name="counts", - var=data['counts'], - ) - # Time of arrival bin edges - _create_dataset_from_var( - root_entry=nx_detector_1, - var=data['counts'].coords['t'], - name="t_bin", - long_name="t_bin TOF (ms)", - ) - # Pixel IDs - _create_compressed_dataset( - root_entry=nx_detector_1, - name="pixel_id", - var=data['counts'].coords['id'], - long_name="pixel ID", - ) - return nx_instrument - - -def _create_detector_group(data: sc.DataGroup, nx_entry: h5py.Group) -> h5py.Group: - nx_detector = nx_entry.create_group("NXdetector") - # Position of the first pixel (lowest ID) in the detector - _create_compressed_dataset( - root_entry=nx_detector, - name="origin", - var=data['origin_position'], - ) - # Fast axis, along where the pixel ID increases by 1 - _create_dataset_from_var( - root_entry=nx_detector, var=data['fast_axis'], name="fast_axis" - ) - # Slow axis, along where the pixel ID increases - # by the number of pixels in the fast axis - _create_dataset_from_var( - root_entry=nx_detector, var=data['slow_axis'], name="slow_axis" - ) - return nx_detector - - -def _create_source_group(data: sc.DataGroup, nx_entry: h5py.Group) -> h5py.Group: - nx_source = nx_entry.create_group("NXsource") - nx_source["name"] = "European Spallation Source" - nx_source["short_name"] = "ESS" - nx_source["type"] = "Spallation Neutron Source" - nx_source["distance"] = sc.norm(data['source_position']).value - nx_source["probe"] = "neutron" - nx_source["target_material"] = "W" - return nx_source - - -def export_as_nexus( - data: sc.DataGroup, output_file: str | pathlib.Path | io.BytesIO -) -> None: - """Export the reduced data to a NeXus file. +def _retrieve_compression_arguments(compress_mode: Compression) -> dict: + if compress_mode == Compression.BITSHUFFLE_LZ4: + try: + import bitshuffle.h5 + + compression_filter = bitshuffle.h5.H5FILTER + compression_opts = (0, bitshuffle.h5.H5_COMPRESS_LZ4) + except ImportError: + warnings.warn( + UserWarning( + "Could not find the bitshuffle.h5 module from bitshuffle package. " + "The bitshuffle package is not installed properly. " + "Trying with gzip compression instead..." 
+ ), + stacklevel=2, + ) + compression_filter = "gzip" + compression_opts = 4 + else: + compression_filter = None + compression_opts = None - Currently exporting step is not expected to be part of sciline pipelines. - """ - warnings.warn( - DeprecationWarning( - "Exporting to custom NeXus format will be deprecated in the near future " - ">=26.12.0. " - "Please use ``export_as_nxlauetof`` instead." - ), - stacklevel=2, - ) - with h5py.File(output_file, "w") as f: - f.attrs["default"] = "NMX_data" - nx_entry = _create_root_data_entry(f) - _create_sample_group(data, nx_entry) - _create_instrument_group(data, nx_entry) - _create_detector_group(data, nx_entry) - _create_source_group(data, nx_entry) + return {"compression": compression_filter, "compression_opts": compression_opts} def _create_lauetof_data_entry(file_obj: h5py.File) -> h5py.Group: @@ -277,7 +106,9 @@ def _add_lauetof_instrument(nx_entry: h5py.Group) -> h5py.Group: return nx_instrument -def _add_lauetof_source_group(dg, nx_instrument: h5py.Group) -> None: +def _add_lauetof_source_group( + source_position: sc.Variable, nx_instrument: h5py.Group +) -> None: nx_source = nx_instrument.create_group("source") nx_source.attrs["NX_class"] = "NXsource" _create_dataset_from_string( @@ -288,106 +119,71 @@ def _add_lauetof_source_group(dg, nx_instrument: h5py.Group) -> None: root_entry=nx_source, name="type", var="Spallation Neutron Source" ) _create_dataset_from_var( - root_entry=nx_source, name="distance", var=sc.norm(dg["source_position"]) + root_entry=nx_source, name="distance", var=sc.norm(source_position) ) # Legacy probe information. _create_dataset_from_string(root_entry=nx_source, name="probe", var="neutron") -def _add_lauetof_detector_group(dg: sc.DataGroup, nx_instrument: h5py.Group) -> None: - nx_detector = nx_instrument.create_group(dg["detector_name"].value) # Detector name - nx_detector.attrs["NX_class"] = "NXdetector" - _create_dataset_from_var( - name="polar_angle", - root_entry=nx_detector, - var=sc.scalar(0, unit='deg'), # TODO: Add real data - ) - _create_dataset_from_var( - name="azimuthal_angle", - root_entry=nx_detector, - var=sc.scalar(0, unit='deg'), # TODO: Add real data - ) - _create_dataset_from_var( - name="x_pixel_size", root_entry=nx_detector, var=dg["x_pixel_size"] - ) - _create_dataset_from_var( - name="y_pixel_size", root_entry=nx_detector, var=dg["y_pixel_size"] - ) +def _add_lauetof_detector_group( + *, + detector_name: str, + x_pixel_size: sc.Variable, + y_pixel_size: sc.Variable, + origin_position: sc.Variable, + fast_axis: sc.Variable, + slow_axis: sc.Variable, + distance: sc.Variable, + polar_angle: sc.Variable, + azimuthal_angle: sc.Variable, + nx_instrument: h5py.Group, +) -> None: + nx_det = nx_instrument.create_group(detector_name) # Detector name + nx_det.attrs["NX_class"] = "NXdetector" + _create_dataset_from_var(name="polar_angle", root_entry=nx_det, var=polar_angle) _create_dataset_from_var( - name="distance", - root_entry=nx_detector, - var=sc.scalar(0, unit='m'), # TODO: Add real data + name="azimuthal_angle", root_entry=nx_det, var=azimuthal_angle ) + _create_dataset_from_var(name="x_pixel_size", root_entry=nx_det, var=x_pixel_size) + _create_dataset_from_var(name="y_pixel_size", root_entry=nx_det, var=y_pixel_size) + _create_dataset_from_var(name="distance", root_entry=nx_det, var=distance) # Legacy geometry information until we have a better way to store it - _create_dataset_from_var( - name="origin", root_entry=nx_detector, var=dg['origin_position'] - ) + 
_create_dataset_from_var(name="origin", root_entry=nx_det, var=origin_position) # Fast axis, along where the pixel ID increases by 1 - _create_dataset_from_var( - root_entry=nx_detector, var=dg['fast_axis'], name="fast_axis" - ) + _create_dataset_from_var(root_entry=nx_det, name="fast_axis", var=fast_axis) # Slow axis, along where the pixel ID increases # by the number of pixels in the fast axis - _create_dataset_from_var( - root_entry=nx_detector, var=dg['slow_axis'], name="slow_axis" - ) + _create_dataset_from_var(root_entry=nx_det, name="slow_axis", var=slow_axis) -def _add_lauetof_sample_group(dg, nx_entry: h5py.Group) -> None: +def _add_lauetof_sample_group( + *, + crystal_rotation: sc.Variable, + sample_name: str | sc.Variable, + sample_orientation_matrix: sc.Variable, + sample_unit_cell: sc.Variable, + nx_entry: h5py.Group, +) -> None: nx_sample = nx_entry.create_group("sample") nx_sample.attrs["NX_class"] = "NXsample" _create_dataset_from_var( root_entry=nx_sample, - var=dg['crystal_rotation'], + var=crystal_rotation, name='crystal_rotation', long_name='crystal rotation in Phi (XYZ)', ) _create_dataset_from_string( root_entry=nx_sample, name='name', - var=dg['sample_name'].value, + var=sample_name if isinstance(sample_name, str) else sample_name.value, ) _create_dataset_from_var( - name='orientation_matrix', - root_entry=nx_sample, - var=sc.array( - dims=['i', 'j'], - values=[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], - unit="dimensionless", - ), # TODO: Add real data, the sample orientation matrix + name='orientation_matrix', root_entry=nx_sample, var=sample_orientation_matrix ) _create_dataset_from_var( name='unit_cell', root_entry=nx_sample, - var=sc.array( - dims=['i'], - values=[1.0, 1.0, 1.0, 90.0, 90.0, 90.0], - unit="dimensionless", # TODO: Add real data, - # a, b, c, alpha, beta, gamma - ), - ) - - -def _add_lauetof_monitor_group(data: sc.DataGroup, nx_entry: h5py.Group) -> None: - nx_monitor = nx_entry.create_group("control") - nx_monitor.attrs["NX_class"] = "NXmonitor" - _create_dataset_from_string(root_entry=nx_monitor, name='mode', var='monitor') - nx_monitor["preset"] = 0.0 # Check if this is the correct value - data_dset = _create_dataset_from_var( - name='data', - root_entry=nx_monitor, - var=sc.array( - dims=['tof'], values=[1, 1, 1], unit="counts" - ), # TODO: Add real data, bin values - ) - data_dset.attrs["signal"] = 1 - data_dset.attrs["primary"] = 1 - _create_dataset_from_var( - name='time_of_flight', - root_entry=nx_monitor, - var=sc.array( - dims=['tof'], values=[1, 1, 1], unit="s" - ), # TODO: Add real data, bin edges + var=sample_unit_cell, ) @@ -413,8 +209,9 @@ def _add_arbitrary_metadata( ) -def _export_static_metadata_as_nxlauetof( - experiment_metadata, +def export_static_metadata_as_nxlauetof( + sample_metadata: NMXSampleMetadata, + source_metadata: NMXSourceMetadata, output_file: str | pathlib.Path | io.BytesIO, **arbitrary_metadata: sc.Variable, ) -> None: @@ -428,30 +225,70 @@ def _export_static_metadata_as_nxlauetof( Parameters ---------- - experiment_metadata: - Experiment metadata object. + sample_metadata: + Sample metadata object. + source_metadata: + Source metadata object. + monitor_metadata: + Monitor metadata object. output_file: Output file path. arbitrary_metadata: Arbitrary metadata that does not fit into the existing metadata objects. 
""" + _check_file(output_file, overwrite=True) with h5py.File(output_file, "w") as f: f.attrs["NX_class"] = "NXlauetof" nx_entry = _create_lauetof_data_entry(f) _add_lauetof_definition(nx_entry) - _add_lauetof_sample_group(experiment_metadata, nx_entry) + _add_lauetof_sample_group( + crystal_rotation=sample_metadata.crystal_rotation, + sample_name=sample_metadata.sample_name, + sample_orientation_matrix=sample_metadata.sample_orientation_matrix, + sample_unit_cell=sample_metadata.sample_unit_cell, + nx_entry=nx_entry, + ) nx_instrument = _add_lauetof_instrument(nx_entry) - _add_lauetof_source_group(experiment_metadata, nx_instrument) - # Placeholder for ``monitor`` group - _add_lauetof_monitor_group(experiment_metadata, nx_entry) + _add_lauetof_source_group(source_metadata.source_position, nx_instrument) # Skipping ``NXdata``(name) field with data link # Add arbitrary metadata _add_arbitrary_metadata(nx_entry, **arbitrary_metadata) -def _export_detector_metadata_as_nxlauetof( - *detector_metadatas, +def export_monitor_metadata_as_nxlauetof( + monitor_metadata: NMXMonitorMetadata, + output_file: str | pathlib.Path | io.BytesIO, + append_mode: bool = True, +) -> None: + """Export the detector specific metadata to a NeXus file. + + Since NMX can have arbitrary number of detectors, + this function can take multiple detector metadata objects. + + Parameters + ---------- + monitor_metadata: + Monitor metadata object. + output_file: + Output file path. + + """ + if not append_mode: + raise NotImplementedError("Only append mode is supported for now.") + + with h5py.File(output_file, "r+") as f: + nx_entry = f["entry"] + # Placeholder for ``monitor`` group + _add_lauetof_monitor_group( + tof_bin_coord=monitor_metadata.tof_bin_coord, + monitor_histogram=monitor_metadata.monitor_histogram, + nx_entry=nx_entry, + ) + + +def export_detector_metadata_as_nxlauetof( + detector_metadata: NMXDetectorMetadata, output_file: str | pathlib.Path | io.BytesIO, append_mode: bool = True, ) -> None: @@ -478,27 +315,54 @@ def _export_detector_metadata_as_nxlauetof( nx_instrument = _add_lauetof_instrument(f["entry"]) else: nx_instrument = nx_entry["instrument"] + # Add detector group metadata - for detector_metadata in detector_metadatas: - _add_lauetof_detector_group(detector_metadata, nx_instrument) + _add_lauetof_detector_group( + detector_name=detector_metadata.detector_name, + x_pixel_size=detector_metadata.x_pixel_size, + y_pixel_size=detector_metadata.y_pixel_size, + origin_position=detector_metadata.origin_position, + fast_axis=detector_metadata.fast_axis, + slow_axis=detector_metadata.slow_axis, + distance=detector_metadata.distance, + polar_angle=detector_metadata.polar_angle, + azimuthal_angle=detector_metadata.azimuthal_angle, + nx_instrument=nx_instrument, + ) -def _extract_counts(dg: sc.DataGroup) -> sc.Variable: - counts: sc.DataArray = dg['counts'].data - if 'id' in counts.dims: - num_x, num_y = dg["detector_shape"].value - return sc.fold(counts, dim='id', sizes={'x': num_x, 'y': num_y}) - else: - # If there is no 'id' dimension, we assume it is already in the correct shape - return counts +def _add_lauetof_monitor_group( + *, + tof_bin_coord: str, + monitor_histogram: sc.DataArray, + nx_entry: h5py.Group, +) -> None: + nx_monitor = nx_entry.create_group("control") + nx_monitor.attrs["NX_class"] = "NXmonitor" + _create_dataset_from_string(root_entry=nx_monitor, name='mode', var='monitor') + nx_monitor["preset"] = 0.0 # Check if this is the correct value + data_dset = _create_dataset_from_var( + 
name='data', + root_entry=nx_monitor, + var=monitor_histogram.data, + ) + data_dset.attrs["signal"] = 1 + data_dset.attrs["primary"] = 1 + + _create_dataset_from_var( + name='time_of_flight', + root_entry=nx_monitor, + var=monitor_histogram.coords[tof_bin_coord], + ) -def _export_reduced_data_as_nxlauetof( - dg, +def export_reduced_data_as_nxlauetof( + detector_name: str, + da: sc.DataArray, output_file: str | pathlib.Path | io.BytesIO, *, append_mode: bool = True, - compress_counts: bool = True, + compress_mode: Compression = Compression.BITSHUFFLE_LZ4, ) -> None: """Export the reduced data to a NeXus file with the LAUE_TOF application definition. @@ -527,29 +391,30 @@ def _export_reduced_data_as_nxlauetof( raise NotImplementedError("Only append mode is supported for now.") with h5py.File(output_file, "r+") as f: - nx_detector: h5py.Group = f[f"entry/instrument/{dg['detector_name'].value}"] + nx_detector: h5py.Group = f[f"entry/instrument/{detector_name}"] # Data - shape: [n_x_pixels, n_y_pixels, n_tof_bins] # The actual application definition defines it as integer, - # but we keep the original data type for now - num_x, num_y = dg["detector_shape"].value # Probably better way to do this - if compress_counts: - data_dset = _create_compressed_dataset( + # so we overwrite the dtype here. + num_x, num_y = da.sizes['x_pixel_offset'], da.sizes['y_pixel_offset'] + + if compress_mode != Compression.NONE: + compression_args = _retrieve_compression_arguments(compress_mode) + data_dset = _create_dataset_from_var( name="data", root_entry=nx_detector, - var=_extract_counts(dg), - chunks=(num_x, num_y, 1), + var=da.data, + chunks=(num_x, num_y, 1), # Chunk along tof axis dtype=np.uint, + **compression_args, ) else: data_dset = _create_dataset_from_var( - name="data", - root_entry=nx_detector, - var=_extract_counts(dg), - dtype=np.uint, + name="data", root_entry=nx_detector, var=da.data, dtype=np.uint ) + data_dset.attrs["signal"] = 1 _create_dataset_from_var( name='time_of_flight', root_entry=nx_detector, - var=sc.midpoints(dg['counts'].coords['t'], dim='t'), + var=sc.midpoints(da.coords['tof'], dim='tof'), ) diff --git a/src/ess/nmx/types.py b/src/ess/nmx/types.py index ad4d3de..0336ec7 100644 --- a/src/ess/nmx/types.py +++ b/src/ess/nmx/types.py @@ -1,4 +1,8 @@ import enum +from dataclasses import dataclass, field +from typing import NewType + +import scipp as sc class Compression(enum.StrEnum): @@ -9,3 +13,67 @@ class Compression(enum.StrEnum): NONE = 'NONE' BITSHUFFLE_LZ4 = 'BITSHUFFLE_LZ4' + + +TofSimulationMinWavelength = NewType("TofSimulationMinWavelength", sc.Variable) +"""Minimum wavelength for tof simulation to calculate look up table.""" + +TofSimulationMaxWavelength = NewType("TofSimulationMaxWavelength", sc.Variable) +"""Maximum wavelength for tof simulation to calculate look up table.""" + + +@dataclass(kw_only=True) +class NMXSampleMetadata: + crystal_rotation: sc.Variable + sample_position: sc.Variable + sample_name: sc.Variable | str + # Temporarily hardcoding some values + # TODO: Remove hardcoded values + sample_orientation_matrix: sc.Variable = field( + default_factory=lambda: sc.array( + dims=['i', 'j'], + values=[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + unit="dimensionless", + ) + ) + sample_unit_cell: sc.Variable = field( + default_factory=lambda: sc.array( + dims=['i'], + values=[1.0, 1.0, 1.0, 90.0, 90.0, 90.0], + unit="dimensionless", # TODO: Add real data, + # a, b, c, alpha, beta, gamma + ) + ) + + +@dataclass(kw_only=True) +class NMXSourceMetadata: + 
source_position: sc.Variable + + +@dataclass(kw_only=True) +class NMXMonitorMetadata: + monitor_histogram: sc.DataArray + tof_bin_coord: str = field( + default='tof', + metadata={ + "description": "Name of the time-of-flight coordinate " + "in the monitor histogram." + }, + ) + + +@dataclass(kw_only=True) +class NMXDetectorMetadata: + detector_name: str + x_pixel_size: sc.Variable + y_pixel_size: sc.Variable + origin_position: sc.Variable + fast_axis: sc.Variable + slow_axis: sc.Variable + distance: sc.Variable + # TODO: Remove hardcoded values + polar_angle: sc.Variable = field(default_factory=lambda: sc.scalar(0, unit='deg')) + azimuthal_angle: sc.Variable = field( + default_factory=lambda: sc.scalar(0, unit='deg') + ) diff --git a/src/ess/nmx/workflows.py b/src/ess/nmx/workflows.py new file mode 100644 index 0000000..d37517f --- /dev/null +++ b/src/ess/nmx/workflows.py @@ -0,0 +1,308 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +import pathlib +from collections.abc import Callable, Iterable + +import pandas as pd +import sciline +import scipp as sc +import scippnexus as snx +import tof + +from ess.reduce.nexus.types import ( + EmptyDetector, + Filename, + NeXusComponent, + NeXusName, + NeXusTransformation, + Position, + SampleRun, +) +from ess.reduce.time_of_flight import ( + DetectorLtotal, + GenericTofWorkflow, + LtotalRange, + NumberOfSimulatedNeutrons, + SimulationResults, + SimulationSeed, + TofLookupTableWorkflow, +) +from ess.reduce.time_of_flight.types import ( + TimeOfFlightLookupTable, + TimeOfFlightLookupTableFilename, +) +from ess.reduce.workflow import register_workflow + +from .configurations import WorkflowConfig +from .types import ( + NMXDetectorMetadata, + NMXSampleMetadata, + NMXSourceMetadata, + TofSimulationMaxWavelength, + TofSimulationMinWavelength, +) + +default_parameters = { + TofSimulationMaxWavelength: sc.scalar(3.6, unit='angstrom'), + TofSimulationMinWavelength: sc.scalar(1.8, unit='angstrom'), +} + + +def _validate_mergable_workflow(wf: sciline.Pipeline): + if wf.indices: + raise NotImplementedError("Only flat workflow can be merged.") + + +def _merge_workflows( + base_wf: sciline.Pipeline, merged_wf: sciline.Pipeline +) -> sciline.Pipeline: + _validate_mergable_workflow(base_wf) + _validate_mergable_workflow(merged_wf) + + for key, spec in merged_wf.underlying_graph.nodes.items(): + if 'value' in spec: + base_wf[key] = spec['value'] + elif (provider_spec := spec.get('provider')) is not None: + base_wf.insert(provider_spec.func) + + return base_wf + + +def _simulate_fixed_wavelength_tof( + wmin: TofSimulationMinWavelength, + wmax: TofSimulationMaxWavelength, + ltotal_range: LtotalRange, + neutrons: NumberOfSimulatedNeutrons, + seed: SimulationSeed, +) -> SimulationResults: + """ + Simulate a pulse of neutrons propagating through a chopper cascade using the + ``tof`` package (https://tof.readthedocs.io). 
+ + Parameters + ---------- + """ + source = tof.Source( + facility="ess", neutrons=neutrons, pulses=2, seed=seed, wmax=wmax, wmin=wmin + ) + nmx_det = tof.Detector(distance=max(ltotal_range), name="detector") + model = tof.Model(source=source, choppers=[], detectors=[nmx_det]) + results = model.run() + events = results["detector"].data.squeeze().flatten(to="event") + return SimulationResults( + time_of_arrival=events.coords["toa"], + speed=events.coords["speed"], + wavelength=events.coords["wavelength"], + weight=events.data, + distance=results["detector"].distance, + ) + + +def _ltotal_range(detector_ltotal: DetectorLtotal[SampleRun]) -> LtotalRange: + margin = sc.scalar(0.5, unit='m').to( + unit=detector_ltotal.unit + ) # Hardcoded margin of 50 cm. It's because the detector width is ~50 cm. + ltotal_min = sc.min(detector_ltotal) - margin + ltotal_max = sc.max(detector_ltotal) + margin + return LtotalRange((ltotal_min, ltotal_max)) + + +def patch_workflow_lookup_table_steps(*, wf: sciline.Pipeline) -> sciline.Pipeline: + patched_wf = wf.copy() + + # Use TofLookupTableWorkflow + patched_wf = _merge_workflows(patched_wf, TofLookupTableWorkflow()) + patched_wf.insert(_simulate_fixed_wavelength_tof) + patched_wf.insert(_ltotal_range) + return patched_wf + + +def _merge_panels(*da: sc.DataArray) -> sc.DataArray: + """Merge multiple DataArrays representing different panels into one.""" + merged = sc.concat(da, dim='panel') + return merged + + +def select_detector_names( + *, + input_files: list[pathlib.Path] | None = None, + detector_ids: Iterable[int] = (0, 1, 2), +): + if input_files is not None: + detector_names = [] + # Collect all detector names from input files + for input_file in input_files: + with snx.File(input_file) as nexus_file: + detector_names.extend( + nexus_file['entry/instrument'][snx.NXdetector].keys() + ) + detector_names = sorted(set(detector_names)) + return [detector_names[i_d] for i_d in detector_ids] + else: + return ['detector_panel_0', 'detector_panel_1', 'detector_panel_2'] + + +def map_detector_names( + *, + wf: sciline.Pipeline, + detector_names: Iterable[str], + mapped_type: type, + reduce_func: Callable = _merge_panels, +) -> sciline.Pipeline: + """Map detector indices(`panel`) to detector names in the workflow.""" + detector_name_map = pd.DataFrame({NeXusName[snx.NXdetector]: detector_names}) + detector_name_map.rename_axis(index='panel', inplace=True) + wf[mapped_type] = wf[mapped_type].map(detector_name_map).reduce(func=reduce_func) + return wf + + +def assemble_sample_metadata( + crystal_rotation: Position[snx.NXcrystal, SampleRun], + sample_position: Position[snx.NXsample, SampleRun], + sample_component: NeXusComponent[snx.NXsample, SampleRun], +) -> NMXSampleMetadata: + """Assemble sample metadata for NMX reduction workflow.""" + return NMXSampleMetadata( + sample_name=sample_component['name'], + crystal_rotation=crystal_rotation, + sample_position=sample_position, + ) + + +def assemble_source_metadata( + source_position: Position[snx.NXsource, SampleRun], +) -> NMXSourceMetadata: + """Assemble source metadata for NMX reduction workflow.""" + return NMXSourceMetadata(source_position=source_position) + + +def _decide_fast_axis(da: sc.DataArray) -> str: + x_slice = da['x_pixel_offset', 0].coords['detector_number'] + y_slice = da['y_pixel_offset', 0].coords['detector_number'] + + if (x_slice.max() < y_slice.max()).value: + return 'y' + elif (x_slice.max() > y_slice.max()).value: + return 'x' + else: + raise ValueError( + "Cannot decide fast axis based on 
pixel offsets. " + "Please specify the fast axis explicitly." + ) + + +def _decide_step(offsets: sc.Variable) -> sc.Variable: + """Decide the step size based on the offsets assuming at least 2 values.""" + sorted_offsets = sc.sort(offsets, key=offsets.dim, order='ascending') + return sorted_offsets[1] - sorted_offsets[0] + + +def _normalize_vector(vec: sc.Variable) -> sc.Variable: + return vec / sc.norm(vec) + + +def _retrieve_crystal_rotation( + file_path: Filename[SampleRun], +) -> Position[snx.NXcrystal, SampleRun]: + """Temporary provider to retrieve crystal rotation from Nexus file.""" + with snx.File(file_path) as file: + if 'crystal_rotation' not in file['entry/sample']: + import warnings + + warnings.warn( + "No crystal rotation found in the Nexus file under " + "'entry/sample/crystal_rotation'. Returning zero rotation.", + RuntimeWarning, + stacklevel=1, + ) + return Position[snx.NXcrystal, SampleRun](sc.vector([0, 0, 0], unit='deg')) + + # Temporary way of storing crystal rotation. + # streaming-sample-mcstas module writes crystal rotation under + # 'entry/sample/crystal_rotation' as an array of three values. + return Position[snx.NXcrystal, SampleRun]( + file['entry/sample/crystal_rotation'][()] + ) + + +def assemble_detector_metadata( + detector_component: NeXusComponent[snx.NXdetector, SampleRun], + transformation: NeXusTransformation[snx.NXdetector, SampleRun], + source_position: Position[snx.NXsource, SampleRun], + empty_detector: EmptyDetector[SampleRun], +) -> NMXDetectorMetadata: + """Assemble detector metadata for NMX reduction workflow.""" + first_id = empty_detector.coords['detector_number'].min() + # Assuming `empty_detector` has (`x_pixel_offset`, `y_pixel_offset`) dims + origin = empty_detector.flatten(dims=empty_detector.dims, to='detector_number')[ + 'detector_number', first_id + ].coords['position'] + _fast_axis = _decide_fast_axis(empty_detector) + t_unit = transformation.value.unit + + fast_axis_vector = transformation.value * ( + sc.vector([1.0, 0, 0], unit=t_unit) + if _fast_axis == 'x' + else sc.vector([0.0, 1, 0], unit=t_unit) + ) + slow_axis_vector = transformation.value * ( + sc.vector([0.0, 1, 0], unit=t_unit) + if _fast_axis == 'x' + else sc.vector([1.0, 0, 0], unit=t_unit) + ) + x_pixel_size = _decide_step(empty_detector.coords['x_pixel_offset']) + y_pixel_size = _decide_step(empty_detector.coords['y_pixel_offset']) + distance = sc.norm(origin - source_position.to(unit=origin.unit)) + + return NMXDetectorMetadata( + detector_name=detector_component['nexus_component_name'], + x_pixel_size=x_pixel_size, + y_pixel_size=y_pixel_size, + origin_position=origin, + fast_axis=_normalize_vector(fast_axis_vector), + slow_axis=_normalize_vector(slow_axis_vector), + distance=distance, + ) + + +@register_workflow +def NMXWorkflow() -> sciline.Pipeline: + generic_wf = GenericTofWorkflow(run_types=[SampleRun], monitor_types=[]) + + generic_wf.insert(_retrieve_crystal_rotation) + generic_wf.insert(assemble_sample_metadata) + generic_wf.insert(assemble_source_metadata) + generic_wf.insert(assemble_detector_metadata) + for key, value in default_parameters.items(): + generic_wf[key] = value + + return generic_wf + + +def compute_lookup_table( + *, + base_wf: sciline.Pipeline, + workflow_config: WorkflowConfig, + detector_names: Iterable[str], +) -> sc.DataArray: + wf = base_wf.copy() + if workflow_config.tof_lookup_table_file_path is not None: + wf[TimeOfFlightLookupTableFilename] = workflow_config.tof_lookup_table_file_path + else: + wf = 
patch_workflow_lookup_table_steps(wf=wf) + wmax = sc.scalar(workflow_config.tof_simulation_max_wavelength, unit='angstrom') + wmin = sc.scalar(workflow_config.tof_simulation_min_wavelength, unit='angstrom') + wf[TofSimulationMaxWavelength] = wmax + wf[TofSimulationMinWavelength] = wmin + wf[SimulationSeed] = workflow_config.tof_simulation_seed + wf = map_detector_names( + wf=wf, + detector_names=detector_names, + mapped_type=DetectorLtotal[SampleRun], + reduce_func=_merge_panels, + ) + + return wf.compute(TimeOfFlightLookupTable) + + +__all__ = ['NMXWorkflow'] diff --git a/tests/executable_test.py b/tests/executable_test.py index 6c42879..e3560ac 100644 --- a/tests/executable_test.py +++ b/tests/executable_test.py @@ -3,25 +3,25 @@ import pathlib import subprocess +from contextlib import contextmanager from enum import Enum import pydantic import pytest import scipp as sc import scippnexus as snx -from scipp.testing import assert_allclose from ess.nmx._executable_helper import ( InputConfig, OutputConfig, ReductionConfig, - TimeBinCoordinate, - TimeBinUnit, WorkflowConfig, build_reduction_argument_parser, reduction_config_from_args, to_command_arguments, ) +from ess.nmx.configurations import TimeBinCoordinate, TimeBinUnit +from ess.nmx.executables import reduction from ess.nmx.types import Compression @@ -127,9 +127,7 @@ def small_nmx_nexus_path(): return get_small_nmx_nexus() -def _check_output_file( - output_file_path: pathlib.Path, expected_toa_output: sc.Variable -): +def _check_output_file(output_file_path: pathlib.Path, nbins: int): detector_names = [f'detector_panel_{i}' for i in range(3)] with snx.File(output_file_path, 'r') as f: # Test @@ -137,7 +135,7 @@ def _check_output_file( det_gr = f[f'entry/instrument/{name}'] assert det_gr is not None toa_edges = det_gr['time_of_flight'][()] - assert_allclose(toa_edges, expected_toa_output) + assert len(toa_edges) == nbins def test_executable_runs(small_nmx_nexus_path, tmp_path: pathlib.Path): @@ -147,15 +145,6 @@ def test_executable_runs(small_nmx_nexus_path, tmp_path: pathlib.Path): nbins = 20 # Small number of bins for testing. # The output has 1280x1280 pixels per detector per time bin. - expected_toa_bins = sc.linspace( - dim='dim_0', - start=2, # Unrealistic number for testing - stop=int((1 / 15) * 1_000), # Unrealistic number for testing - num=nbins + 1, - unit='ms', - ) - expected_toa_output = sc.midpoints(expected_toa_bins, dim='dim_0').to(unit='ns') - commands = ( 'essnmx-reduce', '--input-file', @@ -164,18 +153,61 @@ def test_executable_runs(small_nmx_nexus_path, tmp_path: pathlib.Path): str(nbins), '--output-file', output_file.as_posix(), - '--min-time-bin', - str(int(expected_toa_bins.min().value)), - '--max-time-bin', - str(int(expected_toa_bins.max().value)), ) # Validate that all commands are strings and contain no unsafe characters result = subprocess.run( # noqa: S603 - We are not accepting arbitrary input here. 
- commands, - text=True, - capture_output=True, - check=False, + commands, text=True, capture_output=True, check=False ) assert result.returncode == 0 assert output_file.exists() - _check_output_file(output_file, expected_toa_output=expected_toa_output) + _check_output_file(output_file, nbins=nbins) + + +@contextmanager +def known_warnings(): + with pytest.warns(RuntimeWarning, match="No crystal rotation*"): + yield + + +@pytest.fixture +def temp_output_file(tmp_path: pathlib.Path): + output_file_path = tmp_path / "scipp_output.h5" + yield output_file_path + if output_file_path.exists(): + output_file_path.unlink() + + +@pytest.fixture +def reduction_config( + small_nmx_nexus_path: pathlib.Path, temp_output_file: pathlib.Path +) -> ReductionConfig: + input_config = InputConfig(input_file=[small_nmx_nexus_path.as_posix()]) + # Compression option is not default (NONE) but + # the actual default compression option, BITSHUFFLE_LZ4, + # only properly works in linux so we set it to NONE here + # for convenience of testing on all platforms. + output_config = OutputConfig( + output_file=temp_output_file.as_posix(), compression=Compression.NONE + ) + return ReductionConfig(inputs=input_config, output=output_config) + + +def _retrieve_one_hist(results: sc.DataGroup) -> sc.DataArray: + """Helper to retrieve the first DataArray from the results dictionary.""" + return results['histogram']['detector_panel_0'] + + +def test_reduction_default_settings(reduction_config: ReductionConfig) -> None: + # Only check if reduction runs without errors with default settings. + with known_warnings(): + reduction(config=reduction_config) + + +def test_reduction_only_number_of_time_bins(reduction_config: ReductionConfig) -> None: + reduction_config.workflow.nbins = 20 + reduction_config.workflow.time_bin_coordinate = TimeBinCoordinate.time_of_flight + with known_warnings(): + hist = _retrieve_one_hist(reduction(config=reduction_config)) + + # Check that the number of time bins is as expected. + assert len(hist.coords['tof']) == 21 # nbins + 1 edges From 6bf8310fc118e9cec16af23b9325df717593d912 Mon Sep 17 00:00:00 2001 From: YooSunyoung <17974113+YooSunYoung@users.noreply.github.com> Date: Thu, 4 Dec 2025 14:21:24 +0100 Subject: [PATCH 2/6] Update input file pattern retrieval routine. --- src/ess/nmx/executables.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/ess/nmx/executables.py b/src/ess/nmx/executables.py index 2d83cb4..4cd16d7 100644 --- a/src/ess/nmx/executables.py +++ b/src/ess/nmx/executables.py @@ -33,16 +33,23 @@ from .workflows import NMXWorkflow, compute_lookup_table, select_detector_names -def _retrieve_input_file(input_file: list[pathlib.Path] | pathlib.Path) -> pathlib.Path: +def _retrieve_input_file(input_file: list[str]) -> pathlib.Path: """Temporary helper to retrieve a single input file from the list Until multiple input file support is implemented. """ - if isinstance(input_file, list) and len(input_file) != 1: - raise NotImplementedError( - "Currently, only a single input file is supported for reduction." - ) - elif isinstance(input_file, list): - input_file_path = input_file[0] + if isinstance(input_file, list): + input_files = collect_matching_input_files(*input_file) + if len(input_files) == 0: + raise ValueError( + "No input files found for reduction." + "Check if the file paths are correct.", + input_file, + ) + elif len(input_files) > 1: + raise NotImplementedError( + "Currently, only a single input file is supported for reduction." 
+ ) + input_file_path = input_files[0] else: input_file_path = input_file @@ -132,9 +139,7 @@ def reduction( """ display = _retrieve_display(logger, display) - input_file_path = _retrieve_input_file( - collect_matching_input_files(*config.inputs.input_file) - ).resolve() + input_file_path = _retrieve_input_file(config.inputs.input_file).resolve() display(f"Input file: {input_file_path}") output_file_path = pathlib.Path(config.output.output_file).resolve() From 26c8987bcd3d396488833ed89ed3e01e15a3e7da Mon Sep 17 00:00:00 2001 From: YooSunyoung <17974113+YooSunYoung@users.noreply.github.com> Date: Fri, 5 Dec 2025 14:58:50 +0100 Subject: [PATCH 3/6] Fix origin of detector. --- src/ess/nmx/workflows.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/ess/nmx/workflows.py b/src/ess/nmx/workflows.py index d37517f..a794623 100644 --- a/src/ess/nmx/workflows.py +++ b/src/ess/nmx/workflows.py @@ -232,11 +232,8 @@ def assemble_detector_metadata( empty_detector: EmptyDetector[SampleRun], ) -> NMXDetectorMetadata: """Assemble detector metadata for NMX reduction workflow.""" - first_id = empty_detector.coords['detector_number'].min() - # Assuming `empty_detector` has (`x_pixel_offset`, `y_pixel_offset`) dims - origin = empty_detector.flatten(dims=empty_detector.dims, to='detector_number')[ - 'detector_number', first_id - ].coords['position'] + # Origin should be the center of the detector. + origin = empty_detector.coords['position'].mean() _fast_axis = _decide_fast_axis(empty_detector) t_unit = transformation.value.unit From 488f108959c3c16d04f943a8cb793122eb69d592 Mon Sep 17 00:00:00 2001 From: YooSunyoung <17974113+YooSunYoung@users.noreply.github.com> Date: Fri, 5 Dec 2025 15:12:23 +0100 Subject: [PATCH 4/6] Calculate fast/slow axis. --- src/ess/nmx/workflows.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/ess/nmx/workflows.py b/src/ess/nmx/workflows.py index a794623..d0af71c 100644 --- a/src/ess/nmx/workflows.py +++ b/src/ess/nmx/workflows.py @@ -232,21 +232,20 @@ def assemble_detector_metadata( empty_detector: EmptyDetector[SampleRun], ) -> NMXDetectorMetadata: """Assemble detector metadata for NMX reduction workflow.""" + positions = empty_detector.coords['position'] # Origin should be the center of the detector. 
- origin = empty_detector.coords['position'].mean() + origin = positions.mean() _fast_axis = _decide_fast_axis(empty_detector) + _slow_axis = 'y' if _fast_axis == 'x' else 'x' t_unit = transformation.value.unit - fast_axis_vector = transformation.value * ( - sc.vector([1.0, 0, 0], unit=t_unit) - if _fast_axis == 'x' - else sc.vector([0.0, 1, 0], unit=t_unit) - ) - slow_axis_vector = transformation.value * ( - sc.vector([0.0, 1, 0], unit=t_unit) - if _fast_axis == 'x' - else sc.vector([1.0, 0, 0], unit=t_unit) - ) + axis_vectors = { + 'x': positions['x_pixel_offset', -1] - positions['x_pixel_offset', 0], + 'y': positions['y_pixel_offset', -1] - positions['y_pixel_offset', 0], + } + + fast_axis_vector = axis_vectors[_fast_axis].to(unit=t_unit) + slow_axis_vector = axis_vectors[_slow_axis].to(unit=t_unit) x_pixel_size = _decide_step(empty_detector.coords['x_pixel_offset']) y_pixel_size = _decide_step(empty_detector.coords['y_pixel_offset']) distance = sc.norm(origin - source_position.to(unit=origin.unit)) From 0813b1fa8b834de285b7f18b140d9b3f1a1b1ba4 Mon Sep 17 00:00:00 2001 From: YooSunyoung <17974113+YooSunYoung@users.noreply.github.com> Date: Fri, 5 Dec 2025 15:14:24 +0100 Subject: [PATCH 5/6] Calculate fast/slow axis. --- src/ess/nmx/workflows.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/ess/nmx/workflows.py b/src/ess/nmx/workflows.py index d0af71c..4e0ec12 100644 --- a/src/ess/nmx/workflows.py +++ b/src/ess/nmx/workflows.py @@ -240,8 +240,10 @@ def assemble_detector_metadata( t_unit = transformation.value.unit axis_vectors = { - 'x': positions['x_pixel_offset', -1] - positions['x_pixel_offset', 0], - 'y': positions['y_pixel_offset', -1] - positions['y_pixel_offset', 0], + 'x': positions['x_pixel_offset', 1]['y_pixel_offset', 0] + - positions['x_pixel_offset', 0]['y_pixel_offset', 0], + 'y': positions['y_pixel_offset', 1]['x_pixel_offset', 0] + - positions['y_pixel_offset', 0]['x_pixel_offset', 0], } fast_axis_vector = axis_vectors[_fast_axis].to(unit=t_unit) From 346819adcfeddfc4e93f5cc4bff29e0d5cbae3b8 Mon Sep 17 00:00:00 2001 From: YooSunyoung <17974113+YooSunYoung@users.noreply.github.com> Date: Fri, 5 Dec 2025 16:02:36 +0100 Subject: [PATCH 6/6] Temporary fix until we fix the tof simulation. 
[skip ci] --- src/ess/nmx/executables.py | 17 ++++++++++------- src/ess/nmx/nexus.py | 2 +- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/ess/nmx/executables.py b/src/ess/nmx/executables.py index 4cd16d7..ebbe335 100644 --- a/src/ess/nmx/executables.py +++ b/src/ess/nmx/executables.py @@ -8,8 +8,8 @@ import scipp as sc import scippnexus as snx -from ess.reduce.nexus.types import Filename, NeXusName, SampleRun -from ess.reduce.time_of_flight.types import TimeOfFlightLookupTable, TofDetector +from ess.reduce.nexus.types import Filename, NeXusName, RawDetector, SampleRun +from ess.reduce.time_of_flight.types import TimeOfFlightLookupTable # , TofDetector from ._executable_helper import ( build_logger, @@ -96,10 +96,11 @@ def _finalize_tof_bin_edges( *, tof_das: sc.DataGroup, config: WorkflowConfig ) -> sc.Variable: tof_bin_edges = sc.concat( - tuple(tof_da.coords['tof'] for tof_da in tof_das.values()), dim='tof' + tuple(tof_da.coords['event_time_offset'] for tof_da in tof_das.values()), + dim='event_time_offset', ) return sc.linspace( - dim='tof', + dim='event_time_offset', start=sc.min(tof_bin_edges), stop=sc.max(tof_bin_edges), num=config.nbins + 1, @@ -170,14 +171,16 @@ def reduction( for detector_name in detector_names: cur_wf = base_wf.copy() cur_wf[NeXusName[snx.NXdetector]] = detector_name - results = cur_wf.compute((TofDetector[SampleRun], NMXDetectorMetadata)) + results = cur_wf.compute((RawDetector[SampleRun], NMXDetectorMetadata)) detector_meta: NMXDetectorMetadata = results[NMXDetectorMetadata] export_detector_metadata_as_nxlauetof( detector_metadata=detector_meta, output_file=config.output.output_file ) detector_metas[detector_name] = detector_meta # Binning into 1 bin and getting final tof bin edges later. - tof_das[detector_name] = results[TofDetector[SampleRun]].bin(tof=1) + tof_das[detector_name] = results[RawDetector[SampleRun]].bin( + event_time_offset=1 + ) tof_bin_edges = _finalize_tof_bin_edges(tof_das=tof_das, config=config.workflow) @@ -198,7 +201,7 @@ def reduction( tof_histograms = sc.DataGroup() for detector_name, tof_da in tof_das.items(): det_meta: NMXDetectorMetadata = detector_metas[detector_name] - histogram = tof_da.hist(tof=tof_bin_edges) + histogram = tof_da.hist(event_time_offset=tof_bin_edges) tof_histograms[detector_name] = histogram export_reduced_data_as_nxlauetof( detector_name=det_meta.detector_name, diff --git a/src/ess/nmx/nexus.py b/src/ess/nmx/nexus.py index eb233b2..c6d6d43 100644 --- a/src/ess/nmx/nexus.py +++ b/src/ess/nmx/nexus.py @@ -416,5 +416,5 @@ def export_reduced_data_as_nxlauetof( _create_dataset_from_var( name='time_of_flight', root_entry=nx_detector, - var=sc.midpoints(da.coords['tof'], dim='tof'), + var=sc.midpoints(da.coords['event_time_offset'], dim='event_time_offset'), )
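
For reference, a minimal sketch (not part of the patch series) of how the configuration models and the reduction entry point introduced here fit together. It mirrors the usage in tests/executable_test.py; the file paths are placeholders, and it assumes the models remain importable from ess.nmx.configurations as added in this series.

    from ess.nmx.configurations import (
        InputConfig,
        OutputConfig,
        ReductionConfig,
        WorkflowConfig,
    )
    from ess.nmx.executables import reduction
    from ess.nmx.types import Compression

    # Placeholder paths - substitute a real NMX NeXus file and output location.
    config = ReductionConfig(
        inputs=InputConfig(input_file=["nmx_data.nxs"]),
        workflow=WorkflowConfig(nbins=20),
        # BITSHUFFLE_LZ4 is the default, but the tests fall back to NONE
        # because bitshuffle only works reliably on Linux.
        output=OutputConfig(output_file="scipp_output.h5", compression=Compression.NONE),
    )

    results = reduction(config=config)
    # Per-detector histograms live under 'histogram'; panel names follow the
    # instrument definition in the input file (e.g. detector_panel_0 in the test data).
    hist = results['histogram']['detector_panel_0']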