Got stuck when training with multiple GPU using dist_train.sh #696

Closed
xiazhongyv opened this issue Dec 4, 2021 · 9 comments · Fixed by #728 or #784
Labels
bug Something isn't working

Comments

@xiazhongyv
Contributor

All child processes get stuck when training with multiple GPUs using dist_train.sh, with CUDA == 11.3 and PyTorch == 1.10.
After some diagnosis, I found they were stuck at https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/utils/common_utils.py#L166-L171

I modified the code from

dist.init_process_group(
        backend=backend,
        init_method='tcp://127.0.0.1:%d' % tcp_port,
        rank=local_rank,
        world_size=num_gpus
)

to

dist.init_process_group(
        backend=backend
)

and it worked.

I'm not sure why this works; if anyone else is hitting the same problem, you can try the same change.
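For context, here is a minimal sketch (not the OpenPCDet code) of what the shortened call relies on: when init_method is omitted, dist.init_process_group falls back to the env:// rendezvous and reads the connection details that torch.distributed.launch (or torchrun) exports as environment variables. The function name init_from_env is just illustrative.

import os
import torch
import torch.distributed as dist

def init_from_env(backend='nccl'):
    # MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE are exported by the launcher,
    # so omitting init_method is equivalent to init_method='env://'.
    dist.init_process_group(backend=backend)
    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())
    return rank, dist.get_world_size()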

@dk-liang
Contributor

dk-liang commented Dec 8, 2021

Thanks. I had the same problem and solved it using your method.

@Eaphan

Eaphan commented Feb 2, 2022

@sshaoshuai After the bug is fixed this way, tcp_port is no longer actually used.
Could you fix it in a cleaner way?
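One possible shape for such a fix, as a sketch only (the actual change in #784 may differ): keep the env:// rendezvous but feed tcp_port into it through MASTER_PORT when the launcher has not already chosen a port.

import os
import torch
import torch.distributed as dist

def init_dist_pytorch(tcp_port, local_rank, backend='nccl'):
    # Fall back to tcp_port only if the launcher did not export an address/port itself.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', str(tcp_port))
    dist.init_process_group(backend=backend)  # env:// rendezvous
    num_gpus = torch.cuda.device_count()
    rank = dist.get_rank()
    torch.cuda.set_device(rank % num_gpus)
    return num_gpus, rank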

@sshaoshuai
Collaborator

Thank you for the bug report. It has been fixed in #784.

Could you help double-check whether it works now?

@Eaphan

Eaphan commented Feb 3, 2022

> Thank you for the bug report. It has been fixed in #784.
>
> Could you help double-check whether it works now?

@sshaoshuai Thanks for your work. It's ok now.

@aotiansysu

For single-machine multi-GPU training, I also had to change local_rank to rank in torch.cuda.set_device() to train properly. Otherwise it throws this error: Duplicate GPU detected : rank 0 and rank 1 both on CUDA device a000.
Modified code:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def init_dist_pytorch(tcp_port, local_rank, backend='nccl'):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    num_gpus = torch.cuda.device_count()

    dist.init_process_group(
        backend=backend,
    )

    # Use the global rank from the process group instead of the local_rank
    # argument, so every process binds to a distinct GPU.
    rank = dist.get_rank()
    torch.cuda.set_device(rank % num_gpus)
    return num_gpus, rank
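As a side note, under the assumption that the job is launched with torchrun or torch.distributed.launch --use_env (both export LOCAL_RANK), the device can also be chosen from LOCAL_RANK; on a single machine this coincides with rank % num_gpus, and it remains the correct device index when several machines are involved.

import os
import torch

# Sketch of an alternative: rely on the LOCAL_RANK variable exported by the launcher.
local_rank = int(os.environ.get('LOCAL_RANK', '0'))
torch.cuda.set_device(local_rank)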

@jiaminglei-lei

@sshaoshuai
torch == 1.9.0, CUDA == 11.1.
Training got stuck at dist.init_process_group even though the code is the latest master.
Other distributed training projects with the same init_process_group call run successfully.

@jiaminglei-lei

> @sshaoshuai torch == 1.9.0, CUDA == 11.1. Training got stuck at dist.init_process_group even though the code is the latest master. Other distributed training projects with the same init_process_group call run successfully.

After I uncommented the lines mentioned in #784 (comment), it works.

sshaoshuai pinned this issue Feb 19, 2022
@sshaoshuai
Collaborator

sshaoshuai commented Feb 19, 2022

I have submitted a new PR, #815, to solve this issue.

Please pull the latest master branch if you still get stuck when training with dist_train.sh.

@Liaoqing-up

> @sshaoshuai torch == 1.9.0, CUDA == 11.1. Training got stuck at dist.init_process_group even though the code is the latest master. Other distributed training projects with the same init_process_group call run successfully.
>
> After I uncommented the lines mentioned in #784 (comment), it works.

So what is the actual cause of this hang? I ran into it as well and will try your approach...
