This repository has been archived by the owner on Nov 1, 2021. It is now read-only.

Fix zero gradient for subtensor assignment. #127

Open · wants to merge 2 commits into base: master
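
As a minimal repro sketch of the bug this PR targets (assuming the torch and autograd packages are installed; this mirrors the f5 test added in test/test.lua and is not code from the PR itself): a gradient flowing through a subtensor assignment into a cloned tensor should reach params.y as a tensor of 2s rather than zeros.

local autograd = require 'autograd'
local torch = require 'torch'

-- Hypothetical standalone repro, modeled on the f5 test case added below.
local f = function(params)
   local xc = torch.clone(params.x)
   xc[2] = params.y * 2.0            -- subtensor assignment of a tensor value
   return torch.sum(xc)
end

local df = autograd(f)
local grads = df({x = torch.randn(10, 10), y = torch.randn(10)})
print(grads.y)                       -- expected: a 10-element tensor of 2s, not zeros
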
13 changes: 12 additions & 1 deletion src/gradcheck.lua
@@ -1,5 +1,6 @@
 -- Autograd
 local autograd = require 'autograd'
+local util = require 'autograd.util'
 
 -- Perturbation (finite diffs):
 local perturbation = 1e-6
@@ -12,20 +13,30 @@ local function jacobianFromAutograd(func, inputs, key)
    -- Autograd:
    local df = autograd(func)
    local grads = df(table.unpack(inputs))
-   local gradsVerify = df(table.unpack(inputs))
 
    -- Find grad:
    local g = autograd.util.nestedGet(grads, key)
+   local g_clone
+   if torch.isTensor(g) then
+      g_clone = g:clone()
+   end
+
+   -- Get the grad again
+   local gradsVerify = df(table.unpack(inputs))
    local gVerify = autograd.util.nestedGet(gradsVerify, key)
    local err
+   local overwrite_err = 0
    if torch.isTensor(g) then
       err = (g - gVerify):abs():max()
+      overwrite_err = (g - g_clone):abs():max()
    else
       err = torch.abs(g - gVerify)
    end
 
    if err ~= 0 then
       error("autograd gradient not deterministic")
+   elseif overwrite_err ~= 0 then
+      error("autograd gradient overwritten when called twice")
    end
 
    -- Return grads:
7 changes: 6 additions & 1 deletion src/gradfuns.lua
@@ -152,7 +152,12 @@ functions.set = {
       return nil
    end,
    function(g, ans, x, k, v)
-      return g[k]
+      local gk = getValue(g[k])
+      if type(gk) == 'number' then
+         return gk
+      else
+         return torch.clone(gk)
+      end
    end,
 }

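The torch.clone added in gradfuns.lua avoids handing back a tensor that aliases the upstream gradient g (g[k] can be a view into it), and the new gradcheck assertion, which calls df twice and compares against a clone, catches exactly that kind of overwrite. Below is a minimal sketch of the aliasing hazard using plain torch tensors, not code from the PR:

local torch = require 'torch'

local g = torch.ones(3, 3)           -- stand-in for an upstream gradient buffer
local grad_alias = g[2]              -- what `return g[k]` hands back: a view into g
local grad_cloned = g[2]:clone()     -- what the patched adjoint hands back

g:zero()                             -- the buffer may later be reused or cleared in place

print(grad_alias:sum())              -- 0: the aliased gradient was silently wiped
print(grad_cloned:sum())             -- 3: the cloned gradient keeps its values
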
6 changes: 6 additions & 0 deletions test/test.lua
@@ -1687,6 +1687,12 @@ local tests = {
          return torch.sum(xc)
       end
       tester:assert(gradcheck(f4,{x=torch.randn(10,10),y=torch.randn(3)}), "Incorrect gradient")
+      local f5 = function(params)
+         local xc = torch.clone(params.x)
+         xc[2] = params.y * 2.0
+         return torch.sum(xc)
+      end
+      tester:assert(gradcheck(f5,{x=torch.randn(10,10),y=torch.randn(10)}), "Incorrect gradient")
    end,
 
    ScalarSigmoid = function()