diff --git a/src/tracksdata/utils/_multiprocessing.py b/src/tracksdata/utils/_multiprocessing.py index 24696b11..7f7c37c9 100644 --- a/src/tracksdata/utils/_multiprocessing.py +++ b/src/tracksdata/utils/_multiprocessing.py @@ -52,7 +52,10 @@ def multiprocessing_apply( options = get_options() disable_tqdm = not options.show_progress - if length == 1: + if length == 0: + return + + elif length == 1: # skipping iteration overhead yield func(sequence[0]) diff --git a/src/tracksdata/utils/_test/test_multiprocessing.py b/src/tracksdata/utils/_test/test_multiprocessing.py new file mode 100644 index 00000000..65b31a8f --- /dev/null +++ b/src/tracksdata/utils/_test/test_multiprocessing.py @@ -0,0 +1,22 @@ +import pytest + +from tracksdata.options import options_context +from tracksdata.utils._multiprocessing import multiprocessing_apply + + +def _square(x: int) -> int: + return x * x + + +@pytest.mark.parametrize("n_workers", [1, 2]) +def test_multiprocessing_apply_empty_sequence(n_workers: int) -> None: + """An empty sequence must be a no-op regardless of the worker count.""" + with options_context(n_workers=n_workers): + assert list(multiprocessing_apply(_square, [], desc="empty")) == [] + + +@pytest.mark.parametrize("n_workers", [1, 2]) +def test_multiprocessing_apply_results(n_workers: int) -> None: + with options_context(n_workers=n_workers): + results = sorted(multiprocessing_apply(_square, [1, 2, 3], desc="squares")) + assert results == [1, 4, 9]