debug #8293 (Draft)

Wants to merge 1 commit into base: master.

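The diff below repeatedly pastes the same three-line debugging probe (an inline import, a print of the current source line, and a stack trace) immediately before various torch_xla._XLAC calls. For reference, here is a minimal standalone sketch of that probe; the helper name debug_probe is illustrative and not part of this PR:

import inspect
import traceback

def debug_probe():
    # Report the line in the caller where the probe was invoked, then dump
    # the full call stack, mirroring the three lines inserted at each call
    # site in the hunks below.
    caller = inspect.currentframe().f_back
    print(f"Current line: {caller.f_lineno}")
    traceback.print_stack()

Calling debug_probe() at a suspect call site prints that site's line number and the stack that led to it, which is what each hunk below does inline.
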
4 changes: 4 additions & 0 deletions torch_xla/_internal/pjrt.py
@@ -141,6 +141,10 @@ def run_multiprocess(fn: Callable[..., R],
Dict of the form {device_ordinal: return_value}, where
return_value is the result of calling `fn`.
"""
return _WORLD_SIZE
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
traceback.print_stack()
if torch_xla._XLAC._xla_runtime_is_initialized():
raise RuntimeError('Runtime is already initialized. Do not use the XLA '
'device before calling xmp.spawn.')
6 changes: 6 additions & 0 deletions torch_xla/_internal/tpu.py
@@ -307,11 +307,17 @@ def discover_master_worker_ip(use_localhost: bool = True) -> str:
def _spmd_find_master_ip(current_worker_hostname: str) -> str:
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
traceback.print_stack()
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
# Translate the hostname to an IP address, e.g. for TPUs on GKE.
current_worker_ip = socket.gethostbyname(current_worker_hostname)
ip_int = int(ip_address(current_worker_ip))
n_dev = xr.global_runtime_device_count()
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
traceback.print_stack()
local_ndev = len(torch_xla._XLAC._xla_get_runtime_devices())
# Create a global (n_dev x 2) tensor containing all process indices and IPs,
# and find the process 0 IP as the master IP.
10 changes: 10 additions & 0 deletions torch_xla/core/xla_model.py
@@ -90,6 +90,9 @@ def get_xla_supported_devices(devkind: Optional[str] = None,
# TODO(wcromar): Remove `devkind` after 2.3 release cut. We no longer support
# multiple device types.
if not devkind:
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
traceback.print_stack()
devices = torch_xla._XLAC._xla_get_devices()
return [
f'xla:{i}'
@@ -224,6 +227,9 @@ def xla_replication_devices(
'Cannot replicate if number of devices ({}) is different from {}'.
format(len(local_devices), len(kind_devices)))
replication_devices = []
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
traceback.print_stack()
for device in torch_xla._XLAC._xla_get_all_devices():
# device is like 'CUDA:0'
xdev = _utils.parse_xla_device(device)
@@ -255,8 +261,12 @@ def set_replication(device: torch.device,
devctx = _get_device_context(device=device)
devices = [str(x) for x in devices]
if devices:
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
# sample replication_devices: ['CUDA:0', 'CUDA:1', 'CUDA:2', 'CUDA:3']
replication_devices = xla_replication_devices(devices)
traceback.print_stack()
print(f"Current line: {inspect.currentframe().f_lineno}")
torch_xla._XLAC._xla_set_replication_devices(replication_devices)
devctx.device_index = devices.index(device)
else:
7 changes: 7 additions & 0 deletions torch_xla/debug/metrics.py
@@ -61,6 +61,9 @@ def clear_all():

def metrics_report():
"""Retrieves a string containing the full metrics and counters report."""
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
traceback.print_stack()
return torch_xla._XLAC._xla_metrics_report()


@@ -78,6 +81,10 @@ def short_metrics_report(counter_names: list = None, metric_names: list = None):
'CompileTime', 'ExecuteTime', 'ExecuteReplicatedTime',
'TransferToDeviceTime', 'TransferFromDeviceTime'
]
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
traceback.print_stack()

return torch_xla._XLAC._short_xla_metrics_report(counter_names, metric_names)


6 changes: 6 additions & 0 deletions torch_xla/distributed/spmd/xla_sharded_tensor.py
@@ -115,6 +115,9 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs):
# which results from the sharding.
@property
def local_shards(self) -> List[XLAShard]:
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
traceback.print_stack()
shard_dev = torch_xla._XLAC._get_local_shards([self.global_tensor])[0]
replica_ind = torch_xla._XLAC._get_local_shard_replica_and_indices(
[self.global_tensor])[0]
@@ -128,6 +131,9 @@ def local_shards(self) -> List[XLAShard]:
def load_local_shards_(self, shards: List[XLAShard]):
data = [s.data for s in shards]
devices = [s.shard_device for s in shards]
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
traceback.print_stack()
torch_xla._XLAC._load_local_shards(self.global_tensor, data, devices)

@property
3 changes: 3 additions & 0 deletions torch_xla/experimental/custom_kernel.py
@@ -54,6 +54,9 @@ def _extract_backend_config(

def jax_import_guard():
# Somehow, we need to grab the TPU before JAX locks it. Otherwise, any pt-xla TPU operations will hang.
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
traceback.print_stack()
torch_xla._XLAC._init_computation_client()


27 changes: 27 additions & 0 deletions torch_xla/runtime.py
@@ -45,6 +45,9 @@ def set_device_type(pjrt_device: str) -> None:
Args:
pjrt_device: 'TPU' or 'CPU'
"""
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
traceback.print_stack()
if torch_xla._XLAC._xla_runtime_is_initialized() and os.environ.get(
xenv.PJRT_DEVICE) != pjrt_device:
raise RuntimeError(
@@ -133,6 +136,9 @@ def local_process_count() -> int:

def global_device_count() -> int:
"""Returns the total number of devices across all processes/hosts."""
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
traceback.print_stack()
return len(torch_xla._XLAC._xla_get_all_devices())


@@ -141,6 +147,9 @@ def world_size() -> int:
global _WORLD_SIZE
if _WORLD_SIZE is not None:
return _WORLD_SIZE
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
traceback.print_stack()
if torch_xla._XLAC._xla_get_replication_devices_count() == 0:
_WORLD_SIZE = 1
else:
@@ -158,6 +167,9 @@ def local_device_count() -> int:

def addressable_device_count() -> int:
"""Returns the number of devices visible to this process."""
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
traceback.print_stack()
return torch_xla._XLAC._xla_num_devices()


@@ -183,10 +195,16 @@ def local_ordinal() -> int:


def process_index() -> int:
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
traceback.print_stack()
return torch_xla._XLAC._xla_get_process_index()


def process_count() -> int:
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
traceback.print_stack()
return torch_xla._XLAC._xla_get_num_processes()


@@ -202,16 +220,25 @@ def host_index() -> int:

# API below will be used to query physical device attribute.
def runtime_device_attributes(device: str) -> Dict[str, object]:
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
traceback.print_stack()
return torch_xla._XLAC._xla_get_device_attributes(device)


def global_runtime_device_attributes() -> List[Dict[str, object]]:
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
traceback.print_stack()
return torch_xla._XLAC._xla_get_all_device_attributes()


@functools.lru_cache()
def global_runtime_device_count() -> int:
"""Returns the total number of runtime devices across all processes/hosts, especially useful for SPMD."""
import traceback,inspect
print(f"Current line: {inspect.currentframe().f_lineno}")
traceback.print_stack()
return len(torch_xla._XLAC._xla_get_all_runtime_devices())

