diff --git a/medcat-v2/medcat/utils/download_scripts.py b/medcat-v2/medcat/utils/download_scripts.py index ebbedb47e..862446393 100644 --- a/medcat-v2/medcat/utils/download_scripts.py +++ b/medcat-v2/medcat/utils/download_scripts.py @@ -8,6 +8,7 @@ import importlib.metadata import tempfile import zipfile +import os import sys from pathlib import Path import requests @@ -77,12 +78,13 @@ def _determine_url(overwrite_url: str | None, return zip_url -def _download_zip(zip_url: str, tmp: tempfile._TemporaryFileWrapper): - with requests.get(zip_url, stream=True, timeout=30) as r: - r.raise_for_status() - for chunk in r.iter_content(chunk_size=8192): - tmp.write(chunk) - tmp.flush() +def _download_zip(zip_url: str, tmp_path: str): + with open(tmp_path, 'wb') as tmp: + with requests.get(zip_url, stream=True, timeout=30) as r: + r.raise_for_status() + for chunk in r.iter_content(chunk_size=8192): + tmp.write(chunk) + tmp.flush() def _extract_zip(dest: Path, zip_path: Path): @@ -146,9 +148,10 @@ def fetch_scripts(destination: str | Path = ".", dest.mkdir(parents=True, exist_ok=True) zip_url = _determine_url(overwrite_url, overwrite_tag) - with tempfile.NamedTemporaryFile() as tmp: - _download_zip(zip_url, tmp) - _extract_zip(dest, Path(tmp.name)) + with tempfile.TemporaryDirectory() as tmp_dir: + zip_path = os.path.join(tmp_dir, 'downloaded_scripts.zip') + _download_zip(zip_url, zip_path) + _extract_zip(dest, Path(zip_path)) _fix_requirements(dest, _get_medcat_version()) logger.info( "You also need to install the requiements by doing:\n" diff --git a/medcat-v2/tests/utils/test_download_scripts.py b/medcat-v2/tests/utils/test_download_scripts.py index 56162c408..e9a75c79a 100644 --- a/medcat-v2/tests/utils/test_download_scripts.py +++ b/medcat-v2/tests/utils/test_download_scripts.py @@ -15,8 +15,16 @@ def setUpClass(cls): with unittest.mock.patch( "medcat.utils.download_scripts._get_medcat_version" ) as mock_get_version: - mock_get_version.return_value = cls.use_version - cls.scripts_path = download_scripts.fetch_scripts(cls._temp_dir.name) + with unittest.mock.patch( + "medcat.utils.download_scripts._find_latest_scripts_tag" + ) as mock_get_tag: + mock_get_version.return_value = cls.use_version + mock_get_tag.return_value = f"medcat/v{cls.use_version}" + cls.scripts_path = download_scripts.fetch_scripts(cls._temp_dir.name) + + @classmethod + def tearDownClass(cls): + cls._temp_dir.cleanup() def test_can_download(self): self.assertTrue(os.path.exists(self.scripts_path))