diff --git a/src/nitypes/vector.py b/src/nitypes/vector.py index ff719c4d..bbb96548 100644 --- a/src/nitypes/vector.py +++ b/src/nitypes/vector.py @@ -8,7 +8,7 @@ from __future__ import annotations -from collections.abc import Iterable, Mapping, MutableSequence +from collections.abc import Iterable, Mapping, MutableSequence, Sequence from typing import TYPE_CHECKING, Any, Union, overload from typing_extensions import Self, TypeVar, final, override @@ -79,19 +79,19 @@ def __init__( Returns: A vector data object. """ - if not values: + backing_values = list(values) + if not backing_values: if not value_type: raise TypeError("You must specify values as non-empty or specify value_type.") self._value_type = value_type else: + self._value_type = type(backing_values[0]) # Validate the values input - for index, value in enumerate(values): - # Only set _value_type once. - if not index: - self._value_type = type(value) - + for value in backing_values: if not isinstance(value, (bool, int, float, str)): - raise invalid_arg_type("vector input data", "bool, int, float, or str", values) + raise invalid_arg_type( + "vector input data", "bool, int, float, or str", backing_values + ) if not isinstance(value, self._value_type): raise TypeError("All values in the values input must be of the same type.") @@ -99,7 +99,7 @@ def __init__( if not isinstance(units, str): raise invalid_arg_type("units", "str", units) - self._values = list(values) + self._values = backing_values if copy_extended_properties or not isinstance( extended_properties, ExtendedPropertyDictionary ): @@ -178,14 +178,20 @@ def __setitem__(self, index: int | slice, value: TScalar | Iterable[TScalar]) -> raise TypeError("You must assign an Iterable to a Vector slice.") elif isinstance(value, str): # Narrow the type to exclude string. raise TypeError("You cannot assign a string to Vector slice.") + + if isinstance(value, Sequence): + # A Sequence can be iterated over multiple times, so no need for a wrapper list. + replacement_values = value else: - # Assigning an empty Iterable to a slice is valid, so we don't check for empty. - # If an empty Iterable is assigned to a slice, that slice is deleted. - for subval in value: - if not isinstance(subval, self._value_type): - raise self._create_value_mismatch_exception(subval) + replacement_values = list(value) - self._values[index] = value + # Assigning an empty Iterable to a slice is valid, so we don't check for empty. + # If an empty Iterable is assigned to a slice, that slice is deleted. + for subval in replacement_values: + if not isinstance(subval, self._value_type): + raise self._create_value_mismatch_exception(subval) + + self._values[index] = replacement_values def __delitem__(self, index: int | slice) -> None: """Delete item(s) from the specified location.""" diff --git a/tests/unit/vector/test_vector.py b/tests/unit/vector/test_vector.py index 65cced64..6f4e68c8 100644 --- a/tests/unit/vector/test_vector.py +++ b/tests/unit/vector/test_vector.py @@ -2,6 +2,7 @@ import copy import pickle +from collections.abc import Generator from typing import Any import pytest @@ -51,7 +52,7 @@ def test___int_data_values___create___creates_with_int_data_and_default_units() assert data.units == "" -def test___float_data_value___create___creates_scalar_data_with_data_and_default_units() -> None: +def test___float_data_values___create___creates_with_float_data_and_default_units() -> None: data = Vector([20.2, 30.3, 40.4]) assert_type(data._values[0], float) @@ -59,7 +60,7 @@ def test___float_data_value___create___creates_scalar_data_with_data_and_default assert data.units == "" -def test___str_data_value___create___creates_scalar_data_with_data_and_default_units() -> None: +def test___str_data_values___create___creates_with_str_data_and_default_units() -> None: data = Vector(["one", "two"]) assert_type(data._values[0], str) @@ -67,6 +68,19 @@ def test___str_data_value___create___creates_scalar_data_with_data_and_default_u assert data.units == "" +def test___generator_data_values___create___creates_with_data_and_default_units() -> None: + def get_values() -> Generator[float]: + yield 20.2 + yield 30.3 + yield 40.4 + + data = Vector(get_values()) + + assert_type(data._values[0], float) + assert data._values == [20.2, 30.3, 40.4] + assert data.units == "" + + @pytest.mark.parametrize("data_value", [True, 10, 20.0, "value"]) @pytest.mark.parametrize("units", ["volts"]) def test___data_value_and_units___create___creates_scalar_data_with_data_and_units( @@ -101,6 +115,29 @@ def test___invalid_data_value___create___raises_type_error(data_value: Any) -> N assert exc.value.args[0].startswith("The vector input data must be a bool, int, float, or str.") +def test___invalid_generator_data_values___create___raises_type_error() -> None: + def get_values() -> Generator[Any]: + yield from [{"value_one", "value_two"}, 1.0, 2.0] + + with pytest.raises(TypeError) as exc: + _ = Vector(get_values()) + + assert exc.value.args[0].startswith("The vector input data must be a bool, int, float, or str.") + assert "value_one" in exc.value.args[0] + + +def test___empty_generator_data_values___create___raises_type_error() -> None: + def get_values() -> Generator[float]: + yield from [] + + with pytest.raises(TypeError) as exc: + _ = Vector(get_values()) + + assert exc.value.args[0].startswith( + "You must specify values as non-empty or specify value_type." + ) + + def test___mixed_data_values___create___raises_type_error() -> None: with pytest.raises(TypeError) as exc: _ = Vector([True, "string", 1.0]) @@ -143,6 +180,19 @@ def test___vector_with_data___set_item_at_slice___values_set_correctly() -> None vector = Vector([1, 2, 3], "volts") vector[0:2] = [6, 7] + + assert vector._values == [6, 7, 3] + + +def test___vector_with_data___set_item_at_slice_with_generator___values_set_correctly() -> None: + def get_values() -> Generator[int]: + yield 6 + yield 7 + + vector = Vector([1, 2, 3], "volts") + + vector[0:2] = get_values() + assert vector._values == [6, 7, 3] @@ -150,6 +200,7 @@ def test___vector_with_data___set_item_at_slice_to_empty_list___values_set_corre vector = Vector([1, 2, 3], "volts") vector[0:2] = [] + assert vector._values == [3]