relax the CUDA arch limit to SM89
leeeizhang committed Aug 21, 2024
Parent 40210ea · commit a2a62aa
Showing 1 changed file with 5 additions and 5 deletions.
torchtitan/float8.py (10 changes: 5 additions & 5 deletions)

@@ -23,9 +23,9 @@
 from torchtitan.parallelisms import ParallelDims
 
 
-def _is_sm90_or_later():
-    # Float8 is only supported on H100+ GPUs
-    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+def _is_sm89_or_later():
+    # Float8 is only supported on SM89 or later (H100+ GPUs)
+    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 
 class Float8Handler:
@@ -35,9 +35,9 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         float8_config = job_config.float8
         if not float8_config.enable_float8_linear:
            return
-        if not _is_sm90_or_later():
+        if not _is_sm89_or_later():
            logger.warning(
-                "Failed to swap to Float8Linear because SM90 or later is not available",
+                "Failed to swap to Float8Linear because only SM89 or later is available",
            )
            return
        try:
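
For context, here is a minimal, self-contained sketch of the capability check this commit relaxes, assuming only the public PyTorch API (torch.cuda.is_available and torch.cuda.get_device_capability, which returns a (major, minor) tuple). The device examples in the comments are illustrative: lowering the threshold from (9, 0) to (8, 9) admits Ada-generation GPUs (e.g. RTX 4090, L4, L40) in addition to Hopper (H100).

    import torch

    def _is_sm89_or_later() -> bool:
        # get_device_capability() returns (major, minor), e.g. (8, 0) for A100,
        # (8, 9) for Ada GPUs such as RTX 4090 / L40, and (9, 0) for H100.
        # Tuples compare lexicographically, so (9, 0) >= (8, 9) is True while
        # (8, 6) >= (8, 9) is False, which is exactly the SM89-or-later cutoff.
        return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

    if __name__ == "__main__":
        # Prints True on SM89+ GPUs, False otherwise (including CPU-only machines).
        print("Float8 supported on this GPU:", _is_sm89_or_later())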
