
Commit

fix: RaggedTensor lookup should pass name in eager mode
jq authored and rhdong committed Jun 6, 2024
1 parent 2c79a0d commit f5d6861
Showing 3 changed files with 27 additions and 35 deletions.
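The change in brief: when executing eagerly, the ragged lookup path dropped the caller-supplied name, so the requested scope never reached the underlying ops. Below is a minimal stand-in sketch of the fixed behavior in plain TensorFlow; lookup_sparse here is illustrative, not the repository's actual helper.

import tensorflow as tf

# Stand-in sketch (not the repository's code): the fix forwards the
# caller-supplied `name` so the scope applies even when executing eagerly.
def lookup_sparse(params, ragged_ids, name=None):
  # tf.name_scope (the v2 API) works in both graph and eager mode;
  # `name or default` mirrors the v1 ops.name_scope(name, default) fallback.
  with tf.name_scope(name or "embedding_lookup_sparse"):
    segment_ids = ragged_ids.value_rowids()
    ids = ragged_ids.flat_values
    embeddings = tf.nn.embedding_lookup(params, ids)
    return tf.math.segment_sum(embeddings, segment_ids)  # "sum" combiner

params = tf.random.normal([10, 4])
ids = tf.ragged.constant([[1, 2], [3]], dtype=tf.int64)
print(lookup_sparse(params, ids, name="my_lookup").shape)  # (2, 4)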
File 1 of 3
@@ -751,7 +751,6 @@ def test_embedding_lookup_unique(self):
np.testing.assert_almost_equal(embedded_np, embedded_de)


- @test_util.run_all_in_graph_and_eager_modes
class EmbeddingLookupSparseTest(test.TestCase, parameterized.TestCase):

def _random_ids_and_weights(self,
@@ -801,6 +800,7 @@ def _group_by_batch_entry(self, vals, vals_per_batch_entry):
index += num_val
return grouped_vals

+ @test_util.run_all_in_graph_and_eager_modes
@parameterized.parameters(itertools.product([True, False]))
def test_embedding_lookup_sparse(self, ragged):
var_id = 0
@@ -863,8 +863,6 @@ def test_embedding_lookup_sparse(self, ragged):
) else random_init.eval()
grouped_params = self._group_by_batch_entry(np_params,
vals_per_batch_entry)
- if context.executing_eagerly():
-   params = de.shadow_ops.ShadowVariable(params)
if ragged:
embedding_sum = embedding_lookup_sparse(
params,
@@ -900,6 +898,7 @@ def test_embedding_lookup_sparse(self, ragged):
atol = rtol
self.assertAllClose(np_embedding_sum, tf_embedding_sum, rtol, atol)

+ @test_util.run_all_in_graph_and_eager_modes
def test_embedding_lookup_sparse_shape_checking(self):
if context.executing_eagerly():
self.skipTest("Skip eager test")
@@ -920,7 +919,6 @@ def test_embedding_lookup_sparse_shape_checking(self):
embedding_lookup_test.get_shape().as_list())


- @test_util.run_all_in_graph_and_eager_modes
class SafeEmbeddingLookupSparseTest(test.TestCase, parameterized.TestCase):

def _get_ids_and_weights_3d(self, valid_ids):
@@ -932,10 +930,9 @@ def _get_ids_and_weights_3d(self, valid_ids):
embedding_weights_values = embedding_weights_values.numpy(
) if context.executing_eagerly() else embedding_weights_values.eval()
self.evaluate(embedding_weights.upsert(valid_ids, embedding_weights_values))
- if context.executing_eagerly():
-   embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights)
return embedding_weights, embedding_weights_values, sparse_ids, sparse_weights

+ @test_util.run_all_in_graph_and_eager_modes
@parameterized.parameters(itertools.product([True, False]))
def test_safe_embedding_lookup_sparse_return_zero_vector(self, ragged=False):
with self.session(use_gpu=test_util.is_gpu_available(),
@@ -958,9 +955,6 @@ def test_safe_embedding_lookup_sparse_return_zero_vector(self, ragged=False):
self.evaluate(
embedding_weights.upsert(valid_ids, embedding_weights_values))

- # check
- if context.executing_eagerly():
-   embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights)
if ragged:
embedding_lookup_result = safe_embedding_lookup_sparse(
embedding_weights, sparse_ids, sparse_weights)
@@ -983,6 +977,7 @@ def test_safe_embedding_lookup_sparse_return_zero_vector(self, ragged=False):
],
)

+ @test_util.run_all_in_graph_and_eager_modes
@parameterized.parameters(itertools.product([True, False]))
def test_safe_embedding_lookup_sparse_return_special_vector(
self, ragged=False):
@@ -1000,10 +995,6 @@ def test_safe_embedding_lookup_sparse_return_special_vector(
) else weights.eval()
self.evaluate(
embedding_weights.upsert(valid_ids, embedding_weights_values))

- # check
- if context.executing_eagerly():
-   embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights)
if ragged:
embedding_lookup_result = safe_embedding_lookup_sparse(
embedding_weights, sparse_ids, sparse_weights, default_id=3)
@@ -1025,6 +1016,7 @@ def test_safe_embedding_lookup_sparse_return_special_vector(
],
)

+ @test_util.run_all_in_graph_and_eager_modes
@parameterized.parameters(itertools.product([True, False]))
def test_safe_embedding_lookup_sparse_no_weights(self, ragged=False):
with self.session(use_gpu=test_util.is_gpu_available(),
@@ -1041,9 +1033,6 @@ def test_safe_embedding_lookup_sparse_no_weights(self, ragged=False):
) else weights.eval()
self.evaluate(
embedding_weights.upsert(valid_ids, embedding_weights_values))

- if context.executing_eagerly():
-   embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights)
if ragged:
embedding_lookup_result = safe_embedding_lookup_sparse(
embedding_weights, sparse_ids, None)
@@ -1065,6 +1054,7 @@ def test_safe_embedding_lookup_sparse_no_weights(self, ragged=False):
],
)

+ @test_util.run_all_in_graph_and_eager_modes
@parameterized.parameters(itertools.product([True, False]))
def test_safe_embedding_lookup_sparse_partitioned(self, ragged=False):
with self.session(use_gpu=test_util.is_gpu_available(),
@@ -1081,9 +1071,6 @@ def test_safe_embedding_lookup_sparse_partitioned(self, ragged=False):
) else weights.eval()
self.evaluate(
embedding_weights.upsert(valid_ids, embedding_weights_values))

- if context.executing_eagerly():
-   embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights)
if ragged:
embedding_lookup_result = safe_embedding_lookup_sparse(
embedding_weights, sparse_ids, None)
@@ -1105,6 +1092,7 @@ def test_safe_embedding_lookup_sparse_partitioned(self, ragged=False):
],
)

+ @test_util.run_all_in_graph_and_eager_modes
@parameterized.parameters(itertools.product([True, False]))
def test_safe_embedding_lookup_sparse_inconsistent_ids_type(
self, ragged=False):
@@ -1115,8 +1103,6 @@ def fn():
embedding_weights = _random_weights(num_shards=3,
key_dtype=dtypes.int32)
sparse_ids, sparse_weights = _ids_and_weights_2d(ragged=ragged)
- if context.executing_eagerly():
-   embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights)
if ragged:
safe_embedding_lookup_sparse(embedding_weights, sparse_ids,
sparse_weights)
@@ -1126,6 +1112,7 @@ def fn():

self.assertRaises(TypeError, fn)

+ @test_util.run_all_in_graph_and_eager_modes
@parameterized.parameters(itertools.product([True, False]))
def test_safe_embedding_lookup_sparse_inconsistent_weights_type(
self, ragged=False):
@@ -1135,8 +1122,6 @@ def test_safe_embedding_lookup_sparse_inconsistent_weights_type(
def fn():
embedding_weights = _random_weights(num_shards=3, key_dtype=dtypes.half)
sparse_ids, sparse_weights = _ids_and_weights_2d(ragged=ragged)
- if context.executing_eagerly():
-   embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights)
if ragged:
safe_embedding_lookup_sparse(embedding_weights, sparse_ids,
sparse_weights)
@@ -1146,6 +1131,7 @@ def fn():

self.assertRaises(TypeError, fn)

+ @test_util.run_all_in_graph_and_eager_modes
def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
with self.session(use_gpu=test_util.is_gpu_available(),
config=default_config):
@@ -1172,6 +1158,7 @@ def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
],
)

+ @test_util.run_all_in_graph_and_eager_modes
def test_safe_embedding_lookup_sparse_3d_return_special_vector(self):
with self.session(use_gpu=test_util.is_gpu_available(),
config=default_config):
@@ -1199,6 +1186,7 @@ def test_safe_embedding_lookup_sparse_3d_return_special_vector(self):
],
)

+ @test_util.run_all_in_graph_and_eager_modes
def test_safe_embedding_lookup_sparse_3d_no_weights(self):
with self.session(use_gpu=test_util.is_gpu_available(),
config=default_config):
@@ -1227,6 +1215,7 @@ def test_safe_embedding_lookup_sparse_3d_no_weights(self):
],
)

+ @test_util.run_all_in_graph_and_eager_modes
def test_safe_embedding_lookup_sparse_3d_partitioned(self):
with self.session(use_gpu=test_util.is_gpu_available(),
config=default_config):
@@ -1240,8 +1229,6 @@ def test_safe_embedding_lookup_sparse_3d_partitioned(self):
) if context.executing_eagerly() else embedding_weights_values.eval()
self.evaluate(
embedding_weights.upsert(valid_ids, embedding_weights_values))
- if context.executing_eagerly():
-   embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights)
embedding_lookup_result = de.safe_embedding_lookup_sparse(
embedding_weights, sparse_ids, None)
embedding_lookup_result = embedding_lookup_result.numpy(
@@ -1265,6 +1252,7 @@ def test_safe_embedding_lookup_sparse_3d_partitioned(self):
],
)

+ @test_util.run_all_in_graph_and_eager_modes
def test_safe_embedding_lookup_sparse_with_initializer(self):
id = 0
embed_dim = 8
@@ -1315,9 +1303,6 @@ def test_safe_embedding_lookup_sparse_with_initializer(self):
constant_op.constant(ids, dtypes.int64),
constant_op.constant(dense_shape, dtypes.int64),
)
- if context.executing_eagerly():
-   embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights)

vals_op = de.safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
None,
@@ -1333,6 +1318,7 @@ def test_safe_embedding_lookup_sparse_with_initializer(self):
self.assertAllClose(target_mean, mean, rtol, atol)
self.assertAllClose(target_stddev, stddev, rtol, atol)

+ @test_util.run_all_in_graph_and_eager_modes
def test_safe_embedding_lookup_sparse_shape_checking(self):
if context.executing_eagerly():
self.skipTest("Skip eager test")
@@ -1354,6 +1340,7 @@ def test_safe_embedding_lookup_sparse_shape_checking(self):
self.assertAllEqual(embedding_lookup_base.get_shape(),
embedding_lookup_test.get_shape())

+ @test_util.run_all_in_graph_and_eager_modes
def test_dynamic_embedding_variable_clear(self):
with self.session(use_gpu=test_util.is_gpu_available(),
config=default_config):
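The test changes above are mechanical: the graph-and-eager decorator moves from the class to individual test methods, and the tests no longer wrap the table in a ShadowVariable before the lookup, since the lookup now handles the raw variable in eager mode. For reference, a hedged sketch of running one test under both modes with TensorFlow's per-method decorator test_util.run_in_graph_and_eager_modes; the class name and test body here are illustrative, not from this repository.

import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class LookupModesTest(test.TestCase):

  # Per-method decorator: the body runs twice, once building a graph and
  # once eagerly, so eager-only regressions (like a dropped name) surface.
  @test_util.run_in_graph_and_eager_modes
  def test_lookup(self):
    params = tf.constant([[0.0, 1.0], [2.0, 3.0]])
    result = tf.nn.embedding_lookup(params, tf.constant([1, 0]))
    self.assertAllClose(self.evaluate(result), [[2.0, 3.0], [0.0, 1.0]])


if __name__ == "__main__":
  test.main()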
File 2 of 3
@@ -17,6 +17,7 @@
Dynamic Embedding is designed for Large-scale Sparse Weights Training.
See [Sparse Domain Isolation](https://github.com/tensorflow/community/pull/237)
"""
+ from tensorflow.python.ops.variables import VariableAggregation

from tensorflow_recommenders_addons import dynamic_embedding as de
from tensorflow_recommenders_addons.utils.resource_loader import get_tf_version_triple
@@ -62,7 +63,6 @@
except:
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.util import compat, dispatch
- from tensorflow.python.util.tf_export import tf_export

from tensorflow.python.keras.utils import tf_utils
try: # tf version >= 2.14.0
@@ -649,7 +649,7 @@ def _create_or_get_trainable(trainable_name):

with ops.colocate_with(ids, ignore_existing=True):
if distribute_ctx.has_strategy():
- trainable_ = _distribute_trainable_store.get(name, None)
+ trainable_ = params._distribute_trainable_store.get(name, None)
if trainable_ is None:
strategy_devices = distribute_ctx.get_strategy(
).extended.worker_devices
@@ -773,7 +773,8 @@ def embedding_lookup_sparse(
Args:
params: A single `dynamic_embedding.Variable` instance representing
- the complete embedding tensor or a `ShadowVariable` instance.
+ the complete embedding tensor, in which case a new TrainableWrapper is created and returned,
+ or a `ShadowVariable` instance, in which case params is returned without creating a new TrainableWrapper.
sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
and M is arbitrary.
sp_weights: either a `SparseTensor` of float / double weights, or `None` to
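Per the revised docstring, embedding_lookup_sparse now takes either form of params directly. A usage sketch, assuming a working TFRA installation; the variable name, ids, and dimensions are illustrative.

import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

# Illustrative ids; zeros are implicit in the dense form and drop out.
sp_ids = tf.sparse.from_dense(tf.constant([[1, 0], [3, 0]], dtype=tf.int64))

# Option 1: pass the Variable itself; per the docstring above, a new
# TrainableWrapper is created and returned internally.
table = de.get_variable(name="embeddings", dim=4, initializer=0.1)
out1 = de.embedding_lookup_sparse(table, sp_ids, None, combiner="sum")

# Option 2: pass a pre-built ShadowVariable; it is used as-is, with no
# extra TrainableWrapper -- the boilerplate the test diff above deletes.
shadow = de.shadow_ops.ShadowVariable(table)
out2 = de.embedding_lookup_sparse(shadow, sp_ids, None, combiner="sum")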
File 3 of 3
@@ -1,4 +1,5 @@
import tensorflow as tf
+ from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes, ops
from tensorflow.python.ops import resource_variable_ops, array_ops, math_ops, gen_ragged_array_ops, gen_math_ops
from tensorflow.python.ops.bincount_ops import validate_dense_weights
@@ -19,7 +20,7 @@ def _bincount(arr,
binary_output=False):

name = "bincount" if name is None else name
- with ops.name_scope(name):
+ with tf.name_scope(name):
arr = tf.convert_to_tensor(arr, name="arr")
if weights is not None:
weights = tf.convert_to_tensor(weights, name="weights")
@@ -144,7 +145,10 @@ def _embedding_lookup_sparse_impl(
if isinstance(params, de.shadow_ops.ShadowVariable):
embeddings = de.shadow_ops.embedding_lookup(params, ids)
else:
-   embeddings = de.embedding_lookup(params, ids)
+   if context.executing_eagerly():
+     embeddings = de.embedding_lookup(params, ids, name=name)
+   else:
+     embeddings = de.embedding_lookup(params, ids)

if not ignore_weights:
if segment_ids.dtype != dtypes.int32:
@@ -314,8 +318,8 @@ def embedding_lookup_sparse(
rt_ids.values.get_shape().assert_is_compatible_with(
rt_weights.values.get_shape())
rt_ids.get_shape().assert_is_compatible_with(rt_weights.get_shape())
- #
- with ops.name_scope(name, "embedding_lookup_sparse") as name:
+
+ with tf.name_scope(name or "embedding_lookup_sparse") as name:
segment_ids = rt_ids.value_rowids()
ids = rt_ids.flat_values
return _embedding_lookup_sparse_impl(
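The recurring ops.name_scope -> tf.name_scope swap is the heart of the eager fix: the internal v1 helper skips scoping under eager execution by default (its skip_on_eager flag), which is how the caller's name got lost, while the public tf.name_scope applies in both modes, and `name or default` reproduces the v1 default-name fallback. A small illustrative sketch; the bincount_like helper is hypothetical.

import tensorflow as tf

# Hypothetical helper mirroring the diff's pattern: v2-style scoping that
# behaves identically under graph construction and eager execution.
def bincount_like(arr, name=None):
  with tf.name_scope(name or "bincount"):
    arr = tf.convert_to_tensor(arr, name="arr")
    return tf.math.bincount(arr)

print(bincount_like([0, 1, 1, 3]))           # tf.Tensor([1 2 0 1], ...)
print(bincount_like([0, 1], name="counts"))  # ops scoped under "counts"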
