Skip to content

Commit

Permalink
Merge branch 'main' into vla
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyever authored Oct 4, 2024
2 parents 32c179c + 7a4472a commit b4f3bd7
Show file tree
Hide file tree
Showing 80 changed files with 3,920 additions and 1,427 deletions.
6 changes: 4 additions & 2 deletions defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ def get_fbgemm_inline_avx2_srcs(msvc = False, buck = False):
asm_srcs = ["src/FbgemmFP16UKernelsAvx2.cc"]
if buck:
return select({
"DEFAULT": asm_srcs if not msvc else intrinsics_srcs,
"DEFAULT": asm_srcs,
"ovr_config//compiler:cl": intrinsics_srcs,
"ovr_config//cpu:arm64": intrinsics_srcs,
})
return asm_srcs if not msvc else intrinsics_srcs
Expand Down Expand Up @@ -135,7 +136,8 @@ def get_fbgemm_inline_avx512_srcs(msvc = False, buck = False):
]
if buck:
return select({
"DEFAULT": asm_srcs if not msvc else intrinsics_srcs,
"DEFAULT": asm_srcs,
"ovr_config//compiler:cl": intrinsics_srcs,
"ovr_config//cpu:arm64": intrinsics_srcs,
})
return asm_srcs if not msvc else intrinsics_srcs
Expand Down
9 changes: 5 additions & 4 deletions fbgemm_gpu/FbgemmGpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ set(GPU_ONLY_OPTIMIZERS
lamb
partial_rowwise_adam
partial_rowwise_lamb
ensemble_rowwise_adagrad
lars_sgd
none
rowwise_adagrad_with_counter)
Expand All @@ -87,7 +86,6 @@ set(GPU_OPTIMIZERS ${COMMON_OPTIMIZERS} ${GPU_ONLY_OPTIMIZERS})
set(VBE_OPTIMIZERS
rowwise_adagrad
rowwise_adagrad_with_counter
ensemble_rowwise_adagrad
sgd
dense)

Expand Down Expand Up @@ -265,10 +263,10 @@ list(APPEND gen_gpu_host_source_files
foreach(optimizer ${ALL_OPTIMIZERS})
list(APPEND gen_cpu_source_files
"gen_embedding_backward_split_${optimizer}_cpu.cpp"
"gen_embedding_backward_split_${optimizer}_pt2_cpu_wrapper.cpp")
"gen_embedding_backward_split_${optimizer}_pt2_cpu_wrapper.cpp"
"gen_embedding_split_${optimizer}_pt2_autograd.cpp")
list(APPEND gen_gpu_host_source_files
"gen_embedding_backward_split_${optimizer}.cpp"
"gen_embedding_split_${optimizer}_pt2_autograd.cpp"
"gen_embedding_backward_split_${optimizer}_pt2_cuda_wrapper.cpp")
endforeach()

Expand Down Expand Up @@ -456,6 +454,7 @@ set(fbgemm_gpu_sources_static_cpu
codegen/training/forward/embedding_forward_split_cpu.cpp
codegen/inference/embedding_forward_quantized_host_cpu.cpp
codegen/training/backward/embedding_backward_dense_host_cpu.cpp
codegen/training/pt2/pt2_autograd_utils.cpp
codegen/utils/embedding_bounds_check_host_cpu.cpp
src/config/feature_gates.cpp
src/memory_utils/memory_utils.cpp
Expand All @@ -473,6 +472,7 @@ set(fbgemm_gpu_sources_static_cpu
src/layout_transform_ops/layout_transform_ops_cpu.cpp
src/quantize_ops/quantize_ops_cpu.cpp
src/quantize_ops/quantize_ops_meta.cpp
src/sparse_ops/sparse_async_cumsum.cpp
src/sparse_ops/sparse_ops_cpu.cpp
src/sparse_ops/sparse_ops_meta.cpp
src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
Expand All @@ -481,6 +481,7 @@ set(fbgemm_gpu_sources_static_cpu
src/split_embeddings_cache/lru_cache_populate_byte.cpp
src/split_embeddings_cache/lxu_cache.cpp
src/split_embeddings_cache/split_embeddings_cache_ops.cpp
src/split_embeddings_utils/split_embeddings_utils_cpu.cpp
codegen/training/index_select/batch_index_select_dim0_ops.cpp
codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)

Expand Down
6 changes: 1 addition & 5 deletions fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@
else:
from fbgemm_gpu.bench.bench_utils import benchmark_torch_function

if torch.version.hip:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
else:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")


def generate_unary_feature(
Expand Down
6 changes: 1 addition & 5 deletions fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@
# pyre-ignore[21]
from fbgemm_gpu import open_source # noqa: F401
except Exception:
if torch.version.hip:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
else:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")


def benchmark_hbc_function(
Expand Down
6 changes: 1 addition & 5 deletions fbgemm_gpu/bench/jagged_tensor_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@
else:
from fbgemm_gpu.bench.bench_utils import benchmark_torch_function

if torch.version.hip:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
else:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
)
Expand All @@ -47,7 +44,6 @@
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu"
)
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")


@click.group()
Expand Down
6 changes: 1 addition & 5 deletions fbgemm_gpu/bench/quantize_ops_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@
else:
from fbgemm_gpu.bench.bench_utils import benchmark_torch_function

if torch.version.hip:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
else:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")


@click.group()
Expand Down
6 changes: 1 addition & 5 deletions fbgemm_gpu/bench/stride_gemm_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@
# pyre-ignore[21]
from fbgemm_gpu import open_source # noqa: F401
except Exception:
if torch.version.hip:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
else:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")


@click.group()
Expand Down
1 change: 0 additions & 1 deletion fbgemm_gpu/codegen/genscript/generate_backward_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,6 @@ def generate() -> None:
lars_sgd(),
partial_rowwise_adam(),
partial_rowwise_lamb(),
ensemble_rowwise_adagrad(),
rowwise_adagrad(),
approx_rowwise_adagrad(),
rowwise_adagrad_with_weight_decay(),
Expand Down
1 change: 1 addition & 0 deletions fbgemm_gpu/codegen/genscript/generate_forward_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def generate_pt2_wrappers() -> None:
f"gen_embedding_forward_split_pt2_cpu_wrapper.cpp",
has_cpu_support=True,
is_forward=True,
has_vbe_support=True,
)

# Generate PT2 forward wrapper (CUDA)
Expand Down
121 changes: 0 additions & 121 deletions fbgemm_gpu/codegen/genscript/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,127 +1020,6 @@ def adam() -> Dict[str, Any]:
}


def ensemble_rowwise_adagrad() -> Dict[str, Any]:
    """Build the codegen spec for the ``ensemble_rowwise_adagrad`` optimizer.

    The returned dictionary carries the optimizer name, its argument set,
    the device-code template fragments, and the feature-support flags.

    Per the embedded templates:
      * precomputation averages the squared gradient over ``D``, adds it to
        ``momentum2[idx]``, and derives ``multiplier`` as
        ``learning_rate / (sqrtf(momentum2[idx]) + eps)``; ``step_mode``
        (1 or 2, anything else disables both) selects how ``should_ema`` and
        ``should_swap`` are computed from ``row_counter`` / ``prev_iter``;
      * the weight update subtracts ``multiplier * grad`` from the weights,
        optionally blends the result into ``momentum1`` (when ``should_ema``)
        and optionally overwrites the weights with ``momentum1`` (when
        ``should_swap``).

    Returns:
        Dict[str, Any]: the optimizer description. It is flagged as a
        prototype optimizer with GPU and VBE support enabled and CPU,
        global-weight-decay and SSD support disabled.
    """
    # Declares the per-thread accumulator used by the gradient-sum loop.
    precompute_prologue = """
at::acc_type<cache_t, true> g_local_sum_square = 0.0;
"""

    # Loop body handed to the grad-sum loop generator; `{grad_vec}` is left
    # as a placeholder inside the snippet.
    grad_square_accum_snippet = """
const float4* grad = &{grad_vec}.acc;
auto gx = grad->x;
auto gy = grad->y;
auto gz = grad->z;
auto gw = grad->w;
g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw;
"""

    # Thread 0 computes the step multiplier and the EMA/swap decisions, then
    # the values are shared via SHFL_SYNC from lane 0.
    precompute_epilogue = """
const at::acc_type<cache_t, true> g_avg_square =
GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type<cache_t, true>) / D;
at::acc_type<cache_t, true> multiplier;
at::acc_type<cache_t, true> coef_ema;
at::acc_type<cache_t, true> should_ema;
at::acc_type<cache_t, true> should_swap;
if (threadIdx.x == 0) {
at::acc_type<cache_t, true> new_sum_square_grads = momentum2[idx] + g_avg_square;
momentum2[idx] = new_sum_square_grads;
multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
coef_ema = (row_counter[idx] > step_start) ? (momentum*1.0) : 0.0;
if (step_mode == 1) {
// row_counter[idx] tracks the number of appearances of this ID
row_counter[idx] += 1.0;
should_ema = floorf(row_counter[idx] / step_ema) - floorf((row_counter[idx]-1.0) / step_ema);
should_swap = floorf(row_counter[idx] / step_swap) - floorf((row_counter[idx]-1.0) / step_swap);
} else if (step_mode == 2) {
should_ema = floorf((iter*1.0 - row_counter[idx]) / step_ema);
should_swap = floorf((iter*1.0 - prev_iter[idx]) / step_swap);
// row_counter[idx] records the step of last ema
if (should_ema > 0.5) {
coef_ema = powf(coef_ema, (iter*1.0 - row_counter[idx]) / step_ema);
row_counter[idx] = iter*1.0;
}
// prev_iter[idx] records the step of last swap
if (should_swap > 0.5) {
prev_iter[idx] = iter*1.0;
}
} else {
should_ema = 0.0;
should_swap = 0.0;
}
}
multiplier = SHFL_SYNC(multiplier, 0);
coef_ema = SHFL_SYNC(coef_ema, 0);
should_ema = SHFL_SYNC(should_ema, 0);
should_swap = SHFL_SYNC(should_swap, 0);
"""

    precomputation_code = (
        precompute_prologue
        + generate_optimized_grad_sum_loop_access(grad_square_accum_snippet)
        + precompute_epilogue
    )

    weight_update_code = """
weight_new.acc.x = weight_new.acc.x - multiplier * grad.acc.x;
weight_new.acc.y = weight_new.acc.y - multiplier * grad.acc.y;
weight_new.acc.z = weight_new.acc.z - multiplier * grad.acc.z;
weight_new.acc.w = weight_new.acc.w - multiplier * grad.acc.w;
if (should_ema > 0.5) { // slow table ema
Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
m_t.acc.x = (1.0 - coef_ema) * weight_new.acc.x + coef_ema * m_t.acc.x + (momentum - coef_ema) * multiplier * grad.acc.x;
m_t.acc.y = (1.0 - coef_ema) * weight_new.acc.y + coef_ema * m_t.acc.y + (momentum - coef_ema) * multiplier * grad.acc.y;
m_t.acc.z = (1.0 - coef_ema) * weight_new.acc.z + coef_ema * m_t.acc.z + (momentum - coef_ema) * multiplier * grad.acc.z;
m_t.acc.w = (1.0 - coef_ema) * weight_new.acc.w + coef_ema * m_t.acc.w + (momentum - coef_ema) * multiplier * grad.acc.w;
m_t.store(&momentum1[idx * D + d]);
}
if (should_swap > 0.5) { // slow-to-fast swap
Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
weight_new.acc.x = m_t.acc.x * 1.0;
weight_new.acc.y = m_t.acc.y * 1.0;
weight_new.acc.z = m_t.acc.z * 1.0;
weight_new.acc.w = m_t.acc.w * 1.0;
}
"""

    cpu_weight_update_code = ""  # TODO

    # Argument order matters: placeholder tensors, plain tensors, float
    # scalars, then int scalars -- identical to the hand-written listing.
    placeholder_items = [
        OptimItem(
            ArgType.PLACEHOLDER_TENSOR,
            state_name,
            # A fresh list per item so the two placeholders never share one.
            ph_tys=[ArgType.FLOAT_TENSOR, ArgType.BFLOAT16_TENSOR],
        )
        for state_name in ("momentum1", "momentum2")
    ]
    tensor_items = [
        OptimItem(ArgType.TENSOR, tensor_name)
        for tensor_name in ("prev_iter", "row_counter")
    ]
    float_items = [
        OptimItem(ArgType.FLOAT, scalar_name)
        for scalar_name in (
            "learning_rate",
            "eps",
            "step_ema",
            "step_swap",
            "step_start",
            "momentum",
        )
    ]
    int_items = [
        OptimItem(ArgType.INT, scalar_name)
        for scalar_name in ("iter", "step_mode")
    ]

    optimizer_spec: Dict[str, Any] = {
        "optimizer": "ensemble_rowwise_adagrad",
        "is_prototype_optimizer": True,
        "args": OptimizerArgsSet.create(
            placeholder_items + tensor_items + float_items + int_items
        ),
        "split_precomputation": precomputation_code,
        "split_weight_update": weight_update_code,
        "split_post_update": "",
        "split_weight_update_cpu": cpu_weight_update_code,
        "has_cpu_support": False,
        "has_gpu_support": True,
        "has_vbe_support": True,
        "has_global_weight_decay_support": False,
        "has_ssd_support": False,
    }
    return optimizer_spec


def partial_rowwise_adam() -> Dict[str, Any]:
split_precomputation = """
at::acc_type<cache_t, true> g_local_sum_square = 0.0;
Expand Down
Loading

0 comments on commit b4f3bd7

Please sign in to comment.