Skip to content

Commit

Permalink
Merge pull request #78 from coreweave/es/te-base-flash-attn
Browse files Browse the repository at this point in the history
feat(torch): Update `TransformerEngine` to v1.11 & move `flash-attn`
  • Loading branch information
wbrown authored Sep 27, 2024
2 parents 12e0613 + 52fb3c4 commit 9184858
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 58 deletions.
44 changes: 0 additions & 44 deletions torch-extras/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,9 @@

ARG BASE_IMAGE
ARG DEEPSPEED_VERSION="0.14.4"
ARG FLASH_ATTN_VERSION="2.6.3"
ARG APEX_COMMIT="23c1f86520e22b505e8fdfcf6298273dff2d93d8"
ARG XFORMERS_VERSION="0.0.27.post2"

FROM alpine/git:2.36.3 as flash-attn-downloader
WORKDIR /git
ARG FLASH_ATTN_VERSION
RUN git clone --recurse-submodules --shallow-submodules -j8 --depth 1 \
--filter=blob:none --also-filter-submodules \
https://github.com/Dao-AILab/flash-attention -b v${FLASH_ATTN_VERSION}

FROM alpine/git:2.36.3 as apex-downloader
WORKDIR /git
ARG APEX_COMMIT
Expand Down Expand Up @@ -155,40 +147,6 @@ SHELL ["/bin/sh", "-c"]
WORKDIR /wheels


FROM builder-base as flash-attn-builder

SHELL ["/bin/bash", "-o", "pipefail", "-c"]
RUN --mount=type=bind,from=flash-attn-downloader,source=/git/flash-attention,target=flash-attention/,rw \
python3 -m pip install -U --no-cache-dir \
packaging setuptools wheel pip && \
export CC=$(realpath -e ./compiler) \
MAX_JOBS="$(./scale.sh "$(./effective_cpu_count.sh)" 8 12)" \
NVCC_APPEND_FLAGS='-diag-suppress 186,177' \
PYTHONUNBUFFERED=1 \
FLASH_ATTENTION_FORCE_BUILD='TRUE' && \
cd flash-attention && \
( \
for EXT_DIR in $(realpath -s -e \
. \
csrc/ft_attention \
csrc/fused_dense_lib \
csrc/fused_softmax \
csrc/layer_norm \
csrc/rotary \
csrc/xentropy); \
do \
cd $EXT_DIR && \
python3 setup.py bdist_wheel --dist-dir /wheels && \
cd - || \
exit 1; \
done; \
) | \
grep -Ev --line-buffered 'ptxas info\s*:|bytes spill stores'
SHELL ["/bin/sh", "-c"]

WORKDIR /wheels


FROM builder-base as apex-builder

RUN LIBNCCL2_VERSION=$(dpkg-query --showformat='${Version}' --show libnccl2) && \
Expand Down Expand Up @@ -279,8 +237,6 @@ RUN apt-get -qq update && \

RUN --mount=type=bind,from=deepspeed-builder,source=/wheels,target=/tmp/wheels \
python3 -m pip install --no-cache-dir /tmp/wheels/*.whl
RUN --mount=type=bind,from=flash-attn-builder,source=/wheels,target=/tmp/wheels \
python3 -m pip install --no-cache-dir /tmp/wheels/*.whl
RUN --mount=type=bind,from=apex-builder,source=/wheels,target=/tmp/wheels \
python3 -m pip install --no-cache-dir /tmp/wheels/*.whl
RUN --mount=type=bind,from=xformers-builder,source=/wheels,target=/tmp/wheels \
Expand Down
84 changes: 70 additions & 14 deletions torch/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ ARG FINAL_BASE_IMAGE="nvidia/cuda:12.0.1-base-ubuntu22.04"
ARG BUILD_TORCH_VERSION="2.4.0"
ARG BUILD_TORCH_VISION_VERSION="0.19.0"
ARG BUILD_TORCH_AUDIO_VERSION="2.4.0"
ARG BUILD_TRANSFORMERENGINE_VERSION="1.9"
ARG BUILD_TRANSFORMERENGINE_VERSION="458c7de038ed34bdaf471ced4e3162a28055def7"
ARG BUILD_FLASH_ATTN_VERSION="2.6.3"
ARG BUILD_TRITON_VERSION=""
ARG BUILD_TORCH_CUDA_ARCH_LIST="6.0 6.1 6.2 7.0 7.2 7.5 8.0 8.6 8.9 9.0+PTX"

Expand Down Expand Up @@ -62,6 +63,11 @@ FROM downloader-base as transformerengine-downloader
ARG BUILD_TRANSFORMERENGINE_VERSION
RUN ./clone.sh NVIDIA/TransformerEngine TransformerEngine "${BUILD_TRANSFORMERENGINE_VERSION}"

# Fetch the flash-attention sources at the pinned version in a dedicated
# stage, so the builder stages can bind-mount them without network access.
FROM downloader-base as flash-attn-downloader
WORKDIR /git
ARG BUILD_FLASH_ATTN_VERSION
# clone.sh comes from downloader-base; presumably a shallow clone with
# submodules, mirroring the other *-downloader stages — TODO confirm
RUN ./clone.sh Dao-AILab/flash-attention flash-attention "${BUILD_FLASH_ATTN_VERSION}"

FROM downloader-base as triton-downloader
ARG BUILD_TRITON_VERSION
RUN if [ -n "${BUILD_TRITON_VERSION}" ]; then \
Expand Down Expand Up @@ -96,7 +102,7 @@ RUN curl -sSfo- "${AOCL_URL}" | tar xzf - --strip-components 1 && \


## Build PyTorch on a builder image.
FROM ${BUILDER_BASE_IMAGE} as builder
FROM ${BUILDER_BASE_IMAGE} as builder-base
ENV DEBIAN_FRONTEND=noninteractive

ARG BUILD_CCACHE_SIZE="1Gi"
Expand Down Expand Up @@ -168,12 +174,13 @@ RUN CODENAME="$(lsb_release -cs)" && \
SETUP_TOOLCHAIN() { \
apt-add-repository -y ppa:ubuntu-toolchain-r/test 2>&1 \
| sed -e '/connection timed out/{p; Q1}' && \
apt-get -qq install --no-install-recommends -y gcc-11 g++-11 lld-17 && \
apt-get -qq install --no-install-recommends -y gcc-11 g++-11 gfortran-11 lld-17 && \
apt-get clean; \
} && \
{ SETUP_TOOLCHAIN || { sleep "$(shuf -i10-20 -n1)" && SETUP_TOOLCHAIN; }; } && \
update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 11 && \
update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-11 11 && \
update-alternatives --install /usr/bin/gfortran gfortran /usr/bin/gfortran-11 11 && \
update-alternatives --install /usr/bin/ld ld /usr/bin/ld.lld-17 1

# Install AOCL-BLAS and AOCL-LAPACK
Expand Down Expand Up @@ -211,6 +218,8 @@ RUN mkdir /build /build/dist
WORKDIR /build
COPY --chmod=755 effective_cpu_count.sh .
COPY --chmod=755 scale.sh .
COPY compiler_wrapper.f95 .
RUN gfortran -O3 ./compiler_wrapper.f95 -o ./compiler && rm ./compiler_wrapper.f95

COPY <<-"EOT" /build/version-string.sh
#!/bin/sh
Expand Down Expand Up @@ -324,10 +333,18 @@ RUN --mount=type=bind,from=pytorch-downloader,source=/git/pytorch,target=pytorch
WITH_BLAS=FLAME \
PYTORCH_BUILD_VERSION="$(../version-string.sh "$TORCH_VERSION")" \
PYTORCH_BUILD_NUMBER=0 \
TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
TORCH_NVCC_FLAGS="-Xfatbin -compress-all -diag-suppress 191,186,177" \
python3 setup.py bdist_wheel --dist-dir ../dist
RUN pip3 install --no-cache-dir --upgrade dist/torch*.whl

ENV NVCC_APPEND_FLAGS="-diag-suppress 191,186,177"

RUN python3 -m pip install -U --no-cache-dir \
packaging setuptools wheel pip

FROM builder-base as torchvision-builder
RUN rm ./dist/*

## Build torchvision
ARG BUILD_TORCH_VISION_VERSION
ENV TORCH_VISION_VERSION=$BUILD_TORCH_VISION_VERSION
Expand Down Expand Up @@ -362,6 +379,9 @@ RUN --mount=type=bind,from=torchvision-downloader,source=/git/vision,target=visi
TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
python3 setup.py bdist_wheel --dist-dir ../dist

FROM builder-base as torchaudio-builder
RUN rm ./dist/*

## Build torchaudio
ARG BUILD_TORCH_AUDIO_VERSION
ENV TORCH_AUDIO_VERSION=$BUILD_TORCH_AUDIO_VERSION
Expand Down Expand Up @@ -396,17 +416,53 @@ RUN --mount=type=bind,from=torchaudio-downloader,source=/git/audio,target=audio/
TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
python3 setup.py bdist_wheel --dist-dir ../dist

FROM builder-base as transformerengine-builder
RUN rm ./dist/*

# Build TransformerEngine
RUN --mount=type=bind,from=transformerengine-downloader,source=/git/TransformerEngine,target=TransformerEngine/,rw \
python3 -m pip install -U --no-cache-dir \
packaging setuptools wheel pip && \
export MAX_JOBS=$(($(./effective_cpu_count.sh) + 2)) && \
export NVCC_APPEND_FLAGS='-diag-suppress 186,177' && \
cd TransformerEngine && \
if [ `python3 -V | cut -f 2 -d ' ' | cut -f 2 -d '.'` -lt "9" ] ; then \
sed -i "s/from functools import cache/from functools import lru_cache as cache/g" \
build_tools/utils.py; fi && \
python3 setup.py bdist_wheel --dist-dir /build/dist
export MAX_JOBS=$(($(./effective_cpu_count.sh) + 2)) && \
cd TransformerEngine && \
if python3 -c "import sys; sys.exit(sys.version_info.minor > 8)"; then \
sed -i "s/from functools import cache/from functools import lru_cache as cache/g" \
build_tools/utils.py; \
fi && \
python3 setup.py bdist_wheel --dist-dir /build/dist

# Build flash-attention (and its csrc extension modules) as wheels in an
# isolated stage; only the wheels in /build/dist are carried forward.
FROM builder-base as flash-attn-builder
# Drop wheels inherited from builder-base so this stage exports only its own.
RUN rm ./dist/*

# bash with pipefail so a failed extension build is not masked by the
# grep filter at the end of the pipeline below.
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
# CC points at the gfortran-built compiler wrapper from builder-base;
# MAX_JOBS is scaled from the effective CPU count to bound parallelism;
# FLASH_ATTENTION_FORCE_BUILD='TRUE' forces a source build instead of a
# prebuilt-wheel download. The subshell loops over the package root plus
# each csrc/ extension, building one wheel per directory into /build/dist,
# and the grep strips noisy per-kernel ptxas statistics from the log.
RUN --mount=type=bind,from=flash-attn-downloader,source=/git/flash-attention,target=flash-attention/,rw \
export CC=$(realpath -e ./compiler) \
MAX_JOBS="$(./scale.sh "$(./effective_cpu_count.sh)" 8 12)" \
PYTHONUNBUFFERED=1 \
FLASH_ATTENTION_FORCE_BUILD='TRUE' && \
cd flash-attention && \
( \
for EXT_DIR in $(realpath -s -e \
. \
csrc/ft_attention \
csrc/fused_dense_lib \
csrc/fused_softmax \
csrc/layer_norm \
csrc/rotary \
csrc/xentropy); \
do \
cd $EXT_DIR && \
python3 setup.py bdist_wheel --dist-dir /build/dist && \
cd - || \
exit 1; \
done; \
) | \
grep -Ev --line-buffered 'ptxas info\s*:|bytes spill stores'
# Restore the default shell for any subsequent instructions.
SHELL ["/bin/sh", "-c"]

# Aggregate every component's wheels into a single `builder` stage;
# builder-base already holds the torch wheel in /build/dist, and each
# copy below layers another component's wheels on top so the final image
# can install them all from one bind mount.
FROM builder-base as builder
# --link makes each COPY layer independent of prior layers (better caching).
COPY --link --from=torchaudio-builder /build/dist/ /build/dist/
COPY --link --from=torchvision-builder /build/dist/ /build/dist/
COPY --link --from=transformerengine-builder /build/dist/ /build/dist/
COPY --link --from=flash-attn-builder /build/dist/ /build/dist/

## Build the final torch image.
FROM ${FINAL_BASE_IMAGE}
Expand Down Expand Up @@ -486,5 +542,5 @@ WORKDIR /usr/src/app

# Install custom PyTorch wheels.
RUN --mount=type=bind,from=builder,source=/build/dist,target=. \
pip3 install --no-cache-dir -U numpy && \
pip3 install --no-cache-dir -U numpy packaging && \
pip3 install --no-cache-dir -U ./*.whl
76 changes: 76 additions & 0 deletions torch/compiler_wrapper.f95
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
PROGRAM compiler_wrapper
! Wraps GCC invocations,
! replacing -D__AVX512__ and -D__SCALAR__ preprocessor definitions
! with -D__AVX256__, and -march=native with -march=skylake,
! for better reproducibility and compatibility.
! All other arguments are forwarded verbatim, shell-quoted.
IMPLICIT NONE
! exitcode: status returned by SYSTEM(); full_length / truncated are the
! LENGTH and STATUS outputs of GET_COMMAND_ARGUMENT.
INTEGER :: i, exitcode = 0, full_length = 0, truncated = 0
CHARACTER(len=:), ALLOCATABLE :: arg, command
! Start with a 128-char argument buffer; regrown on demand below.
ALLOCATE(CHARACTER(len=128) :: arg)
command = "gcc"

DO i = 1, COMMAND_ARGUMENT_COUNT()
! Retry loop: STATUS == -1 means the value was truncated, so reallocate
! the buffer to the reported full length and fetch the argument again.
DO
CALL GET_COMMAND_ARGUMENT(i, arg, full_length, truncated)
IF (truncated == 0) THEN
EXIT
ELSE IF (truncated == -1) THEN
DEALLOCATE(arg)
ALLOCATE(CHARACTER(len=full_length) :: arg)
ELSE
! Any other nonzero status is a retrieval error; exit with a
! distinctive code so wrapper failures are distinguishable.
CALL EXIT(95)
END IF
END DO
IF (arg == "-march=native") THEN
command = command // " '-march=skylake'"
ELSE IF (arg == "-D__AVX512__" .OR. arg == "-D__SCALAR__") THEN
command = command // " '-D__AVX256__'"
ELSE
! NOTE(review): shell_escaped uses LEN_TRIM, so an argument with
! genuine trailing blanks would be shortened; assumed not to occur
! in compiler command lines — confirm.
command = command // shell_escaped(arg)
END IF
END DO
CALL SYSTEM(command, exitcode)
! Presumably the raw wait status from gfortran's SYSTEM extension can
! exceed 255; fold it into a nonzero single-byte exit code (MAX(...,1)
! guarantees a failure never maps back to 0) — see gfortran SYSTEM docs.
IF (exitcode > 255) THEN
exitcode = MAX(IAND(exitcode, 255), 1)
END IF
CALL EXIT(exitcode)


CONTAINS
FUNCTION shell_escaped(str) RESULT(out)
! Turns [str] into [ 'str'] and replaces all
! internal ['] characters with ['"'"']
! so the result can be appended safely to a shell command line.
IMPLICIT NONE
CHARACTER(len=*), INTENT(IN) :: str
CHARACTER(len=:), ALLOCATABLE :: out
INTEGER :: old_i, out_i, old_len, out_len

old_len = LEN_TRIM(str)
! Figure out the new length to allocate by scanning `str`.
! This always needs to add at least [ '] at the beginning
! and ['] at the end, so the length increases by at least 3.
out_len = old_len + 3
DO old_i = 1, old_len
IF (str(old_i:old_i) == "'") THEN
! Each embedded quote expands from 1 char to the 5-char ['"'"'].
out_len = out_len + 4
END IF
END DO
ALLOCATE(CHARACTER(len=out_len) :: out)

! Copy over the string, performing necessary escapes.
out(1:2) = " '"
out_i = 3
DO old_i = 1, old_len
IF (str(old_i:old_i) == "'") THEN
! Escape internal single-quotes
out(out_i:out_i + 4) = '''"''"'''
out_i = out_i + 5
ELSE
! No escaping needed
out(out_i:out_i) = str(old_i:old_i)
out_i = out_i + 1
END IF
END DO
! Closing single quote fills the final reserved position.
out(out_i:out_i) = "'"
END FUNCTION
END PROGRAM

0 comments on commit 9184858

Please sign in to comment.