ShardedQuantEmbeddingBagCollection doesn't seem to be distributing the shards properly #2575

Open
Hanyu-Li opened this issue Nov 21, 2024 · 1 comment

@Hanyu-Li

Hi TorchRec team,

I'm following the protocol in https://github.com/pytorch/torchrec/blob/main/TorchRec_Interactive_Tutorial_Notebook_OSS_version.ipynb to set up sharded quantized embedding tables, but I have run into the following issues:

  1. In a multi-GPU setup, I'm following the standard protocol: convert an EmbeddingBagCollection(device="meta", ...) to its quantized version with quantize_dynamic, then shard it across GPUs with the DistributedModelParallel (or _shard_modules) wrapper. However, every rank always ends up allocating all of the shards. The attached script prints a comparison between the original and quantized embedding tables: the original sharded embedding bag collection keeps only one local shard of the weight tensor, whereas the quantized counterpart appears to hold every piece locally, which creates a massive memory footprint.

  2. The quantized table also loses all of its values during sharding. I cannot recover the quantized weight/weight_qscale/weight_qbias without manually reloading them from a state_dict, which doesn't feel right and may point to an issue with the sharding process itself.
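The weight / weight_qscale / weight_qbias split in item 2 can be illustrated with a minimal sketch of asymmetric per-row 8-bit quantization. This assumes a generic uint8 scheme in which every row carries its own scale and bias; FBGEMM's actual int8 storage layout may differ, and the function names below are hypothetical. The point it shows is that all three tensors are needed to reconstruct the embedding values, so a sharded copy that is missing any of them cannot reproduce the original lookups.

```python
import torch


def rowwise_quantize_uint8(weight: torch.Tensor):
    """Hypothetical helper: quantize each row to uint8 with its own scale and bias.

    Dequantization is weight ~= q * scale + bias, computed per row.
    """
    w = weight.float()
    row_min = w.min(dim=1, keepdim=True).values
    row_max = w.max(dim=1, keepdim=True).values
    # Clamp the range so constant rows don't cause a division by zero.
    scale = (row_max - row_min).clamp(min=1e-8) / 255.0
    bias = row_min
    q = torch.round((w - bias) / scale).clamp(0, 255).to(torch.uint8)
    return q, scale, bias


def rowwise_dequantize(q: torch.Tensor, scale: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: invert rowwise_quantize_uint8 up to rounding error."""
    return q.float() * scale + bias
```

Because of rounding, each reconstructed value is within half a quantization step (scale / 2) of the original value in its row.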

Let me know if you have any suggestions, thanks!

I'm using the following versions:

Python 3.8.16

NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.1

torch                                    2.1.0
torchrec                                 0.4.0a0+6e8cc97
fbgemm-gpu                               0.5.0rc3

Minimal code to reproduce the issue:

#!/usr/bin/env python
# coding: utf-8
from __future__ import annotations

from typing import Dict
from typing import List
from typing import cast

import copy
import datetime
import logging
import os
import time
from pprint import pprint

import multiprocess
import torch
import torch.distributed as dist
import torchrec
from torch import nn
from torch import quantization as quant
from torch.distributed._shard.sharded_tensor import Shard
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharded_tensor import ShardMetadata
from torch.distributed._shard.sharded_tensor import init_from_local_shards
from torch.distributed.remote_device import _remote_device
from torch.quantization import QConfig as QuantConfig
from torchrec import EmbeddingBagCollection
from torchrec import JaggedTensor
from torchrec import KeyedJaggedTensor
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
from torchrec.distributed.fused_embeddingbag import FusedEmbeddingBagCollectionSharder
from torchrec.distributed.fused_embeddingbag import ShardedFusedEmbeddingBagCollection
from torchrec.distributed.fused_params import FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS
from torchrec.distributed.fused_params import FUSED_PARAM_REGISTER_TBE_BOOL
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import EmbeddingShardingPlanner
from torchrec.distributed.planner import Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder
from torchrec.distributed.quant_embeddingbag import ShardedQuantEmbeddingBagCollection
from torchrec.distributed.shard import _shard_modules

# from torchrec.distributed.test_utils.test_sharding import copy_state_dict
from torchrec.distributed.test_utils.test_sharding import ModelInputCallable
from torchrec.distributed.test_utils.test_sharding import SharderType
from torchrec.distributed.test_utils.test_sharding import create_test_sharder
from torchrec.distributed.test_utils.test_sharding import gen_model_and_input
from torchrec.distributed.types import BoundsCheckMode
from torchrec.distributed.types import ModuleSharder
from torchrec.distributed.types import ShardingEnv
from torchrec.distributed.types import ShardingType
from torchrec.modules.embedding_configs import DATA_TYPE_NUM_BITS
from torchrec.modules.embedding_configs import DataType
from torchrec.quant.embedding_modules import EmbeddingBagCollection as QuantEmbeddingBagCollection
from torchrec.quant.embedding_modules import quant_prep_enable_quant_state_dict_split_scale_bias_for_types

# from torchsnapshot import Snapshot

logger = logging.getLogger(__name__)

VOCAB_SIZE = 10_000
EMBEDDING_DIM = 8
QUANT_DTYPE = torch.quint8


class InferenceModule(torch.nn.Module):
    def __init__(self, ebc: torchrec.EmbeddingBagCollection):
        super().__init__()
        self.ebc_ = ebc

    def forward(self, kjt: KeyedJaggedTensor):
        return self.ebc_(kjt)


class WrapperModuleA(torch.nn.Module):
    def __init__(self, ebc: torchrec.EmbeddingBagCollection):
        super().__init__()
        self.inf_model = InferenceModule(ebc=ebc)

    def forward(self, kjt: KeyedJaggedTensor):
        return self.inf_model.ebc_(kjt)


def create_tables(large_table_cnt=1, embedding_dim=8, vocab_size=16, data_type=DataType.FP32):
    large_tables = [
        torchrec.EmbeddingBagConfig(
            name="large_table_" + str(i),
            embedding_dim=embedding_dim,
            num_embeddings=vocab_size,
            feature_names=["large_table_feature_" + str(i)],
            pooling=torchrec.PoolingType.SUM,
            data_type=data_type,
        )
        for i in range(large_table_cnt)
    ]
    return large_tables


def gen_constraints(
    large_table_cnt=1,
    sharding_type: ShardingType = ShardingType.TABLE_WISE,
    compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED,
) -> Dict[str, ParameterConstraints]:
    compute_kernels = []
    if compute_kernel:
        compute_kernels = [compute_kernel.value]

    large_table_constraints = {
        "large_table_" + str(i): ParameterConstraints(
            sharding_types=[sharding_type.value],
            compute_kernels=compute_kernels,
        )
        for i in range(large_table_cnt)
    }
    return large_table_constraints


def shard_ori_model(model, pg, sharding_type):
    env = ShardingEnv.from_process_group(pg)

    topology = Topology(
        world_size=env.world_size,
        compute_device="cuda",
    )
    constraints = gen_constraints(sharding_type=sharding_type)
    planner = EmbeddingShardingPlanner(
        topology=topology,
        constraints=constraints,
    )

    sharders = [EmbeddingBagCollectionSharder()]

    plan = planner.collective_plan(model, sharders, pg)

    dmp_model = DistributedModelParallel(
        module=model,
        env=env,
        device=torch.device("cuda"),
        plan=plan,
        sharders=sharders,
        init_parameters=True,
    )
    return dmp_model


def inplace_quantize(module, quant_dtype=torch.quint8, split_scale_bias=False):
    if split_scale_bias:
        quant_prep_enable_quant_state_dict_split_scale_bias_for_types(module, [EmbeddingBagCollection])

    qconfig = QuantConfig(
        # dtype of the result of the embedding lookup, post activation
        # torch.float generally for compatibility with the rest of the model
        # as rest of the model here usually isn't quantized
        activation=quant.PlaceholderObserver.with_args(dtype=torch.float),
        # quantized type for embedding weights, aka parameters to actually quantize
        weight=quant.PlaceholderObserver.with_args(dtype=quant_dtype),
    )
    qconfig_spec = {
        # Map of module type to qconfig
        torchrec.EmbeddingBagCollection: qconfig,
    }
    mapping = {
        # Map of module type to quantized module type
        torchrec.EmbeddingBagCollection: QuantEmbeddingBagCollection,
    }

    quant.quantize_dynamic(
        module,
        qconfig_spec=qconfig_spec,
        mapping=mapping,
        inplace=True,
    )


def shard_q_model(model, pg, sharding_type):
    torch.cuda.set_device(dist.get_rank())
    split_scale_bias = sharding_type != ShardingType.TABLE_WISE
    env = ShardingEnv.from_process_group(pg)

    topology = Topology(
        world_size=env.world_size,
        compute_device="cuda",
    )
    compute_kernel = EmbeddingComputeKernel.QUANT
    if sharding_type == ShardingType.DATA_PARALLEL:
        compute_kernel = None
    constraints = gen_constraints(sharding_type=sharding_type, compute_kernel=compute_kernel)
    planner = EmbeddingShardingPlanner(
        topology=topology,
        constraints=constraints,
    )

    sharders = [
        QuantEmbeddingBagCollectionSharder(
            fused_params={
                FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS: split_scale_bias,
                FUSED_PARAM_REGISTER_TBE_BOOL: True,
            }
        )
    ]

    plan = planner.collective_plan(model, sharders, pg)
    sharded_q_model = DistributedModelParallel(
        model,
        env=ShardingEnv.from_process_group(pg),
        plan=plan,
        sharders=sharders,
        device=torch.device("cuda"),
    )
    return sharded_q_model


def single_rank_q_execution(
    rank: int,
    world_size: int,
    sharding_type: ShardingType,
    backend: str,
) -> None:
    timeout = datetime.timedelta(minutes=120)

    def init_distributed_single_host(
        rank: int,
        world_size: int,
        backend: str,
        # pyre-fixme[11]: Annotation `ProcessGroup` is not defined as a type.
    ) -> dist.ProcessGroup:
        os.environ["RANK"] = f"{rank}"
        os.environ["WORLD_SIZE"] = f"{world_size}"
        dist.init_process_group(rank=rank, world_size=world_size, backend=backend, timeout=timeout)
        return dist.group.WORLD

    if backend == "nccl":
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    pg = init_distributed_single_host(rank, world_size, backend)
    rank = dist.get_rank()

    # prepare kjt
    large_jt = JaggedTensor(
        values=torch.tensor([1, round(VOCAB_SIZE * 0.25), round(VOCAB_SIZE * 0.5), round(VOCAB_SIZE * 0.75)]),
        lengths=torch.tensor([1, 1, 1, 1]),
    )
    kjt = KeyedJaggedTensor.from_jt_dict(
        {
            "large_table_feature_0": large_jt,
        }
    )
    kjt = kjt.to('cuda')

    """ Option 1: Create the model at original resolution + shard"""
    vocab_size = VOCAB_SIZE
    embedding_dim = EMBEDDING_DIM
    large_tables = create_tables(embedding_dim=embedding_dim, vocab_size=vocab_size, data_type=DataType.FP16)
    ebc = torchrec.EmbeddingBagCollection(device=torch.device("meta"), tables=large_tables)

    model = WrapperModuleA(ebc)
    dmp_model = shard_ori_model(model, pg=pg, sharding_type=sharding_type)

    ori_out = dmp_model(kjt).wait()

    """ Option 2: Create quantized table + shard"""

    quant_dtype = QUANT_DTYPE
    split_scale_bias = sharding_type != ShardingType.TABLE_WISE

    large_tables_2 = create_tables(embedding_dim=embedding_dim, vocab_size=vocab_size)
    new_ebc_2 = torchrec.EmbeddingBagCollection(device=torch.device("cuda"), tables=large_tables_2)
    new_model_2 = WrapperModuleA(new_ebc_2)
    inplace_quantize(new_model_2, quant_dtype=quant_dtype, split_scale_bias=split_scale_bias)
    sharded_q_model = shard_q_model(new_model_2, pg, sharding_type=sharding_type)

    kjt = kjt.to(device)
    output = sharded_q_model(kjt).to_dict()

    """ Compare outputs """
    print(f'ori output {dist.get_rank()}', ori_out['large_table_feature_0'][:, :])
    print(f'q output {dist.get_rank()}', output['large_table_feature_0'][:, :])

    time.sleep(2)
    if dist.get_rank() == 1:
        logger.warning('=== original state ===')
        pprint(dmp_model.module.inf_model.ebc_.__dict__)

        print('==============')

        logger.warning('=== quant state ===')
        pprint(sharded_q_model.module.inf_model.ebc_.__dict__)

    return sharded_q_model


def spmd_sharing_q_simulation(
    sharding_type: ShardingType = ShardingType.ROW_WISE,
    world_size=2,
):
    ctx = multiprocess.get_context("spawn")
    processes = []
    for rank in range(world_size):
        p = ctx.Process(
            target=single_rank_q_execution,
            args=(rank, world_size, sharding_type, "nccl"),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
        assert 0 == p.exitcode


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    world_size = torch.cuda.device_count()

    spmd_sharing_q_simulation(sharding_type=ShardingType.ROW_WISE, world_size=world_size)
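Printing `__dict__` on rank 1 makes the footprint comparison hard to read. A minimal sketch of a more direct check, assuming the weights of interest are registered as parameters or buffers on the module (presumably the case for the TBE submodule when FUSED_PARAM_REGISTER_TBE_BOOL is set): sum the bytes of every registered tensor per device and print the result on each rank. The helper name `tensor_bytes_by_device` is hypothetical, and tensors held in plain attributes or other unregistered containers will not be counted.

```python
from collections import defaultdict
from typing import Dict

import torch
from torch import nn


def tensor_bytes_by_device(module: nn.Module) -> Dict[str, int]:
    """Hypothetical helper: total bytes of registered parameters and buffers, per device."""
    totals: Dict[str, int] = defaultdict(int)
    # parameters() and buffers() already skip duplicates shared between submodules.
    for t in list(module.parameters()) + list(module.buffers()):
        totals[str(t.device)] += t.numel() * t.element_size()
    return dict(totals)
```

For example, calling `print(dist.get_rank(), tensor_bytes_by_device(sharded_q_model))` on every rank would show whether each GPU holds one shard or the full table. On a CUDA rank, `torch.cuda.memory_allocated(rank)` gives a coarser cross-check of the same comparison.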
@dstaay-fb
Contributor

At a quick look, I think the key gap here is how inference and training differ with respect to devices and processes. In training, DMP assumes one Python process per (CUDA) device, and the devices communicate implicitly through collectives. In inference, by contrast, a single process drives all devices, and data moves between them through device-to-device copies (i.e. tensor.to(...)).

So your test setup (multiple processes) is incompatible with this assumption (which, I agree, is poorly documented). Effectively, what you have implemented is the inference model twice, once in each process. Take a look at our tests to see how we do this: https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/tests/test_quant_model_parallel.py
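To make the single-process model concrete, here is a minimal sketch in plain PyTorch. It is not TorchRec's implementation, only an illustration of the pattern described above: one process splits a SUM-pooled table row-wise across a list of devices, routes each lookup to the shard that owns the row, and moves the partial results back with tensor.to(...), with no torch.distributed collectives involved. The class `SingleProcessRowWiseEBC` is hypothetical.

```python
from typing import List

import torch
import torch.nn.functional as F


class SingleProcessRowWiseEBC:
    """Hypothetical sketch: one SUM-pooled table split row-wise across devices."""

    def __init__(self, weight: torch.Tensor, devices: List[torch.device]):
        block = -(-weight.shape[0] // len(devices))  # ceil division
        self.shards = []  # (first_row, shard_tensor) pairs, one per non-empty shard
        for i, device in enumerate(devices):
            start = i * block
            rows = weight[start : start + block]
            if rows.shape[0] > 0:
                self.shards.append((start, rows.to(device)))

    def __call__(
        self, ids: torch.Tensor, offsets: torch.Tensor, out_device: torch.device
    ) -> torch.Tensor:
        pooled = None
        for start, shard in self.shards:
            device = shard.device
            local = ids - start
            owned = (local >= 0) & (local < shard.shape[0])
            # Ids owned by other shards get weight 0 and a dummy in-range index.
            local = local.clamp(0, shard.shape[0] - 1)
            partial = F.embedding_bag(
                local.to(device),
                shard,
                offsets.to(device),
                mode="sum",
                per_sample_weights=owned.to(device=device, dtype=shard.dtype),
            )
            # Data moves between devices with a plain tensor copy, not a collective.
            partial = partial.to(out_device)
            pooled = partial if pooled is None else pooled + partial
        return pooled
```

With `devices=[torch.device("cuda:0"), torch.device("cuda:1")]`, the same code spreads one table over two GPUs from a single process. Passing two CPU devices also works, which is handy for checking the result against a plain `F.embedding_bag` over the full table.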
