ValueError: Per sample gradient is not initialized. Not updated in backward pass? #645
Our DP optimizer requires a per_sample_gradient for every trainable parameter in the model (https://github.com/pytorch/opacus/blob/main/opacus/optimizers/optimizer.py#L283C17-L283C18). In your case, however, some part of the model is not activated during (some of) the training, so those parameters end up without a per_sample_gradient. A potential fix is to freeze those parameters (set param.requires_grad = False).
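To find which parameters are affected, one option is to run an ordinary (non-private) backward pass on the unwrapped model and list the trainable parameters that end up with no gradient; such parameters would also have no per-sample gradient under Opacus. A minimal sketch in plain PyTorch; `params_without_grad` and the toy `TwoHeadNet` are hypothetical names used only for illustration:

```python
import torch
import torch.nn as nn


def params_without_grad(model: nn.Module, loss: torch.Tensor) -> list[str]:
    """Run a plain (non-private) backward pass and return the names of
    trainable parameters that received no gradient."""
    model.zero_grad(set_to_none=True)  # clear stale .grad so None means "unused"
    loss.backward()
    return [
        name
        for name, p in model.named_parameters()
        if p.requires_grad and p.grad is None
    ]


class TwoHeadNet(nn.Module):
    # Hypothetical toy model: a shared encoder and two heads, only one used per forward.
    def __init__(self):
        super().__init__()
        self.enc = nn.Linear(4, 8)
        self.head_a = nn.Linear(8, 2)
        self.head_b = nn.Linear(8, 2)
        self.use_a = True

    def forward(self, x):
        h = torch.relu(self.enc(x))
        return self.head_a(h) if self.use_a else self.head_b(h)


model = TwoHeadNet()
loss = model(torch.randn(3, 4)).sum()
print(params_without_grad(model, loss))  # → ['head_b.weight', 'head_b.bias']
```

The listed parameters are the candidates for freezing with `requires_grad = False`.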
Thanks for your advice! I tried to fix the problem in my code following your suggestion, but the code still can't work perfectly. Could you please tell me how you fixed that problem with your method?
Thanks for the update. Could you elaborate on what you mean by "the code still can't work perfectly"? Is a new error popping up?
Here's the thing: I modified the code following your suggestion, using my own approach. In the set_submodel method of the convnet class, after building the submodel I first freeze all parameters and then re-enable the parameters of the selected submodel. Here is that part of the code:

```python
import torch.nn as nn  # needed by the method below


def set_submodel(self, ind, strategy=None):
    # Method of the convnet class: select the active encoder blocks and head.
    self.ind = ind
    assert ind <= 3
    if strategy is None:
        strategy = self.strategy
    print(strategy, ind, self.dataset)
    if strategy == 'progressive':
        # Use the first ind + 1 blocks and the matching intermediate head.
        modules = []
        for i in range(ind + 1):
            modules.append(self.module_splits[i])
        self.enc = nn.Sequential(*modules)
        self.head = self.head_splits[ind]
        # Freeze everything, then unfreeze only the active blocks and head.
        for param in self.parameters():
            param.requires_grad = False
        for module in self.enc:
            for param in module.parameters():
                param.requires_grad = True
        for param in self.head.parameters():
            param.requires_grad = True
    elif strategy == 'baseline':
        # Use every block and the final classifier.
        modules = []
        for i in range(len(self.module_splits)):
            modules.append(self.module_splits[i])
        self.enc = nn.Sequential(*modules)
        self.head = self.classifier
        # Freeze everything, then unfreeze the encoder and classifier.
        for param in self.parameters():
            param.requires_grad = False
        for module in self.enc:
            for param in module.parameters():
                param.requires_grad = True
        for param in self.head.parameters():
            param.requires_grad = True
    else:
        raise NotImplementedError()
```

However, after this rewrite the same error still occurs, and I'm a bit confused. Could you please share the code you used to solve this problem?
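One way to confirm that the freezing leaves exactly the intended parameters trainable is to list them right after calling set_submodel. A self-contained sketch, assuming a toy stand-in for the convnet (the real layers are not shown in the issue); `ToyProgNet` and `trainable_names` are hypothetical, and only the 'progressive' branch is mirrored:

```python
import torch.nn as nn


class ToyProgNet(nn.Module):
    # Hypothetical stand-in for the convnet: 4 blocks and one head per stage.
    def __init__(self):
        super().__init__()
        self.module_splits = nn.ModuleList(nn.Linear(4, 4) for _ in range(4))
        self.head_splits = nn.ModuleList(nn.Linear(4, 2) for _ in range(4))

    def set_submodel(self, ind):
        # Same freezing pattern as the 'progressive' branch of set_submodel.
        self.enc = nn.Sequential(*self.module_splits[: ind + 1])
        self.head = self.head_splits[ind]
        for p in self.parameters():
            p.requires_grad = False
        for p in self.enc.parameters():
            p.requires_grad = True
        for p in self.head.parameters():
            p.requires_grad = True


def trainable_names(model: nn.Module) -> list[str]:
    # Names of parameters that currently have requires_grad=True.
    return [n for n, p in model.named_parameters() if p.requires_grad]


net = ToyProgNet()
net.set_submodel(1)
print(trainable_names(net))
```

For stage 1 this should list only the parameters of blocks 0 and 1 and of head 1.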
🐛 Bug
Hello!
I am trying to train a model on CIFAR10 with Opacus, but I get the error "ValueError: Per sample gradient is not initialized. Not updated in backward pass?". I have seen the earlier discussions about this error, but I still can't figure out what is wrong with my code.
I provide the code here in the hope that others can reproduce the problem and help fix it.
Code
Let me briefly introduce my code:
I am using "progressive learning" to train my model. In progressive learning, the model is first divided into smaller blocks, each made trainable on its own by attaching a small structure called a "head". At the beginning, only a small part of the model is used; after training for some epochs, the subsequent blocks are connected in sequence until the model is complete. In the code, the parameter update_cycle controls this process.
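The growth schedule described above can be sketched as a function from epoch to stage index. This is an assumption about how update_cycle might be used (the issue does not show the exact rule): one more block every update_cycle epochs, capped at stage 3 to match `assert ind <= 3` in set_submodel. `stage_for_epoch` is a hypothetical name:

```python
def stage_for_epoch(epoch: int, update_cycle: int, num_stages: int = 4) -> int:
    """Hypothetical schedule: grow the model by one block every
    `update_cycle` epochs, capped at the last stage (num_stages - 1)."""
    if update_cycle <= 0:
        raise ValueError("update_cycle must be positive")
    return min(epoch // update_cycle, num_stages - 1)


# With update_cycle=5: epochs 0-4 use stage 0, 5-9 stage 1, ..., 15 onward stage 3.
print([stage_for_epoch(e, 5) for e in range(0, 25, 5)])  # → [0, 1, 2, 3, 3]
```

Under this assumption, the returned index would be passed to set_submodel whenever it changes.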
Expected behavior
First, running the code produces several warnings:
```
UserWarning: Secure RNG turned off. This is perfectly fine for experimentation as it allows for much faster training performance, but remember to turn it on and retrain one last time before production with secure_mode turned on.
UserWarning: Optimal order is the largest alpha. Please consider expanding the range of alphas to get a tighter privacy bound.
RuntimeWarning: invalid value encountered in log
UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
```
Then, with the "baseline" strategy the code runs without problems. With the "progressive" strategy, however, it throws "ValueError: Per sample gradient is not initialized. Not updated in backward pass?":
```
Traceback (most recent call last):
  File "C:\Users\yhzyy\anaconda3\envs\ProgFed\ProgDP\main.py", line 265, in <module>
    main()
  File "C:\Users\yhzyy\anaconda3\envs\ProgFed\ProgDP\main.py", line 231, in main
    optimizer.step()
  File "C:\Users\yhzyy\anaconda3\envs\ProgFed\lib\site-packages\opacus\optimizers\optimizer.py", line 513, in step
    if self.pre_step():
  File "C:\Users\yhzyy\anaconda3\envs\ProgFed\lib\site-packages\opacus\optimizers\optimizer.py", line 494, in pre_step
    self.clip_and_accumulate()
  File "C:\Users\yhzyy\anaconda3\envs\ProgFed\lib\site-packages\opacus\optimizers\optimizer.py", line 397, in clip_and_accumulate
    if len(self.grad_samples[0]) == 0:
  File "C:\Users\yhzyy\anaconda3\envs\ProgFed\lib\site-packages\opacus\optimizers\optimizer.py", line 345, in grad_samples
    ret.append(self._get_flat_grad_sample(p))
  File "C:\Users\yhzyy\anaconda3\envs\ProgFed\lib\site-packages\opacus\optimizers\optimizer.py", line 282, in _get_flat_grad_sample
    raise ValueError(
ValueError: Per sample gradient is not initialized. Not updated in backward pass?
```
Environment
python=3.10.13
pytorch=2.1.0
cuda=12.1
opacus=1.4.0
numpy=1.26.0