You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
I've been trying to use memory_efficient_fusion to see if I can speed up a main bottleneck in my code, but I hit a RuntimeError. This issue continues from #1011.
The code to reproduce this is as follows.
import torch
from torch import nn
import functorch
from functorch import make_functional, vmap, jacrev, grad
from functorch.compile import memory_efficient_fusion
import time
# Fix the RNG seed so the reproduction is deterministic.
torch.manual_seed(1234)

# Report the library versions used for this reproduction.
# Values seen by the reporter: PyTorch 2.0.0.dev20230116, CUDA 11.6,
# FuncTorch 2.0.0.dev20230116.
for label, version in (
    ("PyTorch version: ", torch.__version__),
    ("CUDA version: ", torch.version.cuda),
    ("FuncTorch version: ", functorch.__version__),
):
    print(label, version)
#=============================================#
#time with torch synchronization
def sync_time() -> float:
    """Return a high-resolution timestamp taken after pending CUDA work finishes.

    CUDA kernels are launched asynchronously, so the device is synchronized
    before the clock is read. On machines without CUDA the synchronization is
    skipped so the helper still works as a plain timer.

    Returns:
        The value of ``time.perf_counter()`` in seconds.
    """
    # torch.cuda.synchronize() raises when no CUDA device is available.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter()
class model(nn.Module):
    """Small feed-forward network producing a sign / log-abs-determinant pair.

    Each input coordinate is paired with the mean over the second-to-last
    axis, passed through two linear layers with a tanh in between, and the
    resulting square matrix is reduced with ``torch.linalg.slogdet``.
    """

    def __init__(self, num_inputs, num_hidden):
        """Build the layers.

        Args:
            num_inputs: number of input nodes; also the output width of the
                last layer.
            num_hidden: width of the hidden layer.
        """
        super().__init__()
        self.num_inputs = num_inputs
        self.func = nn.Tanh()
        # Two input features per node: the coordinate and the pooled mean.
        self.fc1 = nn.Linear(2, num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_inputs)

    def forward(self, x):
        """
        Takes x in [B,A] and maps it to sign/logabsdet value in Tuple([B,], [B,])
        """
        features = x.unsqueeze(-1)
        ndim = features.dim()  # differs depending on whether vmap is used
        # Repeat only along the node axis; every other axis keeps its size.
        tile = [1] * ndim
        tile[-2] = self.num_inputs
        pooled = features.mean(dim=ndim - 2, keepdim=True).repeat(*tile)
        stacked = torch.cat((features, pooled), dim=-1)
        hidden = self.func(self.fc1(stacked))
        matrix = self.fc2(hidden)
        sign, logabsdet = torch.linalg.slogdet(matrix)
        return sign, logabsdet
#=============================================#
# Problem sizes.
B = 4096  # batch
N = 2  # input nodes
H = 64  # number of hidden nodes

device = torch.device('cuda')

# Input data is drawn before the model is built, as in the original script.
x = torch.randn(B, N, device=device)

# Build the model and move it to the device.
net = model(N, H).to(device)

# Plain forward pass on the module itself.
sgn, logabs = net(x)

# Functional version of the network: parameters are passed explicitly.
fnet, params = make_functional(net)
def calc_logabs(params, x):
    """Evaluate the functional network and return only its second output.

    ``fnet`` returns a ``(sign, logabs)`` pair; the sign is discarded.
    """
    outputs = fnet(params, x)
    return outputs[1]
def calc_dlogabs_dx(params, x):
    """Return the Jacobian of ``calc_logabs`` with respect to ``x``, twice.

    The second copy is returned as an auxiliary output so that a caller
    differentiating this function with ``has_aux=True`` also gets the
    first derivative back.
    """
    jacobian_fn = jacrev(calc_logabs, argnums=1)
    first_derivative = jacobian_fn(params, x)
    return first_derivative, first_derivative  # (value, aux)
def local_kinetic_from_log_vmap(params, x):
    """Compute -0.5 * (trace of second derivative + squared first derivative).

    Differentiates ``calc_dlogabs_dx`` with respect to ``x``; the first
    derivative comes back through the auxiliary output.
    """
    second_derivative_fn = jacrev(calc_dlogabs_dx, argnums=1, has_aux=True)
    second_derivative, first_derivative = second_derivative_fn(params, x)
    # Sum of the diagonal over the last two axes.
    laplacian_term = second_derivative.diagonal(0, -2, -1).sum()
    gradient_term = first_derivative.pow(2).sum()
    return -0.5 * (laplacian_term + gradient_term)
# memory efficient fusion here
# with torch.jit.fuser("fuser2"): # is this needed (from functorch/issues/840)

# Per-sample gradient of the local kinetic term w.r.t. the parameters:
# params are shared (in_dims None), x is mapped over its first axis.
ps_elocal = vmap(grad(local_kinetic_from_log_vmap, argnums=0), in_dims=(None, 0))
ps_elocal_fusion = memory_efficient_fusion(
    vmap(grad(local_kinetic_from_log_vmap, argnums=0), in_dims=(None, 0))
)

t1 = sync_time()
out1 = ps_elocal(params, x)
t2 = sync_time()
# NOTE(review): this first fused call presumably also includes compilation
# time, so the comparison below is not like-for-like -- confirm.
ps_elocal_fusion(params, x)  # crashes here: aten::is_same_size no batching rule
t3 = sync_time()

# Compare memory_efficient_fusion on the function's walltime.
# The format string must be applied with "%"; passing the value as a second
# print() argument would print the raw "%4.2e" placeholder instead.
print("Laplacian (standard): %4.2e (s)" % (t2 - t1))
print("Laplacian (fusion): %4.2e (s)" % (t3 - t2))
The traceback is as follows,
PyTorch version: 2.0.0.dev20230116
CUDA version: 11.6
FuncTorch version: 2.0.0.dev20230116
Failed to collect metadata on function, produced code may be suboptimal. Known situations this can occur are inference mode only compilation involving resize_ or prims (!schema.hasAnyAliasInfo() INTERNAL ASSERT FAILED); if your situation looks different please file a bug to PyTorch.
Traceback (most recent call last):
File "~/anaconda3/envs/nightly_20230117/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1368, in aot_wrapper_dedupe
fw_metadata, _out = run_functionalized_fw_and_collect_metadata(flat_fn)(
File "~/anaconda3/envs/nightly_20230117/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 569, in inner
flat_f_outs = f(*flat_f_args)
... lot more errors
File "~/anaconda3/envs/nightly_20230117/lib/python3.10/site-packages/torch/autograd/__init__.py", line 285, in grad
grad_outputs_ = _make_grads(t_outputs, grad_outputs_, is_grads_batched=is_grads_batched)
File "~/anaconda3/envs/nightly_20230117/lib/python3.10/site-packages/torch/autograd/__init__.py", line 53, in _make_grads
if not torch.is_same_size(out, first_grad):
RuntimeError: Batching rule not implemented for aten::is_same_size. We could not generate a fallback.
This code was created with the latest nightly version.
PyTorch version: 2.0.0.dev20230116
CUDA version: 11.6
FuncTorch version: 2.0.0.dev20230116
The text was updated successfully, but these errors were encountered:
Hi All,
I've been trying to use `memory_efficient_fusion` to see if I can speed up a main bottleneck in my code, but I hit a `RuntimeError`. This issue continues from #1011.
The code to reproduce this is as follows.
The traceback is as follows,
This code was created with the latest nightly version.
The text was updated successfully, but these errors were encountered: