FastCell Example Fixes, Generalized trainer for both batch_first args #174

Status: Open. Wants to merge 7 commits into base: master.
9 changes: 8 additions & 1 deletion examples/pytorch/FastCells/fastcell_example.py
@@ -43,10 +43,17 @@ def main():

(dataDimension, numClasses, Xtrain, Ytrain, Xtest, Ytest,
mean, std) = helpermethods.preProcessData(dataDir)

assert dataDimension % inputDims == 0, "Infeasible per-step input, " + \
"timeSteps must be an integer"

timeSteps = int(dataDimension / inputDims)
Xtrain = Xtrain.reshape((-1, timeSteps, inputDims))
Xtest = Xtest.reshape((-1, timeSteps, inputDims))

if not batch_first:
Xtrain = np.swapaxes(Xtrain, 0, 1)
Xtest = np.swapaxes(Xtest, 0, 1)

currDir = helpermethods.createTimeStampDir(dataDir, cell)

helpermethods.dumpCommand(sys.argv, currDir)
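For context, below is a minimal sketch of the shape handling this hunk introduces: the flattened features are reshaped into (batch, timeSteps, inputDims) and, when batch_first is not set, swapped into the time-major layout (timeSteps, batch, inputDims) that the generalized trainer also accepts. The concrete numbers (inputDims=16, dataDimension=160, a batch of 32) are illustrative assumptions, not values from the example.

```python
import numpy as np

# Hypothetical dimensions, chosen only for illustration.
inputDims = 16          # features fed to the cell at each time step
dataDimension = 160     # flattened feature length per example
batchSize = 32
batch_first = False     # same flag the example reads from its arguments

assert dataDimension % inputDims == 0, "Infeasible per-step input, " \
    "timeSteps must be an integer"
timeSteps = dataDimension // inputDims                # 10

Xtrain = np.random.rand(batchSize, dataDimension)
Xtrain = Xtrain.reshape((-1, timeSteps, inputDims))   # (32, 10, 16), batch-first

if not batch_first:
    # Time-major layout expected when batch_first=False.
    Xtrain = np.swapaxes(Xtrain, 0, 1)                # (10, 32, 16)

print(Xtrain.shape)
```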
190 changes: 153 additions & 37 deletions pytorch/edgeml_pytorch/graph/rnn.py
@@ -144,8 +144,8 @@ def getVars(self):

def get_model_size(self):
'''
Function to get aimed model size
'''
mats = self.getVars()
endW = self._num_W_matrices
endU = endW + self._num_U_matrices
@@ -261,7 +261,7 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))

self.copy_previous_UW()
# self.copy_previous_UW()
Contributor: Why is this commented out? This supports sparsity for the KWS code. fastcell_example.py doesn't depend on it because this is a bit broken given the new PyTorch updates. Check with Harsha.

Contributor: This line is causing a segmentation fault in the RNNPool code.

Contributor: Yeah, I know the code is broken; it needs to be fixed. @harsha-simhadri used this in one of the codes, so we might have to fix it.


@property
def name(self):
@@ -330,7 +330,7 @@ class FastGRNNCUDACell(RNNCell):
'''
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, wSparsity=1.0, uSparsity=1.0, name="FastGRNNCUDACell"):
super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_non_linearity, update_nonlinearity,
super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_nonlinearity, update_nonlinearity,
1, 1, 2, wRank, uRank, wSparsity, uSparsity)
if utils.findCUDA() is None:
raise Exception('FastGRNNCUDA is supported only on GPU devices.')
@@ -967,78 +967,145 @@ class BaseRNN(nn.Module):
[batchSize, timeSteps, inputDims]
'''

def __init__(self, cell: RNNCell, batch_first=False):
def __init__(self, cell: RNNCell, batch_first=False, cell_reverse: RNNCell=None, bidirectional=False):
Contributor: What is this RNNCell=None argument? I don't know what it does. Why should there be an RNNCell for this argument?

Contributor: If the bidirectional pass doesn't share weights, this passes an extra initialized RNNCell for the backward pass.

Contributor (Author): Should I resolve this conversation? @oindrilasaha

super(BaseRNN, self).__init__()
self._RNNCell = cell
self._batch_first = batch_first
self._bidirectional = bidirectional
if cell_reverse is not None:
self._RNNCell_reverse = cell_reverse
elif self._bidirectional:
self._RNNCell_reverse = cell

def getVars(self):
return self._RNNCell.getVars()

def forward(self, input, hiddenState=None,
cellState=None):
self.device = input.device
self.num_directions = 2 if self._bidirectional else 1
if self._bidirectional:
self.num_directions = 2
else:
self.num_directions = 1

hiddenStates = torch.zeros(
[input.shape[0], input.shape[1],
self._RNNCell.output_size]).to(self.device)

if self._bidirectional:
hiddenStates_reverse = torch.zeros(
[input.shape[0], input.shape[1],
self._RNNCell_reverse.output_size]).to(self.device)

if hiddenState is None:
hiddenState = torch.zeros(
[input.shape[0] if self._batch_first else input.shape[1],
[self.num_directions, input.shape[0] if self._batch_first else input.shape[1],
self._RNNCell.output_size]).to(self.device)

if self._batch_first is True:
if self._RNNCell.cellType == "LSTMLR":
cellStates = torch.zeros(
[input.shape[0], input.shape[1],
self._RNNCell.output_size]).to(self.device)
if self._bidirectional:
cellStates_reverse = torch.zeros(
[input.shape[0], input.shape[1],
self._RNNCell_reverse.output_size]).to(self.device)
if cellState is None:
cellState = torch.zeros(
[input.shape[0], self._RNNCell.output_size]).to(self.device)
[self.num_directions, input.shape[0], self._RNNCell.output_size]).to(self.device)
for i in range(0, input.shape[1]):
hiddenState, cellState = self._RNNCell(
input[:, i, :], (hiddenState, cellState))
hiddenStates[:, i, :] = hiddenState
cellStates[:, i, :] = cellState
return hiddenStates, cellStates
hiddenState[0], cellState[0] = self._RNNCell(
input[:, i, :], (hiddenState[0].clone(), cellState[0].clone()))
hiddenStates[:, i, :] = hiddenState[0]
cellStates[:, i, :] = cellState[0]
if self._bidirectional:
hiddenState[1], cellState[1] = self._RNNCell_reverse(
input[:, input.shape[1]-i-1, :], (hiddenState[1].clone(), cellState[1].clone()))
hiddenStates_reverse[:, i, :] = hiddenState[1]
cellStates_reverse[:, i, :] = cellState[1]
if not self._bidirectional:
return hiddenStates, cellStates
else:
return torch.cat([hiddenStates,hiddenStates_reverse],-1), torch.cat([cellStates,cellStates_reverse],-1)
else:
for i in range(0, input.shape[1]):
hiddenState = self._RNNCell(input[:, i, :], hiddenState)
hiddenStates[:, i, :] = hiddenState
return hiddenStates
hiddenState[0] = self._RNNCell(input[:, i, :], hiddenState[0].clone())
hiddenStates[:, i, :] = hiddenState[0]
if self._bidirectional:
hiddenState[1] = self._RNNCell_reverse(
input[:, input.shape[1]-i-1, :], hiddenState[1].clone())
hiddenStates_reverse[:, i, :] = hiddenState[1]
if not self._bidirectional:
return hiddenStates
else:
return torch.cat([hiddenStates,hiddenStates_reverse],-1)
else:
if self._RNNCell.cellType == "LSTMLR":
cellStates = torch.zeros(
[input.shape[0], input.shape[1],
self._RNNCell.output_size]).to(self.device)
if self._bidirectional:
cellStates_reverse = torch.zeros(
[input.shape[0], input.shape[1],
self._RNNCell_reverse.output_size]).to(self.device)
if cellState is None:
cellState = torch.zeros(
[input.shape[1], self._RNNCell.output_size]).to(self.device)
[self.num_directions, input.shape[1], self._RNNCell.output_size]).to(self.device)
for i in range(0, input.shape[0]):
hiddenState, cellState = self._RNNCell(
input[i, :, :], (hiddenState, cellState))
hiddenStates[i, :, :] = hiddenState
cellStates[i, :, :] = cellState
return hiddenStates, cellStates
hiddenState[0], cellState[0] = self._RNNCell(
input[i, :, :], (hiddenState[0].clone(), cellState[0].clone()))
hiddenStates[i, :, :] = hiddenState[0]
cellStates[i, :, :] = cellState[0]
if self._bidirectional:
hiddenState[1], cellState[1] = self._RNNCell_reverse(
input[input.shape[0]-i-1, :, :], (hiddenState[1].clone(), cellState[1].clone()))
hiddenStates_reverse[i, :, :] = hiddenState[1]
cellStates_reverse[i, :, :] = cellState[1]
if not self._bidirectional:
return hiddenStates, cellStates
else:
return torch.cat([hiddenStates,hiddenStates_reverse],-1), torch.cat([cellStates,cellStates_reverse],-1)
else:
for i in range(0, input.shape[0]):
hiddenState = self._RNNCell(input[i, :, :], hiddenState)
hiddenStates[i, :, :] = hiddenState
return hiddenStates
hiddenState[0] = self._RNNCell(input[i, :, :], hiddenState[0].clone())
hiddenStates[i, :, :] = hiddenState[0]
if self._bidirectional:
hiddenState[1] = self._RNNCell_reverse(
input[input.shape[0]-i-1, :, :], hiddenState[1].clone())
hiddenStates_reverse[i, :, :] = hiddenState[1]
if not self._bidirectional:
return hiddenStates
else:
return torch.cat([hiddenStates,hiddenStates_reverse],-1)


class LSTM(nn.Module):
"""Equivalent to nn.LSTM using LSTMLRCell"""

def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
update_nonlinearity="tanh", wRank=None, uRank=None,
wSparsity=1.0, uSparsity=1.0, batch_first=False):
wSparsity=1.0, uSparsity=1.0, batch_first=False,
bidirectional=False, is_shared_bidirectional=True):
super(LSTM, self).__init__()
self._bidirectional = bidirectional
self._batch_first = batch_first
self._is_shared_bidirectional = is_shared_bidirectional
self.cell = LSTMLRCell(input_size, hidden_size,
gate_nonlinearity=gate_nonlinearity,
update_nonlinearity=update_nonlinearity,
wRank=wRank, uRank=uRank,
wSparsity=wSparsity, uSparsity=uSparsity)
self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)

if self._bidirectional is True and self._is_shared_bidirectional is False:
self.cell_reverse = LSTMLRCell(input_size, hidden_size,
gate_nonlinearity=gate_nonlinearity,
update_nonlinearity=update_nonlinearity,
wRank=wRank, uRank=uRank,
wSparsity=wSparsity, uSparsity=uSparsity)
self.unrollRNN = BaseRNN(self.cell, cell_reverse=self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)

def forward(self, input, hiddenState=None, cellState=None):
return self.unrollRNN(input, hiddenState, cellState)
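The hunks above make BaseRNN run a second, reversed pass over the sequence and concatenate the two hidden-state streams on the last axis, so the output feature size doubles; the LSTM wrapper exposes this through bidirectional and, when is_shared_bidirectional=False, builds a second LSTMLRCell that is handed to BaseRNN as cell_reverse. A rough usage sketch follows; the import path, low-rank settings, and tensor sizes are assumptions for illustration, not part of the PR.

```python
import torch
# Import path assumed from the repository layout; adjust if it differs.
from edgeml_pytorch.graph.rnn import LSTM

batch, timesteps, in_dim, hid = 8, 10, 32, 64

# Shared weights: the same low-rank LSTM cell is unrolled forwards and
# backwards over the sequence.
shared = LSTM(in_dim, hid, wRank=16, uRank=16,
              batch_first=True, bidirectional=True)

# Separate reverse weights: a second LSTMLRCell is created internally and
# passed to BaseRNN as cell_reverse.
separate = LSTM(in_dim, hid, wRank=16, uRank=16,
                batch_first=True, bidirectional=True,
                is_shared_bidirectional=False)

x = torch.randn(batch, timesteps, in_dim)
hidden, cell = shared(x)      # LSTM cells return hidden and cell state streams
print(hidden.shape)           # expected: (8, 10, 128), i.e. 2 * hidden_size
```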
@@ -1049,14 +1116,26 @@ class GRU(nn.Module):

def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
update_nonlinearity="tanh", wRank=None, uRank=None,
wSparsity=1.0, uSparsity=1.0, batch_first=False):
wSparsity=1.0, uSparsity=1.0, batch_first=False,
bidirectional=False, is_shared_bidirectional=True):
super(GRU, self).__init__()
self._bidirectional = bidirectional
self._batch_first = batch_first
self._is_shared_bidirectional = is_shared_bidirectional
self.cell = GRULRCell(input_size, hidden_size,
gate_nonlinearity=gate_nonlinearity,
update_nonlinearity=update_nonlinearity,
wRank=wRank, uRank=uRank,
wSparsity=wSparsity, uSparsity=uSparsity)
self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)

if self._bidirectional is True and self._is_shared_bidirectional is False:
self.cell_reverse = GRULRCell(input_size, hidden_size,
gate_nonlinearity=gate_nonlinearity,
update_nonlinearity=update_nonlinearity,
wRank=wRank, uRank=uRank,
wSparsity=wSparsity, uSparsity=uSparsity)
self.unrollRNN = BaseRNN(self.cell, cell_reverse=self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)

def forward(self, input, hiddenState=None, cellState=None):
return self.unrollRNN(input, hiddenState, cellState)
@@ -1067,14 +1146,26 @@ class UGRNN(nn.Module):

def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
update_nonlinearity="tanh", wRank=None, uRank=None,
wSparsity=1.0, uSparsity=1.0, batch_first=False):
wSparsity=1.0, uSparsity=1.0, batch_first=False,
bidirectional=False, is_shared_bidirectional=True):
super(UGRNN, self).__init__()
self._bidirectional = bidirectional
self._batch_first = batch_first
self._is_shared_bidirectional = is_shared_bidirectional
self.cell = UGRNNLRCell(input_size, hidden_size,
gate_nonlinearity=gate_nonlinearity,
update_nonlinearity=update_nonlinearity,
wRank=wRank, uRank=uRank,
wSparsity=wSparsity, uSparsity=uSparsity)
self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)

if self._bidirectional is True and self._is_shared_bidirectional is False:
self.cell_reverse = UGRNNLRCell(input_size, hidden_size,
gate_nonlinearity=gate_nonlinearity,
update_nonlinearity=update_nonlinearity,
wRank=wRank, uRank=uRank,
wSparsity=wSparsity, uSparsity=uSparsity)
self.unrollRNN = BaseRNN(self.cell, cell_reverse=self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)

def forward(self, input, hiddenState=None, cellState=None):
return self.unrollRNN(input, hiddenState, cellState)
@@ -1085,15 +1176,28 @@ class FastRNN(nn.Module):

def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
update_nonlinearity="tanh", wRank=None, uRank=None,
wSparsity=1.0, uSparsity=1.0, alphaInit=-3.0, betaInit=3.0, batch_first=False):
wSparsity=1.0, uSparsity=1.0, alphaInit=-3.0, betaInit=3.0,
batch_first=False, bidirectional=False, is_shared_bidirectional=True):
super(FastRNN, self).__init__()
self._bidirectional = bidirectional
self._batch_first = batch_first
self._is_shared_bidirectional = is_shared_bidirectional
self.cell = FastRNNCell(input_size, hidden_size,
gate_nonlinearity=gate_nonlinearity,
update_nonlinearity=update_nonlinearity,
wRank=wRank, uRank=uRank,
wSparsity=wSparsity, uSparsity=uSparsity,
alphaInit=alphaInit, betaInit=betaInit)
self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)

if self._bidirectional is True and self._is_shared_bidirectional is False:
self.cell_reverse = FastRNNCell(input_size, hidden_size,
gate_nonlinearity=gate_nonlinearity,
update_nonlinearity=update_nonlinearity,
wRank=wRank, uRank=uRank,
wSparsity=wSparsity, uSparsity=uSparsity,
alphaInit=alphaInit, betaInit=betaInit)
self.unrollRNN = BaseRNN(self.cell, cell_reverse=self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)

def forward(self, input, hiddenState=None, cellState=None):
return self.unrollRNN(input, hiddenState, cellState)
@@ -1105,15 +1209,27 @@ class FastGRNN(nn.Module):
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
update_nonlinearity="tanh", wRank=None, uRank=None,
wSparsity=1.0, uSparsity=1.0, zetaInit=1.0, nuInit=-4.0,
batch_first=False):
batch_first=False, bidirectional=False, is_shared_bidirectional=True):
super(FastGRNN, self).__init__()
self._bidirectional = bidirectional
self._batch_first = batch_first
self._is_shared_bidirectional = is_shared_bidirectional
self.cell = FastGRNNCell(input_size, hidden_size,
gate_nonlinearity=gate_nonlinearity,
update_nonlinearity=update_nonlinearity,
wRank=wRank, uRank=uRank,
wSparsity=wSparsity, uSparsity=uSparsity,
zetaInit=zetaInit, nuInit=nuInit)
self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)

if self._bidirectional is True and self._is_shared_bidirectional is False:
self.cell_reverse = FastGRNNCell(input_size, hidden_size,
gate_nonlinearity=gate_nonlinearity,
update_nonlinearity=update_nonlinearity,
wRank=wRank, uRank=uRank,
wSparsity=wSparsity, uSparsity=uSparsity,
zetaInit=zetaInit, nuInit=nuInit)
self.unrollRNN = BaseRNN(self.cell, cell_reverse=self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)

def getVars(self):
return self.unrollRNN.getVars()
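Since FastGRNN now exposes both batch_first and bidirectional, the generalized trainer can feed it either input layout. A short sketch of the two call patterns, with sizes and the import path assumed for illustration:

```python
import torch
# Import path assumed from the repository layout; adjust if it differs.
from edgeml_pytorch.graph.rnn import FastGRNN

timesteps, batch, in_dim, hid = 10, 8, 32, 64

# Time-major input (batch_first=False), matching the swapaxes branch in
# fastcell_example.py.
rnn_tm = FastGRNN(in_dim, hid, wRank=16, uRank=16, batch_first=False)
x_tm = torch.randn(timesteps, batch, in_dim)
print(rnn_tm(x_tm).shape)     # expected: (10, 8, 64)

# Batch-first, bidirectional with shared weights: the last dimension doubles
# because the forward and reverse hidden states are concatenated.
rnn_bi = FastGRNN(in_dim, hid, wRank=16, uRank=16,
                  batch_first=True, bidirectional=True)
x_bf = torch.randn(batch, timesteps, in_dim)
print(rnn_bi(x_bf).shape)     # expected: (8, 10, 128)
```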
@@ -1222,8 +1338,8 @@ def getVars(self):

def get_model_size(self):
'''
Function to get aimed model size
'''
mats = self.getVars()
endW = self._num_W_matrices
endU = endW + self._num_U_matrices