Skip to content

Commit

Permalink
Remove rocm libtorch pre-cxx11 ABI option from preview (#1857)
Browse files Browse the repository at this point in the history
* fix

* Don't update cxx11 abi for rocm

* fix
  • Loading branch information
atalman authored Dec 23, 2024
1 parent efca5dd commit 7998469
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 32 deletions.
1 change: 0 additions & 1 deletion published_versions.json
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@
"rocm5.x": {
"note": null,
"versions": {
"Download here (Pre-cxx11 ABI):": "https://download.pytorch.org/libtorch/nightly/rocm6.2.4/libtorch-shared-with-deps-latest.zip",
"Download here (cxx11 ABI):": "https://download.pytorch.org/libtorch/nightly/rocm6.2.4/libtorch-cxx11-abi-shared-with-deps-latest.zip"
}
}
Expand Down
99 changes: 68 additions & 31 deletions scripts/gen_quick_start_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,22 @@
published_version.json file
"""

import json
import copy
import argparse
import copy
import json
from enum import Enum
from pathlib import Path
from typing import Dict
from enum import Enum

BASE_DIR = Path(__file__).parent.parent
BASE_DIR = Path(__file__).parent.parent


class OperatingSystem(Enum):
LINUX: str = "linux"
WINDOWS: str = "windows"
MACOS: str = "macos"


PRE_CXX11_ABI = "pre-cxx11"
CXX11_ABI = "cxx11-abi"
DEBUG = "debug"
Expand All @@ -38,29 +40,30 @@ class OperatingSystem(Enum):
"cuda.x": ("cuda", "11.8"),
"cuda.y": ("cuda", "12.1"),
"cuda.z": ("cuda", "12.4"),
"rocm5.x": ("rocm", "6.0")
},
"rocm5.x": ("rocm", "6.0"),
},
"release": {
"accnone": ("cpu", ""),
"cuda.x": ("cuda", "11.8"),
"cuda.y": ("cuda", "12.1"),
"cuda.z": ("cuda", "12.4"),
"rocm5.x": ("rocm", "6.0")
}
}
"rocm5.x": ("rocm", "6.0"),
},
}

# Initialize arch version to default values
# these default values will be overwritten by
# extracted values from the release marix
acc_arch_ver_map = acc_arch_ver_default

LIBTORCH_DWNL_INSTR = {
PRE_CXX11_ABI: "Download here (Pre-cxx11 ABI):",
CXX11_ABI: "Download here (cxx11 ABI):",
RELEASE: "Download here (Release version):",
DEBUG: "Download here (Debug version):",
MACOS: "Download arm64 libtorch here (ROCm and CUDA are not supported):",
}
PRE_CXX11_ABI: "Download here (Pre-cxx11 ABI):",
CXX11_ABI: "Download here (cxx11 ABI):",
RELEASE: "Download here (Release version):",
DEBUG: "Download here (Debug version):",
MACOS: "Download arm64 libtorch here (ROCm and CUDA are not supported):",
}


def load_json_from_basedir(filename: str):
try:
Expand All @@ -71,32 +74,39 @@ def load_json_from_basedir(filename: str):
except json.JSONDecodeError as exc:
raise ImportError(f"Invalid JSON {filename}") from exc


def read_published_versions():
return load_json_from_basedir("published_versions.json")


def write_published_versions(versions):
with open(BASE_DIR / "published_versions.json", "w") as outfile:
json.dump(versions, outfile, indent=2)


def read_matrix_for_os(osys: OperatingSystem, channel: str):
jsonfile = load_json_from_basedir(f"{osys.value}_{channel}_matrix.json")
return jsonfile["include"]


def read_quick_start_module_template():
with open(BASE_DIR / "_includes" / "quick-start-module.js") as fptr:
return fptr.read()


def get_package_type(pkg_key: str, os_key: OperatingSystem) -> str:
if pkg_key != "pip":
return pkg_key
return "manywheel" if os_key == OperatingSystem.LINUX.value else "wheel"


def get_gpu_info(acc_key, instr, acc_arch_map):
gpu_arch_type, gpu_arch_version = acc_arch_map[acc_key]
if DEFAULT in instr:
gpu_arch_type, gpu_arch_version = acc_arch_map["accnone"]
return (gpu_arch_type, gpu_arch_version)


# This method is used for generating new published_versions.json file
# It will modify versions json object with installation instructions
# Provided by generate install matrix Github Workflow, stored in release_matrix
Expand All @@ -109,42 +119,62 @@ def update_versions(versions, release_matrix, release_version):
if release_version != "nightly":
version = release_matrix[OperatingSystem.LINUX.value][0]["stable_version"]
if version not in versions["versions"]:
versions["versions"][version] = copy.deepcopy(versions["versions"][template])
versions["versions"][version] = copy.deepcopy(
versions["versions"][template]
)
versions["latest_stable"] = version

# Perform update of the json file from release matrix
for os_key, os_vers in versions["versions"][version].items():
for pkg_key, pkg_vers in os_vers.items():
for acc_key, instr in pkg_vers.items():
package_type = get_package_type(pkg_key, os_key)
gpu_arch_type, gpu_arch_version = get_gpu_info(acc_key, instr, acc_arch_map)
gpu_arch_type, gpu_arch_version = get_gpu_info(
acc_key, instr, acc_arch_map
)

pkg_arch_matrix = [
x for x in release_matrix[os_key]
if (x["package_type"], x["gpu_arch_type"], x["gpu_arch_version"]) ==
(package_type, gpu_arch_type, gpu_arch_version)
]
x
for x in release_matrix[os_key]
if (x["package_type"], x["gpu_arch_type"], x["gpu_arch_version"])
== (package_type, gpu_arch_type, gpu_arch_version)
]

if pkg_arch_matrix:
if package_type != "libtorch":
instr["command"] = pkg_arch_matrix[0]["installation"]
else:
if os_key == OperatingSystem.LINUX.value:
rel_entry_dict = {
x["devtoolset"]: x["installation"] for x in pkg_arch_matrix
x["devtoolset"]: x["installation"]
for x in pkg_arch_matrix
if x["libtorch_variant"] == "shared-with-deps"
}
}
if instr["versions"] is not None:
for ver in [PRE_CXX11_ABI, CXX11_ABI]:
instr["versions"][LIBTORCH_DWNL_INSTR[ver]] = rel_entry_dict[ver]
if gpu_arch_type == "rocm" and ver == PRE_CXX11_ABI:
continue
else:
instr["versions"][LIBTORCH_DWNL_INSTR[ver]] = (
rel_entry_dict[ver]
)

elif os_key == OperatingSystem.WINDOWS.value:
rel_entry_dict = {x["libtorch_config"]: x["installation"] for x in pkg_arch_matrix}
rel_entry_dict = {
x["libtorch_config"]: x["installation"]
for x in pkg_arch_matrix
}
if instr["versions"] is not None:
for ver in [RELEASE, DEBUG]:
instr["versions"][LIBTORCH_DWNL_INSTR[ver]] = rel_entry_dict[ver]
instr["versions"][LIBTORCH_DWNL_INSTR[ver]] = (
rel_entry_dict[ver]
)
elif os_key == OperatingSystem.MACOS.value:
if instr["versions"] is not None:
instr["versions"][LIBTORCH_DWNL_INSTR[MACOS]] = pkg_arch_matrix[0]["installation"]
instr["versions"][LIBTORCH_DWNL_INSTR[MACOS]] = (
pkg_arch_matrix[0]["installation"]
)


# This method is used for generating new quick-start-module.js
# from the versions json object
Expand All @@ -158,21 +188,25 @@ def gen_install_matrix(versions) -> Dict[str, str]:
for os_key, os_vers in versions["versions"][ver_key].items():
for pkg_key, pkg_vers in os_vers.items():
for acc_key, instr in pkg_vers.items():
extra_key = 'python' if pkg_key != 'libtorch' else 'cplusplus'
extra_key = "python" if pkg_key != "libtorch" else "cplusplus"
key = f"{ver},{pkg_key},{os_key},{acc_key},{extra_key}"
note = instr["note"]
lines = [note] if note is not None else []
if pkg_key == "libtorch":
ivers = instr["versions"]
if ivers is not None:
lines += [f"{lab}<br /><a href='{val}'>{val}</a>" for (lab, val) in ivers.items()]
lines += [
f"{lab}<br /><a href='{val}'>{val}</a>"
for (lab, val) in ivers.items()
]
else:
command = instr["command"]
if command is not None:
lines.append(command)
result[key] = "<br />".join(lines)
return result


# This method is used for extracting two latest verisons of cuda and
# last verion of rocm. It will modify the acc_arch_ver_map object used
# to update getting started page.
Expand All @@ -195,7 +229,7 @@ def gen_ver_list(chan, gpu_arch_type):

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--autogenerate', dest='autogenerate', action='store_true')
parser.add_argument("--autogenerate", dest="autogenerate", action="store_true")
parser.set_defaults(autogenerate=True)

options = parser.parse_args()
Expand All @@ -217,8 +251,11 @@ def main():
template = read_quick_start_module_template()
versions_str = json.dumps(gen_install_matrix(versions))
template = template.replace("{{ installMatrix }}", versions_str)
template = template.replace("{{ VERSION }}", f"\"Stable ({versions['latest_stable']})\"")
template = template.replace(
"{{ VERSION }}", f"\"Stable ({versions['latest_stable']})\""
)
print(template.replace("{{ ACC ARCH MAP }}", json.dumps(acc_arch_ver_map)))


if __name__ == "__main__":
main()

0 comments on commit 7998469

Please sign in to comment.