From c9dee969e0d4c9760c0af2ace344035edd1680cd Mon Sep 17 00:00:00 2001 From: Minhua Chen Date: Fri, 13 Sep 2024 22:11:52 -0700 Subject: [PATCH] refactor step_mode in ensemble_rowwise_adagrad (#3137) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3137 X-link: https://github.com/facebookresearch/FBGEMM/pull/230 refactor step_mode in ensemble_rowwise_adagrad Reviewed By: q10, spcyppt Differential Revision: D62608418 --- fbgemm_gpu/codegen/genscript/optimizers.py | 25 ++++++++++++---------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py index 2e6b6ea54..67ecec19d 100644 --- a/fbgemm_gpu/codegen/genscript/optimizers.py +++ b/fbgemm_gpu/codegen/genscript/optimizers.py @@ -1047,27 +1047,24 @@ def ensemble_rowwise_adagrad() -> Dict[str, Any]: momentum2[idx] = new_sum_square_grads; multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps); - coef_ema = fabs(momentum); + coef_ema = momentum*1.0; if (step_mode == 1) { // row_counter[idx] records the number of appearances of this row row_counter[idx] += 1.0; should_ema = ((int64_t)round(fmod(row_counter[idx], step_ema)) == 0); should_swap = (row_counter[idx] > step_start && (int64_t)round(fmod(row_counter[idx], step_swap)) == 0); - } else if (step_mode == 2) { + } else { // row_counter[idx] records the step of last ema; prev_iter[idx] records the step of last swap should_ema = ((iter*1.0 - row_counter[idx]) >= step_ema); should_swap = (iter*1.0 > step_start && (iter*1.0 - prev_iter[idx]) >= step_swap); if (should_ema) { - coef_ema = (momentum>0) ? powf(fabs(momentum), (iter*1.0 - row_counter[idx])/max(1.0, step_ema)) : fabs(momentum); + coef_ema = powf(momentum, (iter*1.0 - row_counter[idx])/max(1.0, step_ema)); row_counter[idx] = iter*1.0; } if (should_swap) { prev_iter[idx] = iter*1.0; } - } else { - should_ema = false; - should_swap = false; - } + } } multiplier = SHFL_SYNC(multiplier, 0); coef_ema = SHFL_SYNC(coef_ema, 0); @@ -1083,10 +1080,16 @@ def ensemble_rowwise_adagrad() -> Dict[str, Any]: if (should_ema) { // slow table ema Vec4T m_t(&momentum1[idx * D + d]); - m_t.acc.x = (1.0 - coef_ema) * weight_new.acc.x + coef_ema * m_t.acc.x + (fabs(momentum) - coef_ema) * multiplier * grad.acc.x; - m_t.acc.y = (1.0 - coef_ema) * weight_new.acc.y + coef_ema * m_t.acc.y + (fabs(momentum) - coef_ema) * multiplier * grad.acc.y; - m_t.acc.z = (1.0 - coef_ema) * weight_new.acc.z + coef_ema * m_t.acc.z + (fabs(momentum) - coef_ema) * multiplier * grad.acc.z; - m_t.acc.w = (1.0 - coef_ema) * weight_new.acc.w + coef_ema * m_t.acc.w + (fabs(momentum) - coef_ema) * multiplier * grad.acc.w; + m_t.acc.x = (1.0 - coef_ema) * weight_new.acc.x + coef_ema * m_t.acc.x; + m_t.acc.y = (1.0 - coef_ema) * weight_new.acc.y + coef_ema * m_t.acc.y; + m_t.acc.z = (1.0 - coef_ema) * weight_new.acc.z + coef_ema * m_t.acc.z; + m_t.acc.w = (1.0 - coef_ema) * weight_new.acc.w + coef_ema * m_t.acc.w; + if (step_mode == 2) { + m_t.acc.x = m_t.acc.x + (momentum - coef_ema) * multiplier * grad.acc.x; + m_t.acc.y = m_t.acc.y + (momentum - coef_ema) * multiplier * grad.acc.y; + m_t.acc.z = m_t.acc.z + (momentum - coef_ema) * multiplier * grad.acc.z; + m_t.acc.w = m_t.acc.w + (momentum - coef_ema) * multiplier * grad.acc.w; + } m_t.store(&momentum1[idx * D + d]); }