Significant performance difference between Ghost clipping and Normal clipping, even on a single GPU #692
Comments
Thanks for this issue! I am curious whether you have tested with the noise multiplier set to 0 and other sources of randomness fixed. Do you still notice performance gaps?
Thanks, that's an insightful suggestion. The performance gap persists with noise_multiplier=0.0 (see below, 34% vs. 44%). Could that point to a difference in the gradient clipping step? One interesting observation is that setting the random seed makes
For quick access: this is the Opacus tutorial notebook, and this is my notebook with the random seed set, switching between ghost and hooks clipping.
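Setting noise_multiplier=0.0 is a useful control because, without noise, a DP-SGD step reduces to averaging flat-clipped per-sample gradients, and the result no longer depends on the noise RNG. A minimal pure-Python sketch of that aggregation follows; clip_and_aggregate is a hypothetical helper (not an Opacus API), and the 1e-6 stabilizer and averaging by the actual batch size are assumptions made for illustration.

```python
import math
import random


def clip_and_aggregate(per_sample_grads, max_grad_norm, noise_multiplier, rng):
    """Hypothetical DP-SGD aggregation sketch: flat-clip, sum, add noise, average.

    per_sample_grads: list of per-sample gradients, each a flat list of floats.
    """
    dim = len(per_sample_grads[0])
    summed = [0.0] * dim
    for g in per_sample_grads:
        norm = math.sqrt(sum(x * x for x in g))
        # Flat clipping: scale a sample down only if its norm exceeds C.
        # The 1e-6 stabilizer is an assumption for this sketch.
        factor = min(1.0, max_grad_norm / (norm + 1e-6))
        for j, x in enumerate(g):
            summed[j] += factor * x
    # Gaussian noise with std = noise_multiplier * C; with noise_multiplier=0
    # every draw is exactly 0.0, so the seed cannot affect the result.
    std = noise_multiplier * max_grad_norm
    batch_size = len(per_sample_grads)
    return [(s + rng.gauss(0.0, std)) / batch_size for s in summed]


grads = [[3.0, 4.0], [0.3, 0.4]]  # norms 5.0 (clipped) and 0.5 (kept)
print(clip_and_aggregate(grads, 1.0, 0.0, random.Random(0)))
print(clip_and_aggregate(grads, 1.0, 0.0, random.Random(123)))
```

With noise_multiplier=0.0 both calls return the same vector regardless of seed, so a remaining gap between the ghost and hooks modes would have to come from how the per-sample norms, clip factors, or gradients are computed (or from data ordering), not from the noise.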
Following up on the issue, we have two surprising findings: ghost clipping performs significantly worse than clipping.txt
Thank you in advance for any suggestions,
Rob
We came across similar unexpected behavior with the GhostClipping implementation in Opacus while benchmarking different implementations with @sebasrb09. (In our case, the loss did not decrease using GhostClipping from Opacus, but it did with other implementations of GhostClipping. I can share an updated preprint with our comparisons between the methods in January.) Thanks for bringing up a reproducible example, @RobRomijnders. We can try to share code details with the help of @sebasrb09 in January.
There's a significant performance difference between ghost clipping and the default hooks clipping. This is highly surprising, as ghost clipping is supposed to be only a computational-efficiency change. Where does the performance difference originate?
To reproduce, I take the Opacus tutorial here and make a single one-line change: passing grad_sample_mode='ghost' to the privacy engine (and using the modified criterion). The results I get are significantly different.
Reproduce
The colab to reproduce the above result is at github.com/RobRomijnders/random/blob/master/opacus.ipynb
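Why the modified criterion should not change the result, as we understand the general ghost clipping scheme (not verified here against Opacus's source): a first backward pass yields per-sample gradient norms, these become clipping coefficients c_i, and a second backward pass runs on the reweighted loss sum_i c_i * L_i with the c_i held constant. Since gradients are linear, that second pass gives sum_i c_i * grad L_i, exactly the sum explicit per-sample clipping computes. The sketch below checks this identity for a toy linear least-squares model in pure Python, with central finite differences standing in for autograd; all function names are hypothetical.

```python
import math


def per_sample_grad(w, x, y):
    # Gradient of L_i = 0.5 * (w . x - y)^2 with respect to w.
    r = sum(wj * xj for wj, xj in zip(w, x)) - y
    return [r * xj for xj in x]


def clip_coeffs(grads, max_grad_norm):
    # Flat clipping coefficient per sample: min(1, C / (||g_i|| + 1e-6)).
    return [min(1.0, max_grad_norm / (math.sqrt(sum(v * v for v in g)) + 1e-6))
            for g in grads]


def weighted_loss(w, xs, ys, coeffs):
    # Reweighted batch loss sum_i c_i * L_i, with the coefficients held constant.
    total = 0.0
    for x, y, c in zip(xs, ys, coeffs):
        r = sum(wj * xj for wj, xj in zip(w, x)) - y
        total += c * 0.5 * r * r
    return total


def numeric_grad(f, w, h=1e-6):
    # Central finite differences, standing in for autograd's second backward pass.
    g = []
    for j in range(len(w)):
        wp = list(w)
        wp[j] += h
        wm = list(w)
        wm[j] -= h
        g.append((f(wp) - f(wm)) / (2 * h))
    return g


w = [0.5, -1.0]
xs = [[1.0, 2.0], [3.0, 0.5], [-1.0, 1.0]]
ys = [1.0, 0.0, 2.0]

grads = [per_sample_grad(w, x, y) for x, y in zip(xs, ys)]
coeffs = clip_coeffs(grads, max_grad_norm=1.0)

# Explicit per-sample clipping: sum_i c_i * grad L_i.
explicit = [sum(c * g[j] for c, g in zip(coeffs, grads)) for j in range(len(w))]
# Ghost-style: gradient of the reweighted loss sum_i c_i * L_i.
reweighted = numeric_grad(lambda v: weighted_loss(v, xs, ys, coeffs), w)
print(explicit, reweighted)
```

The two vectors agree to finite-difference precision, so any systematic gap between the modes would point to the per-sample norms (and hence the c_i) or to how the reweighted loss is reduced, rather than to the idea itself.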
Expected behavior
I would expect the performance (e.g., loss and gradient norm) to be the same for normal clipping and ghost clipping. Some difference might be attributable to the random seed, but this is a single-GPU evaluation.
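The expectation that gradient norms match rests on the ghost-norm identity. For a linear layer applied to one sample's activations a_1..a_T (rows of A, length d) with output gradients g_1..g_T (rows of G, length p), the per-sample weight gradient is G^T A, and ||G^T A||_F^2 = sum over s, t of (a_s . a_t)(g_s . g_t); for T = 1 this is just ||a||^2 * ||g||^2. A pure-Python sketch comparing the explicit and ghost computations for the weight gradient only (bias omitted); the function names are hypothetical, and this checks the math, not Opacus's implementation.

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def explicit_grad_norm_sq(acts, out_grads):
    # Materialize one sample's weight gradient G^T A (shape p x d), then take
    # its squared Frobenius norm. acts: T rows of length d; out_grads: T rows of length p.
    p, d = len(out_grads[0]), len(acts[0])
    total = 0.0
    for k in range(p):
        for j in range(d):
            entry = sum(g[k] * a[j] for g, a in zip(out_grads, acts))
            total += entry * entry
    return total


def ghost_grad_norm_sq(acts, out_grads):
    # Ghost norm: sum over position pairs (s, t) of (a_s . a_t) * (g_s . g_t).
    # Only T x T inner products are needed; the p x d gradient is never built.
    total = 0.0
    for a_s, g_s in zip(acts, out_grads):
        for a_t, g_t in zip(acts, out_grads):
            total += dot(a_s, a_t) * dot(g_s, g_t)
    return total


rng = random.Random(0)
T, d, p = 4, 3, 2  # sequence length, input features, output features
acts = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(T)]
out_grads = [[rng.uniform(-1, 1) for _ in range(p)] for _ in range(T)]
print(explicit_grad_norm_sq(acts, out_grads), ghost_grad_norm_sq(acts, out_grads))
```

The two values agree up to floating-point rounding, and for acts=[[1.0, 2.0]], out_grads=[[3.0]] both functions return 45.0. If ghost and hooks modes disagree on the same batch, comparing their per-sample norms layer by layer against a check like this is one way to localize the discrepancy.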
Environment
PyTorch version: 2.4.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31
Python version: 3.8.10 (default, Sep 11 2024, 16:02:53) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-102-generic-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090
GPU 2: NVIDIA GeForce RTX 3090
GPU 3: NVIDIA GeForce RTX 3090
Nvidia driver version: 535.171.04
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 40
On-line CPU(s) list: 0-39
Thread(s) per core: 2
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.6.77
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] torch==2.4.1
[pip3] torchao==0.3.1
[pip3] torchtune==0.2.1
[pip3] torchvision==0.19.1
[pip3] triton==3.0.0
Other comments
I started chasing this bug and made this minimal reproducible example. In another project I'm working on, the performance difference is consistently about 2% in classification accuracy.