[torchbench] hf_Longformer fails to run #453

Closed
alexbaden opened this issue Feb 6, 2024 · 7 comments

@alexbaden
Contributor

This appears to be an out-of-memory issue. The fbgemm_gpu undefined-symbol messages are fairly common and also appear on passing tests.

» benchmarks/dynamo/torchbench.py --float32 -dxpu -n10 --no-skip --dashboard --training --inductor --accuracy --output /tmp/torchbench.csv --filter hf_Longformer      

loading model: 0it [00:04, ?it/s]
xpu  train hf_Longformer                      
/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/fbgemm_gpu/fbgemm_gpu_py.so: undefined symbol: _ZNK5torch8autograd4Node4nameEv
/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/fbgemm_gpu/fbgemm_gpu_py.so: undefined symbol: _ZNK5torch8autograd4Node4nameEv
skipping cudagraphs for unknown reason
skipping cudagraphs for unknown reason
ERROR:common:backend='inductor' raised:
RuntimeError: Allocation is out of device memory on current platform.

While executing %full_562 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([48, 3, 256, 513], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: xpu:0, pin_memory: False})
Original traceback:
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/transformers/models/longformer/modeling_longformer.py", line 1318, in <resume in forward>
    layer_outputs = layer_module(
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/transformers/models/longformer/modeling_longformer.py", line 1246, in forward
    self_attn_outputs = self.attention(
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/transformers/models/longformer/modeling_longformer.py", line 1182, in forward
    self_outputs = self.self(
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/transformers/models/longformer/modeling_longformer.py", line 571, in forward
    attn_scores = self._sliding_chunks_query_key_matmul(
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/transformers/models/longformer/modeling_longformer.py", line 868, in _sliding_chunks_query_key_matmul
    diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[


Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
Traceback (most recent call last):
  File "/localdisk/abaden/Projects/frameworks.ai.pytorch.private-gpu/benchmarks/dynamo/common.py", line 2145, in check_accuracy
    new_result = optimized_model_iter_fn(model_copy, example_inputs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/localdisk/abaden/Projects/frameworks.ai.pytorch.private-gpu/benchmarks/dynamo/common.py", line 1909, in run_n_iterations
    self.model_iter_fn(mod, inputs, collect_outputs=False)
  File "/localdisk/abaden/Projects/frameworks.ai.pytorch.private-gpu/benchmarks/dynamo/torchbench.py", line 461, in forward_and_backward_pass
    cloned_inputs = clone_inputs(inputs)
  File "/localdisk/abaden/Projects/frameworks.ai.pytorch.private-gpu/benchmarks/dynamo/torchbench.py", line 462, in <resume in forward_and_backward_pass>
    self.optimizer_zero_grad(mod)
  File "/localdisk/abaden/Projects/frameworks.ai.pytorch.private-gpu/benchmarks/dynamo/torchbench.py", line 464, in <resume in forward_and_backward_pass>
    pred = mod(*cloned_inputs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/transformers/models/longformer/modeling_longformer.py", line 1835, in forward
    outputs = self.longformer(
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/transformers/models/longformer/modeling_longformer.py", line 1738, in forward
    encoder_outputs = self.encoder(
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/transformers/models/longformer/modeling_longformer.py", line 1291, in forward
    is_global_attn = is_index_global_attn.flatten().any().item()
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 641, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
    tracer.run()
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
    super().run()
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2162, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 857, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 957, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1024, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1009, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/__init__.py", line 1568, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1150, in compile_fx
    return aot_autograd(
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 55, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3891, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3429, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2212, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2392, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2825, in aot_dispatch_autograd
    fw_module, bw_module = aot_config.partition_fn(
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1119, in partition_fn
    joint_graph_passes(graph)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_inductor/fx_passes/joint_graph.py", line 204, in joint_graph_passes
    constant_fold_uniform_value(graph)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_inductor/fx_passes/joint_graph.py", line 119, in constant_fold_uniform_value
    cf.run()
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_inductor/freezing.py", line 175, in run
    return super().run(initial_env=env)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/fx/interpreter.py", line 138, in run
    self.env[node] = self.run_node(node)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_inductor/freezing.py", line 145, in run_node
    out = super().run_node(node)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/fx/interpreter.py", line 195, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/fx/interpreter.py", line 267, in call_function
    return target(*args, **kwargs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/_ops.py", line 448, in __call__
    return self._op(*args, **kwargs or {})
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Allocation is out of device memory on current platform.

While executing %full_562 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([48, 3, 256, 513], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: xpu:0, pin_memory: False})
Original traceback:
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/transformers/models/longformer/modeling_longformer.py", line 1318, in <resume in forward>
    layer_outputs = layer_module(
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/transformers/models/longformer/modeling_longformer.py", line 1246, in forward
    self_attn_outputs = self.attention(
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/transformers/models/longformer/modeling_longformer.py", line 1182, in forward
    self_outputs = self.self(
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/transformers/models/longformer/modeling_longformer.py", line 571, in forward
    attn_scores = self._sliding_chunks_query_key_matmul(
  File "/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/transformers/models/longformer/modeling_longformer.py", line 868, in _sliding_chunks_query_key_matmul
    diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[


Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

TorchDynamo optimized model failed to run because of following error
fail_to_run
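
For reference, the failing node itself requests a fairly small tensor. A minimal, hypothetical sketch to size and reproduce that single allocation in isolation (assuming an XPU-enabled PyTorch/IPEX environment; not part of the original report):

    import torch
    import intel_extension_for_pytorch  # noqa: F401  (assumption: IPEX provides the "xpu" device)

    # Shape, dtype, and device copied from the failing FX node in the log:
    # %full_562 = aten.full.default([48, 3, 256, 513], 0, dtype=torch.float32, device=xpu:0)
    shape = (48, 3, 256, 513)
    nbytes = 48 * 3 * 256 * 513 * 4  # float32 -> 75,644,928 bytes (~72 MiB)
    print(f"requested allocation: {nbytes / 2**20:.1f} MiB")

    # If this single allocation succeeds on an otherwise idle device, the OOM in the
    # benchmark is more likely cumulative than caused by one oversized tensor.
    t = torch.full(shape, 0.0, dtype=torch.float32, device="xpu:0")
    print(tuple(t.shape), t.numel() * t.element_size(), "bytes allocated")
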
@etiotto
Contributor

etiotto commented Feb 6, 2024

@alexbaden these warnings might cause the benchmark to miscompile and then fail at runtime. Do you know how to fix them?

/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/fbgemm_gpu/fbgemm_gpu_py.so: undefined symbol: _ZNK5torch8autograd4Node4nameEv
/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/fbgemm_gpu/fbgemm_gpu_py.so: undefined symbol: _ZNK5torch8autograd4Node4nameEv
skipping cudagraphs for unknown reason
skipping cudagraphs for unknown reason

@ESI-SYD any input on how to fix the benchmark?

@alexbaden alexbaden self-assigned this Feb 6, 2024
@ESI-SYD
Contributor

ESI-SYD commented Feb 7, 2024

@alexbaden these warnings might cause the benchmark to miscompile and then fail at runtime. Do you know how to fix them?

/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/fbgemm_gpu/fbgemm_gpu_py.so: undefined symbol: _ZNK5torch8autograd4Node4nameEv
/localdisk/abaden/Projects/envs/triton-benchmark-env/lib/python3.10/site-packages/fbgemm_gpu/fbgemm_gpu_py.so: undefined symbol: _ZNK5torch8autograd4Node4nameEv
skipping cudagraphs for unknown reason
skipping cudagraphs for unknown reason

@ESI-SYD any input on how to fix the benchmark?

It seems like an fbgemm-gpu-nightly incompatibility issue. Please check and try the stable release instead: pip install fbgemm-gpu
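
A quick way to check whether the installed fbgemm_gpu wheel matches the local PyTorch ABI is to import it directly; a minimal sketch (assuming the same Python environment used for the benchmark):

    import torch  # import first so the libtorch symbols are loaded
    try:
        import fbgemm_gpu  # a wheel built against a different torch ABI fails here
        print("fbgemm_gpu loaded from", fbgemm_gpu.__file__)
    except (ImportError, OSError) as err:
        # e.g. "undefined symbol: _ZNK5torch8autograd4Node4nameEv"
        print("fbgemm_gpu failed to load:", err)
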

@alexbaden
Contributor Author

It might be an environment issue, yes. But is fbgemm_gpu used if CUDA is not present?

@whitneywhtsang
Contributor

It also fails with v2.1.

@vlad-penkin vlad-penkin added the bug Something isn't working label Feb 9, 2024
@vlad-penkin vlad-penkin added this to the E2E pass rate milestone Feb 9, 2024
@ienkovich ienkovich self-assigned this Feb 15, 2024
@ienkovich
Contributor

I tried tracing the benchmark execution to see whether we were allocating too much memory. From the trace, the reason for the OOM is not clear. The failing call is always the same:

>>>> [160949440778] zeCommandListAppendLaunchKernel: hCommandList = 0x55c701e4d650 hKernel = 0x55c703075f70 (at::AtenIpexTypeXPU::ElementwiseKernelFunctor<at::AtenIpexTypeXPU::dpcpp_loops_launch_legacy_kernel_functor<at::AtenIpexTypeXPU::BinaryFunctor<float, float, bool, at::AtenIpexTypeXPU::impl::EqKernelDpcppFunctor<float> >, bool, OffsetCalculator<3, unsigned int, false>, 3> >) pLaunchFuncArgs = 0x7ffc44794600 {448, 1, 1} hSignalEvent = 0x55c7018eaec0 numWaitEvents = 0 phWaitEvents = 0
<<<< [161952725571] zeCommandListAppendLaunchKernel(847) [1003272957 ns] -> ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY(0x1879048195)

This kernel had previously been launched multiple times with the same size with no problems. The total memory allocated via zeMemAllocDevice is 1,593,835,520 bytes (~1.5 GiB). It's unclear why we suddenly get an OOM here, but the problem is not in Triton or the Triton kernels: the benchmark mostly uses IPEX kernels, plus 4 small Triton kernels that run with no problems.
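
One way to narrow this down would be to log the allocator state around an allocation of the same size as the failing node. A hypothetical instrumentation sketch, assuming an IPEX build that exposes the torch.xpu memory counters (these calls are not part of the trace above):

    import torch
    import intel_extension_for_pytorch  # noqa: F401  -- registers the torch.xpu backend

    def report(tag):
        # counters analogous to the CUDA caching-allocator statistics (assumed available via IPEX)
        allocated = torch.xpu.memory_allocated() / 2**20
        reserved = torch.xpu.memory_reserved() / 2**20
        print(f"[{tag}] allocated={allocated:.1f} MiB reserved={reserved:.1f} MiB")

    report("before")
    # same shape/dtype/device as the failing %full_562 node
    t = torch.full((48, 3, 256, 513), 0.0, dtype=torch.float32, device="xpu:0")
    report("after")
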

@gshimansky
Contributor

Regarding fbgemm, you may want to take a look at building fbgemm from source for CPU. The current fbgemm binaries cannot be used because they do not link against the PyTorch build that we use. See #548 (comment).

@vlad-penkin
Contributor

The run failure is no longer reproducible.

Env:

  • pytorch is built from source, top of the main trunk, commit_id - 9a8ab778d34bd24c5caceb340837483decc4c311
  • triton xpu is built from source, top of the main trunk, commit_id - fe93a00ffe438e9ba8c8392c0b051b1662c810de
  • benchmark is built from source, top of the main trunk, commit_id - d54ca9f80ead108c8797441681e219becaf963d8
  • torchaudio is built from source, top of the main trunk, commit_id - 1980f8af5bcd0bb2ce51965cf79d8d4c25dad8a0
  • torchvision is built from source, top of the main trunk, commit_id - 10239873229e527f8b7e7b3340c40ee38bb1cfc4
  • PyTorch Dependency Bundle 0.5.0
  • Latest Rolling Driver
