Hello, I am using this linked solution from Stack Overflow to compute gradients more efficiently than a manual loop.
I notice a small difference between the gradients computed by the two methods (i.e. torch.abs(grads_torch - grads_func).sum() returns ~1e-05). What might explain this difference? Is one solution more correct than the other?
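For scale: the Jacobian here has 2 × 10 × 15,910 = 318,200 entries (the MLP has 15,910 parameters), so a summed absolute difference of ~1e-05 works out to roughly 5e-11 per element on average, far below float32 machine epsilon. A minimal sketch to see this, using grads_torch and grads_func as computed in the MWE below:

# sketch: express the summed difference in per-element terms
# (grads_torch and grads_func come from the MWE below)
diff = (grads_torch - grads_func).abs()
print(diff.max())                       # worst single-element gap
print(diff.sum() / diff.numel())        # mean per-element gap, ~5e-11 here
print(torch.finfo(torch.float32).eps)   # ~1.19e-07, float32 resolution for comparison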
MWE
import torch
import torch.nn as nn
from torchvision import datasets, transforms

###### SETUP ######
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h = self.fc1(x)
        pred = self.fc2(self.relu(h))
        return pred

train_dataset = datasets.MNIST(root='./data', train=True, download=True,
                               transform=transforms.Compose(
                                   [transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (0.5,))]))
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=False)
X, y = next(iter(train_dataloader))  # take the first batch of data (shuffle=False)
net = MLP(28*28, 20, 10)             # define a network
###### CALCULATE GRADIENTS WITH TORCH.AUTOGRAD.GRAD ######
def calculate_gradients(model, X):
    # Create a tensor to hold the gradients
    gradients = torch.zeros(X.shape[0], 10, sum(p.numel() for p in model.parameters()))
    # Calculate the gradients for each input and output dimension
    for i in range(X.shape[0]):
        for j in range(10):
            model.zero_grad()
            output = model(X[i])
            # Calculate the gradients
            grads = torch.autograd.grad(output[j], model.parameters())
            # Flatten the gradients and store them
            gradients[i, j, :] = torch.cat([g.view(-1) for g in grads])
    return gradients

grads_torch = calculate_gradients(net, X.view(X.shape[0], -1))
###### NOW CALCULATE THE SAME GRADIENTS WITH FUNCTORCH ######
# extract the parameters and buffers for a functional call
params = {k: v.detach() for k, v in net.named_parameters()}
buffers = {k: v.detach() for k, v in net.named_buffers()}

def one_sample(sample):
    # this calculates the gradients for a single sample
    # we want the gradients of each output wrt the parameters,
    # which is the Jacobian of the network wrt the parameters
    # define a function that takes the parameters as input and returns the output of the network
    call = lambda p: torch.func.functional_call(net, (p, buffers), sample)
    # calculate the Jacobian of the network wrt the parameters
    J = torch.func.jacrev(call)(params)
    # J is a dictionary mapping parameter names to their gradients; we want a single tensor
    grads = torch.cat([v.flatten(1) for v in J.values()], -1)
    return grads

# now we can use vmap to calculate the gradients for all samples at once
grads_func = torch.vmap(one_sample)(X.flatten(1))
print(torch.allclose(grads_torch, grads_func)) # returns True
print(torch.abs(grads_torch - grads_func).sum()) # returns tensor(1.4454e-05)
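To check whether one method is more correct, here is a minimal sketch (assuming the MWE above has already run in the same session): recompute the Jacobian in float64 and treat it as a reference. Both float32 results typically land at about the same distance from it, which points to accumulation-order rounding in float32 rather than an error in either method.

###### OPTIONAL: FLOAT64 REFERENCE (sketch, assumes the MWE above has run) ######
net.double()                                   # converts the module to float64 in place
X64 = X.view(X.shape[0], -1).double()
# refresh the detached params/buffers that one_sample closes over
params = {k: v.detach() for k, v in net.named_parameters()}
buffers = {k: v.detach() for k, v in net.named_buffers()}
grads_ref = torch.vmap(one_sample)(X64)        # float64 reference Jacobian

print((grads_torch.double() - grads_ref).abs().max())  # loop method vs reference
print((grads_func.double() - grads_ref).abs().max())   # vmap method vs reference
# both gaps are typically of the same order, so neither method is
# systematically more accurate; they just round differently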