ensemble rowwise adagrad (fbgemm backend) (pytorch#2889)
Summary:
X-link: facebookresearch/FBGEMM#49

Pull Request resolved: pytorch#2889

ensemble rowwise adagrad (fbgemm diff)

Reviewed By: jiayixu64, csmiler, spcyppt

Differential Revision: D60189486

fbshipit-source-id: e67609654c4bedd3848e4d76dd7beb406416c89b
minhua-chen authored and facebook-github-bot committed Aug 21, 2024
1 parent 1b07a5d commit d3af64c
Showing 4 changed files with 150 additions and 1 deletion.
2 changes: 2 additions & 0 deletions fbgemm_gpu/FbgemmGpu.cmake
@@ -60,6 +60,7 @@ set(GPU_ONLY_OPTIMIZERS
lamb
partial_rowwise_adam
partial_rowwise_lamb
ensemble_rowwise_adagrad
lars_sgd
none
rowwise_adagrad_with_counter)
@@ -86,6 +87,7 @@ set(GPU_OPTIMIZERS ${COMMON_OPTIMIZERS} ${GPU_ONLY_OPTIMIZERS})
set(VBE_OPTIMIZERS
rowwise_adagrad
rowwise_adagrad_with_counter
ensemble_rowwise_adagrad
sgd
dense)

1 change: 1 addition & 0 deletions fbgemm_gpu/codegen/genscript/generate_backward_split.py
@@ -335,6 +335,7 @@ def generate() -> None:
lars_sgd(),
partial_rowwise_adam(),
partial_rowwise_lamb(),
ensemble_rowwise_adagrad(),
rowwise_adagrad(),
approx_rowwise_adagrad(),
rowwise_adagrad_with_weight_decay(),
110 changes: 110 additions & 0 deletions fbgemm_gpu/codegen/genscript/optimizers.py
@@ -1020,6 +1020,116 @@ def adam() -> Dict[str, Any]:
}


def ensemble_rowwise_adagrad() -> Dict[str, Any]:
split_precomputation = """
at::acc_type<cache_t, true> g_local_sum_square = 0.0;
"""
split_precomputation += generate_optimized_grad_sum_loop_access(
"""
const float4* grad = &{grad_vec}.acc;
auto gx = grad->x;
auto gy = grad->y;
auto gz = grad->z;
auto gw = grad->w;
g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw;
"""
)
split_precomputation += """
const at::acc_type<cache_t, true> g_avg_square =
GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type<cache_t, true>) / D;
at::acc_type<cache_t, true> multiplier;
at::acc_type<bool, true> should_ema;
at::acc_type<bool, true> should_swap;
if (threadIdx.x == 0) {
at::acc_type<cache_t, true> new_sum_square_grads = momentum2[idx] + g_avg_square;
momentum2[idx] = new_sum_square_grads;
multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
if (step_mode == 1) {
// row_counter[idx] records the number of appearances of this row
row_counter[idx] += 1.0;
should_ema = (row_counter[idx] > step_start && (int64_t)round(fmod(row_counter[idx], step_ema)) == 0);
should_swap = (row_counter[idx] > step_start && (int64_t)round(fmod(row_counter[idx], step_swap)) == 0);
} else if (step_mode == 2) {
// row_counter[idx] records the iteration at which this row last appeared
should_ema = (iter * 1.0 > step_start && floorf(iter*1.0 / step_ema) - floorf(row_counter[idx] / step_ema) > 0.5);
should_swap = (iter * 1.0 > step_start && floorf(iter*1.0 / step_swap) - floorf(row_counter[idx] / step_swap) > 0.5);
row_counter[idx] = iter * 1.0;
} else {
should_ema = false;
should_swap = false;
}
}
multiplier = SHFL_SYNC(multiplier, 0);
should_ema = SHFL_SYNC(should_ema, 0);
should_swap = SHFL_SYNC(should_swap, 0);
"""

split_weight_update = """
weight_new.acc.x = weight_new.acc.x - multiplier * grad.acc.x;
weight_new.acc.y = weight_new.acc.y - multiplier * grad.acc.y;
weight_new.acc.z = weight_new.acc.z - multiplier * grad.acc.z;
weight_new.acc.w = weight_new.acc.w - multiplier * grad.acc.w;
if (should_ema) { // slow table ema
Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
m_t.acc.x = (1.0 - momentum) * weight_new.acc.x + momentum * m_t.acc.x;
m_t.acc.y = (1.0 - momentum) * weight_new.acc.y + momentum * m_t.acc.y;
m_t.acc.z = (1.0 - momentum) * weight_new.acc.z + momentum * m_t.acc.z;
m_t.acc.w = (1.0 - momentum) * weight_new.acc.w + momentum * m_t.acc.w;
m_t.store(&momentum1[idx * D + d]);
}
if (should_swap) { // slow-to-fast swap
Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
weight_new.acc.x = m_t.acc.x;
weight_new.acc.y = m_t.acc.y;
weight_new.acc.z = m_t.acc.z;
weight_new.acc.w = m_t.acc.w;
}
"""

split_weight_update_cpu = "" # TODO

return {
"optimizer": "ensemble_rowwise_adagrad",
"is_prototype_optimizer": True,
"args": OptimizerArgsSet.create(
[
OptimItem(
ArgType.PLACEHOLDER_TENSOR,
"momentum1",
ph_tys=[ArgType.FLOAT_TENSOR, ArgType.BFLOAT16_TENSOR],
),
OptimItem(
ArgType.PLACEHOLDER_TENSOR,
"momentum2",
ph_tys=[ArgType.FLOAT_TENSOR, ArgType.BFLOAT16_TENSOR],
),
OptimItem(ArgType.TENSOR, "row_counter"),
OptimItem(ArgType.FLOAT, "learning_rate"),
OptimItem(ArgType.FLOAT, "eps"),
OptimItem(ArgType.FLOAT, "step_ema"),
OptimItem(ArgType.FLOAT, "step_swap"),
OptimItem(ArgType.FLOAT, "step_start"),
OptimItem(ArgType.FLOAT, "momentum"),
OptimItem(ArgType.INT, "iter"),
OptimItem(ArgType.INT, "step_mode"),
]
),
"split_precomputation": split_precomputation,
"split_weight_update": split_weight_update,
"split_post_update": "",
"split_weight_update_cpu": split_weight_update_cpu,
"has_cpu_support": False,
"has_gpu_support": True,
"has_vbe_support": True,
"has_global_weight_decay_support": False,
"has_ssd_support": False,
}


def partial_rowwise_adam() -> Dict[str, Any]:
split_precomputation = """
at::acc_type<cache_t, true> g_local_sum_square = 0.0;
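For orientation, here is a minimal NumPy sketch of the per-row update that the ensemble_rowwise_adagrad kernel strings added above generate: a row-wise Adagrad step on the fast weights, an occasional EMA of the fast weights into the slow table (momentum1), and an occasional slow-to-fast swap. Function and variable names are illustrative only, and the scheduling helper mirrors the step_mode logic of the CUDA snippet under simplifying assumptions; none of this is part of the FBGEMM API.

import numpy as np

def ensemble_rowwise_adagrad_row_update(
    weight_row,    # fast weights for one row of the table, shape (D,)
    slow_row,      # momentum1: slow (EMA) copy of the same row, shape (D,)
    v_row,         # momentum2: row-wise Adagrad accumulator, one scalar per row
    grad_row,      # gradient for the row, shape (D,)
    learning_rate, eps, momentum,
    should_ema, should_swap,
):
    # Row-wise Adagrad: accumulate the mean squared gradient of the row.
    v_row = v_row + np.mean(grad_row * grad_row)
    multiplier = learning_rate / (np.sqrt(v_row) + eps)

    # Fast-table update: plain row-wise Adagrad step on the embedding weights.
    weight_row = weight_row - multiplier * grad_row

    # Occasionally fold the fast weights into the slow table via an EMA.
    if should_ema:
        slow_row = (1.0 - momentum) * weight_row + momentum * slow_row

    # Occasionally copy the slow table back over the fast weights (slow-to-fast swap).
    if should_swap:
        weight_row = slow_row.copy()

    return weight_row, slow_row, v_row

def ensemble_schedule(row_counter, it, step_start, step_ema, step_swap, step_mode):
    # step_mode == 1: row_counter counts how many times this row has appeared.
    if step_mode == 1:
        row_counter += 1.0
        should_ema = row_counter > step_start and round(row_counter % step_ema) == 0
        should_swap = row_counter > step_start and round(row_counter % step_swap) == 0
    # step_mode == 2: row_counter stores the iteration at which the row last appeared;
    # EMA / swap fire whenever the current iteration crosses a step_ema / step_swap
    # boundary the row has not seen yet.
    elif step_mode == 2:
        should_ema = it > step_start and (it // step_ema) > (row_counter // step_ema)
        should_swap = it > step_start and (it // step_swap) > (row_counter // step_swap)
        row_counter = float(it)
    else:
        should_ema, should_swap = False, False
    return should_ema, should_swap, row_counter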
@@ -67,6 +67,18 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
{%- if "beta2" in args.split_function_arg_names %}
beta2: float = 0.999,
{%- endif %}
{%- if "step_ema" in args.split_function_arg_names %}
step_ema: float = 10000,
{%- endif %}
{%- if "step_swap" in args.split_function_arg_names %}
step_swap: float = 10000,
{%- endif %}
{%- if "step_start" in args.split_function_arg_names %}
step_start: float = 0,
{%- endif %}
{%- if "step_mode" in args.split_function_arg_names %}
step_mode: int = 1,
{%- endif %}
{%- if "weight_decay" in args.split_function_arg_names %}
weight_decay: float = 0.0,
{%- endif %}
@@ -95,6 +107,18 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
{%- if "beta2" in args.split_function_arg_names %}
beta2=beta2,
{%- endif %}
{%- if "step_ema" in args.split_function_arg_names %}
step_ema=step_ema,
{%- endif %}
{%- if "step_swap" in args.split_function_arg_names %}
step_swap=step_swap,
{%- endif %}
{%- if "step_start" in args.split_function_arg_names %}
step_start=step_start,
{%- endif %}
{%- if "step_mode" in args.split_function_arg_names %}
step_mode=step_mode,
{%- endif %}
{%- if "weight_decay" in args.split_function_arg_names %}
weight_decay=weight_decay,
{%- endif %}
@@ -139,7 +163,7 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
rowwise = False
{% endif %}
{% elif state_tensor == "momentum2" %}
{% if optimizer in ["partial_rowwise_adam", "partial_rowwise_lamb"] %}
{% if optimizer in ["partial_rowwise_adam", "partial_rowwise_lamb", "ensemble_rowwise_adagrad"] %}
rowwise = True
{% else %}
rowwise = False
@@ -189,6 +213,18 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
{%- if "beta2" in args.split_function_arg_names %}
self.beta2 = beta2
{%- endif %}
{%- if "step_ema" in args.split_function_arg_names %}
self.step_ema = step_ema
{%- endif %}
{%- if "step_swap" in args.split_function_arg_names %}
self.step_swap = step_swap
{%- endif %}
{%- if "step_start" in args.split_function_arg_names %}
self.step_start = step_start
{%- endif %}
{%- if "step_mode" in args.split_function_arg_names %}
self.step_mode = step_mode
{%- endif %}
{%- if "weight_decay" in args.split_function_arg_names %}
self.weight_decay = weight_decay
{%- endif %}
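To make the state layout concrete: the hunk at -139 above marks momentum2 as row-wise for ensemble_rowwise_adagrad, while momentum1 stays full-sized. A rough sketch of the per-table optimizer state this implies, for one hypothetical table (the sizes, names, and dtypes here are illustrative assumptions, not taken from the diff):

import torch

rows, D = 1000, 64                  # illustrative table size and embedding dimension

weights     = torch.zeros(rows, D)  # fast weights: the embedding table itself
momentum1   = torch.zeros(rows, D)  # slow (EMA) copy of the weights, full size
momentum2   = torch.zeros(rows)     # row-wise Adagrad accumulator: one scalar per row,
                                    # hence rowwise = True for momentum2 in the template
row_counter = torch.zeros(rows)     # per-row appearance count (step_mode == 1) or
                                    # last-seen iteration (step_mode == 2)

With the template defaults above (step_ema = step_swap = 10000, step_start = 0, step_mode = 1), a given row's slow copy is refreshed by EMA and swapped back into the fast table roughly once every 10000 appearances of that row.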
