Passing Device as str in EmbeddingBagCollectionSharder().shard #2624

Open
ArijitSinghEDA opened this issue Dec 10, 2024 · 3 comments

Comments

@ArijitSinghEDA

ArijitSinghEDA commented Dec 10, 2024

I am following the TorchRec Tutorial.

I am on the "Planner Result" section. When I try the steps given there, I get the following error:

```
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/local_custom_recsys/custom_recommender_torch.py", line 79, in <module>
[rank0]:     sharded_ebc = sharder.shard(ebc, plan.plan[""], env, torch.device("cuda"))
[rank0]:   File "/home/local_custom_recsys/.venv/lib/python3.10/site-packages/torchrec/distributed/embeddingbag.py", line 1247, in shard
[rank0]:     return ShardedEmbeddingBagCollection(
[rank0]:   File "/home/local_custom_recsys/.venv/lib/python3.10/site-packages/torchrec/distributed/embeddingbag.py", line 674, in __init__
[rank0]:     if module.device not in ["meta", "cpu"] and module.device.type not in [
[rank0]: AttributeError: 'str' object has no attribute 'type'
[rank0]:[W1210 12:30:57.097978618 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())
```

If I run `torch.device("cuda")` on its own, the result does have a `type` attribute, yet I still get the error above.
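
The failing check reads `module.device`, presumably the device the `EmbeddingBagCollection` was constructed with, rather than the device argument given to `shard`; that would explain why a valid `torch.device("cuda")` in the `shard` call does not help. The pure-Python sketch below mimics only the first part of that condition, because the rest is cut off in the traceback. `FakeDevice` and `device_type_if_not_meta_or_cpu` are hypothetical names standing in for `torch.device` and TorchRec's internal check; neither is a TorchRec API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeDevice:
    """Hypothetical stand-in for torch.device: it only carries a .type field."""
    type: str


def device_type_if_not_meta_or_cpu(device):
    """Mimic the first half of the guard shown in the traceback:
    `module.device not in ["meta", "cpu"] and module.device.type ...`
    The second list is truncated in the traceback, so this sketch stops
    at the attribute access that fails."""
    if device not in ["meta", "cpu"]:
        return device.type  # raises AttributeError when device is a plain str
    return None


print(device_type_if_not_meta_or_cpu(FakeDevice("cuda")))  # cuda
print(device_type_if_not_meta_or_cpu("cpu"))  # None: short-circuits before .type

try:
    device_type_if_not_meta_or_cpu("cuda")
except AttributeError as exc:
    print(exc)  # 'str' object has no attribute 'type'
```

A plain string such as `"cuda"` passes the `not in ["meta", "cpu"]` test and then fails on `.type`, while `"cpu"` short-circuits before the attribute access, so the problem would only surface on a GPU machine.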

@sarckk
Member

sarckk commented Dec 20, 2024

Hi @ArijitSinghEDA, could you share more info about how you're calling `sharder.shard`? I followed that section of the TorchRec tutorial and it worked fine for me.

Specifically, could you share how you are initializing the `EmbeddingBagCollection` instance `ebc`?

@ArijitSinghEDA
Author

@sarckk I am attaching my code below:

```python
import torch
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig, PoolingType, JaggedTensor, KeyedJaggedTensor
import os
import torch.distributed as dist
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ShardingEnv


# Set up environment variables for distributed training
# RANK is which GPU we are on, default 0
os.environ["RANK"] = "0"
# How many devices in our "world", colab notebook can only handle 1 process
os.environ["WORLD_SIZE"] = "1"
# Localhost as we are training locally
os.environ["MASTER_ADDR"] = "localhost"
# Port for distributed training
os.environ["MASTER_PORT"] = "29500"

# nccl backend is for GPUs, gloo is for CPUs
# dist.init_process_group(backend="gloo")
dist.init_process_group(backend="nccl")

local_device = "cuda" if torch.cuda.is_available() else "cpu"

ebc = EmbeddingBagCollection(
    device=local_device,
    tables=[
        EmbeddingBagConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
            pooling=PoolingType.SUM,
        ),
        EmbeddingBagConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
            pooling=PoolingType.SUM,
        )
    ]
)

product_jt = JaggedTensor(
    values=torch.tensor([1, 2, 1, 5], device=local_device),
    lengths=torch.tensor([3, 1], device=local_device)
)
user_jt = JaggedTensor(
    values=torch.tensor([2, 3, 4, 1], device=local_device),
    lengths=torch.tensor([2, 2], device=local_device)
)

kjt = KeyedJaggedTensor.from_jt_dict({
    "product": product_jt,
    "user": user_jt
})

result = ebc(kjt)

sharder = EmbeddingBagCollectionSharder()

pg = dist.GroupMember.WORLD
assert pg is not None, "Process group is not initialized"

planner = EmbeddingShardingPlanner(
    topology=Topology(
        world_size=1,
        compute_device=local_device
    )
)

plan = planner.collective_plan(ebc, [sharder], pg)

env = ShardingEnv.from_process_group(pg)

sharded_ebc = sharder.shard(ebc, plan.plan[""], env, torch.device(type="cuda", index=0))

print(f"Sharded EBC Module:\n{sharded_ebc}")
```
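
Separately, the `ProcessGroupNCCL` warning at the end of the traceback asks the program to call `destroy_process_group` before exiting. A minimal sketch of that teardown follows. It assumes a single-process CPU run with the `gloo` backend and a file-based rendezvous; both choices are made here only so the snippet runs without a GPU and are not taken from the issue, which uses `nccl`.

```python
import os
import tempfile

import torch.distributed as dist

# Assumption: one process, CPU-only, gloo backend, rendezvous through a
# temporary file instead of MASTER_ADDR/MASTER_PORT.
store_path = os.path.join(tempfile.mkdtemp(), "pg_store")
dist.init_process_group(
    backend="gloo",
    init_method=f"file://{store_path}",
    rank=0,
    world_size=1,
)
try:
    # ... build the EmbeddingBagCollection, plan, and shard it here ...
    print("initialized:", dist.is_initialized())  # initialized: True
finally:
    # Explicit teardown, as the ProcessGroupNCCL warning requests.
    dist.destroy_process_group()

print("initialized:", dist.is_initialized())  # initialized: False
```

The `try`/`finally` ensures the process group is torn down even if sharding raises, as it does in the traceback above.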

@sarckk
Member

sarckk commented Dec 20, 2024

`EmbeddingBagCollection` expects its `device` argument to be a `torch.device` [source]. Please try the following and let me know if it works!

```python
ebc = EmbeddingBagCollection(
    device=torch.device(local_device), # <-- pass in torch.device
    tables=[
        EmbeddingBagConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
            pooling=PoolingType.SUM,
        ),
        EmbeddingBagConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
            pooling=PoolingType.SUM,
        )
    ]
)
```
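
To avoid the string/`torch.device` mismatch wherever a device is chosen at runtime, one option is to normalize it once. The sketch below uses a hypothetical helper, `as_torch_device`, that is not part of TorchRec; it accepts either a string such as `"cuda"` or an existing `torch.device` and always returns a `torch.device`. Constructing `torch.device("cuda")` does not require a GPU, so the helper also works on CPU-only machines.

```python
import torch


def as_torch_device(device):
    """Hypothetical helper: return a torch.device for either a str
    ("cuda", "cpu", "cuda:0", ...) or an existing torch.device."""
    if isinstance(device, torch.device):
        return device  # already the right type, return it unchanged
    return torch.device(device)


# Same selection logic as in the issue, but producing a torch.device.
local_device = as_torch_device("cuda" if torch.cuda.is_available() else "cpu")
print(local_device.type)  # "cuda" or "cpu", depending on the machine
```

The resulting `local_device` can then be passed as `device=` to `EmbeddingBagCollection` and reused for the input tensors, so every consumer sees the same `torch.device` object with a `.type` attribute.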

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants