
Question: How to use Float8InferenceLinear with FSDP1/2? #704

Open
qingquansong opened this issue Aug 19, 2024 · 15 comments

@qingquansong
Contributor

qingquansong commented Aug 19, 2024

Hey Team,

I'm trying to use FSDP1/2 with Float8InferenceLinear but am running into some issues (with torch 2.3.1+cu118). Do you suggest bumping to a higher version of torch and trying again, or maybe using the training setup without the inference layer? I also tried using the Float8Linear layer without the quantization function that converts to Float8InferenceLinear, but ran into issues with FSDP1: when computing the amax, some input x tensors are empty (x.numel() == 0) and some are NaN.
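For concreteness, here is a minimal sketch (my own illustration, not from my actual run) of the failing reduction, assuming the amax path reduces the per-expert input with torch.max(torch.abs(x)), which raises when an expert receives zero tokens:

    # Suspected failure mode: an MoE expert that receives no routed tokens produces an
    # empty activation tensor, and torch.max() without a `dim` argument refuses to
    # reduce over zero elements.
    import torch

    x = torch.empty(0, 4096)  # hypothetical: expert input with zero routed tokens
    try:
        amax = torch.max(torch.abs(x))
    except RuntimeError as e:
        print(e)  # max(): Expected reduction dim to be specified for input.numel() == 0 ...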

Best regards,
QQ

@supriyar
Contributor

cc @drisspg @jainapurva

@drisspg
Contributor

drisspg commented Aug 19, 2024

Unfortunately, Float8InferenceLinear is being developed against the latest PyTorch nightly and is not well tested on older versions of PyTorch. If it is possible for you to update your PyTorch version, that is recommended. If the problem still persists after updating and you are able to create a minimal reproducer, we can look into it.

@qingquansong
Contributor Author

@drisspg Got it. Thank you! To confirm, torch==2.5.0dev should be the right one to use?

@msaroufim
Member

pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall

Today this would indeed install a 2.5 dev build, but generally, for any feature leveraging torch.compile, you want to either be on the latest stable release (currently 2.4) or use the nightlies.
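As a quick sanity check (my own suggestion, not part of the original reply), you can confirm which build and CUDA toolkit actually got picked up after the reinstall:

    # Verify the installed build and CUDA pairing after installing the nightly wheel.
    import torch

    print(torch.__version__)          # e.g. 2.5.0.devYYYYMMDD+cu121
    print(torch.version.cuda)         # CUDA toolkit the wheel was built against
    print(torch.cuda.is_available())  # True if the driver/runtime pairing works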

@qingquansong
Contributor Author

qingquansong commented Aug 19, 2024

Thanks @msaroufim ! Is the cu118 version also supported and tested (if I disable torch.compile and the FSDP2 DTensor path and just use FSDP1)? Let me do a quick test and check. Thank you!

@qingquansong
Contributor Author

qingquansong commented Aug 19, 2024

A quick update: it turns out that there might be some issues with torch==2.5.0.dev20240819+cu118 installed from https://download.pytorch.org/whl/nightly/cu118

[rank3]:     attn_output = torch.nn.functional.scaled_dot_product_attention(
[rank3]: RuntimeError: cuDNN Frontend error: [cudnn_frontend] Error: No execution plans support the graph.

Exploring other options now; I'll probably have to use the 12.1 runtime version instead.

Update: the root cause seems to be libnvrtc shared library loading issues:

for 11.8: Could not load library libnvrtc.so.11.2. Error: libnvrtc.so.11.2: cannot open shared object file: No such file or directory
Could not load library libnvrtc.so. Error: libnvrtc.so: cannot open shared object file: No such file or directory

for 12.1: Could not load library libnvrtc.so.12. Error: libnvrtc.so.12: cannot open shared object file: No such file or directory

Investigating now.

@msaroufim
Member

I'd try isolating things in a fresh conda environment. Also, if you're mucking around with CUDA versions, keep in mind that the torchao binaries on PyPI are built against CUDA 12.1, so I'd recommend installing ao from source or downloading it from the PyTorch index.

@qingquansong
Contributor Author

Thank you! Resolved the above issue by adding the relevant library path to LD_LIBRARY_PATH; currently testing fp8 with the latest ao build + torch 2.5.0 dev cu121 as suggested 🤞

@qingquansong
Contributor Author

qingquansong commented Aug 20, 2024

Faced the same issue when testing the Mixtral 8x7B model (the gated routing layer has been excluded) with the layer-replacement + FSDP code below:

[rank2]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 934, in forward
[rank2]:     hidden_states, router_logits = self.block_sparse_moe(hidden_states)
[rank2]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 861, in forward
[rank2]:     current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
[rank2]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 797, in forward
[rank2]:     current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
[rank2]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torchao/float8/float8_linear.py", line 360, in forward
[rank2]:     input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
[rank2]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torchao/float8/float8_linear.py", line 253, in cast_input_to_float8
[rank2]:     _maybe_initialize_amaxes_scales_for_float8_cast(
[rank2]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torchao/float8/float8_scaling_utils.py", line 119, in _maybe_initialize_amaxes_scales_for_float8_cast
[rank2]:     new_amax = tensor_to_amax(x, reduce_amax=reduce_amax)
[rank2]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank2]:     return func(*args, **kwargs)
[rank2]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torchao/float8/float8_utils.py", line 102, in tensor_to_amax
[rank2]:     amax = torch.max(torch.abs(x))
[rank2]: RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.
    # Define the FSDP configuration
    import functools

    def custom_auto_wrap_policy(module, recurse, nonwrapped_numel):
        # Define the set of layers that you want to wrap
        layers_to_wrap = {MixtralDecoderLayer}
        # Check if the module is in the set of layers to wrap
        return type(module) in layers_to_wrap

    if args.enable_fp8:
        # from train_utils import patch_torch

        # patch_torch()
        from torchao.float8 import (  # precompute_float8_dynamic_scale_for_fsdp, # specific to fsdp2 + dynamic scaling, apply after each training loop iter
            CastConfig,
            Float8LinearConfig,
            ScalingType,
            convert_to_float8_training,
        )

        config = Float8LinearConfig(
            # enable_amax_init=True,  # only needed for autocast + compile + FSDP +  float8 delayed
            # enable_pre_and_post_forward=True,  # only needed for autocast + compile + FSDP +  float8 delayed
            # enable_fsdp_float8_all_gather=True,
            cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
            cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
            cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
        )

        # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
        # type
        def module_filter_fn(mod: torch.nn.Module, fqn: str):
            # don't convert the output module
            if "lm_head" in fqn:
                return False
            # don't convert linear modules with weight dimensions not divisible by 16
            if isinstance(mod, torch.nn.Linear):
                if "block_sparse_moe.gate" in fqn:
                    print(f"Ignore router layer replacement {fqn}")
                    # if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                    return False
            return True

        convert_to_float8_training(
            model,
            config=config,
            module_filter_fn=module_filter_fn
        )
        from torchao.float8.inference import (
            ActivationCasting,
            Float8InferenceLinear,
            QuantConfig,
            quantize_to_float8,
        )
        quant_config = QuantConfig(ActivationCasting.DYNAMIC)
        # quantize_to_float8(model, quant_config)

    print(model)
    torch.distributed.constants.default_pg_timeout = timedelta(seconds=7200)
    fsdp_config = FSDP(
        model,
        auto_wrap_policy=custom_auto_wrap_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        # backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        # state_dict_type="sharded",
        mixed_precision=MixedPrecision(
           param_dtype=torch.bfloat16,
           reduce_dtype=torch.bfloat16,
            # buffer_dtype=torch.bfloat16,
        ),
        device_id=torch.cuda.current_device(),
        use_orig_params=True,
    )

Also tried uncommenting the line quantize_to_float8(model, quant_config) to replace the layers with Float8InferenceLinear, and got an error when wrapping them with FSDP (I tried modifying the autocast_to_copy code a bit but still hit other errors):

[rank4]:   File "/export/home/qsong/torch_fsdp_inference.py", line 222, in main
[rank4]:     fsdp_config = FSDP(
[rank4]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 483, in __init__
[rank4]:     _auto_wrap(
[rank4]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
[rank4]:     _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
[rank4]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 545, in _recursive_wrap
[rank4]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank4]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 545, in _recursive_wrap
[rank4]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank4]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 545, in _recursive_wrap
[rank4]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank4]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 563, in _recursive_wrap
[rank4]:     return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
[rank4]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 492, in _wrap
[rank4]:     return wrapper_cls(module, **kwargs)
[rank4]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 509, in __init__
[rank4]:     _init_param_handle_from_module(
[rank4]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 612, in _init_param_handle_from_module
[rank4]:     _move_module_to_device(
[rank4]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 1005, in _move_module_to_device
[rank4]:     _move_states_to_device(params_to_move, bufs_to_move, device_from_device_id)
[rank4]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 1035, in _move_states_to_device
[rank4]:     param.data = param.to(device_from_device_id)
[rank4]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torchao/float8/float8_tensor.py", line 359, in __torch_dispatch__
[rank4]:     return FLOAT8_OPS_TABLE[func](func, args, kwargs)
[rank4]:   File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torchao/float8/float8_ops.py", line 244, in autocast_to_copy
[rank4]:     len(kwargs) == 1 and "dtype" in kwargs
[rank4]: AssertionError: Only support dtype kwarg for autocast
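One possible workaround (a sketch of my own, not verified): the trace shows FSDP moving parameters with param.to(device), which hits the Float8Tensor _to_copy override with a device kwarg it does not handle. Moving the module onto the target GPU before quantizing and wrapping should avoid that copy entirely:

    # Hypothetical workaround sketch: place the model on the target device while its
    # weights are still regular tensors, so FSDP does not need to .to(device) the
    # Float8Tensor parameters during wrapping.
    device = torch.device("cuda", torch.cuda.current_device())
    model = model.to(device)
    quantize_to_float8(model, quant_config)   # swap nn.Linear -> Float8InferenceLinear
    model = FSDP(
        model,
        auto_wrap_policy=custom_auto_wrap_policy,
        device_id=torch.cuda.current_device(),
        use_orig_params=True,
    )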

@msaroufim
Member

@qingquansong thanks! do you have a minimal repro so we can take a look?

@qingquansong
Contributor Author

Let me create a mini mixtral model with some synthetic data.

@qingquansong
Contributor Author

qingquansong commented Aug 20, 2024

Update:

  1. For the first issue, it is caused by NaN values that result in some Mixtral experts receiving no tokens. Something else that is odd: FSDP loading itself seems to have some issues with the loaded weights (even without the FP8 layers) and behaves differently under different wrapping policies; it is likely related to the reduction precision being set to bfloat16 plus sync_module_states needing to be True. I'll need to debug this a bit more, but I can confirm that if the weights load correctly and every expert receives at least one token, this error is resolved. An extra problem is:

[rank7]: ValueError: The module has CPU parameters or buffers when sync_module_states=True, which requires them to be on GPU. Please specify the device_id argument or move the module to GPU before passing it to FSDP.

It seems the original model weight parameters are still allocated on CPU when the FP8 convert_to_float8_training conversion is enabled, which causes a bit of an issue. I temporarily set sync_module_states=True and enable_pre_and_post_forward=False to work around it, but I'm not sure this is the correct way.

Something else I'm unsure about: if I just want to do inference, how should I set the following six args plus the FSDP args? Speed seems to slow down with the FP8 layers in this case, and memory is also not reduced as much as expected. Setting the input config to DYNAMIC seems to make things faster, but still only comparable with bf16 for Mixtral 8x7B. (See the sketch after the config below.)

            config = Float8LinearConfig(
                # enable_amax_init=True,  # only needed for autocast + compile + FSDP +  float8 delayed
                # enable_pre_and_post_forward=False,  # only needed for autocast + compile + FSDP +  float8 delayed
                # enable_fsdp_float8_all_gather=False,
                cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
                cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
                cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
            )
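For an inference-only run, a minimal sketch of what I would try (an assumption on my part, not a confirmed recommendation) is dynamic scaling everywhere, which drops the delayed-scaling amax/scale bookkeeping:

            # Sketch: all-dynamic scaling. grad_output casting never runs under
            # torch.no_grad(), so its setting mainly matters for consistency here.
            config = Float8LinearConfig(
                cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
                cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
                cast_config_grad_output=CastConfig(scaling_type=ScalingType.DYNAMIC),
            )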
  2. For the second Float8InferenceLinear issue, it can be reproduced with the code below and the run command below (torch and ao are both the latest versions with cu121, tested on 8x H100; the model config can be made smaller to speed things up). I commented out the line # quantize_to_float8(model, quant_config), so this script should run smoothly as-is; if you uncomment it, it will raise the Float8InferenceLinear errors.
ACCELERATE_USE_FSDP=1 FSDP_CPU_RAM_EFFICIENT_LOADING=1 torchrun --nnodes=1 --nproc-per-node=8 torch_fsdp_inference_mini.py \
   --batch_size 16 \
   --enable_fp8

Save this script as torch_fsdp_inference_mini.py:

import os
from datetime import timedelta
import argparse
from dataclasses import _MISSING_TYPE, dataclass

import torch
import torch.distributed as dist
# parse_args is defined later in this script; no external config module is needed
from torch.distributed.fsdp import BackwardPrefetch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

import numpy as np
from torch.utils.data import Dataset, DataLoader




class SyntheticDataset(Dataset):
    def __init__(self, num_samples, max_length):
        self.num_samples = num_samples
        self.max_length = max_length
        self.input_ids = np.random.randint(0, num_samples, (num_samples, max_length))
        self.attention_mask = np.ones((num_samples, max_length), dtype=np.int32)
        self.labels = self.input_ids

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.labels[idx]
        }


def get_distributed_dataloader(
        batch_size, shuffle=True
):
    dataset = SyntheticDataset(num_samples=512, max_length=4096)
    sampler = DistributedSampler(
        dataset,
        shuffle=shuffle,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
    )
    return dataloader


def configure_model():
    from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
    mini_model_config=MixtralConfig(
        attention_dropout=0.0,
        bos_token_id=1,
        eos_token_id=2,
        hidden_act="silu",
        hidden_size= 4096,
        initializer_range=0.02,
        intermediate_size=14336,
        max_position_embeddings=32768,
        num_attention_heads=32,
        num_experts_per_tok=2,
        num_hidden_layers=1,
        num_key_value_heads=8,
        num_local_experts=8,
        output_router_logits=False,
        rms_norm_eps=1e-5,
        rope_theta=1000000.0,
        router_aux_loss_coef=0.02,
        sliding_window=None,
        tie_word_embeddings=False,
        use_cache=True,
        vocab_size=32000,
        # At rope backward
        # Eager produces incontiguous dq and dk
        # SDPA produces contiguous dq and incontiguous dk
        # Flash_attn produces contiguous dq and dk
        attn_implementation="sdpa",  # default value, pytorch native attention
    )
    return MixtralForCausalLM(mini_model_config).to(dtype=torch.float16)


def cleanup():
    dist.destroy_process_group()


def run_inference(model, dataloader, device):
    num_correct = 0
    num_total = 0

    with torch.no_grad():
        for batch in tqdm(
                dataloader, desc=f"Processing batches on rank {dist.get_rank()}"
        ):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch[
                    "labels"
                ],
            )
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = batch["labels"][..., 1:].contiguous()
            mask = shift_labels != -100
            correct = (shift_logits.argmax(dim=-1) == shift_labels) & mask
            num_correct += correct.sum().item()
            num_total += mask.sum().item()

    accuracy = num_correct / num_total
    print(f"Final prediction accuracy: {accuracy}")
    return accuracy


@dataclass
class TrainingArgs:
    enable_fp8: bool = False
    batch_size: int = 8


def parse_args() -> TrainingArgs:
    parser = argparse.ArgumentParser()
    for k, v in TrainingArgs.__dataclass_fields__.items():
        if v.type != bool:
            parser.add_argument(f"--{k}", type=v.type, default=v.default)
        else:
            if not v.default:
                parser.add_argument(f"--{k}", action="store_true")
            else:
                parser.add_argument(f"--{k}", action="store_false")
    parsed = parser.parse_args()
    return TrainingArgs(
        **{k: v for k, v in vars(parsed).items() if not isinstance(v, _MISSING_TYPE)}
    )



def main():
    args = parse_args()
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])

    torch.manual_seed(42)
    val_dataloader = get_distributed_dataloader(
        args.batch_size,
    )

    # Initialize and configure the model
    model = configure_model()
    # Set device and run inference
    torch.cuda.set_device(local_rank)
    torch.cuda.empty_cache()
    device = "cuda:" + str(local_rank)

    # Define the FSDP configuration
    def custom_auto_wrap_policy(module, recurse, nonwrapped_numel):
        # Define the set of layers that you want to wrap
        layers_to_wrap = {MixtralDecoderLayer}
        # Check if the module is in the set of layers to wrap
        return type(module) in layers_to_wrap

    if args.enable_fp8:
        # from train_utils import patch_torch  # local helper from my environment; commented out so the script is self-contained

        # patch_torch()
        from torchao.float8 import (  # precompute_float8_dynamic_scale_for_fsdp, # specific to fsdp2 + dynamic scaling, apply after each training loop iter
            CastConfig,
            Float8LinearConfig,
            ScalingType,
            convert_to_float8_training,
        )

        config = Float8LinearConfig(
            # enable_amax_init=True,  # only needed for autocast + compile + FSDP +  float8 delayed
            # enable_pre_and_post_forward=True,  # only needed for autocast + compile + FSDP +  float8 delayed
            # enable_fsdp_float8_all_gather=True,
            cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
            cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
            cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
        )

        # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
        # type
        def module_filter_fn(mod: torch.nn.Module, fqn: str):
            # don't convert the output module
            if "lm_head" in fqn:
                return False
            # don't convert linear modules with weight dimensions not divisible by 16
            if isinstance(mod, torch.nn.Linear):
                if "block_sparse_moe.gate" in fqn:
                    print(f"Ignore router layer replacement {fqn}")
                    # if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                    return False
            return True

        convert_to_float8_training(
            model,
            config=config,
            module_filter_fn=module_filter_fn
        )
        from torchao.float8.inference import (
            ActivationCasting,
            Float8InferenceLinear,
            QuantConfig,
            quantize_to_float8,
        )
        quant_config = QuantConfig(ActivationCasting.DYNAMIC)
        # quantize_to_float8(model, quant_config)

    torch.distributed.constants.default_pg_timeout = timedelta(seconds=7200)
    fsdp_config = FSDP(
        model,
        auto_wrap_policy=custom_auto_wrap_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        # backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        # state_dict_type="sharded",
        sync_module_states=True,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            # reduce_dtype=torch.bfloat16,
            # buffer_dtype=torch.bfloat16,
        ),
        device_id=torch.cuda.current_device(),
        use_orig_params=True,
    )
    # inference and record the time
    init_start_event = torch.cuda.Event(enable_timing=True)
    init_end_event = torch.cuda.Event(enable_timing=True)
    init_start_event.record()
    run_inference(fsdp_config, val_dataloader, device)
    init_end_event.record()
    torch.cuda.synchronize()

    if global_rank == 0:
        print(
            f"CUDA event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000}sec"
        )
        print(f"{model}")
    # Clean up
    cleanup()


if __name__ == "__main__":
    main()

@qingquansong
Contributor Author

For the speed / memory issue, I guess it's related to not using torch.compile, based on the related ticket:

#685 [FP8] performance degradation in speed and memory without compile

I'll check if I can use torch.compile here. Thanks.
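For reference, a rough sketch of the ordering I have in mind (an assumption, not a verified recipe): convert to float8, wrap with FSDP with use_orig_params=True, then compile the wrapped module:

    # Sketch of combining float8 conversion, FSDP1, and torch.compile.
    # use_orig_params=True is required for FSDP1 to work with torch.compile.
    convert_to_float8_training(model, config=config, module_filter_fn=module_filter_fn)
    model = FSDP(
        model,
        auto_wrap_policy=custom_auto_wrap_policy,
        device_id=torch.cuda.current_device(),
        use_orig_params=True,
    )
    model = torch.compile(model)  # needed to recover the fp8 speed/memory wins (see #685)
    run_inference(model, val_dataloader, device)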

@qingquansong
Contributor Author

qingquansong commented Aug 22, 2024

Currently it's a bit blocked on torch.compile + Mixtral. The context for using torch.compile is that it seems to be required in combination with the fp8 linear layers to get the speed improvement, as discussed in these threads:

#685

pytorch/torchtitan#462 (comment)

The Hugging Face Mixtral model does not directly support torch.compile, mainly because the sparse MoE routes a dynamic number of tokens to each expert; supporting this is an ongoing effort here: huggingface/transformers#30793

I've tried the approach in gpt-fast (similar to the PR change above, converting to a fused MoE), but it's more suitable for the text-generation phase with small batch sizes and would have high memory consumption for the large-batch prefill stage. It also breaks the plain linear-layer structure, so those layers can no longer be replaced with Float8Linear directly.

@qingquansong
Contributor Author

I put some of my raw test scripts here https://github.com/qingquansong/fp8_fsdp_test in case anyone is interested. Sorry that I didn't change the local model and data paths.
