Skip to content

Commit

Permalink
Add package contraints to torchbench
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Jun 19, 2024
1 parent 8ab8a3e commit ef770cf
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 18 deletions.
19 changes: 15 additions & 4 deletions install.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pathlib import Path

from userbenchmark import list_userbenchmarks
from utils import get_pkg_versions, TORCH_DEPS
from utils import get_pkg_versions, TORCH_DEPS, generate_pkg_constraints

REPO_ROOT = Path(__file__).parent

Expand Down Expand Up @@ -38,6 +38,11 @@ def pip_install_requirements(requirements_txt="requirements.txt"):
action="store_true",
help="Run in test mode and check package versions",
)
parser.add_argument(
"--check-only",
action="store_true",
help="Only run the version check and generate the contraints"
)
parser.add_argument("--canary", action="store_true", help="Install canary model.")
parser.add_argument("--continue_on_fail", action="store_true")
parser.add_argument("--verbose", "-v", action="store_true")
Expand All @@ -51,12 +56,12 @@ def pip_install_requirements(requirements_txt="requirements.txt"):
os.chdir(os.path.realpath(os.path.dirname(__file__)))

print(
f"checking packages {', '.join(TORCH_DEPS)} are installed...",
f"checking packages {', '.join(TORCH_DEPS)} are installed, generating constaints...",
end="",
flush=True,
)
if args.userbenchmark:
TORCH_DEPS = ["torch"]
TORCH_DEPS = ["numpy", "torch"]
try:
versions = get_pkg_versions(TORCH_DEPS)
except ModuleNotFoundError as e:
Expand All @@ -65,8 +70,12 @@ def pip_install_requirements(requirements_txt="requirements.txt"):
f"Error: Users must first manually install packages {TORCH_DEPS} before installing the benchmark."
)
sys.exit(-1)
generate_pkg_constraints(versions)
print("OK")

if args.check_only:
exit(0)

if args.userbenchmark:
# Install userbenchmark dependencies if exists
userbenchmark_dir = REPO_ROOT.joinpath("userbenchmark", args.userbenchmark)
Expand Down Expand Up @@ -101,7 +110,9 @@ def pip_install_requirements(requirements_txt="requirements.txt"):
new_versions = get_pkg_versions(TORCH_DEPS)
if versions != new_versions:
print(
f"The torch packages are re-installed after installing the benchmark deps. \
f"The numpy and torch package versions become inconsistent after installing the benchmark deps. \
Before: {versions}, after: {new_versions}"
)
sys.exit(-1)
else:
print(f"installed torchbench with package constraints: {versions}")
4 changes: 2 additions & 2 deletions torchbenchmark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,10 @@ def setup(
versions = get_pkg_versions(TORCH_DEPS)
success, errmsg, stdout_stderr = _install_deps(model_path, verbose=verbose)
if test_mode:
new_versions = get_pkg_versions(TORCH_DEPS, reload=True)
new_versions = get_pkg_versions(TORCH_DEPS)
if versions != new_versions:
print(
f"The torch packages are re-installed after installing the benchmark model {model_path}. \
f"The numpy and torch packages are re-installed after installing the benchmark model {model_path}. \
Before: {versions}, after: {new_versions}"
)
sys.exit(-1)
Expand Down
9 changes: 5 additions & 4 deletions torchbenchmark/util/env_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
This file may be loaded without torch packages installed, e.g., in OnDemand CI.
"""

import argparse
import copy
import importlib
import os
import shutil
import argparse
Expand Down Expand Up @@ -187,10 +185,13 @@ def deterministic_torch_manual_seed(*args, **kwargs):


def get_pkg_versions(packages: List[str]) -> Dict[str, str]:
import sys
import subprocess
versions = {}
for module in packages:
module = importlib.import_module(module)
versions[module] = module.__version__
cmd = [sys.executable, "-c", f'import {module}; print({module}.__version__)']
version = subprocess.check_output(cmd).decode().strip()
versions[module] = version
return versions


Expand Down
26 changes: 18 additions & 8 deletions utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import importlib
import sys
import subprocess
from typing import Dict, List
from pathlib import Path

TORCH_DEPS = ["torch", "torchvision", "torchaudio"]
REPO_DIR = Path(__file__).parent.parent
TORCH_DEPS = ["numpy", "torch", "torchvision", "torchaudio"]


class add_path:
Expand All @@ -18,12 +20,20 @@ def __exit__(self, exc_type, exc_value, traceback):
except ValueError:
pass


def get_pkg_versions(packages: List[str], reload: bool = False) -> Dict[str, str]:
def get_pkg_versions(packages: List[str]) -> Dict[str, str]:
versions = {}
for module in packages:
module = importlib.import_module(module)
if reload:
module = importlib.reload(module)
versions[module.__name__] = module.__version__
cmd = [sys.executable, "-c", f'import {module}; print({module}.__version__)']
version = subprocess.check_output(cmd).decode().strip()
versions[module] = version
return versions

def generate_pkg_constraints(package_versions: Dict[str, str]):
"""
Generate package versions dict and save them to {REPO_ROOT}/build/constraints.txt
"""
output_dir = REPO_DIR.joinpath("build")
output_dir.mkdir(exist_ok=True)
with open(output_dir.joinpath("constraints.txt"), "w") as fp:
for k, v in package_versions.items():
fp.write(f"{k}=={v}\n")

0 comments on commit ef770cf

Please sign in to comment.