feat: initial AMD GPU Support (#223)
* feat: initial AMD GPU Support

  After looking into flash attention support, only a few cards are supported. This check will prevent errors from appearing during the install. Everything still works without it.

  style: fix updates to amd_go_fast to fit with coding standards
* feat: `--amd` flag for amd specific optimizations
* tests: reqs.rocm.txt consistency with reqs.txt check
* style: fix
* chore: update pre-commit torch pin
* docs: improved readme; note improved amd support

---------

Co-authored-by: tazlin <[email protected]>
Showing 18 changed files with 315 additions and 34 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
name: ldm | ||
channels: | ||
- conda-forge | ||
- defaults | ||
# Keep this list to the minimal essentials needed to get the binaries going; everything else is managed in requirements.txt to keep it universal. | ||
dependencies: | ||
- git | ||
- pip | ||
- python==3.11.6 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
#!/bin/bash | ||
# Get the directory of the current script | ||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" | ||
|
||
# Build the absolute path to the Conda environment | ||
CONDA_ENV_PATH="$SCRIPT_DIR/conda/envs/linux/lib" | ||
|
||
# Add the Conda environment to LD_LIBRARY_PATH | ||
export LD_LIBRARY_PATH="$CONDA_ENV_PATH${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}" | ||
|
||
# Set torch garbage collection thresholds. The AMD defaults cause problems. | ||
export PYTORCH_HIP_ALLOC_CONF=garbage_collection_threshold:0.6,max_split_size_mb:2048 | ||
|
||
# List of directories to check | ||
dirs=( | ||
"/usr/lib" | ||
"/usr/local/lib" | ||
"/lib" | ||
"/lib64" | ||
"/usr/lib/x86_64-linux-gnu" | ||
) | ||
|
||
# Check each directory | ||
for dir in "${dirs[@]}"; do | ||
    if [ -f "$dir/libjemalloc.so.2" ]; then | ||
        export LD_PRELOAD="$dir/libjemalloc.so.2" | ||
        printf 'Using jemalloc from %s\n' "$dir" | ||
        break | ||
    fi | ||
done | ||
|
||
# If jemalloc was not found, print a warning | ||
if [ -z "$LD_PRELOAD" ]; then | ||
    # Single quotes keep the backticks literal; inside double quotes the shell would run the command. | ||
    printf 'WARNING: jemalloc not found. You may run into memory issues! We recommend running `sudo apt install libjemalloc2`\n' | ||
    # Press q to quit or any other key to continue | ||
    read -n 1 -s -r -p "Press q to quit or any other key to continue: " key | ||
    if [ "$key" = "q" ]; then | ||
        printf "\n" | ||
        exit 1 | ||
    fi | ||
fi | ||
|
||
|
||
if ./runtime-rocm.sh python -s download_models.py; then | ||
    echo "Model Download OK. Starting worker..." | ||
    # "$@" forwards each argument unchanged, even when it contains spaces | ||
    ./runtime-rocm.sh python -s run_worker.py --amd "$@" | ||
else | ||
    echo "download_models.py exited with error code. Aborting" | ||
fi |
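The jemalloc lookup in the script above is a first-match search over a fixed list of directories. A minimal Python sketch of the same idea follows; the function name `find_jemalloc` and the `DEFAULT_DIRS` constant are hypothetical and not part of this commit, and the directory list is simply copied from the script.

```python
from pathlib import Path
from typing import Iterable, Optional

# Same candidate directories as the shell script (copied, not authoritative).
DEFAULT_DIRS = (
    "/usr/lib",
    "/usr/local/lib",
    "/lib",
    "/lib64",
    "/usr/lib/x86_64-linux-gnu",
)


def find_jemalloc(
    dirs: Iterable[str] = DEFAULT_DIRS,
    name: str = "libjemalloc.so.2",
) -> Optional[str]:
    """Return the path of the first regular file called `name`, or None."""
    for directory in dirs:
        candidate = Path(directory) / name
        if candidate.is_file():  # same role as `[ -f ... ]` in the script
            return str(candidate)
    return None
```

A caller could export the returned path as `LD_PRELOAD` before launching the worker, and fall back to the warning prompt when the result is `None`.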
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import torch | ||
from loguru import logger | ||
|
||
if "AMD" in torch.cuda.get_device_name() or "Radeon" in torch.cuda.get_device_name(): | ||
    try:  # this import is handled via script, skipping it in mypy. If this fails somehow the module will simply not run. | ||
        from flash_attn import flash_attn_func  # type: ignore | ||
|
||
        sdpa = torch.nn.functional.scaled_dot_product_attention | ||
|
||
        def sdpa_hijack(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None): | ||
            if query.shape[3] <= 128 and attn_mask is None and query.dtype != torch.float32: | ||
                hidden_states = flash_attn_func( | ||
                    q=query.transpose(1, 2), | ||
                    k=key.transpose(1, 2), | ||
                    v=value.transpose(1, 2), | ||
                    dropout_p=dropout_p, | ||
                    causal=is_causal, | ||
                    softmax_scale=scale, | ||
                ).transpose(1, 2) | ||
            else: | ||
                hidden_states = sdpa( | ||
                    query=query, | ||
                    key=key, | ||
                    value=value, | ||
                    attn_mask=attn_mask, | ||
                    dropout_p=dropout_p, | ||
                    is_causal=is_causal, | ||
                    scale=scale, | ||
                ) | ||
            return hidden_states | ||
|
||
        torch.nn.functional.scaled_dot_product_attention = sdpa_hijack | ||
        logger.debug("# # # AMD GO FAST # # #") | ||
    except ImportError as e: | ||
        logger.debug(f"# # # AMD GO SLOW {e} # # #") | ||
else: | ||
    logger.debug(f"# # # AMD GO SLOW Could not detect AMD GPU from: {torch.cuda.get_device_name()} # # #") |
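The dispatch rule inside `sdpa_hijack` can be read in isolation: flash attention is used only when the head dimension is at most 128, no attention mask is passed, and the dtype is not float32; otherwise the call falls back to the original SDPA. The sketch below restates that rule on plain Python values so it can be checked without a GPU. The helper names `use_flash_attention` and `to_flash_layout` are hypothetical and not part of this commit, and the dtype is passed as a string purely for illustration.

```python
from typing import Sequence, Tuple

MAX_HEAD_DIM = 128  # limit used by the condition in sdpa_hijack


def use_flash_attention(query_shape: Sequence[int], has_mask: bool, dtype_name: str) -> bool:
    """Mirror the condition in sdpa_hijack on plain values.

    query_shape is (batch, heads, seq_len, head_dim), the layout that
    scaled_dot_product_attention receives.
    """
    head_dim = query_shape[3]
    return head_dim <= MAX_HEAD_DIM and not has_mask and dtype_name != "float32"


def to_flash_layout(shape: Sequence[int]) -> Tuple[int, int, int, int]:
    """Shape after transpose(1, 2): (batch, seq_len, heads, head_dim)."""
    batch, heads, seq_len, head_dim = shape
    return (batch, seq_len, heads, head_dim)
```

The `transpose(1, 2)` calls in the patch exist because `flash_attn_func` takes tensors laid out as (batch, seq_len, heads, head_dim), while SDPA uses (batch, heads, seq_len, head_dim); the result is transposed back so callers see no difference.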
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#!/bin/bash | ||
|
||
# Determine if the user has a flash attention supported card. | ||
SUPPORTED_CARD=$(rocminfo | grep -c -e gfx1100 -e gfx1101 -e gfx1102) | ||
|
||
if [ "$SUPPORTED_CARD" -gt 0 ]; then | ||
    if ! python -s -m pip install -U git+https://github.com/ROCm/flash-attention@howiejay/navi_support; then | ||
        echo "Tried to install flash attention and failed!" | ||
    else | ||
        echo "Installed flash attn." | ||
        PY_SITE_DIR=$(python -c "import sysconfig; print(sysconfig.get_path('purelib'))") | ||
        if ! cp horde_worker_regen/amd_go_fast/amd_go_fast.py "${PY_SITE_DIR}"/hordelib/nodes/; then | ||
            echo "Failed to install AMD GO FAST." | ||
        else | ||
            echo "Installed AMD GO FAST." | ||
        fi | ||
    fi | ||
else | ||
    echo "Did not detect support for AMD GO FAST" | ||
fi |
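The supported-card check counts the lines of `rocminfo` output that mention one of three gfx targets, which is what `grep -c -e ... -e ...` does. A rough Python equivalent is sketched below for reference. The function name `count_supported_lines` is hypothetical and not part of this commit; the target list is copied from the script.

```python
import re
from typing import Iterable

# Targets copied from the install script's grep patterns.
SUPPORTED_TARGETS = ("gfx1100", "gfx1101", "gfx1102")


def count_supported_lines(rocminfo_output: str, targets: Iterable[str] = SUPPORTED_TARGETS) -> int:
    """Count lines containing any target, like `grep -c -e A -e B -e C`."""
    pattern = re.compile("|".join(re.escape(t) for t in targets))
    # grep -c counts matching lines, not individual matches.
    return sum(1 for line in rocminfo_output.splitlines() if pattern.search(line))
```

Like the grep call, this matches substrings, so a card is treated as supported whenever any line mentions one of the targets; a count above zero corresponds to the `-gt 0` branch in the script.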