How to use all TPU cores in PyTorch XLA #8215

Open
fancy45daddy opened this issue Oct 4, 2024 · 0 comments

❓ Questions and Help

I followed the code in https://github.com/pytorch/xla/blob/master/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb, but called xmp.spawn(print_device, args=(lock,), nprocs=8, start_method='fork') instead.

The source code:

import os
# Unset TPU_PROCESS_ADDRESSES as in the linked notebook
# (the default argument avoids a KeyError if the variable is not set).
os.environ.pop('TPU_PROCESS_ADDRESSES', None)

import multiprocessing as mp

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

lock = mp.Manager().Lock()

def print_device(i, lock):
    # Each spawned worker reports which XLA device it was assigned.
    device = xm.xla_device()
    with lock:
        print('process', i, device)

xmp.spawn(print_device, args=(lock,), nprocs=8, start_method='fork')

WARNING:root:Unsupported nprocs (8), ignoring...
process 4 xla:0
process 5 xla:1
process 0 xla:0
process 1 xla:1
process 2 xla:0
process 3 xla:1
process 6 xla:0
process 7 xla:1

Each process can only see 2 XLA devices. But when I run xm.get_xla_supported_devices() it lists all of them: ['xla:0', 'xla:1', 'xla:2', 'xla:3', 'xla:4', 'xla:5', 'xla:6', 'xla:7']. How can I use all of the TPU cores?
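For reference, here is a diagnostic variant of the worker I am using to check this (a sketch only, assuming the same Kaggle PJRT TPU runtime as above; xm.get_ordinal(), torch_xla.runtime, and leaving nprocs unset are my additions, not code from the notebook):

import os
os.environ.pop('TPU_PROCESS_ADDRESSES', None)

import multiprocessing as mp

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr  # assumption: runtime helpers are available in this torch_xla version

lock = mp.Manager().Lock()

def print_device(i, lock):
    device = xm.xla_device()                  # local device in this process
    ordinal = xm.get_ordinal()                # global ordinal across all cores
    total = xr.global_runtime_device_count()  # total devices visible to the runtime
    with lock:
        print('process', i, 'device', device, 'ordinal', ordinal, 'of', total)

# nprocs is left unset so the PJRT runtime decides how many workers to start,
# instead of ignoring nprocs=8 as in the warning above.
xmp.spawn(print_device, args=(lock,), start_method='fork')

If all eight cores are actually in use, the printed ordinals should cover 0 through 7 even though each process only reports a local xla:0 or xla:1.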
