Skip to content

Commit

Permalink
remove estimator for tf 2.16
Browse files Browse the repository at this point in the history
  • Loading branch information
jq authored and rhdong committed Aug 26, 2024
1 parent bc8a9b6 commit d2ae283
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/util.h"
#if TF_VERSION_INTEGER >= 2160
#include "unsupported/Eigen/CXX11/Tensor"
#include "Eigen/Core"
#include "unsupported/Eigen/CXX11/Tensor"
#else
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
Expand Down Expand Up @@ -84,7 +84,6 @@ class SparseSegmentSumGpuOp : public AsyncOpKernel {
explicit SparseSegmentSumGpuOp(OpKernelConstruction* context)
: AsyncOpKernel(context){};


void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
const Tensor& input_data = context->input(0);
const Tensor& indices = context->input(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,13 @@
except:
from tensorflow.python.training.tracking.util import Checkpoint
from tensorflow.python.util import compat
from tensorflow_estimator.python.estimator import estimator
from tensorflow_estimator.python.estimator import estimator_lib

try: # tf version <= 2.15
from tensorflow_estimator.python.estimator import estimator
from tensorflow_estimator.python.estimator import estimator_lib
except:
# do nothing
pass

try: # The data_structures has been moved to the new package in tf 2.11
from tensorflow.python.trackable import data_structures
Expand Down Expand Up @@ -970,52 +975,56 @@ def test_table_save_load_local_file_system(self):

del table

def test_table_save_load_local_file_system_for_estimator(self):
try: # only test for tensorflow <= 2.15

def input_fn():
return {"x": constant_op.constant([1], dtype=dtypes.int64)}
def test_table_save_load_local_file_system_for_estimator(self):

def model_fn(features, labels, mode, params):
file_system_saver = de.FileSystemSaver()
embedding = de.get_variable(
name="embedding",
dim=3,
trainable=False,
key_dtype=dtypes.int64,
value_dtype=dtypes.float32,
initializer=-1.0,
kv_creator=de.CuckooHashTableCreator(saver=file_system_saver),
)
lookup = de.embedding_lookup(embedding, features["x"])
upsert = embedding.upsert(features["x"],
constant_op.constant([[1.0, 2.0, 3.0]]))

with ops.control_dependencies([lookup, upsert]):
train_op = state_ops.assign_add(training.get_global_step(), 1)

scaffold = training.Scaffold(
saver=saver.Saver(sharded=True,
max_to_keep=1,
keep_checkpoint_every_n_hours=None,
defer_build=True,
save_relative_paths=True))
est = estimator_lib.EstimatorSpec(mode=mode,
scaffold=scaffold,
loss=constant_op.constant(0.),
train_op=train_op,
predictions=lookup)
return est
def input_fn():
return {"x": constant_op.constant([1], dtype=dtypes.int64)}

save_dir = os.path.join(self.get_temp_dir(), "save_restore")
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
def model_fn(features, labels, mode, params):
file_system_saver = de.FileSystemSaver()
embedding = de.get_variable(
name="embedding",
dim=3,
trainable=False,
key_dtype=dtypes.int64,
value_dtype=dtypes.float32,
initializer=-1.0,
kv_creator=de.CuckooHashTableCreator(saver=file_system_saver),
)
lookup = de.embedding_lookup(embedding, features["x"])
upsert = embedding.upsert(features["x"],
constant_op.constant([[1.0, 2.0, 3.0]]))

with ops.control_dependencies([lookup, upsert]):
train_op = state_ops.assign_add(training.get_global_step(), 1)

scaffold = training.Scaffold(
saver=saver.Saver(sharded=True,
max_to_keep=1,
keep_checkpoint_every_n_hours=None,
defer_build=True,
save_relative_paths=True))
est = estimator_lib.EstimatorSpec(mode=mode,
scaffold=scaffold,
loss=constant_op.constant(0.),
train_op=train_op,
predictions=lookup)
return est

save_dir = os.path.join(self.get_temp_dir(), "save_restore")
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

# train and save
est = estimator.Estimator(model_fn=model_fn, model_dir=save_path)
est.train(input_fn=input_fn, steps=1)
# train and save
est = estimator.Estimator(model_fn=model_fn, model_dir=save_path)
est.train(input_fn=input_fn, steps=1)

# restore and predict
predict_results = next(est.predict(input_fn=input_fn))
self.assertAllEqual(predict_results, [1.0, 2.0, 3.0])
# restore and predict
predict_results = next(est.predict(input_fn=input_fn))
self.assertAllEqual(predict_results, [1.0, 2.0, 3.0])
except:
pass

def test_save_restore_only_table(self):
if context.executing_eagerly():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,12 @@
kinit2 = None
pass # for compatible with TF < 2.3.x

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.keras import initializers as keras_init_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.training import device_setter
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
from tensorflow_recommenders_addons import dynamic_embedding as de

import tensorflow as tf
Expand Down Expand Up @@ -213,10 +195,14 @@ def test_warm_start_rename(self):
self._test_warm_start_rename(num_shards, True)
self._test_warm_start_rename(num_shards, False)

def test_warm_start_estimator(self):
for num_shards in [1, 3]:
self._test_warm_start_estimator(num_shards, True)
self._test_warm_start_estimator(num_shards, False)
try: # tf version <= 2.15

def test_warm_start_estimator(self):
for num_shards in [1, 3]:
self._test_warm_start_estimator(num_shards, True)
self._test_warm_start_estimator(num_shards, False)
except:
print(f"estimator is not supported in this version of tensorflow")


if __name__ == "__main__":
Expand Down

0 comments on commit d2ae283

Please sign in to comment.