Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions medcat-v2/medcat/utils/download_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import importlib.metadata
import tempfile
import zipfile
import os
import sys
from pathlib import Path
import requests
Expand Down Expand Up @@ -77,12 +78,13 @@ def _determine_url(overwrite_url: str | None,
return zip_url


def _download_zip(zip_url: str, tmp: tempfile._TemporaryFileWrapper):
    """Stream the zip archive at ``zip_url`` into the open file ``tmp``.

    Args:
        zip_url (str): Address of the archive to fetch.
        tmp (tempfile._TemporaryFileWrapper): Already-open temporary file
            that receives the downloaded bytes.

    Raises:
        requests.HTTPError: Propagated from ``raise_for_status`` when the
            server answers with an error status.
    """
    response = requests.get(zip_url, stream=True, timeout=30)
    with response:
        # Bail out before writing anything if the request was unsuccessful.
        response.raise_for_status()
        for block in response.iter_content(chunk_size=8192):
            tmp.write(block)
    # Push any buffered bytes out to the underlying file.
    tmp.flush()
def _download_zip(zip_url: str, tmp_path: str):
    """Download the zip archive at ``zip_url`` and store it at ``tmp_path``.

    The target file is opened in binary write mode and filled chunk by
    chunk, so the whole archive is never held in memory at once.

    Args:
        zip_url (str): Address of the archive to fetch.
        tmp_path (str): Filesystem path the archive is written to.

    Raises:
        requests.HTTPError: Propagated from ``raise_for_status`` when the
            server answers with an error status.
    """
    with open(tmp_path, 'wb') as out_file:
        response = requests.get(zip_url, stream=True, timeout=30)
        with response:
            # Bail out before writing anything if the request failed.
            response.raise_for_status()
            for block in response.iter_content(chunk_size=8192):
                out_file.write(block)
        out_file.flush()


def _extract_zip(dest: Path, zip_path: Path):
Expand Down Expand Up @@ -146,9 +148,10 @@ def fetch_scripts(destination: str | Path = ".",
dest.mkdir(parents=True, exist_ok=True)

zip_url = _determine_url(overwrite_url, overwrite_tag)
with tempfile.NamedTemporaryFile() as tmp:
_download_zip(zip_url, tmp)
_extract_zip(dest, Path(tmp.name))
with tempfile.TemporaryDirectory() as tmp_dir:
zip_path = os.path.join(tmp_dir, 'downloaded_scripts.zip')
_download_zip(zip_url, zip_path)
_extract_zip(dest, Path(zip_path))
_fix_requirements(dest, _get_medcat_version())
logger.info(
"You also need to install the requiements by doing:\n"
Expand Down
12 changes: 10 additions & 2 deletions medcat-v2/tests/utils/test_download_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,16 @@ def setUpClass(cls):
with unittest.mock.patch(
"medcat.utils.download_scripts._get_medcat_version"
) as mock_get_version:
mock_get_version.return_value = cls.use_version
cls.scripts_path = download_scripts.fetch_scripts(cls._temp_dir.name)
with unittest.mock.patch(
"medcat.utils.download_scripts._find_latest_scripts_tag"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally agree that this fixes it

I'm not 100% sure we should mock stuff in the stability check though - maybe we can just add some retries with backoff? Is there any chance that we actually break this but don't test it, e.g. we delete all the tags in GitHub (accidentally...), and this action still says success but the scripts are broken? Now that I say it, that seems like a 0% chance...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ignore my comment altogether actually - just saw that this is a regular test. 100% ignore my above, and carry on as you were

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, a regular test that also gets run during the stability workflow. But during that there's A LOT of them which means we get rate limited (apparently).

) as mock_get_tag:
mock_get_version.return_value = cls.use_version
mock_get_tag.return_value = f"medcat/v{cls.use_version}"
cls.scripts_path = download_scripts.fetch_scripts(cls._temp_dir.name)

@classmethod
def tearDownClass(cls):
    """Clean up the class-level ``_temp_dir`` after the tests in this class."""
    temp_dir = cls._temp_dir
    temp_dir.cleanup()

def test_can_download(self):
self.assertTrue(os.path.exists(self.scripts_path))
Expand Down
Loading