diff --git a/sprockets_postgres.py b/sprockets_postgres.py index 33e3892..449deb1 100644 --- a/sprockets_postgres.py +++ b/sprockets_postgres.py @@ -5,7 +5,6 @@ import os import time import typing -from distutils import util from urllib import parse import aiodns @@ -42,6 +41,16 @@ """Type annotation for timeout values""" +def _strtobool(value: str) -> int: + """Convert common configuration strings to 1 or 0.""" + normalized = value.lower() + if normalized in ('y', 'yes', 't', 'true', 'on', '1'): + return 1 + if normalized in ('n', 'no', 'f', 'false', 'off', '0'): + return 0 + raise ValueError('invalid truth value {!r}'.format(value)) + + class QueryResult: """Contains the results of the query that was executed. @@ -476,17 +485,17 @@ def _create_postgres_settings(self) -> dict: DEFAULT_POSTGRES_CONNECTION_TIMEOUT))), 'enable_hstore': self.settings.get( 'postgres_hstore', - util.strtobool( + _strtobool( os.environ.get( 'POSTGRES_HSTORE', DEFAULT_POSTGRES_HSTORE))), 'enable_json': self.settings.get( 'postgres_json', - util.strtobool( + _strtobool( os.environ.get( 'POSTGRES_JSON', DEFAULT_POSTGRES_JSON))), 'enable_uuid': self.settings.get( 'postgres_uuid', - util.strtobool( + _strtobool( os.environ.get( 'POSTGRES_UUID', DEFAULT_POSTGRES_UUID))), 'query_timeout': int(self.settings.get( diff --git a/tests.py b/tests.py index 44b6459..35b35c8 100644 --- a/tests.py +++ b/tests.py @@ -361,6 +361,25 @@ async def test_postgres_status_before_first_connection(self): 'pool_free': 0}) +class StrToBoolTestCase(unittest.TestCase): + + def test_truthy_values(self): + for value in ('y', 'yes', 't', 'true', 'on', '1', + 'Y', 'YES', 'T', 'TRUE', 'ON', 'True'): + with self.subTest(value=value): + self.assertEqual(sprockets_postgres._strtobool(value), 1) + + def test_falsy_values(self): + for value in ('n', 'no', 'f', 'false', 'off', '0', + 'N', 'NO', 'F', 'FALSE', 'OFF'): + with self.subTest(value=value): + self.assertEqual(sprockets_postgres._strtobool(value), 0) + + def test_invalid_value(self): + with self.assertRaisesRegex(ValueError, 'invalid truth value'): + sprockets_postgres._strtobool('maybe') + + class ReconnectionTestCast(TestCase): @ttesting.gen_test