diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 52c98a5..6729f1e 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -22,7 +22,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 pytest pytest-cov + python -m pip install flake8 pytest pytest-cov flake8-pyproject if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with flake8 run: | diff --git a/CHANGELOG.md b/CHANGELOG.md index 136e789..ae09fa6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # PyProbables Changelog +### Version 0.7.0 + +***Breaking Changes*** +Minor breaking changes; mismatched Bloom filters raise a `SimilarityError` instead of returning `None` + +* `Bitarray` + * Add ability to read and write as bytes + * Add ability to export +* Updated typing to be more consistent and correct + + ### Version 0.6.2 * `BloomFilterOnDisk` diff --git a/probables/__init__.py b/probables/__init__.py index 8791d39..be36fb6 100644 --- a/probables/__init__.py +++ b/probables/__init__.py @@ -15,6 +15,7 @@ NotSupportedError, ProbablesBaseException, RotatingBloomFilterError, + SimilarityError, ) from probables.quotientfilter import QuotientFilter from probables.utilities import Bitarray @@ -48,4 +49,5 @@ "RotatingBloomFilterError", "QuotientFilter", "Bitarray", + "SimilarityError", ] diff --git a/probables/blooms/bloom.py b/probables/blooms/bloom.py index 888f895..921e647 100644 --- a/probables/blooms/bloom.py +++ b/probables/blooms/bloom.py @@ -18,7 +18,7 @@ from textwrap import wrap from typing import Union -from probables.exceptions import InitializationError, NotSupportedError +from probables.exceptions import InitializationError, NotSupportedError, SimilarityError from probables.hashes import HashFuncT, HashResultsT, KeyT, default_fnv_1a from probables.utilities import MMap, is_hex_string, is_valid_file, resolve_path @@ 
-368,7 +368,7 @@ def current_false_positive_rate(self) -> float: exp = math.exp(dbl) return math.pow((1 - exp), self.number_hashes) - def intersection(self, second: SimpleBloomT) -> Union[SimpleBloomT, None]: + def intersection(self, second: SimpleBloomT) -> SimpleBloomT: """Return a new Bloom Filter that contains the intersection of the two @@ -378,15 +378,14 @@ def intersection(self, second: SimpleBloomT) -> Union[SimpleBloomT, None]: BloomFilter: The new Bloom Filter containing the intersection Raises: TypeError: When second is not either a :class:`BloomFilter` or :class:`BloomFilterOnDisk` + SimilarityError: When second is not of the same size (false_positive_rate and est_elements) Note: - `second` may be a BloomFilterOnDisk object - Note: - If `second` is not of the same size (false_positive_rate and est_elements) then this will return `None`""" + `second` may be a BloomFilterOnDisk object""" if not _verify_not_type_mismatch(second): raise TypeError(MISMATCH_MSG) if self._verify_bloom_similarity(second) is False: - return None + raise SimilarityError("Bloom Filters are not similar") res = BloomFilter( self.estimated_elements, @@ -399,7 +398,7 @@ def intersection(self, second: SimpleBloomT) -> Union[SimpleBloomT, None]: res.elements_added = res.estimate_elements() return res - def union(self, second: SimpleBloomT) -> Union["BloomFilter", None]: + def union(self, second: SimpleBloomT) -> "BloomFilter": """Return a new Bloom Filter that contains the union of the two Args: @@ -408,15 +407,14 @@ def union(self, second: SimpleBloomT) -> Union["BloomFilter", None]: BloomFilter: The new Bloom Filter containing the union Raises: TypeError: When second is not either a :class:`BloomFilter` or :class:`BloomFilterOnDisk` + SimilarityError: When second is not of the same size (false_positive_rate and est_elements) Note: - `second` may be a BloomFilterOnDisk object - Note: - If `second` is not of the same size (false_positive_rate and est_elements) then this will return 
`None`""" + `second` may be a BloomFilterOnDisk object""" if not _verify_not_type_mismatch(second): raise TypeError(MISMATCH_MSG) if self._verify_bloom_similarity(second) is False: - return None + raise SimilarityError("Bloom Filters are not similar") res = BloomFilter( self.estimated_elements, @@ -429,7 +427,7 @@ def union(self, second: SimpleBloomT) -> Union["BloomFilter", None]: res.elements_added = res.estimate_elements() return res - def jaccard_index(self, second: SimpleBloomT) -> Union[float, None]: + def jaccard_index(self, second: SimpleBloomT) -> float: """Calculate the jaccard similarity score between two Bloom Filters Args: @@ -438,15 +436,14 @@ def jaccard_index(self, second: SimpleBloomT) -> Union[float, None]: float: A numeric value between 0 and 1 where 1 is identical and 0 means completely different Raises: TypeError: When second is not either a :class:`BloomFilter` or :class:`BloomFilterOnDisk` + SimilarityError: When second is not of the same size (false_positive_rate and est_elements) Note: - `second` may be a BloomFilterOnDisk object - Note: - If `second` is not of the same size (false_positive_rate and est_elements) then this will return `None`""" + `second` may be a BloomFilterOnDisk object""" if not _verify_not_type_mismatch(second): raise TypeError(MISMATCH_MSG) if self._verify_bloom_similarity(second) is False: - return None + raise SimilarityError("Bloom Filters are not similar") count_union = 0 @@ -694,7 +691,7 @@ def _get_element(self, idx: int) -> int: def __update(self): """update the on disk Bloom Filter and ensure everything is out to disk""" - self._bloom.flush() - self.__file_pointer.seek(-1 * self._UPDATE_OFFSET.size, os.SEEK_END) - self.__file_pointer.write(self._EXPECTED_ELM_STRUCT.pack(self.elements_added)) - self.__file_pointer.flush() + self._bloom.flush() # type: ignore + self.__file_pointer.seek(-1 * self._UPDATE_OFFSET.size, os.SEEK_END) # type: ignore + 
self.__file_pointer.write(self._EXPECTED_ELM_STRUCT.pack(self.elements_added)) # type: ignore + self.__file_pointer.flush() # type: ignore diff --git a/probables/blooms/countingbloom.py b/probables/blooms/countingbloom.py index 99015cb..99556a6 100644 --- a/probables/blooms/countingbloom.py +++ b/probables/blooms/countingbloom.py @@ -12,7 +12,7 @@ from probables.blooms.bloom import BloomFilter from probables.constants import UINT32_T_MAX, UINT64_T_MAX -from probables.exceptions import InitializationError +from probables.exceptions import InitializationError, SimilarityError from probables.hashes import HashFuncT, HashResultsT, KeyT from probables.utilities import is_hex_string, is_valid_file, resolve_path @@ -208,7 +208,7 @@ def remove_alt(self, hashes: HashResultsT, num_els: int = 1) -> int: self.elements_added -= to_remove return min_val - to_remove - def intersection(self, second: "CountingBloomFilter") -> Union["CountingBloomFilter", None]: # type: ignore + def intersection(self, second: "CountingBloomFilter") -> "CountingBloomFilter": # type: ignore """Take the intersection of two Counting Bloom Filters Args: @@ -217,17 +217,16 @@ def intersection(self, second: "CountingBloomFilter") -> Union["CountingBloomFil CountingBloomFilter: The new Counting Bloom Filter containing the union Raises: TypeError: When second is not a :class:`CountingBloomFilter` + SimilarityError: When second is not of the same size (false_positive_rate and est_elements) Note: The elements_added property will be set to the estimated number of unique elements \ - added as found in estimate_elements() - Note: - If `second` is not of the same size (false_positive_rate and est_elements) then \ - this will return `None`""" + added as found in estimate_elements()""" if not _verify_not_type_mismatch(second): raise TypeError(MISMATCH_MSG) if self._verify_bloom_similarity(second) is False: - return None + raise SimilarityError("Counting Bloom Filters are not similar enough to calculate similarity") 
+ res = CountingBloomFilter( est_elements=self.estimated_elements, false_positive_rate=self.false_positive_rate, @@ -241,7 +240,7 @@ def intersection(self, second: "CountingBloomFilter") -> Union["CountingBloomFil res.elements_added = res.estimate_elements() return res - def jaccard_index(self, second: "CountingBloomFilter") -> Union[float, None]: # type:ignore + def jaccard_index(self, second: "CountingBloomFilter") -> float: # type: ignore """Take the Jaccard Index of two Counting Bloom Filters Args: @@ -250,15 +249,14 @@ def jaccard_index(self, second: "CountingBloomFilter") -> Union[float, None]: # float: A numeric value between 0 and 1 where 1 is identical and 0 means completely different Raises: TypeError: When second is not a :class:`CountingBloomFilter` + SimilarityError: When second is not of the same size (false_positive_rate and est_elements) Note: - The Jaccard Index is based on the unique set of elements added and not the number of each element added - Note: - If `second` is not of the same size (false_positive_rate and est_elements) then this will return `None`""" + The Jaccard Index is based on the unique set of elements added and not the number of each element added""" if not _verify_not_type_mismatch(second): raise TypeError(MISMATCH_MSG) if self._verify_bloom_similarity(second) is False: - return None + raise SimilarityError("Counting Bloom Filters are not similar enough to calculate similarity") count_union = 0 count_inter = 0 @@ -271,7 +269,7 @@ def jaccard_index(self, second: "CountingBloomFilter") -> Union[float, None]: # return 1.0 return count_inter / count_union - def union(self, second: "CountingBloomFilter") -> Union["CountingBloomFilter", None]: # type:ignore + def union(self, second: "CountingBloomFilter") -> "CountingBloomFilter": # type:ignore """Return a new Countiong Bloom Filter that contains the union of the two @@ -281,16 +279,16 @@ def union(self, second: "CountingBloomFilter") -> Union["CountingBloomFilter", N 
CountingBloomFilter: The new Counting Bloom Filter containing the union Raises: TypeError: When second is not a :class:`CountingBloomFilter` + SimilarityError: When second is not of the same size (false_positive_rate and est_elements) Note: The elements_added property will be set to the estimated number of unique elements added as \ - found in estimate_elements() - Note: - If `second` is not of the same size (false_positive_rate and est_elements) then this will return `None`""" + found in estimate_elements()""" if not _verify_not_type_mismatch(second): raise TypeError(MISMATCH_MSG) if self._verify_bloom_similarity(second) is False: - return None + raise SimilarityError("Counting Bloom Filters are not similar enough to calculate similarity") + res = CountingBloomFilter( est_elements=self.estimated_elements, false_positive_rate=self.false_positive_rate, diff --git a/probables/cuckoo/countingcuckoo.py b/probables/cuckoo/countingcuckoo.py index e073757..4fb52c0 100644 --- a/probables/cuckoo/countingcuckoo.py +++ b/probables/cuckoo/countingcuckoo.py @@ -304,7 +304,7 @@ def _parse_buckets(self, d: ByteString) -> None: start = end end += bin_size - def _expand_logic(self, extra_fingerprint: "CountingCuckooBin") -> None: + def _expand_logic(self, extra_fingerprint: Union["CountingCuckooBin", None]) -> None: """the logic to acutally expand the cuckoo filter""" # get all the fingerprints fingerprints = self._setup_expand(extra_fingerprint) diff --git a/probables/cuckoo/cuckoo.py b/probables/cuckoo/cuckoo.py index cbc8136..1505da2 100644 --- a/probables/cuckoo/cuckoo.py +++ b/probables/cuckoo/cuckoo.py @@ -487,7 +487,7 @@ def _indicies_from_fingerprint(self, fingerprint): Args: fingerprint (int): The fingerprint to use for generating indicies""" idx_1 = fingerprint % self.capacity - idx_2 = self.__hash_func(str(fingerprint)) % self.capacity + idx_2 = self.__hash_func(str(fingerprint)) % self.capacity # type: ignore return idx_1, idx_2 def _generate_fingerprint_info(self, key: 
KeyT) -> tuple[int, int, int]: @@ -497,7 +497,7 @@ def _generate_fingerprint_info(self, key: KeyT) -> tuple[int, int, int]: key (str): The element for which information is to be generated """ # generate the fingerprint along with the two possible indecies - hash_val = self.__hash_func(key) + hash_val = self.__hash_func(key) # type: ignore fingerprint = get_x_bits(hash_val, 64, self.fingerprint_size_bits, True) idx_1, idx_2 = self._indicies_from_fingerprint(fingerprint) diff --git a/probables/exceptions.py b/probables/exceptions.py index b0e3bf1..31324e5 100644 --- a/probables/exceptions.py +++ b/probables/exceptions.py @@ -37,6 +37,17 @@ def __init__(self, message: str) -> None: super().__init__(self.message) +class SimilarityError(ProbablesBaseException): + """Similarity Exception + + Args: + message (str): The error message to be reported""" + + def __init__(self, message: str) -> None: + self.message = message + super().__init__(self.message) + + class CuckooFilterFullError(ProbablesBaseException): """Cuckoo Filter Full Exception diff --git a/probables/hashes.py b/probables/hashes.py index 8875869..b7006b9 100644 --- a/probables/hashes.py +++ b/probables/hashes.py @@ -9,12 +9,13 @@ KeyT = Union[str, bytes] SimpleHashT = Callable[[KeyT, int], int] +SimpleHashBytesT = Callable[[KeyT, int], bytes] HashResultsT = list[int] HashFuncT = Callable[[KeyT, int], HashResultsT] HashFuncBytesT = Callable[[KeyT, int], bytes] -def hash_with_depth_bytes(func: HashFuncBytesT) -> HashFuncT: +def hash_with_depth_bytes(func: Union[HashFuncBytesT, SimpleHashBytesT]) -> HashFuncT: """Decorator to turns a function taking a single key and hashes it to bytes. Wraps functions to be used in Bloom filters and Count-Min sketch data structures. 
@@ -40,7 +41,7 @@ def hashing_func(key, depth=1): return hashing_func -def hash_with_depth_int(func: HashFuncT) -> HashFuncT: +def hash_with_depth_int(func: Union[HashFuncT, SimpleHashT]) -> HashFuncT: """Decorator to turn a function that takes a single key and hashes it to an int. Wraps functions to be used in Bloom filters and Count-Min sketch data structures. diff --git a/pyproject.toml b/pyproject.toml index 7769600..4f43705 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ max-line-length = 120 max-line-length = 120 [tool.flake8] +extend-ignore = ["E203"] max-line-length = 120 [tool.isort] diff --git a/tests/bloom_test.py b/tests/bloom_test.py index 21f8788..5e2f927 100755 --- a/tests/bloom_test.py +++ b/tests/bloom_test.py @@ -14,7 +14,7 @@ from probables import BloomFilter, BloomFilterOnDisk # noqa: E402 from probables.constants import UINT64_T_MAX # noqa: E402 -from probables.exceptions import InitializationError, NotSupportedError # noqa: E402 +from probables.exceptions import InitializationError, NotSupportedError, SimilarityError # noqa: E402 from probables.hashes import hash_with_depth_int # noqa: E402 from tests.utilities import calc_file_md5, different_hash # noqa: E402 @@ -95,8 +95,7 @@ def test_bf_union_diff(self): blm.add("this is a test") blm2 = BloomFilter(est_elements=10, false_positive_rate=0.05, hash_function=different_hash) - blm3 = blm.union(blm2) - self.assertEqual(blm3, None) + self.assertRaises(SimilarityError, lambda: blm.union(blm2)) def test_bf_intersection(self): """test the union of two bloom filters""" @@ -146,8 +145,7 @@ def test_bf_intersection_diff(self): blm.add("this is a test") blm2 = BloomFilter(est_elements=100, false_positive_rate=0.05) - blm3 = blm.intersection(blm2) - self.assertEqual(blm3, None) + self.assertRaises(SimilarityError, lambda: blm.intersection(blm2)) def test_bf_jaccard(self): """test the jaccard index of two bloom filters""" @@ -168,14 +166,13 @@ def test_bf_jaccard_diff(self): 
blm.add("this is a test") blm2 = BloomFilter(est_elements=100, false_positive_rate=0.05) - blm3 = blm.jaccard_index(blm2) - self.assertEqual(blm3, None) + self.assertRaises(SimilarityError, lambda: blm.jaccard_index(blm2)) def test_bf_jaccard_invalid(self): """use an invalid type in a jaccard index""" blm = BloomFilter(est_elements=10, false_positive_rate=0.05) blm.add("this is a test") - self.assertRaises(TypeError, lambda: blm.jaccard_index(1)) + self.assertRaises(TypeError, lambda: blm.jaccard_index(1)) # type: ignore def test_bf_jaccard_invalid_msg(self): """check invalid type in a jaccard index message""" @@ -183,7 +180,7 @@ def test_bf_jaccard_invalid_msg(self): blm = BloomFilter(est_elements=10, false_positive_rate=0.05) blm.add("this is a test") try: - blm.jaccard_index(1) + blm.jaccard_index(1) # type: ignore except TypeError as ex: self.assertEqual(str(ex), msg) else: @@ -193,7 +190,7 @@ def test_bf_union_invalid(self): """use an invalid type in a union""" blm = BloomFilter(est_elements=10, false_positive_rate=0.05) blm.add("this is a test") - self.assertRaises(TypeError, lambda: blm.jaccard_index(1)) + self.assertRaises(TypeError, lambda: blm.jaccard_index(1)) # type: ignore def test_bf_union_invalid_msg(self): """check invalid type in a union message""" @@ -201,7 +198,7 @@ def test_bf_union_invalid_msg(self): blm = BloomFilter(est_elements=10, false_positive_rate=0.05) blm.add("this is a test") try: - blm.union(1) + blm.union(1) # type: ignore except TypeError as ex: self.assertEqual(str(ex), msg) else: @@ -211,7 +208,7 @@ def test_bf_intersection_invalid(self): """use an invalid type in a intersection""" blm = BloomFilter(est_elements=10, false_positive_rate=0.05) blm.add("this is a test") - self.assertRaises(TypeError, lambda: blm.jaccard_index(1)) + self.assertRaises(TypeError, lambda: blm.jaccard_index(1)) # type: ignore def test_bf_intersec_invalid_msg(self): """check invalid type in a intersection message""" @@ -219,7 +216,7 @@ def 
test_bf_intersec_invalid_msg(self): blm = BloomFilter(est_elements=10, false_positive_rate=0.05) blm.add("this is a test") try: - blm.intersection(1) + blm.intersection(1) # type: ignore except TypeError as ex: self.assertEqual(str(ex), msg) else: @@ -432,7 +429,7 @@ def test_invalid_fpr_2(self): def runner(): """runner""" - BloomFilter(est_elements=100, false_positive_rate="1.1") + BloomFilter(est_elements=100, false_positive_rate="1.1") # type: ignore self.assertRaises(InitializationError, runner) try: @@ -448,7 +445,7 @@ def test_invalid_estimated_els_2(self): def runner(): """runner""" - BloomFilter(est_elements=[0], false_positive_rate=0.1) + BloomFilter(est_elements=[0], false_positive_rate=0.1) # type: ignore self.assertRaises(InitializationError, runner) try: @@ -713,7 +710,7 @@ def test_bfod_close_del(self): blm.add("this is a test") del blm try: - self.assertEqual(True, blm) # noqa: F821 + self.assertEqual(True, blm) # type: ignore # noqa: F821 except UnboundLocalError as ex: msg1 = "local variable 'blm' referenced before assignment" msg2 = "cannot access local variable 'blm' where it is not associated with a value" @@ -859,8 +856,7 @@ def test_bfod_union_diff(self): blm.add("this is a test") blm2 = BloomFilter(est_elements=10, false_positive_rate=0.05, hash_function=different_hash) - blm3 = blm.union(blm2) - self.assertEqual(blm3, None) + self.assertRaises(SimilarityError, lambda: blm.union(blm2)) def test_bfod_intersection_diff(self): """make sure checking for different bloom filters on disk works intersection""" @@ -869,8 +865,7 @@ def test_bfod_intersection_diff(self): blm.add("this is a test") blm2 = BloomFilter(est_elements=10, false_positive_rate=0.05, hash_function=different_hash) - blm3 = blm.intersection(blm2) - self.assertEqual(blm3, None) + self.assertRaises(SimilarityError, lambda: blm.intersection(blm2)) def test_bfod_jaccard_diff(self): """make sure checking for different bloom filters on disk works jaccard""" @@ -879,15 +874,14 @@ def 
test_bfod_jaccard_diff(self): blm.add("this is a test") blm2 = BloomFilter(est_elements=10, false_positive_rate=0.05, hash_function=different_hash) - blm3 = blm.jaccard_index(blm2) - self.assertEqual(blm3, None) + self.assertRaises(SimilarityError, lambda: blm.jaccard_index(blm2)) def test_bfod_jaccard_invalid(self): """use an invalid type in a jaccard index cbf""" with NamedTemporaryFile(dir=os.getcwd(), suffix=".blm", delete=DELETE_TEMP_FILES) as fobj: blm = BloomFilterOnDisk(fobj.name, est_elements=10, false_positive_rate=0.05) blm.add("this is a test") - self.assertRaises(TypeError, lambda: blm.jaccard_index(1)) + self.assertRaises(TypeError, lambda: blm.jaccard_index(1)) # type: ignore def test_bfod_jaccard_invalid_msg(self): """check invalid type in a jaccard index message cbf""" @@ -896,7 +890,7 @@ def test_bfod_jaccard_invalid_msg(self): blm = BloomFilterOnDisk(fobj.name, est_elements=10, false_positive_rate=0.05) blm.add("this is a test") try: - blm.jaccard_index(1) + blm.jaccard_index(1) # type: ignore except TypeError as ex: self.assertEqual(str(ex), msg) else: @@ -907,7 +901,7 @@ def test_bfod_union_invalid(self): with NamedTemporaryFile(dir=os.getcwd(), suffix=".blm", delete=DELETE_TEMP_FILES) as fobj: blm = BloomFilterOnDisk(fobj.name, est_elements=10, false_positive_rate=0.05) blm.add("this is a test") - self.assertRaises(TypeError, lambda: blm.jaccard_index(1)) + self.assertRaises(TypeError, lambda: blm.jaccard_index(1)) # type: ignore def test_bfod_union_invalid_msg(self): """check invalid type in a union message cbf""" @@ -916,7 +910,7 @@ def test_bfod_union_invalid_msg(self): blm = BloomFilterOnDisk(fobj.name, est_elements=10, false_positive_rate=0.05) blm.add("this is a test") try: - blm.union(1) + blm.union(1) # type: ignore except TypeError as ex: self.assertEqual(str(ex), msg) else: @@ -927,7 +921,7 @@ def test_bfod_intersection_invalid(self): with NamedTemporaryFile(dir=os.getcwd(), suffix=".blm", delete=DELETE_TEMP_FILES) as fobj: blm = 
BloomFilterOnDisk(fobj.name, est_elements=10, false_positive_rate=0.05) blm.add("this is a test") - self.assertRaises(TypeError, lambda: blm.jaccard_index(1)) + self.assertRaises(TypeError, lambda: blm.jaccard_index(1)) # type: ignore def test_cbf_intersec_invalid_msg(self): """check invalid type in a intersection message cbf""" @@ -936,7 +930,7 @@ def test_cbf_intersec_invalid_msg(self): blm = BloomFilterOnDisk(fobj.name, est_elements=10, false_positive_rate=0.05) blm.add("this is a test") try: - blm.intersection(1) + blm.intersection(1) # type: ignore except TypeError as ex: self.assertEqual(str(ex), msg) diff --git a/tests/countingbloom_test.py b/tests/countingbloom_test.py index 6527945..d7177a1 100755 --- a/tests/countingbloom_test.py +++ b/tests/countingbloom_test.py @@ -13,7 +13,7 @@ sys.path.insert(0, str(this_dir.parent)) from probables import CountingBloomFilter # noqa: E402 -from probables.exceptions import InitializationError # noqa: E402 +from probables.exceptions import InitializationError, SimilarityError # noqa: E402 from tests.utilities import calc_file_md5, different_hash # noqa: E402 DELETE_TEMP_FILES = True @@ -370,7 +370,7 @@ def test_cbf_jaccard_different_2(self): """test jaccard of an mismath of counting bloom filters""" blm1 = CountingBloomFilter(est_elements=101, false_positive_rate=0.01) blm2 = CountingBloomFilter(est_elements=10, false_positive_rate=0.01) - self.assertEqual(blm1.jaccard_index(blm2), None) + self.assertRaises(SimilarityError, lambda: blm1.jaccard_index(blm2)) def test_cbf_jaccard_invalid(self): """use an invalid type in a jaccard index""" @@ -492,7 +492,7 @@ def test_cbf_union_diff(self): """test union of an mismath of counting bloom filters""" blm1 = CountingBloomFilter(est_elements=101, false_positive_rate=0.01) blm2 = CountingBloomFilter(est_elements=10, false_positive_rate=0.01) - self.assertEqual(blm1.union(blm2), None) + self.assertRaises(SimilarityError, lambda: blm1.union(blm2)) def test_cbf_inter(self): """test 
calculating the intersection between two @@ -529,7 +529,7 @@ def test_cbf_inter_diff(self): """test intersection of an mismath of counting bloom filters""" blm1 = CountingBloomFilter(est_elements=101, false_positive_rate=0.01) blm2 = CountingBloomFilter(est_elements=10, false_positive_rate=0.01) - self.assertEqual(blm1.intersection(blm2), None) + self.assertRaises(SimilarityError, lambda: blm1.intersection(blm2)) def test_cbf_all_bits_set(self): """test inserting too many elements so that the all bits are set""" diff --git a/tests/countminsketch_test.py b/tests/countminsketch_test.py index 910e14d..9832847 100755 --- a/tests/countminsketch_test.py +++ b/tests/countminsketch_test.py @@ -125,7 +125,7 @@ def test_cms_check_min(self): def test_cms_check_min_called(self): """test checking number elements using min algorithm called out""" cms = CountMinSketch(width=1000, depth=5) - cms.query_type = None + cms.query_type = None # type: ignore self.assertEqual(cms.add("this is a test", 255), 255) self.assertEqual(cms.add("this is another test", 189), 189) self.assertEqual(cms.add("this is also a test", 16), 16) @@ -425,7 +425,7 @@ def test_cms_join_invalid(self): cms = CountMinSketch(width=1000, depth=5) try: - cms.join(1) + cms.join(1) # type: ignore except TypeError as ex: msg = "Unable to merge a count-min sketch with {}".format("") self.assertEqual(str(ex), msg) @@ -469,7 +469,7 @@ def test_cms_invalid_width_2(self): def runner(): """runner""" - CountMinSketch(width="0.0", depth=5) + CountMinSketch(width="0.0", depth=5) # type: ignore self.assertRaises(InitializationError, runner) msg = "CountMinSketch: width and depth must be greater than 0" @@ -485,7 +485,7 @@ def test_cms_invalid_depth_2(self): def runner(): """runner""" - CountMinSketch(width=1000, depth=[]) + CountMinSketch(width=1000, depth=[]) # type: ignore self.assertRaises(InitializationError, runner) msg = "CountMinSketch: width and depth must be greater than 0" @@ -533,7 +533,7 @@ def 
test_cms_invalid_conf_2(self): def runner(): """runner""" - CountMinSketch(confidence=3.0, error_rate="0.99") + CountMinSketch(confidence=3.0, error_rate="0.99") # type: ignore self.assertRaises(InitializationError, runner) msg = "CountMinSketch: width and depth must be greater than 0" @@ -549,7 +549,7 @@ def test_cms_invalid_err_rate_2(self): def runner(): """runner""" - CountMinSketch(width=1000, depth=[]) + CountMinSketch(width=1000, depth=[]) # type: ignore self.assertRaises(InitializationError, runner) msg = "CountMinSketch: width and depth must be greater than 0" diff --git a/tests/cuckoo_test.py b/tests/cuckoo_test.py index c49e23b..969f92b 100755 --- a/tests/cuckoo_test.py +++ b/tests/cuckoo_test.py @@ -7,6 +7,7 @@ import unittest from pathlib import Path from tempfile import NamedTemporaryFile +from typing import Union this_dir = Path(__file__).parent sys.path.insert(0, str(this_dir)) @@ -62,9 +63,10 @@ def test_cuckoo_filter_add(self): def test_cuckoo_filter_diff_hash(self): """test using a different hash function""" - def my_hash(key): + def my_hash(key: Union[str, bytes], depth: int = 1) -> int: """fake hash""" - return int(hashlib.sha512(key.encode("utf-8")).hexdigest(), 16) + k = key if isinstance(key, bytes) else key.encode("utf-8") + return int(hashlib.sha512(k).hexdigest(), 16) cko = CuckooFilter( capacity=100, @@ -402,7 +404,7 @@ def test_invalid_capacity_2(self): def runner(): """runner""" - CuckooFilter(capacity="abc") + CuckooFilter(capacity="abc") # type: ignore self.assertRaises(InitializationError, runner) msg = "CuckooFilter: capacity, bucket_size, and max_swaps must be an integer greater than 0" @@ -418,7 +420,7 @@ def test_invalid_buckets_2(self): def runner(): """runner""" - CuckooFilter(bucket_size=[0]) + CuckooFilter(bucket_size=[0]) # type: ignore self.assertRaises(InitializationError, runner) msg = "CuckooFilter: capacity, bucket_size, and max_swaps must be an integer greater than 0" @@ -434,7 +436,7 @@ def 
test_invalid_swaps_2(self): def runner(): """runner""" - CuckooFilter(max_swaps=None) + CuckooFilter(max_swaps=None) # type: ignore self.assertRaises(InitializationError, runner) msg = "CuckooFilter: capacity, bucket_size, and max_swaps must be an integer greater than 0" diff --git a/tests/quotientfilter_test.py b/tests/quotientfilter_test.py index 773c998..de2b76d 100644 --- a/tests/quotientfilter_test.py +++ b/tests/quotientfilter_test.py @@ -7,6 +7,7 @@ import unittest from pathlib import Path from tempfile import NamedTemporaryFile +from typing import TextIO, cast from probables.exceptions import QuotientFilterError @@ -435,7 +436,7 @@ def test_quotient_filter_print_empty(self): """Test printing the data of a quotient filter in a manner to be read through""" qf = QuotientFilter(quotient=7) with NamedTemporaryFile(dir=os.getcwd(), suffix=".txt", delete=DELETE_TEMP_FILES, mode="wt") as fobj: - qf.print(file=fobj.file) + qf.print(file=cast(TextIO, fobj)) fobj.flush() with open(fobj.name) as fobj: @@ -443,7 +444,7 @@ def test_quotient_filter_print_empty(self): data = [x.strip() for x in data] self.assertEqual(data[0], "idx\t--\tO-C-S\tStatus") for i in range(2, len(data)): - self.assertEqual(data[i], f"{i-2}\t--\t0-0-0\tEmpty") + self.assertEqual(data[i], f"{i - 2}\t--\t0-0-0\tEmpty") def test_quotient_filter_print(self): """Test printing the data of a quotient filter in a manner to be read through not empty""" @@ -453,7 +454,7 @@ def test_quotient_filter_print(self): qf.add(a) with NamedTemporaryFile(dir=os.getcwd(), suffix=".txt", delete=DELETE_TEMP_FILES, mode="wt") as fobj: - qf.print(file=fobj.file) + qf.print(file=cast(TextIO, fobj)) fobj.flush() with open(fobj.name) as fobj: