# --- new file: vulnerabilities/pipelines/v2_improvers/collect_patch_texts.py ---
#
# Copyright (c) nexB Inc. and others. All rights reserved.
# VulnerableCode is a trademark of nexB Inc.
# SPDX-License-Identifier: Apache-2.0
# See http://www.apache.org/licenses/LICENSE-2.0 for the license text.
# See https://github.com/aboutcode-org/vulnerablecode for support or download.
# See https://aboutcode.org for more information about nexB OSS projects.
#

import logging
from urllib.parse import urlsplit
from urllib.parse import urlunsplit

import requests
from aboutcode.pipeline import LoopProgress
from django.db.models import Q

from vulnerabilities.models import Patch
from vulnerabilities.pipelines import VulnerableCodePipeline

# Seconds to wait on a single patch download before giving up on it.
PATCH_REQUEST_TIMEOUT = 10

# Path suffixes of URLs that are fetched as-is.
RAW_PATCH_SUFFIXES = (".patch", ".diff")

# Hosting service -> path markers of pages for which ".patch" is appended to the path.
PATCH_PAGE_MARKERS_BY_HOST = {
    "github.com": ("/commit/", "/pull/"),
    "gitlab.com": ("/commit/", "/merge_requests/"),
}


class CollectPatchTextsPipeline(VulnerableCodePipeline):
    """
    Improver pipeline to collect missing patch texts for Patch objects that have a patch_url.
    """

    pipeline_id = "collect_patch_texts_v2"
    license_expression = None

    @classmethod
    def steps(cls):
        return (cls.collect_and_store_patch_texts,)

    def collect_and_store_patch_texts(self):
        """
        Fetch and store the patch text of every Patch that has a non-empty
        ``patch_url`` and a null or empty ``patch_text``.

        A Patch is skipped, and left unchanged, when its URL is not recognized by
        ``get_raw_patch_url``, when the request fails or returns a status other
        than 200, or when the response body is empty. Failures are logged and do
        not stop the loop.
        """
        patches_without_text = Patch.objects.filter(
            Q(patch_url__isnull=False) & ~Q(patch_url=""),
            Q(patch_text__isnull=True) | Q(patch_text=""),
        )

        # Count once: each call to count() is a separate database query.
        patch_count = patches_without_text.count()
        self.log(f"Processing {patch_count:,d} patches to collect text.")

        updated_patch_count = 0
        progress = LoopProgress(total_iterations=patch_count, logger=self.log)

        for patch in progress.iter(patches_without_text.iterator(chunk_size=500)):
            raw_url = get_raw_patch_url(patch.patch_url)
            if not raw_url:
                continue

            # Keep the try body to the only call expected to raise RequestException.
            try:
                response = requests.get(raw_url, timeout=PATCH_REQUEST_TIMEOUT)
            except requests.RequestException as e:
                self.log(f"Error fetching patch from {raw_url}: {e}", level=logging.ERROR)
                continue

            if response.status_code != 200:
                self.log(
                    f"Failed to fetch patch from {raw_url}: Status {response.status_code}",
                    level=logging.WARNING if response.status_code < 500 else logging.ERROR,
                )
                continue

            patch_text = response.text
            if not patch_text:
                # Saving an empty text would leave the Patch selected by the filter
                # above while reporting it as successfully collected.
                self.log(f"Empty patch received from {raw_url}", level=logging.WARNING)
                continue

            patch.patch_text = patch_text
            patch.save()
            updated_patch_count += 1

        self.log(f"Successfully collected text for {updated_patch_count:,d} Patch entries.")


def get_raw_patch_url(url):
    """
    Return a fetchable raw patch URL from common VCS hosting URLs,
    or the URL itself if it already points to a .patch or .diff file.
    Return None if the URL type is not recognized.

    Recognized hosting URLs are github.com commit and pull request pages and
    gitlab.com commit and merge request pages, including subdomains such as
    ``www.github.com``. For these, the query string, the fragment and any
    trailing slash are dropped before ".patch" is appended to the path.
    """
    if not url:
        return None

    url = url.strip()

    try:
        parts = urlsplit(url)
    except ValueError:
        # urlsplit rejects some malformed URLs, such as an unbalanced IPv6 bracket.
        return None

    # Check this first so that a hosting URL that already ends with ".diff"
    # is returned as-is rather than turned into "<...>.diff.patch".
    if parts.path.endswith(RAW_PATCH_SUFFIXES):
        return url

    # Match on the parsed host name, not on a substring of the whole URL.
    hostname = parts.hostname or ""
    path = parts.path.rstrip("/")

    for host, markers in PATCH_PAGE_MARKERS_BY_HOST.items():
        if hostname != host and not hostname.endswith(f".{host}"):
            continue
        if any(marker in path for marker in markers):
            return urlunsplit((parts.scheme, parts.netloc, f"{path}.patch", "", ""))

    return None


# --- new file: vulnerabilities/tests/pipelines/v2_improvers/test_collect_patch_texts.py ---
#
# Copyright (c) nexB Inc. and others. All rights reserved.
# VulnerableCode is a trademark of nexB Inc.
# SPDX-License-Identifier: Apache-2.0
# See http://www.apache.org/licenses/LICENSE-2.0 for the license text.
# See https://github.com/aboutcode-org/vulnerablecode for support or download.
# See https://aboutcode.org for more information about nexB OSS projects.
#

import unittest
from unittest.mock import MagicMock
from unittest.mock import patch as mock_patch

import requests

from vulnerabilities.pipelines.v2_improvers.collect_patch_texts import CollectPatchTextsPipeline
from vulnerabilities.pipelines.v2_improvers.collect_patch_texts import get_raw_patch_url


def make_patch(patch_url, patch_text=None):
    """Return a stand-in for a Patch row with the given URL and text."""
    return MagicMock(patch_url=patch_url, patch_text=patch_text)


def make_queryset(patches):
    """Return a stand-in for the queryset that the pipeline counts and iterates."""
    queryset = MagicMock()
    queryset.count.return_value = len(patches)
    queryset.iterator.return_value = patches
    return queryset


def make_response(status_code, text=""):
    """Return a stand-in for a requests response with the given status and body."""
    response = MagicMock()
    response.status_code = status_code
    response.text = text
    return response


class TestCollectPatchTextsPipeline(unittest.TestCase):
    """Tests for the patch text collection step and its URL helper, without DB or network."""

    def setUp(self):
        self.pipeline = CollectPatchTextsPipeline()

    def test_get_raw_patch_url(self):
        cases = [
            (
                "https://github.com/user/repo/commit/abc1234567890",
                "https://github.com/user/repo/commit/abc1234567890.patch",
            ),
            (
                "https://github.com/user/repo/pull/123",
                "https://github.com/user/repo/pull/123.patch",
            ),
            (
                "https://gitlab.com/user/repo/-/commit/abc1234567890",
                "https://gitlab.com/user/repo/-/commit/abc1234567890.patch",
            ),
            (
                "https://gitlab.com/user/repo/-/merge_requests/123",
                "https://gitlab.com/user/repo/-/merge_requests/123.patch",
            ),
            ("https://example.com/fix.patch", "https://example.com/fix.patch"),
            ("https://example.com/fix.diff", "https://example.com/fix.diff"),
            # Surrounding whitespace is stripped before the URL is classified.
            ("  https://example.com/fix.patch  ", "https://example.com/fix.patch"),
            ("https://example.com/some/article", None),
        ]

        for url, expected in cases:
            with self.subTest(url=url):
                self.assertEqual(get_raw_patch_url(url), expected)

    def test_get_raw_patch_url_without_url(self):
        for url in (None, "", "   "):
            with self.subTest(url=url):
                self.assertIsNone(get_raw_patch_url(url))

    @mock_patch("vulnerabilities.pipelines.v2_improvers.collect_patch_texts.Patch")
    @mock_patch("requests.get")
    def test_collect_and_store_patch_texts(self, mock_get, mock_patch_model):
        p1 = make_patch("https://github.com/u/r/commit/c1", patch_text=None)
        p2 = make_patch("https://github.com/u/r/pull/1", patch_text="")
        p3 = make_patch("https://example.com/no-patch", patch_text=None)
        p4 = make_patch("https://example.com/fix.patch", patch_text=None)
        mock_patch_model.objects.filter.return_value = make_queryset([p1, p2, p3, p4])

        texts_by_url = {
            "https://github.com/u/r/commit/c1.patch": "diff --git a/file b/file\n+code",
            "https://github.com/u/r/pull/1.patch": "diff --git a/pr b/pr\n+pr_code",
            "https://example.com/fix.patch": "diff --git a/direct b/direct\n+direct_code",
        }

        def side_effect(url, timeout=10):
            if url in texts_by_url:
                return make_response(200, texts_by_url[url])
            return make_response(404)

        mock_get.side_effect = side_effect

        self.pipeline.collect_and_store_patch_texts()

        self.assertEqual(p1.patch_text, "diff --git a/file b/file\n+code")
        p1.save.assert_called_once()

        self.assertEqual(p2.patch_text, "diff --git a/pr b/pr\n+pr_code")
        p2.save.assert_called_once()

        # p3 has no recognizable patch URL, so it is neither fetched nor saved.
        p3.save.assert_not_called()
        self.assertEqual(mock_get.call_count, 3)

        self.assertEqual(p4.patch_text, "diff --git a/direct b/direct\n+direct_code")
        p4.save.assert_called_once()

    @mock_patch("vulnerabilities.pipelines.v2_improvers.collect_patch_texts.Patch")
    @mock_patch("requests.get")
    def test_failed_responses_are_not_stored(self, mock_get, mock_patch_model):
        for status_code in (404, 503):
            with self.subTest(status_code=status_code):
                patch = make_patch("https://example.com/fix.patch", patch_text=None)
                mock_patch_model.objects.filter.return_value = make_queryset([patch])
                mock_get.side_effect = None
                mock_get.return_value = make_response(status_code, "<html>error page</html>")

                self.pipeline.collect_and_store_patch_texts()

                self.assertIsNone(patch.patch_text)
                patch.save.assert_not_called()

    @mock_patch("vulnerabilities.pipelines.v2_improvers.collect_patch_texts.Patch")
    @mock_patch("requests.get")
    def test_request_error_does_not_stop_the_loop(self, mock_get, mock_patch_model):
        failing = make_patch("https://example.com/broken.patch", patch_text=None)
        working = make_patch("https://example.com/fix.patch", patch_text=None)
        mock_patch_model.objects.filter.return_value = make_queryset([failing, working])

        def side_effect(url, timeout=10):
            if url == "https://example.com/broken.patch":
                raise requests.RequestException("connection reset")
            return make_response(200, "diff --git a/direct b/direct\n+direct_code")

        mock_get.side_effect = side_effect

        self.pipeline.collect_and_store_patch_texts()

        self.assertIsNone(failing.patch_text)
        failing.save.assert_not_called()

        self.assertEqual(working.patch_text, "diff --git a/direct b/direct\n+direct_code")
        working.save.assert_called_once()