Skip to content

Commit

Permalink
Use class for returning versions partitioned by cutoff time
Browse files Browse the repository at this point in the history
Signed-off-by: Shivam Sandbhor <[email protected]>
  • Loading branch information
sbs2001 committed Jun 13, 2021
1 parent e8f0a57 commit 33d8daa
Show file tree
Hide file tree
Showing 16 changed files with 53 additions and 41 deletions.
4 changes: 3 additions & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
[pytest]
DJANGO_SETTINGS_MODULE = vulnerablecode.settings
DJANGO_SETTINGS_MODULE = vulnerablecode.settings
markers =
webtest
5 changes: 2 additions & 3 deletions vulnerabilities/importers/apache_httpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from vulnerabilities.data_source import VulnerabilitySeverity
from vulnerabilities.package_managers import GitHubTagsAPI
from vulnerabilities.severity_systems import scoring_systems
from vulnerabilities.helpers import create_etag
from vulnerabilities.helpers import nearest_patched_package


Expand Down Expand Up @@ -106,7 +105,7 @@ def to_advisory(self, data):
fixed_packages.extend(
[
PackageURL(type="apache", name="httpd", version=version)
for version in self.version_api.get("apache/httpd")["valid"]
for version in self.version_api.get("apache/httpd").valid_versions
if MavenVersion(version) in version_range
]
)
Expand All @@ -115,7 +114,7 @@ def to_advisory(self, data):
affected_packages.extend(
[
PackageURL(type="apache", name="httpd", version=version)
for version in self.version_api.get("apache/httpd")["valid"]
for version in self.version_api.get("apache/httpd").valid_versions
if MavenVersion(version) in version_range
]
)
Expand Down
4 changes: 2 additions & 2 deletions vulnerabilities/importers/apache_kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def to_advisory(self, advisory_page):

fixed_packages = [
PackageURL(type="apache", name="kafka", version=version)
for version in self.version_api.get("apache/kafka")["valid"]
for version in self.version_api.get("apache/kafka").valid_versions
if any(
[
MavenVersion(version) in version_range
Expand All @@ -83,7 +83,7 @@ def to_advisory(self, advisory_page):

affected_packages = [
PackageURL(type="apache", name="kafka", version=version)
for version in self.version_api.get("apache/kafka")["valid"]
for version in self.version_api.get("apache/kafka").valid_versions
if any(
[
MavenVersion(version) in version_range
Expand Down
6 changes: 4 additions & 2 deletions vulnerabilities/importers/apache_tomcat.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def updated_advisories(self):

def fetch_pages(self):
tomcat_major_versions = {
i[0] for i in self.version_api.get("org.apache.tomcat:tomcat")["valid"]
i[0] for i in self.version_api.get("org.apache.tomcat:tomcat").valid_versions
}
for version in tomcat_major_versions:
page_url = self.base_url.format(version)
Expand Down Expand Up @@ -104,7 +104,9 @@ def to_advisories(self, apache_tomcat_advisory_html):
PackageURL(
type="maven", namespace="apache", name="tomcat", version=version
)
for version in self.version_api.get("org.apache.tomcat:tomcat")["valid"]
for version in self.version_api.get(
"org.apache.tomcat:tomcat"
).valid_versions
if MavenVersion(version) in version_range
]
)
Expand Down
2 changes: 1 addition & 1 deletion vulnerabilities/importers/elixir_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def get_versions_for_pkg_from_range_list(self, version_range_list, pkg_name):

safe_pkg_versions = []
vuln_pkg_versions = []
all_version_list = self.pkg_manager_api.get(pkg_name)["valid"]
all_version_list = self.pkg_manager_api.get(pkg_name).valid_versions
if not version_range_list:
return [], all_version_list
version_ranges = [
Expand Down
2 changes: 1 addition & 1 deletion vulnerabilities/importers/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def process_response(self) -> List[Advisory]:
aff_vers, unaff_vers = self.categorize_versions(
self.version_api.package_type,
aff_range,
self.version_api.get(name, until=cutoff_time)["valid"],
self.version_api.get(name, until=cutoff_time).valid_versions,
)
affected_purls = [
PackageURL(name=pkg_name, namespace=ns, version=version, type=pkg_type)
Expand Down
2 changes: 1 addition & 1 deletion vulnerabilities/importers/istio.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def get_pkg_versions_from_ranges(self, version_range_list, release_date):
"""Takes a list of version ranges(affected) of a package
as parameter and returns a tuple of safe package versions and
vulnerable package versions"""
all_version = self.version_api.get("istio/istio", release_date)["valid"]
all_version = self.version_api.get("istio/istio", release_date).valid_versions
safe_pkg_versions = []
vuln_pkg_versions = []
version_ranges = [
Expand Down
10 changes: 6 additions & 4 deletions vulnerabilities/importers/nginx.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ def set_api(self):

# For some reason nginx tags it's releases are in the form of `release-1.2.3`
# Chop off the `release-` part here.
for index, version in enumerate(self.version_api.cache["nginx/nginx"]["valid"]):
self.version_api.cache["nginx/nginx"]["valid"][index] = version.replace("release-", "")
for index, version in enumerate(self.version_api.cache["nginx/nginx"].valid_versions):
self.version_api.cache["nginx/nginx"].valid_versions[index] = version.replace(
"release-", ""
)

def updated_advisories(self):
advisories = []
Expand Down Expand Up @@ -135,7 +137,7 @@ def extract_fixed_pkgs(self, vuln_info):
)

valid_versions = find_valid_versions(
self.version_api.get("nginx/nginx")["valid"], version_ranges
self.version_api.get("nginx/nginx").valid_versions, version_ranges
)

return [
Expand Down Expand Up @@ -172,7 +174,7 @@ def extract_vuln_pkgs(self, vuln_info):
)

valid_versions = find_valid_versions(
self.version_api.get("nginx/nginx")["valid"], version_ranges
self.version_api.get("nginx/nginx").valid_versions, version_ranges
)
qualifiers = {}
if windows_only:
Expand Down
2 changes: 1 addition & 1 deletion vulnerabilities/importers/npm.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def process_file(self, file) -> List[Advisory]:
publish_date = parse(record["updated_at"])
publish_date.replace(tzinfo=pytz.UTC)

all_versions = self.versions.get(package_name, until=publish_date)["valid"]
all_versions = self.versions.get(package_name, until=publish_date).valid_versions
aff_range = record.get("vulnerable_versions")
if not aff_range:
aff_range = ""
Expand Down
5 changes: 1 addition & 4 deletions vulnerabilities/importers/ruby.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,7 @@ def process_file(self, path) -> List[Advisory]:

if not getattr(self, "pkg_manager_api", None):
self.pkg_manager_api = RubyVersionAPI()
all_vers = self.pkg_manager_api.get(package_name, until=publish_time)["valid"]
print(
f"Ignored {len(self.pkg_manager_api.get(package_name,until=publish_time)['new'])} versions"
)
all_vers = self.pkg_manager_api.get(package_name, until=publish_time).valid_versions
safe_versions, affected_versions = self.categorize_versions(all_vers, safe_version_ranges)

impacted_purls = [
Expand Down
2 changes: 1 addition & 1 deletion vulnerabilities/importers/rust.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _load_advisory(self, path: str) -> Optional[Advisory]:
references.append(Reference(url=advisory["url"]))

publish_date = parse(advisory["date"]).replace(tzinfo=pytz.UTC)
all_versions = self.crates_api.get(crate_name, publish_date)["valid"]
all_versions = self.crates_api.get(crate_name, publish_date).valid_versions

# FIXME: Avoid wildcard version ranges for now.
# See https://github.com/RustSec/advisory-db/discussions/831
Expand Down
2 changes: 1 addition & 1 deletion vulnerabilities/importers/safety_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def updated_advisories(self) -> Set[Advisory]:
logger.error(e)
continue

all_package_versions = self.versions.get(package_name)["valid"]
all_package_versions = self.versions.get(package_name).valid_versions
if not len(all_package_versions):
# PyPi does not have data about this package, we skip these
continue
Expand Down
17 changes: 13 additions & 4 deletions vulnerabilities/package_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from dateutil import parser
from json import JSONDecodeError
from typing import Mapping
from typing import Tuple
from typing import Set
from datetime import datetime

Expand All @@ -41,19 +42,27 @@ class Version:
release_date: datetime = None


@dataclasses.dataclass(frozen=True)
class VersionResponse:
valid_versions: Set[str] = dataclasses.field(default_factory=set)
newer_versions: Set[str] = dataclasses.field(default_factory=set)


@dataclasses.dataclass(frozen=True)
class VersionAPI:
def __init__(self, cache: Mapping[str, Set[str]] = None):
self.cache = cache or {}

def get(self, package_name, until=None) -> Set[str]:
versions = {"new": set(), "valid": set()}
new_versions = set()
valid_versions = set()
for version in self.cache.get(package_name, set()):
if until and version.release_date and version.release_date > until:
versions["new"].add(version.value)
new_versions.add(version.value)
continue
versions["valid"].add(version.value)
valid_versions.add(version.value)

return versions
return VersionResponse(valid_versions=valid_versions, newer_versions=new_versions)


def client_session():
Expand Down
22 changes: 11 additions & 11 deletions vulnerabilities/tests/test_package_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,20 @@
# Visit https://github.com/nexB/vulnerablecode/ for support and download.

import asyncio
import json
import os
from datetime import datetime
from bs4 import BeautifulSoup
from dateutil.tz import tzlocal
from pytz import UTC
import os
import json
from unittest import TestCase
from unittest.mock import AsyncMock
import xml.etree.ElementTree as ET

from vulnerabilities.package_managers import ComposerVersionAPI
from vulnerabilities.package_managers import MavenVersionAPI
from vulnerabilities.package_managers import NugetVersionAPI
from vulnerabilities.package_managers import Version
from vulnerabilities.package_managers import VersionResponse

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
TEST_DATA = os.path.join(BASE_DIR, "test_data")
Expand Down Expand Up @@ -343,7 +343,7 @@ def test_extract_versions(self):

def test_fetch(self):

assert self.version_api.get("typo3/cms-core") == {"valid": set(), "new": set()}
assert self.version_api.get("typo3/cms-core") == VersionResponse()
client_session = MockClientSession(self.response)
asyncio.run(self.version_api.fetch("typo3/cms-core", client_session))
assert self.version_api.cache["typo3/cms-core"] == self.expected_versions
Expand All @@ -355,7 +355,7 @@ def setUpClass(cls):
cls.version_api = MavenVersionAPI()
with open(os.path.join(TEST_DATA, "maven_api", "easygcm.html"), "rb") as f:
data = f.read()
cls.response = BeautifulSoup(data)
cls.response = BeautifulSoup(data, features="lxml")
cls.content = data

def test_artifact_url(self):
Expand All @@ -377,7 +377,7 @@ def test_extract_versions(self):
assert expected_versions == self.version_api.extract_versions(self.response)

def test_fetch(self):
assert self.version_api.get("org.apache:kafka") == {"new": set(), "valid": set()}
assert self.version_api.get("org.apache:kafka") == VersionResponse()
expected = {
Version(value="1.2.2", release_date=datetime(2014, 12, 22, 10, 29, tzinfo=UTC)),
Version(value="1.3.0", release_date=datetime(2015, 3, 12, 15, 20, tzinfo=UTC)),
Expand Down Expand Up @@ -467,12 +467,12 @@ def test_extract_versions(self):

def test_fetch(self):

assert self.version_api.get("Exfat.Ntfs") == {"new": set(), "valid": set()}
assert self.version_api.get("Exfat.Ntfs") == VersionResponse()
client_session = MockClientSession(self.response)
asyncio.run(self.version_api.fetch("Exfat.Ntfs", client_session))
assert self.version_api.get("Exfat.Ntfs") == {
"new": set(),
"valid": {
assert self.version_api.get("Exfat.Ntfs") == VersionResponse(
newer_versions=set(),
valid_versions={
"2.0.0",
"2.1.0",
"2.0.0-preview01",
Expand All @@ -488,7 +488,7 @@ def test_fetch(self):
"2.5.0",
"2.6.0",
},
}
)

# def test_load_to_api(self):
# assert self.version_api.get("Exfat.Ntfs") == set()
Expand Down
7 changes: 4 additions & 3 deletions vulnerabilities/tests/test_ruby.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,14 @@
import pathlib
from unittest.mock import patch
from unittest import TestCase
from collections import OrderedDict

from packageurl import PackageURL

from vulnerabilities.importers.ruby import RubyDataSource
from vulnerabilities.data_source import GitDataSourceConfiguration
from vulnerabilities.data_source import Advisory
from vulnerabilities.data_source import Reference
from vulnerabilities.package_managers import RubyVersionAPI
from vulnerabilities.package_managers import VersionResponse
from vulnerabilities.helpers import AffectedPackage

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
Expand All @@ -55,7 +54,9 @@ def setUpClass(cls):

@patch(
"vulnerabilities.package_managers.RubyVersionAPI.get",
return_value={"valid": {"1.0.0", "1.8.0", "2.0.3"}, "new": {}},
return_value=VersionResponse(
valid_versions={"1.0.0", "1.8.0", "2.0.3"}, newer_versions=set()
),
)
def test_process_file(self, mock_write):
expected_advisories = [
Expand Down
2 changes: 1 addition & 1 deletion vulnerabilities/tests/test_rust.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@


def test_categorize_versions():
flatbuffers_versions = MOCKED_CRATES_API_VERSIONS.get("flatbuffers")["valid"]
flatbuffers_versions = MOCKED_CRATES_API_VERSIONS.get("flatbuffers").valid_versions

unaffected_ranges = [VersionSpecifier.from_scheme_version_spec_string("semver", "< 0.4.0")]
affected_ranges = [
Expand Down

0 comments on commit 33d8daa

Please sign in to comment.