diff --git a/.gitignore b/.gitignore index 4cf744f93..c90537e9b 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,5 @@ bazel-genfiles /pip-wheel-metadata/ .bazelrc - +model_dir/ +export_dir/ diff --git a/demo/dynamic_embedding/movielens-1m-keras-with-horovod/README.md b/demo/dynamic_embedding/movielens-1m-keras-with-horovod/README.md index 7eede6c75..8bacf123f 100644 --- a/demo/dynamic_embedding/movielens-1m-keras-with-horovod/README.md +++ b/demo/dynamic_embedding/movielens-1m-keras-with-horovod/README.md @@ -6,10 +6,18 @@ - enable gpu by `python3 -m pip install tensorflow[and-cuda]` - `HOROVOD_WITH_MPI=1 HOROVOD_WITH_GLOO=1 pip install --no-cache-dir horovod` - recommend to use nv docker image `nvcr.io/nvidia/tensorflow:24.02-tf2-py3` -- run `rm -rf model_dir/ export_dir/` to clean up the model and export directory before running the script ## start train: By default, this shell will start a train task with N workers as GPU number on local machine. ```shell sh start.sh ``` +run a debug task with only 1 steps_per_epoch +```shell +sh start.sh 1 +``` +## start export for serving: +```shell +sh test.sh export +sh test.sh inference +``` \ No newline at end of file diff --git a/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py b/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py index e8fcb2aa9..e963e4449 100644 --- a/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py +++ b/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py @@ -1,36 +1,57 @@ import os import shutil -import tensorflow as tf -import tensorflow_datasets as tfds from absl import flags from absl import app + +os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" #VERY IMPORTANT! +os.environ["TF_GPU_THREAD_MODE"] = "gpu_private" +# Because of the two environment variables above no non-standard library imports should happen before this. +import tensorflow as tf from tensorflow_recommenders_addons import dynamic_embedding as de try: from tensorflow.keras.legacy.optimizers import Adam except: from tensorflow.keras.optimizers import Adam - +import tensorflow_datasets as tfds import horovod.tensorflow as hvd +# optimal performance +os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit' -os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" #VERY IMPORTANT! -os.environ["TF_GPU_THREAD_MODE"] = "gpu_private" +def has_horovod() -> bool: + return 'OMPI_COMM_WORLD_RANK' in os.environ or 'PMI_RANK' in os.environ + -# Horovod: initialize Horovod. -hvd.init() +def config(): + # callback calls hvd.rank() so we need to initialize horovod here + hvd.init() + if has_horovod(): + print("Horovod is enabled.") + if hvd.rank() > 0: + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + # Horovod: pin GPU to be used to process local rank (one GPU per process) + config_gpu(hvd.local_rank()) + else: + config_gpu() -if hvd.rank() > 0: - os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' -# Horovod: pin GPU to be used to process local rank (one GPU per process) -physical_devices = tf.config.list_physical_devices('GPU') -tf.config.set_visible_devices(physical_devices[hvd.local_rank()], 'GPU') -tf.config.experimental.set_memory_growth(physical_devices[hvd.local_rank()], - True) +def config_gpu(rank=0): + physical_devices = tf.config.list_physical_devices('GPU') + if physical_devices: + tf.config.set_visible_devices(physical_devices[rank], 'GPU') + tf.config.experimental.set_memory_growth(physical_devices[rank], True) + else: + print("No GPU found, using CPU instead.") + + +def get_cluster_size() -> int: + return hvd.size() if has_horovod() else 1 + + +def get_rank() -> int: + return hvd.rank() if has_horovod() else 0 -# optimal performance -os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit' flags.DEFINE_string('mode', 'train', 'Select the running mode: train or test.') flags.DEFINE_string('model_dir', 'model_dir', @@ -41,8 +62,9 @@ flags.DEFINE_integer('epochs', 1, 'Number of training epochs.') flags.DEFINE_integer('embedding_size', 32, 'Embedding size for users and movies') -flags.DEFINE_integer('test_steps', 128, 'Embedding size for users and movies') -flags.DEFINE_integer('test_batch', 1024, 'Embedding size for users and movies') +flags.DEFINE_integer('test_steps', 128, 'test steps.') +flags.DEFINE_integer('test_batch', 1024, 'test batch size.') +flags.DEFINE_bool('shuffle', True, 'shuffle dataset.') FLAGS = flags.FLAGS input_spec = { @@ -190,10 +212,6 @@ def __init__(self, boundaries, **kwargs): self.boundaries = boundaries super(Bucketize, self).__init__(**kwargs) - def build(self, input_shape): - # Be sure to call this somewhere! - super(Bucketize, self).build(input_shape) - def call(self, x, **kwargs): return tf.raw_ops.Bucketize(input=x, boundaries=self.boundaries) @@ -203,6 +221,25 @@ def get_config(self,): return dict(list(base_config.items()) + list(config.items())) +def get_kv_creator(mpi_size: int, + mpi_rank: int, + vocab_size: int = 1, + value_size: int = 4, + dim: int = 16): + gpus = tf.config.list_physical_devices('GPU') + saver = de.FileSystemSaver(proc_size=mpi_size, proc_rank=mpi_rank) + if gpus: + max_capacity = 2 * vocab_size + config = de.HkvHashTableConfig(init_capacity=vocab_size, + max_capacity=max_capacity, + max_hbm_for_values=max_capacity * + value_size * dim) + return de.HkvHashTableCreator(config=config, saver=saver) + else: + # The saver parameter of kv_creator saves the K-V in the hash table into a separate KV file. + return de.CuckooHashTableCreator(saver=saver) + + class ChannelEmbeddingLayers(tf.keras.layers.Layer): def __init__(self, @@ -214,13 +251,10 @@ def __init__(self, mpi_rank=0): super(ChannelEmbeddingLayers, self).__init__() - - self.gpu_device = ["GPU:0"] - self.cpu_device = ["CPU:0"] - - # The saver parameter of kv_creator saves the K-V in the hash table into a separate KV file. - self.kv_creator = de.CuckooHashTableCreator( - saver=de.FileSystemSaver(proc_size=mpi_size, proc_rank=mpi_rank)) + init_capacity = 4096000 + kv_creator_dense = get_kv_creator(mpi_size, mpi_rank, init_capacity, + tf.dtypes.float32.size, + dense_embedding_size) self.dense_embedding_layer = de.keras.layers.HvdAllToAllEmbedding( mpi_size=mpi_size, @@ -228,22 +262,23 @@ def __init__(self, key_dtype=tf.int64, value_dtype=tf.float32, initializer=embedding_initializer, - devices=self.gpu_device, name=name + '_DenseUnifiedEmbeddingLayer', bp_v2=True, - init_capacity=4096000, - kv_creator=self.kv_creator) + init_capacity=init_capacity, + kv_creator=kv_creator_dense) + kv_creator_sparse = get_kv_creator(mpi_size, mpi_rank, init_capacity, + tf.dtypes.float32.size, + sparse_embedding_size) self.sparse_embedding_layer = de.keras.layers.HvdAllToAllEmbedding( mpi_size=mpi_size, embedding_size=sparse_embedding_size, key_dtype=tf.int64, value_dtype=tf.float32, initializer=embedding_initializer, - devices=self.cpu_device, name=name + '_SparseUnifiedEmbeddingLayer', init_capacity=4096000, - kv_creator=self.kv_creator) + kv_creator=kv_creator_sparse) self.dnn = tf.keras.layers.Dense( 128, @@ -251,9 +286,6 @@ def __init__(self, kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1), bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1)) - def build(self, input_shape): - super(ChannelEmbeddingLayers, self).build(input_shape) - def __call__(self, features_info): dense_inputs = [] dense_input_dims = [] @@ -324,7 +356,7 @@ def __init__(self, super(DualChannelsDeepModel, self).__init__() self.user_embedding_size = user_embedding_size self.movie_embedding_size = movie_embedding_size - + print(f"mpi_size {mpi_size}, mpi_rank {mpi_rank}") self.user_embedding = ChannelEmbeddingLayers( name='user', dense_embedding_size=user_embedding_size, @@ -444,10 +476,11 @@ def get_dataset(batch_size=1): tf.one_hot(tf.cast(x["user_rating"] - 1, dtype=tf.int64), 5) }) dataset = tf.data.Dataset.zip((features, ratings)) - shuffled = dataset.shuffle(1_000_000, - seed=2021, - reshuffle_each_iteration=False) - dataset = shuffled.repeat(1).batch(batch_size).prefetch(tf.data.AUTOTUNE) + if FLAGS.shuffle: + dataset = dataset.shuffle(1_000_000, + seed=2021, + reshuffle_each_iteration=False) + dataset = dataset.repeat(1).batch(batch_size).prefetch(tf.data.AUTOTUNE) # Only GPU:0 since TF is set to be visible to GPU:X dataset = dataset.apply( tf.data.experimental.prefetch_to_device('GPU:0', buffer_size=2)) @@ -495,28 +528,30 @@ def export_to_savedmodel(model, savedmodel_dir): options=save_options) +def save_spec(save_model): + if hasattr(save_model, 'save_spec'): + # tf version >= 2.6 + return save_model.save_spec() + else: + arg_specs = list() + kwarg_specs = dict() + for i in save_model.inputs: + arg_specs.append(i.type_spec) + return [arg_specs], kwarg_specs + + +@tf.function +def serve(save_model, *args, **kwargs): + return save_model(*args, **kwargs) + + def export_for_serving(model, export_dir): save_options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA']) if not os.path.exists(export_dir): os.mkdir(export_dir) - def save_spec(): - if hasattr(model, 'save_spec'): - # tf version >= 2.6 - return model.save_spec() - else: - arg_specs = list() - kwarg_specs = dict() - for i in model.inputs: - arg_specs.append(i.type_spec) - return [arg_specs], kwarg_specs - - @tf.function - def serve(*args, **kwargs): - return model(*args, **kwargs) - - arg_specs, kwarg_specs = save_spec() + arg_specs, kwarg_specs = save_spec(model) ########################## What really happened ########################## # if hvd.rank() == 0: @@ -550,31 +585,37 @@ def serve(*args, **kwargs): options=save_options, signatures={ 'serving_default': - serve.get_concrete_function(*arg_specs, **kwarg_specs) + serve.get_concrete_function(model, *arg_specs, **kwarg_specs) }, ) - if hvd.rank() == 0: + if get_rank() == 0: # Modify the inference graph to a stand-alone version - from tensorflow.python.saved_model import save as tf_save tf.keras.backend.clear_session() + from tensorflow.python.saved_model import save as tf_save de.enable_inference_mode() export_model = DualChannelsDeepModel(FLAGS.embedding_size, FLAGS.embedding_size, tf.keras.initializers.Zeros(), False, - hvd.size(), hvd.rank()) + 1, 0) # The save_and_return_nodes function is used to overwrite the saved_model.pb file generated by the save_model function and rewrite the inference graph. tf_save.save_and_return_nodes(obj=export_model, export_dir=export_dir, options=save_options, - experimental_skip_checkpoint=True) + experimental_skip_checkpoint=True, + signatures={ + 'serving_default': + serve.get_concrete_function( + export_model, *arg_specs, + **kwarg_specs) + }) def train(): dataset = get_dataset(batch_size=32) model = DualChannelsDeepModel(FLAGS.embedding_size, FLAGS.embedding_size, tf.keras.initializers.RandomNormal(0.0, 0.5), - True, hvd.size(), hvd.rank()) + True, get_cluster_size(), get_rank()) optimizer = Adam(1E-3) optimizer = de.DynamicEmbeddingOptimizer(optimizer) @@ -586,19 +627,23 @@ def train(): ]) if os.path.exists(FLAGS.model_dir + '/variables'): - model.load_weights(FLAGS.model_dir + '/variables/variables') + model.load_weights(FLAGS.model_dir) tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=FLAGS.model_dir) save_options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA']) - # horovod callback is used to broadcast the value generated by initializer of rank0. - hvd_opt_init_callback = de.keras.callbacks.DEHvdBroadcastGlobalVariablesCallback( - root_rank=0) ckpt_callback = de.keras.callbacks.ModelCheckpoint( filepath=FLAGS.model_dir + '/weights_epoch{epoch:03d}_loss{loss:.4f}', options=save_options) - callbacks_list = [hvd_opt_init_callback, ckpt_callback] + if has_horovod(): + # horovod callback is used to broadcast the value generated by initializer of rank0. + hvd_opt_init_callback = de.keras.callbacks.DEHvdBroadcastGlobalVariablesCallback( + root_rank=0) + callbacks_list = [hvd_opt_init_callback, ckpt_callback] + else: + callbacks_list = [ckpt_callback] + # The log class callback only takes effect in rank0 for convenience - if hvd.rank() == 0: + if get_rank() == 0: callbacks_list.extend([tensorboard_callback]) # If there are callbacks such as evaluation metrics that call model calculations, take effect on all ranks. # callbacks_list.extend([my_auc_callback]) @@ -607,7 +652,7 @@ def train(): callbacks=callbacks_list, epochs=FLAGS.epochs, steps_per_epoch=FLAGS.steps_per_epoch, - verbose=1 if hvd.rank() == 0 else 0) + verbose=1 if get_rank() == 0 else 0) export_to_savedmodel(model, FLAGS.model_dir) export_for_serving(model, FLAGS.export_dir) @@ -625,13 +670,30 @@ def export(): mpi_size=1, mpi_rank=0) save_options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA']) + dummy_features = { + 'movie_id': tf.constant([0], dtype=tf.int64), + 'movie_genres': tf.constant([0], dtype=tf.int64), + 'user_id': tf.constant([0], dtype=tf.int64), + 'user_gender': tf.constant([0], dtype=tf.int64), + 'user_occupation_label': tf.constant([0], dtype=tf.int64), + 'bucketized_user_age': tf.constant([0], dtype=tf.int64), + 'timestamp': tf.constant([0], dtype=tf.int64) + } + export_model(dummy_features) + arg_specs, kwarg_specs = save_spec(export_model) # Modify the inference graph to a stand-alone version from tensorflow.python.saved_model import save as tf_save # The save_and_return_nodes function is used to overwrite the saved_model.pb file generated by the save_model function and rewrite the inference graph. tf_save.save_and_return_nodes(obj=export_model, export_dir=FLAGS.export_dir, options=save_options, - experimental_skip_checkpoint=True) + experimental_skip_checkpoint=True, + signatures={ + 'serving_default': + serve.get_concrete_function( + export_model, *arg_specs, + **kwarg_specs) + }) def test(): @@ -639,7 +701,6 @@ def test(): dataset = get_dataset(batch_size=FLAGS.test_batch) model = tf.keras.models.load_model(FLAGS.export_dir) - signature = model.signatures['serving_default'] def get_close_or_equal_cnt(model, features, ratings): preds = model(features) @@ -660,14 +721,49 @@ def get_close_or_equal_cnt(model, features, ratings): f' accurate, {equal_cnt}/{FLAGS.test_batch} are absolutely accurate.') +def inference(): + de.enable_inference_mode() + # model = keras.models.load_model( + model = tf.keras.models.load_model(FLAGS.export_dir) + print(f"model signature keys: {model.signatures.keys()} {model.signatures}") + inference_func = model.signatures['serving_default'] + + dataset = get_dataset(batch_size=FLAGS.test_batch) + it = iter(dataset) + + def get_close_or_equal_cnt(preds, ratings): + preds = tf.math.argmax(preds, axis=1) + ratings = tf.math.argmax(ratings, axis=1) + close_cnt = tf.reduce_sum( + tf.cast(tf.math.abs(preds - ratings) <= 1, dtype=tf.int32)) + equal_cnt = tf.reduce_sum( + tf.cast(tf.math.abs(preds - ratings) == 0, dtype=tf.int32)) + return close_cnt, equal_cnt + + for step in range(FLAGS.test_steps): + features, ratings = next(it) + ratings = ratings['user_rating'] + outputs = inference_func(**features) + preds = outputs['user_rating'] + + close_cnt, equal_cnt = get_close_or_equal_cnt(preds, ratings) + + print( + f'In batch prediction, step: {step}, {close_cnt}/{FLAGS.test_batch} are closely' + f' accurate, {equal_cnt}/{FLAGS.test_batch} are absolutely accurate.') + + def main(argv): del argv + config() if FLAGS.mode == 'train': train() elif FLAGS.mode == 'export': export() elif FLAGS.mode == 'test': test() + elif FLAGS.mode == 'inference': + inference() else: raise ValueError('running mode only supports `train` or `test`') diff --git a/demo/dynamic_embedding/movielens-1m-keras-with-horovod/start.sh b/demo/dynamic_embedding/movielens-1m-keras-with-horovod/start.sh index 0b1d0ab29..ae0e918f1 100644 --- a/demo/dynamic_embedding/movielens-1m-keras-with-horovod/start.sh +++ b/demo/dynamic_embedding/movielens-1m-keras-with-horovod/start.sh @@ -2,4 +2,5 @@ rm -rf ./export_dir gpu_num=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) export gpu_num -horovodrun -np $gpu_num python movielens-1m-keras-with-horovod.py --mode="train" --model_dir="./model_dir" --export_dir="./export_dir" \ No newline at end of file +horovodrun -np $gpu_num python movielens-1m-keras-with-horovod.py --mode="train" --model_dir="./model_dir" --export_dir="./export_dir" \ + --steps_per_epoch=${1:-20000} --shuffle={2:-True} \ No newline at end of file diff --git a/demo/dynamic_embedding/movielens-1m-keras-with-horovod/test.sh b/demo/dynamic_embedding/movielens-1m-keras-with-horovod/test.sh new file mode 100644 index 000000000..b3109d55a --- /dev/null +++ b/demo/dynamic_embedding/movielens-1m-keras-with-horovod/test.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +python movielens-1m-keras-with-horovod.py --mode=${1:-"test"} --export_dir="./export_dir" \ No newline at end of file diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks_test.py index 3982eaa49..6bc142a7c 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks_test.py @@ -1,6 +1,6 @@ import pytest from tensorflow_recommenders_addons.dynamic_embedding.python.keras.callbacks import \ - DEHvdBroadcastGlobalVariablesCallback, DEHvdModelCheckpoint + DEHvdBroadcastGlobalVariablesCallback @pytest.fixture diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py index 99003cf73..6f01550eb 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py @@ -144,7 +144,8 @@ def __init__(self, devices: List of devices to place the embedding layer parameter. name: Name of the embedding layer. with_unique: Bool. Whether if the layer does unique on `ids`. Default is True. - + must set with_unique to true in the GPU case due to the default kv is HKV hashtable, + and HKV requires unique key **kwargs: trainable: Bool. Whether if the layer is trainable. Default is True. bp_v2: Bool. If true, the embedding layer will be updated by incremental @@ -263,7 +264,8 @@ def call(self, ids): Args: ids: feature ids of the input. It should be same dtype as the key_dtype - of the layer. + of the layer. ids must be unique or set with_unique to true in the GPU case + due to the default kv is HKV hashtable and HKV requires unique key Returns: A embedding output with shape (shape(ids), embedding_size). diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding_test.py index 465fb4b71..7d050fade 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding_test.py @@ -687,7 +687,7 @@ def test_forward(self): init = tf.keras.initializers.Zeros() de_layer = de.keras.layers.SquashedEmbedding(2, initializer=init, - key_dtype=dtypes.int32, + key_dtype=dtypes.int64, value_dtype=dtypes.float32, name='tr423') dense_init = tf.keras.initializers.Ones() @@ -696,13 +696,13 @@ def test_forward(self): embeddings_initializer=dense_init, name='mt047') - preset_ids = constant_op.constant([3, 0, 1], dtype=dtypes.int32) + preset_ids = constant_op.constant([3, 0, 1], dtype=dtypes.int64) preset_values = constant_op.constant([[1, 1], [1, 1], [1, 1]], dtype=dtypes.float32) de_layer.params.upsert(preset_ids, preset_values) - de_ids = constant_op.constant([3, 0, 1, 2], dtype=tf.int32) + de_ids = constant_op.constant([3, 0, 1, 2], dtype=tf.int64) output = de_layer(de_ids) - tf_ids = constant_op.constant([3, 0, 1], dtype=tf.int32) + tf_ids = constant_op.constant([3, 0, 1], dtype=tf.int64) expected = tf_layer(tf_ids) expected = tf.reduce_sum(expected, axis=0) self.assertAllEqual(output, expected) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_ops_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_ops_test.py index f1280dba6..8b991226a 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_ops_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_ops_test.py @@ -138,7 +138,7 @@ def _ids_and_weights_2d(embed_dim=4, ragged=False): # Row 3: single id # Row 4: all ids have <=0 weight indices = [[0, 0], [0, 1], [0, 2], [1, 0], [3, 0], [4, 0], [4, 1]] - ids = [0, 1, -1, -1, 2, 0, 1] + ids = [0, 1, -100, -100, 2, 0, 1] weights = [1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5] shape = [5, embed_dim] @@ -177,7 +177,7 @@ def _ids_and_weights_3d( [1, 1, 0], [1, 1, 1], ] - ids = [0, 1, -1, -1, 2, 0, 1] + ids = [0, 1, -100, -100, 2, 0, 1] weights = [1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5] shape = [2, 3, embed_dim] @@ -945,7 +945,7 @@ def test_safe_embedding_lookup_sparse_return_zero_vector(self, ragged=False): 0, 1, 2, - -1, + -100, ]) # init @@ -987,7 +987,7 @@ def test_safe_embedding_lookup_sparse_return_special_vector( embedding_weights = _random_weights(embed_dim=dim) sparse_ids, sparse_weights = _ids_and_weights_2d(embed_dim=dim, ragged=ragged) - valid_ids = np.array([0, 1, 2, 3, -1]) + valid_ids = np.array([0, 1, 2, 3, -100]) # init weights = embedding_weights.lookup(valid_ids) @@ -1025,7 +1025,7 @@ def test_safe_embedding_lookup_sparse_no_weights(self, ragged=False): embedding_weights = _random_weights(embed_dim=dim) sparse_ids, sparse_weights = _ids_and_weights_2d(embed_dim=dim, ragged=ragged) - valid_ids = np.array([0, 1, 2, -1]) + valid_ids = np.array([0, 1, 2, -100]) # init weights = embedding_weights.lookup(valid_ids) @@ -1063,7 +1063,7 @@ def test_safe_embedding_lookup_sparse_partitioned(self, ragged=False): embedding_weights = _random_weights(embed_dim=dim, num_shards=3) sparse_ids, sparse_weights = _ids_and_weights_2d(embed_dim=dim, ragged=ragged) - valid_ids = np.array([0, 1, 2, -1]) + valid_ids = np.array([0, 1, 2, -100]) # init weights = embedding_weights.lookup(valid_ids) @@ -1135,7 +1135,7 @@ def fn(): def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self): with self.session(use_gpu=test_util.is_gpu_available(), config=default_config): - valid_ids = np.array([0, 1, 2, -1]) + valid_ids = np.array([0, 1, 2, -100]) embedding_weights, embedding_weights_values, sparse_ids, sparse_weights = self._get_ids_and_weights_3d( valid_ids) @@ -1163,7 +1163,7 @@ def test_safe_embedding_lookup_sparse_3d_return_special_vector(self): with self.session(use_gpu=test_util.is_gpu_available(), config=default_config): embedding_weights, embedding_weights_values, sparse_ids, sparse_weights = self._get_ids_and_weights_3d( - np.array([0, 1, 2, 3, -1])) + np.array([0, 1, 2, 3, -100])) embedding_lookup_result = de.safe_embedding_lookup_sparse( embedding_weights, sparse_ids, sparse_weights, default_id=3) embedding_lookup_result = embedding_lookup_result.numpy( @@ -1190,7 +1190,7 @@ def test_safe_embedding_lookup_sparse_3d_return_special_vector(self): def test_safe_embedding_lookup_sparse_3d_no_weights(self): with self.session(use_gpu=test_util.is_gpu_available(), config=default_config): - valid_ids = np.array([0, 1, 2, -1]) + valid_ids = np.array([0, 1, 2, -100]) embedding_weights, embedding_weights_values, sparse_ids, _ = self._get_ids_and_weights_3d( valid_ids) embedding_lookup_result = de.safe_embedding_lookup_sparse( @@ -1221,7 +1221,7 @@ def test_safe_embedding_lookup_sparse_3d_partitioned(self): config=default_config): embedding_weights = _random_weights(num_shards=3) sparse_ids, _ = _ids_and_weights_3d() - valid_ids = np.array([0, 1, 2, -1]) + valid_ids = np.array([0, 1, 2, -100]) # init embedding_weights_values = embedding_weights.lookup(valid_ids) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py index 036525fae..950d56adf 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py @@ -1556,29 +1556,6 @@ def test_dynamic_embedding_variable_invalid_shape(self): self.evaluate(table.upsert(keys, values)) self.assertAllEqual(3, self.evaluate(table.size())) - def test_dynamic_embedding_variable_duplicate_insert(self): - with self.session(use_gpu=test_util.is_gpu_available(), - config=default_config): - default_val = -1 - keys = constant_op.constant([0, 1, 2, 2], dtypes.int64) - values = constant_op.constant([[0.0], [1.0], [2.0], [3.0]], - dtypes.float32) - table = de.get_variable("t130", - dtypes.int64, - dtypes.float32, - initializer=default_val) - self.assertAllEqual(0, self.evaluate(table.size())) - - self.evaluate(table.upsert(keys, values)) - self.assertAllEqual(3, self.evaluate(table.size())) - - input_keys = constant_op.constant([0, 1, 2], dtypes.int64) - output = table.lookup(input_keys) - - result = self.evaluate(output) - self.assertTrue( - list(result) in [[[0.0], [1.0], [3.0]], [[0.0], [1.0], [2.0]]]) - def test_dynamic_embedding_variable_find_high_rank(self): with self.session(use_gpu=test_util.is_gpu_available(), config=default_config): diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/ragged_embedding_ops_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/ragged_embedding_ops_test.py index 93ac3fd46..7db6c4869 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/ragged_embedding_ops_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/ragged_embedding_ops_test.py @@ -20,3 +20,8 @@ def test_fill_empty_rows(self): tf.debugging.assert_equal(filled_ragged_tensor.to_tensor(), expected_filled.to_tensor()) tf.debugging.assert_equal(is_row_empty, expected_empty) + + +from tensorflow.python.platform import test +if __name__ == "__main__": + test.main() \ No newline at end of file diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/cuckoo_hashtable_ops.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/cuckoo_hashtable_ops.py index 35f65dba5..cf7709e5a 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/cuckoo_hashtable_ops.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/cuckoo_hashtable_ops.py @@ -24,7 +24,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import device as tf_device -from tensorflow.python.ops import array_ops from tensorflow.python.ops.lookup_ops import LookupInterface from tensorflow.python.training.saver import BaseSaverBuilder diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py index c943bc6eb..e8fd70367 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py @@ -1328,7 +1328,8 @@ def embedding_lookup( Ids are flattened to a 1d tensor before being passed to embedding_lookup then, they are unflattend to match the original ids shape plus an extra leading dimension of the size of the embeddings. - + ids must be unique or call safe_embedding_lookup_sparse in the GPU case + if you use HKV hashtable since HKV requires unique key Args: params: A dynamic_embedding.Variable instance. ids: A tensor with any shape as same dtype of params.key_dtype.