Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan6419846 committed Nov 6, 2023
1 parent ec05285 commit bd7c9b7
Show file tree
Hide file tree
Showing 7 changed files with 755 additions and 36 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ jobs:
- name: install package
run:
python -m pip install .[dev]
# - name: test
# run:
# python -m unittest discover --verbose --start-directory tests/
- name: test
run:
python -m unittest discover --verbose --start-directory tests/
- name: lint
run:
flake8
- name: mypy
run:
mypy --strict license_tools/
mypy --strict license_tools/ tests/
if: ${{ matrix.python != '3.7' }}
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Development version

* Switch to *mypy* strict mode.
* Add unit tests.
* Fix handling of license clues.

# Version 0.3.2 - 2023-08-21

* Fix type hints.
Expand Down
111 changes: 79 additions & 32 deletions license_tools/scancode_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import math
import shutil
import subprocess
import sys
import zipfile
from collections import defaultdict
from dataclasses import dataclass, field as dataclass_field
Expand Down Expand Up @@ -74,9 +75,9 @@ class Copyrights:
authors: list[Author] = dataclass_field(default_factory=list)

def __post_init__(self) -> None:
self.copyrights = [Copyright(**x) for x in self.copyrights] # type: ignore[arg-type]
self.holders = [Holder(**x) for x in self.holders] # type: ignore[arg-type]
self.authors = [Author(**x) for x in self.authors] # type: ignore[arg-type]
self.copyrights = [Copyright(**x) if not isinstance(x, Copyright) else x for x in self.copyrights] # type: ignore[arg-type]
self.holders = [Holder(**x) if not isinstance(x, Holder) else x for x in self.holders] # type: ignore[arg-type]
self.authors = [Author(**x) if not isinstance(x, Author) else x for x in self.authors] # type: ignore[arg-type]


@dataclass
Expand All @@ -99,7 +100,7 @@ class Emails:
emails: list[Email] = dataclass_field(default_factory=list)

def __post_init__(self) -> None:
self.emails = [Email(**x) for x in self.emails] # type: ignore[arg-type]
self.emails = [Email(**x) if not isinstance(x, Email) else x for x in self.emails] # type: ignore[arg-type]


@dataclass
Expand All @@ -122,7 +123,7 @@ class Urls:
urls: list[Url] = dataclass_field(default_factory=list)

def __post_init__(self) -> None:
self.urls = [Url(**x) for x in self.urls] # type: ignore[arg-type]
self.urls = [Url(**x) if not isinstance(x, Url) else x for x in self.urls] # type: ignore[arg-type]


@dataclass
Expand Down Expand Up @@ -166,21 +167,30 @@ class LicenseMatch:
license_expression: str
rule_identifier: str
rule_relevance: int
rule_url: str
rule_url: str | None


@dataclass
class LicenseClue(LicenseMatch):
"""
Enriched matching information about a license.
"""

pass


@dataclass
class LicenseDetection:
"""
Information an a specific detected license.
Information on a specific detected license.
"""

license_expression: str
identifier: str
matches: list[LicenseMatch] = dataclass_field(default_factory=list)

def __post_init__(self) -> None:
self.matches = [LicenseMatch(**x) for x in self.matches] # type: ignore[arg-type]
self.matches = [LicenseMatch(**x) if not isinstance(x, LicenseMatch) else x for x in self.matches] # type: ignore[arg-type]


@dataclass
Expand All @@ -189,15 +199,18 @@ class Licenses:
Information on all detected licenses.
"""

detected_license_expression: str
detected_license_expression_spdx: str
detected_license_expression: str | None
detected_license_expression_spdx: str | None
percentage_of_license_text: float
license_detections: list[LicenseDetection] = dataclass_field(default_factory=list)
license_clues: list[str] = dataclass_field(default_factory=list)
license_clues: list[LicenseClue] = dataclass_field(default_factory=list)

def __post_init__(self) -> None:
self.license_detections = [
LicenseDetection(**x) for x in self.license_detections # type: ignore[arg-type]
LicenseDetection(**x) if not isinstance(x, LicenseDetection) else x for x in self.license_detections # type: ignore[arg-type]
]
self.license_clues = [
LicenseClue(**x) if not isinstance(x, LicenseClue) else x for x in self.license_clues # type: ignore[arg-type]
]

def get_scores_of_detected_license_expression_spdx(self) -> list[float]:
Expand Down Expand Up @@ -286,6 +299,19 @@ def to_int(
cls.LDD_DATA * retrieve_ldd_data
)

@classmethod
def all(cls, as_kwargs: bool = False) -> int | dict[str, bool]:
"""
Utility method to enable all flags.
:param: If enabled, return kwargs instead of the integer value.
:return: The value for all flags enabled.
"""
value = cls.to_int(True, True, True, True, True)
if as_kwargs:
return cls.to_kwargs(value)
return value

@classmethod
def is_set(cls, flags: int, flag: int) -> bool:
"""
Expand Down Expand Up @@ -321,10 +347,11 @@ def check_shared_objects(path: Path) -> str | None:
:param path: The file path to analyze.
:return: The analysis results if the path points to a shared object, `None` otherwise.
"""
if path.suffix != '.so' and not (path.suffixes and path.suffixes[0] == '.so'):
# TODO: Handle binary files here as well (like `/usr/bin/bc`).
if path.suffix != ".so" and not (path.suffixes and path.suffixes[0] == ".so"):
return None
output = subprocess.check_output(['ldd', path], stderr=subprocess.PIPE)
return output.decode('UTF-8')
output = subprocess.check_output(["ldd", path], stderr=subprocess.PIPE)
return output.decode("UTF-8")


def run_on_file(
Expand All @@ -343,10 +370,10 @@ def run_on_file(
retrieval_kwargs = RetrievalFlags.to_kwargs(flags=retrieval_flags)

# This data is not yet part of the dataclasses above, as it is a custom analysis.
if retrieval_kwargs.pop('retrieve_ldd_data'):
if retrieval_kwargs.pop("retrieve_ldd_data"):
results = check_shared_objects(path=path)
if results:
print(short_path + '\n' + results)
print(short_path + "\n" + results)

# Register this here as each parallel process has its own directory.
atexit.register(cleanup, scancode_config.scancode_temp_dir)
Expand All @@ -359,6 +386,24 @@ def run_on_file(
)


def get_files_from_directory(directory: str | Path) -> Generator[tuple[Path, str], None, None]:
"""
Get the files from the given directory, recursively.
:param directory: The directory to walk through.
:return: For each file, the complete Path object as well as the path string
relative to the given directory.
"""
directory_string = str(directory)
common_prefix_length = len(directory_string) + int(not directory_string.endswith("/"))

for path in sorted(Path(directory).rglob("*"), key=str):
if path.is_dir():
continue
distribution_path = str(path)[common_prefix_length:]
yield path, distribution_path


def run_on_directory(
directory: str,
job_count: int = 4,
Expand All @@ -367,26 +412,17 @@ def run_on_directory(
"""
Run the analysis on the given directory.
:param path: The directory to analyze.
:param directory: The directory to analyze.
:param job_count: The number of parallel jobs to use.
:param retrieval_flags: Values to retrieve.
:return: The requested results per file.
"""
common_prefix_length = len(directory) + int(not directory.endswith("/"))

def get_paths() -> Generator[tuple[Path, str], None, None]:
for path in sorted(Path(directory).rglob("*"), key=str):
if path.is_dir():
continue
distribution_path = str(path)[common_prefix_length:]
yield path, distribution_path

results = Parallel(n_jobs=job_count)(
delayed(run_on_file)(
path=path,
short_path=short_path,
retrieval_flags=retrieval_flags,
) for path, short_path in get_paths()
) for path, short_path in get_files_from_directory(directory)
)
yield from results

Expand All @@ -399,7 +435,7 @@ def run_on_package_archive_file(
"""
Run the analysis on the given package archive file.
:param path: The package archive path to analyze.
:param archive_path: The package archive path to analyze.
:param job_count: The number of parallel jobs to use.
:param retrieval_flags: Values to retrieve.
:return: The requested results.
Expand Down Expand Up @@ -435,6 +471,8 @@ def run_on_downloaded_package_file(
"""
with TemporaryDirectory() as download_directory:
command = [
sys.executable,
"-m",
"pip",
"download",
"--no-deps",
Expand All @@ -444,7 +482,14 @@ def run_on_downloaded_package_file(
]
if index_url:
command += ["--index-url", index_url]
subprocess.check_output(command)
try:
subprocess.run(command, stderr=subprocess.PIPE, stdout=subprocess.PIPE, check=True)
except subprocess.CalledProcessError as exception:
if exception.stdout:
sys.stdout.write(exception.stdout.decode("UTF-8"))
if exception.stderr:
sys.stderr.write(exception.stderr.decode("UTF-8"))
raise
name = list(Path(download_directory).glob("*"))[0]
yield from run_on_package_archive_file(
archive_path=name.resolve(),
Expand Down Expand Up @@ -505,9 +550,9 @@ def run(

assert _check_that_exactly_one_value_is_set(
[directory, file_path, archive_path, package_definition]
), 'Exactly one source is required.'
), "Exactly one source is required."

license_counts: dict[str, int] = defaultdict(int)
license_counts: dict[str | None, int] = defaultdict(int)
retrieval_flags = RetrievalFlags.to_int(
retrieve_copyrights=retrieve_copyrights,
retrieve_emails=retrieve_emails,
Expand Down Expand Up @@ -550,6 +595,8 @@ def run(
retrieval_flags=retrieval_flags,
)
]
else:
return []

# Display the file-level results.
max_path_length = max(len(result.short_path) for result in results)
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
'flake8-bugbear',
'pep8-naming',
'mypy',
'requests',
'types-requests',
]
},
classifiers=[
Expand Down
Empty file added tests/__init__.py
Empty file.
Loading

0 comments on commit bd7c9b7

Please sign in to comment.