Significant performance difference between ghost clipping and normal clipping, even on a single GPU #692

Open
RobRomijnders opened this issue Nov 28, 2024 · 4 comments

Comments

@RobRomijnders

There's a significant performance difference between ghost clipping and the default hooks clipping. This is highly surprising, as ghost clipping is supposed to be only an efficiency change that computes the same clipped gradients. Where does the performance difference come from?

To reproduce, I take the Opacus tutorial here and make only a one-line change: passing grad_sample_mode='ghost' to the privacy engine (and using the modified criterion it returns). The results are significantly different:

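For context, the change amounts to roughly the following (a minimal, illustrative sketch: the model/data stand-ins and variable names are mine, and the exact make_private signature for ghost clipping may differ between Opacus versions):

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

# Tiny stand-ins for the tutorial's model and data (illustrative only).
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
train_loader = DataLoader(
    TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,))),
    batch_size=16,
)

privacy_engine = PrivacyEngine()

# The one-line change vs. the tutorial: grad_sample_mode="ghost", plus passing
# the criterion so make_private can return the wrapped (ghost-clipping) loss.
model, optimizer, criterion, train_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    criterion=criterion,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="ghost",
)

# The training loop is otherwise unchanged: the wrapped criterion handles the
# extra backward pass that ghost clipping needs.
for images, labels in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()
```
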
Normal clipping:
    Train Epoch: 1  Loss: 2.774419 Acc@1: 15.042706 (ε = 13.85, δ = 1e-05)
    Train Epoch: 1  Loss: 2.356322 Acc@1: 22.598047 (ε = 16.28, δ = 1e-05)
    Train Epoch: 2  Loss: 1.738953 Acc@1: 40.154446 (ε = 18.27, δ = 1e-05)
    Train Epoch: 2  Loss: 1.714563 Acc@1: 42.186048 (ε = 19.86, δ = 1e-05)
    Train Epoch: 3  Loss: 1.737356 Acc@1: 46.379587 (ε = 21.35, δ = 1e-05)
    Train Epoch: 3  Loss: 1.716496 Acc@1: 47.698608 (ε = 22.60, δ = 1e-05)

Ghost clipping:
    Train Epoch: 1  Loss: 2.819867 Acc@1: 13.947351 (ε = 13.74, δ = 1e-05)
    Train Epoch: 1  Loss: 2.425520 Acc@1: 19.355440 (ε = 16.24, δ = 1e-05)
    Train Epoch: 2  Loss: 1.842602 Acc@1: 32.977792 (ε = 18.31, δ = 1e-05)
    Train Epoch: 2  Loss: 1.820248 Acc@1: 34.835068 (ε = 19.94, δ = 1e-05)
    Train Epoch: 3  Loss: 1.777791 Acc@1: 39.394069 (ε = 21.32, δ = 1e-05)
    Train Epoch: 3  Loss: 1.779722 Acc@1: 39.848439 (ε = 22.55, δ = 1e-05)

Reproduce

The Colab notebook to reproduce the above results is at github.com/RobRomijnders/random/blob/master/opacus.ipynb

Expected behavior

I would expect the performance, e.g. the loss and gradient norms, to be the same for normal clipping and ghost clipping. Some difference might be attributed to the random seed, but this is a single-GPU run.
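For reference, the reasoning behind this expectation (standard DP-SGD notation, up to the exact normalization Opacus applies): with per-sample gradients $g_i$, clipping norm $C$, and noise multiplier $\sigma$, both hooks-based and ghost clipping are meant to produce the same noisy gradient

$$\tilde g = \frac{1}{B}\left(\sum_{i=1}^{B} g_i \cdot \min\!\left(1, \frac{C}{\lVert g_i\rVert_2}\right) + \mathcal{N}\!\left(0, \sigma^2 C^2 I\right)\right),$$

and ghost clipping only changes how the norms $\lVert g_i\rVert_2$ are obtained (without materializing the $g_i$), so the two modes should agree up to floating-point error.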

Environment

PyTorch version: 2.4.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.8.10 (default, Sep 11 2024, 16:02:53) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-102-generic-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090
GPU 2: NVIDIA GeForce RTX 3090
GPU 3: NVIDIA GeForce RTX 3090

Nvidia driver version: 535.171.04
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 40
On-line CPU(s) list: 0-39
Thread(s) per core: 2

[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.6.77
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] torch==2.4.1
[pip3] torchao==0.3.1
[pip3] torchtune==0.2.1
[pip3] torchvision==0.19.1
[pip3] triton==3.0.0

Other comments

I started chasing this bug and made this minimal reproducible example. In another project I am working on, the performance difference is consistently about 2% in classification accuracy.

@EnayatUllah
Contributor

Thanks for this issue! I am curious whether you have tested with the noise multiplier set to 0 and other sources of randomness fixed -- do you still notice a performance gap?
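Concretely, that check would be something like the following (an illustrative sketch; variable names follow the setup above, and the exact set of seeds/flags needed may vary):

```python
import random

import numpy as np
import torch

def set_all_seeds(seed: int = 0) -> None:
    # Fix the common sources of randomness before building the model and loaders.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_all_seeds(0)

# Then wrap with noise_multiplier=0.0, so the only remaining difference between
# grad_sample_mode="hooks" and grad_sample_mode="ghost" is the clipping path:
# model, optimizer, criterion, train_loader = privacy_engine.make_private(
#     module=model, optimizer=optimizer, data_loader=train_loader,
#     criterion=criterion, noise_multiplier=0.0, max_grad_norm=1.0,
#     grad_sample_mode="ghost",
# )
```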

@RobRomijnders
Author

RobRomijnders commented Dec 4, 2024

Thanks, that's an insightful suggestion. The performance gap persists with noise_multiplier=0.0 (see below, 34% vs. 44%). That might point to a difference in the gradient clipping itself?

One interesting observation: setting the random seed makes hooks clipping deterministic, but ghost clipping still seems to have randomness. Perhaps that's tangential, but any idea why they differ? With noise_multiplier=0, both should be deterministic.

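To pin down where the remaining randomness comes from, one option is to make PyTorch fail loudly on nondeterministic kernels and to seed the data loader explicitly (a hypothetical checklist, not what the notebook currently does; the dataset below is a placeholder for the notebook's actual training set):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Error out if any op falls back to a nondeterministic CUDA kernel.
# (On CUDA this may additionally require CUBLAS_WORKSPACE_CONFIG=":4096:8".)
torch.use_deterministic_algorithms(True)

# Placeholder dataset; substitute the notebook's real training data.
dataset = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))

# Seed the loader's shuffling explicitly; note the Poisson-sampling loader that
# Opacus returns from make_private draws its own random batches as well.
g = torch.Generator()
g.manual_seed(0)
train_loader = DataLoader(dataset, batch_size=16, shuffle=True,
                          num_workers=0, generator=g)
```
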
Ghost clipping (first run)
    Train Epoch: 1  Loss: 3.098412 Acc@1: 10.270831 (ε = inf, δ = 1e-05)
    Train Epoch: 1  Loss: 2.657283 Acc@1: 13.484471 (ε = inf, δ = 1e-05)
    Train Epoch: 2  Loss: 2.062224 Acc@1: 22.660401 (ε = inf, δ = 1e-05)
    Train Epoch: 2  Loss: 2.042849 Acc@1: 24.484757 (ε = inf, δ = 1e-05)
    Train Epoch: 3  Loss: 2.000369 Acc@1: 32.717716 (ε = inf, δ = 1e-05)
    Train Epoch: 3  Loss: 1.987917 Acc@1: 34.639087 (ε = inf, δ = 1e-05)

Ghost clipping (second run, same seed; why the randomness?)
    Train Epoch: 1  Loss: 3.098769 Acc@1: 10.394780 (ε = inf, δ = 1e-05)
    Train Epoch: 1  Loss: 2.670149 Acc@1: 13.307869 (ε = inf, δ = 1e-05)
    Train Epoch: 2  Loss: 2.081255 Acc@1: 21.829020 (ε = inf, δ = 1e-05)
    Train Epoch: 2  Loss: 2.060630 Acc@1: 23.208688 (ε = inf, δ = 1e-05)
    Train Epoch: 3  Loss: 2.015526 Acc@1: 29.352595 (ε = inf, δ = 1e-05)
    Train Epoch: 3  Loss: 2.010941 Acc@1: 31.919048 (ε = inf, δ = 1e-05)

Hooks clipping (first run)
    Train Epoch: 1  Loss: 2.948032 Acc@1: 11.263424 (ε = inf, δ = 1e-05)
    Train Epoch: 1  Loss: 2.535477 Acc@1: 15.813970 (ε = inf, δ = 1e-05)
    Train Epoch: 2  Loss: 2.029866 Acc@1: 29.735081 (ε = inf, δ = 1e-05)
    Train Epoch: 2  Loss: 2.000195 Acc@1: 32.668854 (ε = inf, δ = 1e-05)
    Train Epoch: 3  Loss: 1.889851 Acc@1: 42.416855 (ε = inf, δ = 1e-05)
    Train Epoch: 3  Loss: 1.850450 Acc@1: 44.689258 (ε = inf, δ = 1e-05)

Hooks clipping (second run, same seed; deterministic)
    Train Epoch: 1  Loss: 2.948032 Acc@1: 11.263424 (ε = inf, δ = 1e-05)
    Train Epoch: 1  Loss: 2.535477 Acc@1: 15.813970 (ε = inf, δ = 1e-05)
    Train Epoch: 2  Loss: 2.029866 Acc@1: 29.735081 (ε = inf, δ = 1e-05)
    Train Epoch: 2  Loss: 2.000195 Acc@1: 32.668854 (ε = inf, δ = 1e-05)
    Train Epoch: 3  Loss: 1.889851 Acc@1: 42.416855 (ε = inf, δ = 1e-05)
    Train Epoch: 3  Loss: 1.850450 Acc@1: 44.689258 (ε = inf, δ = 1e-05)

For quick access: this is the Opacus tutorial notebook, and this is my notebook with the random seed set and a switch between ghost and hooks clipping.

@RobRomijnders
Author

Following up on the issue, there are two surprising findings: (1) ghost clipping performs significantly worse than hooks, ew, or functorch clipping; (2) hooks and ew become deterministic when the random seed is set, but ghost and functorch do not. Can we approach these issues jointly?

clipping.txt
Attached are the results for the four clipping methods, using the off-the-shelf Opacus notebook.
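
For completeness, the four modes can be driven from the same setup by switching only grad_sample_mode (a rough sketch; build_fresh_setup and train are hypothetical stand-ins for the notebook's setup and training loop, and as far as I can tell only the "ghost" mode takes and returns a criterion):

```python
from opacus import PrivacyEngine

def run_with_mode(grad_sample_mode: str):
    # Hypothetical helper: returns a fresh model, optimizer, criterion, loader.
    model, optimizer, criterion, train_loader = build_fresh_setup()
    privacy_engine = PrivacyEngine()
    if grad_sample_mode == "ghost":
        # Ghost clipping wraps and returns the criterion as well.
        model, optimizer, criterion, train_loader = privacy_engine.make_private(
            module=model, optimizer=optimizer, data_loader=train_loader,
            criterion=criterion, noise_multiplier=0.0, max_grad_norm=1.0,
            grad_sample_mode="ghost",
        )
    else:  # "hooks", "ew", "functorch"
        model, optimizer, train_loader = privacy_engine.make_private(
            module=model, optimizer=optimizer, data_loader=train_loader,
            noise_multiplier=0.0, max_grad_norm=1.0,
            grad_sample_mode=grad_sample_mode,
        )
    # Hypothetical helper: the notebook's train/eval loop, returning accuracy.
    return train(model, optimizer, criterion, train_loader)

for mode in ["hooks", "ew", "functorch", "ghost"]:
    print(mode, run_with_mode(mode))
```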

Thank you in advance for any suggestion, Rob

@Solosneros
Contributor

> There's a significant performance difference between ghost clipping and the default hooks clipping. This is highly surprising, as ghost clipping is supposed to be only an efficiency change that computes the same clipped gradients. Where does the performance difference come from?

We came across similar odd behavior with the ghost clipping implementation in Opacus while benchmarking different implementations with @sebasrb09 (in our case the loss did not decrease with Opacus's ghost clipping, but it did with other ghost clipping implementations; I can share an updated preprint with our comparisons between the methods in January). Thanks for bringing up a reproducible example, @RobRomijnders.

We can try to share code details with the help of @sebasrb09 in January.
