Does Torch.ao Support FullyShardedDataParallel? #1413

Open

Lenan22 opened this issue Dec 13, 2024 · 4 comments

Lenan22 commented Dec 13, 2024

When I wrap the model with FullyShardedDataParallel:

```python
net_model_fsdp = FullyShardedDataParallel(net, **settings)
```

and then try to quantize it using:

```python
quantize_(net_model_fsdp, int8_dynamic_activation_int8_weight())
```

I encounter the following error from torch.ao:

```
RuntimeError: CUDA error: an illegal memory access was encountered.
```

If I do not use FullyShardedDataParallel and quantize `net` directly (as shown below), there is no problem:

```python
quantize_(net, int8_dynamic_activation_int8_weight())
```
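
For completeness, here is the same repro as one self-contained sketch, with the imports the snippets above assume (`net` and `settings` are placeholders for my actual model and FSDP keyword arguments, which are not shown here):

```python
# Self-contained sketch of the failing sequence: FSDP1 wrap first, then quantize.
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

net: nn.Module = ...   # placeholder for the actual model
settings: dict = {}    # placeholder for the actual FSDP kwargs

net_model_fsdp = FullyShardedDataParallel(net, **settings)
quantize_(net_model_fsdp, int8_dynamic_activation_int8_weight())  # -> CUDA illegal memory access

# Quantizing the unwrapped model instead works fine:
# quantize_(net, int8_dynamic_activation_int8_weight())
```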

Please help me analyze the reason.

@Lenan22 Lenan22 changed the title Does Torch.ao Support Multi-GPU Quantization for Large Models? Does Torch.ao Support FullyShardedDataParallel? Dec 13, 2024

Lenan22 commented Dec 13, 2024

The issue was ultimately traced to the function `_replace_with_custom_fn_if_matches_filter`, shown below:

```python
def _replace_with_custom_fn_if_matches_filter(
    model,
    replacement_fn,
    filter_fn,
    cur_fqn="",
    device=None,
) -> None:
    """
    Recursively replaces each child module in `model` with the result of
    `replacement_fn(child)` if `filter_fn(child)` returns `True`.

    Args:
        model (torch.nn.Module): The model containing modules to be replaced.
        replacement_fn (Callable[[torch.nn.Module], torch.nn.Module]): The function to replace matching modules.
        filter_fn (Callable[[torch.nn.Module], bool]): The filter function to determine which modules to replace.
        cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "".
        device (device, optional): Device to move the model to before applying `filter_fn`. Defaults to None.

    Returns:
        None
    """
    if isinstance(model, Float8Linear):
        with torch.device("meta"):
            new_module = nn.Linear(model.in_features, model.out_features)
        new_module.weight = model.weight
        new_module.bias = model.bias
        model = new_module
    if filter_fn(model, cur_fqn[:-1]):
        if device is not None:
            model.to(device=device)  # move to device before quantization
        model = replacement_fn(model)

        print(model)         # debug output: the replaced module
        print(model.weight)  # debug output: its weight
        return model
    else:
        for name, child in model.named_children():
            new_child = _replace_with_custom_fn_if_matches_filter(
                child, replacement_fn, filter_fn, f"{cur_fqn}{name}.", device
            )
            if new_child is not child:
                setattr(model, name, new_child)
        if device is not None:
            model.to(device=device)  # move parent module to device
        return model
```

If the model is wrapped with FSDP, the output when printing `model` and `model.weight` looks like this:

With FSDP (bad case):

```
(Pdb) p model
Linear(in_features=5120, out_features=2048, bias=False)

(Pdb) p model.weight
*** RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
```

Without FSDP (good case):

```
(Pdb) p model
Linear(in_features=5120, out_features=2048, bias=False)

(Pdb) p model.weight
Parameter containing:
tensor([[ 0.5496,  2.0243, -2.6941,  ...,  0.0289,  1.3861,  0.3328],
        [-1.9182, -0.6501, -2.1842,  ...,  0.3159,  2.4028,  1.7868],
        [-1.8871, -1.5453, -1.7725,  ...,  2.7592, -1.0941, -1.4395],
        ...,
        [-1.7703,  2.2389, -2.1487,  ...,  2.9390,  0.2677, -0.9585],
        [-1.5673,  2.7989, -1.6924,  ..., -1.2959,  1.1318,  0.0272],
        [-2.4031,  1.1434, -2.4386,  ...,  1.7391,  1.9264, -2.9632]],
       requires_grad=True)
```
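
This looks consistent with FSDP1's flat-parameter sharding: outside forward/backward, per-module `.weight` tensors are views into a sharded `FlatParameter`, so reading them directly can touch storage FSDP has already freed. As a sanity check only (a sketch using FSDP1's `summon_full_params` context manager; it lets the weights be inspected but is not a fix for `quantize_`):

```python
# Diagnostic sketch: gather the full parameters before touching them under FSDP1.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def inspect_fsdp_weights(fsdp_model: FSDP) -> None:
    # Materialize the unsharded parameters on each rank, then read them safely.
    with FSDP.summon_full_params(fsdp_model):
        for name, param in fsdp_model.named_parameters():
            print(name, tuple(param.shape), param.dtype, param.device)
```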

gau-nernst (Collaborator) commented

Generally you should quantize the model first, before applying FSDP; I think this is because you can't re-assign an `nn.Parameter` after FSDP wrapping.

I'm not sure if AQT (the tensor subclass backing `int8_dynamic_activation_int8_weight()`) has implemented sufficient methods to support FSDP. You can try it first.
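
A minimal sketch of that ordering (quantize the plain model, then wrap it), using the same `net` and `settings` from the original report; whether AQT then survives the FSDP wrapping is exactly the open question:

```python
# Sketch only: quantize first, then hand the already-quantized model to FSDP1.
from torch.distributed.fsdp import FullyShardedDataParallel
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

quantize_(net, int8_dynamic_activation_int8_weight())       # swap in quantized weights while unwrapped
net_model_fsdp = FullyShardedDataParallel(net, **settings)  # wrap afterwards
```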


Lenan22 commented Dec 17, 2024

> Generally you should quantize the model first, before applying FSDP; I think this is because you can't re-assign an `nn.Parameter` after FSDP wrapping.
>
> I'm not sure if AQT (the tensor subclass backing `int8_dynamic_activation_int8_weight()`) has implemented sufficient methods to support FSDP. You can try it first.

I have already tried it, and it still results in a memory access error. I will continue to investigate how to address this issue. If you have already resolved it, please share the solution with me.

gau-nernst (Collaborator) commented

Oh, I didn't notice you were using FSDP1. I don't think FSDP1 will be supported. FSDP2 can be supported (similar to NF4 + FSDP2 in torchtune), but I'm not sure if it's currently working. You can try FSDP2, for example:

```python
quantize_(base_model.layers, quantize_fn, set_inductor_config=False)
quantize_(fsdp_model.layers, quantize_fn, set_inductor_config=False)
for layer in fsdp_model.layers:
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(fsdp_model, mp_policy=mp_policy)
```
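
As a more self-contained sketch of that flow (assuming a recent PyTorch with the FSDP2 `fully_shard` / `MixedPrecisionPolicy` API from `torch.distributed._composable.fsdp`, and a model that exposes its transformer blocks as `.layers`; the model and layer names here are placeholders, not your actual model):

```python
# Sketch: quantize first, then apply FSDP2's per-layer fully_shard.
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

def shard_quantized(model: nn.Module) -> nn.Module:
    # 1. Quantize the unwrapped model so quantize_ can freely swap parameters.
    quantize_(model, int8_dynamic_activation_int8_weight(), set_inductor_config=False)

    # 2. Shard each transformer block, then the root module (FSDP2 style).
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
    for layer in model.layers:  # assumes the model exposes its blocks as `.layers`
        fully_shard(layer, mp_policy=mp_policy)
    fully_shard(model, mp_policy=mp_policy)
    return model
```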
