FastCell Example Fixes, Generalized trainer for both batch_first args #174
base: master
Conversation
PR #173 supersedes this PR.
@oindrilasaha it would be good if you could have a look at this PR.
@SachinG007 please incorporate the changes from #173 to fix the optimizer for FC. Earlier, gradient updates were not happening for FC due to a convention mismatch. Please look at this: https://github.com/microsoft/EdgeML/pull/173/files#diff-7b39dde7dda6360cbf530db88f5b9f8dR12-R62 and incorporate it. Also, try to keep PRs for different things separate. When we are fixing fastcell_Example.py related stuff, let us stick to that. Thanks.
@adityakusupati, incorporated the changes from PR #173.
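For context on the gradient fix referenced from PR #173: a common reason gradient updates silently fail in PyTorch is a tensor that is stored as a plain attribute instead of an nn.Parameter, so model.parameters() never hands it to the optimizer. A minimal sketch of that failure mode (class names here are illustrative, not from the PR):

```python
import torch
import torch.nn as nn

class Broken(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain tensor attribute: NOT registered with the module,
        # so .parameters() never returns it and the optimizer never sees it.
        self.W = torch.randn(3, 2, requires_grad=True)

class Fixed(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Parameter: registered in module._parameters, visible to .parameters().
        self.W = nn.Parameter(torch.randn(3, 2))

print(len(list(Broken().parameters())))  # 0 -> optimizer updates nothing
print(len(list(Fixed().parameters())))   # 1
```

Passing `Broken().parameters()` to an optimizer would raise or train nothing, which matches the "gradient updates were not happening" symptom.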
I have reviewed everything except the bidirectional stuff. @oindrilasaha has to review the bi-directional stuff.
class SimpleFC(nn.Module):
    def __init__(self, input_size, num_classes, name="SimpleFC"):
        super(SimpleFC, self).__init__()
        self.FC = nn.Parameter(torch.randn([input_size, num_classes]))
        self.FCbias = nn.Parameter(torch.randn([num_classes]))

    def forward(self, input):
        return torch.matmul(input, self.FC) + self.FCbias
I am not sure whether we should place this here, put it in a separate file, or place it in rnn.py. Ideally, this is the same situation as Bonsai.py or ProtoNN.py. We need to decide on the right place for it.
-        self.FCbias = nn.Parameter(torch.randn(
-            [self.numClasses])).to(self.device)
+        self.simpleFC = SimpleFC(self.FastObj.output_size, self.numClasses).to(self.device)
Use a better name for the instance instead of self.simpleFC (so that we are generic enough to fit the classifier() function).
    def classifier(self, feats):
        '''
        Can be replaced by any classifier
        TODO: Make this a separate class if needed
        '''
-       return torch.matmul(feats, self.FC) + self.FCbias
+       return self.simpleFC(feats)
Look at the above comment.
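The point of routing everything through classifier() is that the head can be swapped without touching the trainer. A hedged sketch of that design (the Trainer class and head argument are assumptions for illustration, not code from the PR):

```python
import torch
import torch.nn as nn

class Trainer(nn.Module):
    def __init__(self, feature_size, num_classes, head=None):
        super().__init__()
        # Any nn.Module mapping features -> logits can serve as the head;
        # default to a plain linear layer when none is supplied.
        self.head = head if head is not None else nn.Linear(feature_size, num_classes)

    def classifier(self, feats):
        # Trainer code calls this and never touches the head's internals.
        return self.head(feats)

t = Trainer(16, 4)
logits = t.classifier(torch.randn(8, 16))
print(logits.shape)  # torch.Size([8, 4])
```

With this shape, replacing SimpleFC with any other classifier module is a one-line change at construction time.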
pytorch/edgeml_pytorch/graph/rnn.py (Outdated)
        hiddenStates = torch.zeros(
            [input.shape[0], input.shape[1],
-            self._RNNCell.output_size]).to(self.device)
+            self.RNNCell.output_size]).to(self.device)
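For reference, this kind of attribute-name mismatch surfaces as an AttributeError at forward time, because the cell was stored under a different name than the one being read. A toy reproduction (names are illustrative, not from rnn.py):

```python
import torch
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, cell):
        super().__init__()
        self.RNNCell = cell  # stored without the leading underscore

    def forward(self, x):
        # Referring to self._RNNCell here would raise AttributeError,
        # since the submodule was registered as self.RNNCell.
        return self.RNNCell(x)

w = Wrapper(nn.Linear(4, 4))
print(w(torch.randn(2, 4)).shape)  # torch.Size([2, 4])
```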
All the lines after this are about the bi-directional stuff. I would defer to @oindrilasaha to check all of it; she has to sign off on it as she uses it the most.
@oindrilasaha any comments on the later section of the code?
        self.FC = nn.Parameter(torch.randn([input_size, num_classes]))
        self.FCbias = nn.Parameter(torch.randn([num_classes]))
@SachinG007 change this to self.weight and self.bias.
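Renaming to self.weight and self.bias follows the nn.Linear naming convention, which lets generic code (weight-decay filters, export or quantization scripts) treat this head like any standard linear layer. A sketch of the renamed module, assuming nothing beyond the rename itself:

```python
import torch
import torch.nn as nn

class SimpleFC(nn.Module):
    def __init__(self, input_size, num_classes):
        super().__init__()
        # nn.Linear-style names. Note: nn.Linear stores weight as [out, in],
        # while this module keeps the original [in, out] layout and matmul order.
        self.weight = nn.Parameter(torch.randn([input_size, num_classes]))
        self.bias = nn.Parameter(torch.randn([num_classes]))

    def forward(self, input):
        return torch.matmul(input, self.weight) + self.bias

m = SimpleFC(8, 3)
print(m(torch.randn(5, 8)).shape)  # torch.Size([5, 3])
```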
@SachinG007 any updates on this PR?
Checked the changes with Aditya.