diff --git a/src/gradcheck.lua b/src/gradcheck.lua
index 4e31780..7dc8cf5 100644
--- a/src/gradcheck.lua
+++ b/src/gradcheck.lua
@@ -1,5 +1,6 @@
 -- Autograd
 local autograd = require 'autograd'
+local util = require 'autograd.util'
 
 -- Perturbation (finite diffs):
 local perturbation = 1e-6
@@ -12,20 +13,30 @@ local function jacobianFromAutograd(func, inputs, key)
    -- Autograd:
    local df = autograd(func)
    local grads = df(table.unpack(inputs))
-   local gradsVerify = df(table.unpack(inputs))
 
    -- Find grad:
    local g = autograd.util.nestedGet(grads, key)
+   local g_clone
+   if torch.isTensor(g) then
+      g_clone = g:clone()
+   end
+
+   -- Get the grad again
+   local gradsVerify = df(table.unpack(inputs))
    local gVerify = autograd.util.nestedGet(gradsVerify, key)
    local err
+   local overwrite_err = 0
    if torch.isTensor(g) then
       err = (g - gVerify):abs():max()
+      overwrite_err = (g - g_clone):abs():max()
    else
       err = torch.abs(g - gVerify)
    end
 
    if err ~= 0 then
       error("autograd gradient not deterministic")
+   elseif overwrite_err ~= 0 then
+      error("autograd gradient overwritten when called twice")
    end
 
    -- Return grads:
diff --git a/src/gradfuns.lua b/src/gradfuns.lua
index 80fade1..24306e7 100644
--- a/src/gradfuns.lua
+++ b/src/gradfuns.lua
@@ -152,7 +152,12 @@ functions.set = {
       return nil
    end,
    function(g, ans, x, k, v)
-      return g[k]
+      local gk = getValue(g[k])
+      if type(gk) == 'number' then
+         return gk
+      else
+         return torch.clone(gk)
+      end
    end,
 }
 
diff --git a/test/test.lua b/test/test.lua
index 9bdaad3..8c8e8eb 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1687,6 +1687,12 @@ local tests = {
          return torch.sum(xc)
       end
       tester:assert(gradcheck(f4,{x=torch.randn(10,10),y=torch.randn(3)}), "Incorrect gradient")
+      local f5 = function(params)
+         local xc = torch.clone(params.x)
+         xc[2] = params.y * 2.0
+         return torch.sum(xc)
+      end
+      tester:assert(gradcheck(f5,{x=torch.randn(10,10),y=torch.randn(10)}), "Incorrect gradient")
    end,
 
    ScalarSigmoid = function()
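
For context: in `functions.set`, the gradient that used to `return g[k]` handed back a view sharing storage with the incoming gradient tensor rather than an independent copy, so a second call to the differentiated function could mutate the gradient returned by the first call; the new overwrite check in gradcheck.lua detects exactly that. Below is a minimal sketch of the failure mode, assuming torch and the autograd package are installed (`f` and the parameter names are illustrative, not part of the patch):

   local autograd = require 'autograd'

   -- Mirrors the new f5 test: write a parameter into a cloned tensor,
   -- which routes the backward pass through the `set` gradient.
   local f = function(params)
      local xc = torch.clone(params.x)
      xc[2] = params.y
      return torch.sum(xc)
   end

   local df = autograd(f)
   local params = {x = torch.randn(10, 10), y = torch.randn(10)}

   -- First call: snapshot the gradient it returns.
   local grads = df(params)
   local snapshot = grads.y:clone()

   -- Second call: if the `set` gradient hands back an aliased view
   -- instead of a copy, this call can overwrite grads.y in place.
   df(params)
   assert((grads.y - snapshot):abs():max() == 0,
          "gradient from the first call was overwritten")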