diff --git a/MANIFEST.in b/MANIFEST.in
index e69de29bb..5576404e2 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -0,0 +1,2 @@
+include easy_rec/python/ops/1.12/*.so*
+include easy_rec/python/ops/1.15/*.so*
diff --git a/easy_rec/__init__.py b/easy_rec/__init__.py
index cfafba708..f3c7afd0b 100644
--- a/easy_rec/__init__.py
+++ b/easy_rec/__init__.py
@@ -15,6 +15,16 @@
logging.basicConfig(
level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
+ops_dir = os.path.join(curr_dir, 'python/ops')
+if 'PAI' in tf.__version__:
+ ops_dir = os.path.join(ops_dir, '1.12_pai')
+elif tf.__version__.startswith('1.12'):
+ ops_dir = os.path.join(ops_dir, '1.12')
+elif tf.__version__.startswith('1.15'):
+ ops_dir = os.path.join(ops_dir, '1.15')
+else:
+ ops_dir = None
+
from easy_rec.python.inference.predictor import Predictor # isort:skip # noqa: E402
from easy_rec.python.main import evaluate # isort:skip # noqa: E402
from easy_rec.python.main import distribute_evaluate # isort:skip # noqa: E402
@@ -32,12 +42,6 @@
_global_config = {}
-ops_dir = os.path.join(curr_dir, 'python/ops')
-if tf.__version__.startswith('1.12'):
- ops_dir = os.path.join(ops_dir, '1.12')
-elif tf.__version__.startswith('1.15'):
- ops_dir = os.path.join(ops_dir, '1.15')
-
def help():
print("""
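Usage note (not part of the patch): a minimal sketch of how the version-specific ops_dir selected above is consumed when loading the bundled custom ops. It assumes easy_rec.ops_dir is populated at import time and left as None for unsupported TF versions, as in the hunk above; the load_custom_op helper name is hypothetical.

import logging
import os

import tensorflow as tf

import easy_rec


def load_custom_op(so_name):
  # load a bundled .so from the TF-version-specific ops dir, or return None
  if easy_rec.ops_dir is None:
    logging.warning('no prebuilt custom ops for tf %s' % tf.__version__)
    return None
  return tf.load_op_library(os.path.join(easy_rec.ops_dir, so_name))


# e.g. the kafka dataset ops shipped with this change
kafka_module = load_custom_op('kafka.so')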
diff --git a/easy_rec/python/compat/optimizers.py b/easy_rec/python/compat/optimizers.py
index 21fede4b8..37969cb1e 100644
--- a/easy_rec/python/compat/optimizers.py
+++ b/easy_rec/python/compat/optimizers.py
@@ -35,6 +35,9 @@
from tensorflow.python.training import moving_averages
from tensorflow.python.training import optimizer as optimizer_
from tensorflow.python.training import training as train
+from easy_rec.python.ops.incr_record import set_sparse_indices
+import tensorflow as tf
+import logging
OPTIMIZER_CLS_NAMES = {
'Adagrad':
@@ -75,7 +78,8 @@ def optimize_loss(loss,
summaries=None,
colocate_gradients_with_ops=False,
not_apply_grad_after_first_step=False,
- increment_global_step=True):
+ increment_global_step=True,
+ incr_save=False):
"""Given loss and parameters for optimizer, returns a training op.
Various ways of passing optimizers include:
@@ -146,6 +150,7 @@ def optimize_loss(loss,
calls `optimize_loss` multiple times per training step (e.g. to optimize
different parts of the model), use this arg to avoid incrementing
`global_step` more times than necessary.
+    incr_save: if True, record sparse/dense variable updates for incremental checkpoint saving.
Returns:
Training op.
@@ -300,11 +305,23 @@ def optimize_loss(loss,
# Create gradient updates.
def _apply_grad():
+ incr_save_ops = []
+ if incr_save:
+ for grad, var in gradients:
+ if isinstance(grad, ops.IndexedSlices):
+ with ops.colocate_with(var):
+ incr_save_op = set_sparse_indices(grad.indices, var_name=var.op.name)
+ incr_save_ops.append(incr_save_op)
+ ops.add_to_collection('SPARSE_UPDATE_VARIABLES', (var, grad.indices.dtype))
+ else:
+ ops.add_to_collection('DENSE_UPDATE_VARIABLES', var)
+
grad_updates = opt.apply_gradients(
gradients,
global_step=global_step if increment_global_step else None,
name='train')
- return control_flow_ops.with_dependencies([grad_updates], loss)
+
+ return control_flow_ops.with_dependencies([grad_updates] + incr_save_ops, loss)
if not_apply_grad_after_first_step:
train_tensor = control_flow_ops.cond(global_step > 0, lambda: loss,
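Usage note (not part of the patch): a hedged sketch of how the SPARSE_UPDATE_VARIABLES / DENSE_UPDATE_VARIABLES collections registered in _apply_grad above can be inspected. The real consumer is the checkpoint saver hook configured via increment_save_config; the helper below is illustrative only.

import tensorflow as tf


def list_incr_update_variables():
  # entries added for IndexedSlices gradients: (variable, indices_dtype) pairs
  sparse_vars = tf.get_collection('SPARSE_UPDATE_VARIABLES')
  # entries added for dense gradients: plain variables
  dense_vars = tf.get_collection('DENSE_UPDATE_VARIABLES')
  return sparse_vars, dense_vars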
diff --git a/easy_rec/python/input/datahub_input.py b/easy_rec/python/input/datahub_input.py
index 8e86feab7..d5ef8b90d 100644
--- a/easy_rec/python/input/datahub_input.py
+++ b/easy_rec/python/input/datahub_input.py
@@ -4,7 +4,9 @@
import time
import numpy as np
+import json
import tensorflow as tf
+import traceback
from easy_rec.python.input.input import Input
from easy_rec.python.utils import odps_util
@@ -18,14 +20,18 @@
from datahub.exceptions import DatahubException
from datahub.models import RecordType
from datahub.models import CursorType
-except Exception:
+ import urllib3
+ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+ logging.getLogger('datahub.account').setLevel(logging.INFO)
+except Exception as ex:
+ # logging.warning(traceback.format_exc(ex))
logging.warning(
'DataHub is not installed. You can install it by: pip install pydatahub')
DataHub = None
-
class DataHubInput(Input):
- """Common IO based interface, could run at local or on data science."""
+  """DataHubInput is used for online training."""
+
def __init__(self,
data_config,
@@ -35,27 +41,59 @@ def __init__(self,
task_num=1):
super(DataHubInput, self).__init__(data_config, feature_config, '',
task_index, task_num)
- if DataHub is None:
- logging.error('please install datahub: ',
- 'pip install pydatahub ;Python 3.6 recommended')
+
try:
- self._datahub_config = datahub_config
- if self._datahub_config is None:
- pass
- self._datahub = DataHub(self._datahub_config.akId,
- self._datahub_config.akSecret,
- self._datahub_config.region)
self._num_epoch = 0
+ self._datahub_config = datahub_config
+ if self._datahub_config is not None:
+ akId = self._datahub_config.akId
+ akSecret = self._datahub_config.akSecret
+ region = self._datahub_config.region
+ if not isinstance(akId, str):
+ akId = akId.encode('utf-8')
+ akSecret = akSecret.encode('utf-8')
+ region = region.encode('utf-8')
+ self._datahub = DataHub(akId, akSecret, region)
+ else:
+ self._datahub = None
except Exception as ex:
- logging.info('exception in init datahub:', str(ex))
+ logging.info('exception in init datahub: %s' % str(ex))
pass
+ self._offset_dict = {}
+ if datahub_config:
+ if self._datahub_config.offset_info:
+ self._offset_dict = json.loads(self._datahub_config.offset_info)
+ shard_result = self._datahub.list_shard(self._datahub_config.project,
+ self._datahub_config.topic)
+ shards = shard_result.shards
+ self._shards = [shards[i] for i in range(len(shards)) if (i % task_num) == task_index]
+      logging.info('assigned shards: %s' % str(self._shards))
+ offset_dict = {}
+ for x in self._shards:
+ if x.shard_id in self._offset_dict:
+ offset_dict[x.shard_id] = self._offset_dict[x.shard_id]
+ self._offset_dict = offset_dict
def _parse_record(self, *fields):
fields = list(fields)
- inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids}
+ field_dict = {self._input_fields[x]: fields[x] for x in self._effective_fids}
for x in self._label_fids:
- inputs[self._input_fields[x]] = fields[x]
- return inputs
+ field_dict[self._input_fields[x]] = fields[x]
+ field_dict[Input.DATA_OFFSET] = fields[-1]
+ return field_dict
+
+ def _preprocess(self, field_dict):
+ output_dict = super(DataHubInput, self)._preprocess(field_dict)
+
+ # append offset fields
+ if Input.DATA_OFFSET in field_dict:
+ output_dict[Input.DATA_OFFSET] = field_dict[Input.DATA_OFFSET]
+
+ # for _get_features to include DATA_OFFSET
+ if Input.DATA_OFFSET not in self._appended_fields:
+ self._appended_fields.append(Input.DATA_OFFSET)
+
+ return output_dict
def _datahub_generator(self):
logging.info('start epoch[%d]' % self._num_epoch)
@@ -65,62 +103,87 @@ def _datahub_generator(self):
self.get_type_defaults(x, v)
for x, v in zip(self._input_field_types, self._input_field_defaults)
]
- batch_defaults = [
- np.array([x] * self._data_config.batch_size) for x in record_defaults
+ batch_data = [
+ np.asarray([x] * self._data_config.batch_size, order='C', dtype=object)
+ if isinstance(x, str) else
+ np.array([x] * self._data_config.batch_size)
+ for x in record_defaults
]
+ batch_data.append(json.dumps(self._offset_dict))
+
try:
self._datahub.wait_shards_ready(self._datahub_config.project,
self._datahub_config.topic)
topic_result = self._datahub.get_topic(self._datahub_config.project,
self._datahub_config.topic)
if topic_result.record_type != RecordType.TUPLE:
- logging.error('topic type illegal !')
+ logging.error('datahub topic type(%s) illegal' % str(topic_result.record_type))
record_schema = topic_result.record_schema
- shard_result = self._datahub.list_shard(self._datahub_config.project,
- self._datahub_config.topic)
- shards = shard_result.shards
- for shard in shards:
- shard_id = shard._shard_id
- cursor_result = self._datahub.get_cursor(self._datahub_config.project,
+
+ batch_size = self._data_config.batch_size
+
+ tid = 0
+ while True:
+ shard_id = self._shards[tid].shard_id
+ tid += 1
+ if tid >= len(self._shards):
+ tid = 0
+ if shard_id not in self._offset_dict:
+ cursor_result = self._datahub.get_cursor(self._datahub_config.project,
self._datahub_config.topic,
shard_id, CursorType.OLDEST)
- cursor = cursor_result.cursor
- limit = self._data_config.batch_size
- while True:
- get_result = self._datahub.get_tuple_records(
- self._datahub_config.project, self._datahub_config.topic,
- shard_id, record_schema, cursor, limit)
- batch_data_np = [x.copy() for x in batch_defaults]
- for row_id, record in enumerate(get_result.records):
- for col_id in range(len(record_defaults)):
- if record.values[col_id] not in ['', 'Null', None]:
- batch_data_np[col_id][row_id] = record.values[col_id]
- yield tuple(batch_data_np)
- if 0 == get_result.record_count:
- time.sleep(1)
- cursor = get_result.next_cursor
- except DatahubException as e:
- logging.error(e)
+ cursor = cursor_result.cursor
+ else:
+ cursor = self._offset_dict[shard_id]['cursor']
+
+ get_result = self._datahub.get_tuple_records(
+ self._datahub_config.project, self._datahub_config.topic,
+ shard_id, record_schema, cursor, batch_size)
+ count = get_result.record_count
+ if count == 0:
+          # back off briefly when this shard has no new records
+          time.sleep(1)
+          continue
+ time_offset = 0
+ sequence_offset = 0
+ for row_id, record in enumerate(get_result.records):
+ if record.system_time > time_offset:
+ time_offset = record.system_time
+ if record.sequence > sequence_offset:
+ sequence_offset = record.sequence
+ for col_id in range(len(record_defaults)):
+ if record.values[col_id] not in ['', 'Null', 'null', 'NULL', None]:
+ batch_data[col_id][row_id] = record.values[col_id]
+ else:
+ batch_data[col_id][row_id] = record_defaults[col_id]
+ cursor = get_result.next_cursor
+ self._offset_dict[shard_id] = {'sequence_offset': sequence_offset,
+ 'time_offset': time_offset,
+ 'cursor': cursor
+ }
+ batch_data[-1] = json.dumps(self._offset_dict)
+ yield tuple(batch_data)
+ except DatahubException as ex:
+ logging.error('DatahubException: %s' % str(ex))
def _build(self, mode, params):
- # get input type
- list_type = [self.get_tf_type(x) for x in self._input_field_types]
- list_type = tuple(list_type)
- list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))]
+ # get input types
+ list_types = [self.get_tf_type(x) for x in self._input_field_types]
+ list_types.append(tf.string)
+ list_types = tuple(list_types)
+ list_shapes = [tf.TensorShape([None]) for x in range(0, len(self._input_field_types))]
+ list_shapes.append(tf.TensorShape([]))
list_shapes = tuple(list_shapes)
# read datahub
dataset = tf.data.Dataset.from_generator(
self._datahub_generator,
- output_types=list_type,
+ output_types=list_types,
output_shapes=list_shapes)
if mode == tf.estimator.ModeKeys.TRAIN:
- dataset = dataset.shuffle(
- self._data_config.shuffle_buffer_size,
- seed=2020,
- reshuffle_each_iteration=True)
- dataset = dataset.repeat(self.num_epochs)
- else:
- dataset = dataset.repeat(1)
+ if self._data_config.shuffle:
+ dataset = dataset.shuffle(
+ self._data_config.shuffle_buffer_size,
+ seed=2020,
+ reshuffle_each_iteration=True)
+
dataset = dataset.map(
self._parse_record,
num_parallel_calls=self._data_config.num_parallel_calls)
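For reference (not part of the patch): the per-shard offset bookkeeping built by the generator above is a plain dict keyed by shard id and serialized into the DATA_OFFSET field. The key names come from the hunk above; the concrete values below are illustrative only.

import json

offset_dict = {
    '0': {'sequence_offset': 105, 'time_offset': 1650000000000, 'cursor': '<datahub cursor>'},
    '1': {'sequence_offset': 98, 'time_offset': 1650000000123, 'cursor': '<datahub cursor>'},
}
# what DatahubServer.offset_info and the DATA_OFFSET batch field carry
offset_info = json.dumps(offset_dict)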
diff --git a/easy_rec/python/input/input.py b/easy_rec/python/input/input.py
index c0d8653bf..9bb20e599 100644
--- a/easy_rec/python/input/input.py
+++ b/easy_rec/python/input/input.py
@@ -23,6 +23,8 @@
class Input(six.with_metaclass(_meta_type, object)):
+ DATA_OFFSET = 'DATA_OFFSET'
+
def __init__(self,
data_config,
feature_configs,
diff --git a/easy_rec/python/input/kafka_dataset.py b/easy_rec/python/input/kafka_dataset.py
new file mode 100644
index 000000000..5e2ba3a6c
--- /dev/null
+++ b/easy_rec/python/input/kafka_dataset.py
@@ -0,0 +1,144 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Kafka Dataset."""
+
+import logging
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+import traceback
+
+try:
+ from easy_rec.python.ops import gen_kafka_ops
+except ImportError as ex:
+  logging.warning('failed to import gen_kafka_ops: %s' % traceback.format_exc())
+
+
+class KafkaDataset(dataset_ops.Dataset):
+ """A Kafka Dataset that consumes the message."""
+
+ def __init__(self,
+ topics,
+ servers='localhost',
+ group='',
+ eof=False,
+ timeout=1000,
+ config_global=None,
+ config_topic=None,
+ message_key=False,
+ message_offset=False):
+ """Create a KafkaReader.
+
+ Args:
+ topics: A `tf.string` tensor containing one or more subscriptions,
+ in the format of [topic:partition:offset:length],
+ by default length is -1 for unlimited.
+ servers: A list of bootstrap servers.
+ group: The consumer group id.
+ eof: If True, the kafka reader will stop on EOF.
+ timeout: The timeout value for the Kafka Consumer to wait
+ (in millisecond).
+ config_global: A `tf.string` tensor containing global configuration
+ properties in [Key=Value] format,
+ eg. ["enable.auto.commit=false",
+ "heartbeat.interval.ms=2000"],
+ please refer to 'Global configuration properties'
+ in librdkafka doc.
+ config_topic: A `tf.string` tensor containing topic configuration
+ properties in [Key=Value] format,
+ eg. ["auto.offset.reset=earliest"],
+ please refer to 'Topic configuration properties'
+ in librdkafka doc.
+ message_key: If True, the kafka will output both message value and key.
+ message_offset: If True, the kafka will output both message value and offset.
+ """
+ self._topics = ops.convert_to_tensor(
+ topics, dtype=dtypes.string, name='topics')
+ self._servers = ops.convert_to_tensor(
+ servers, dtype=dtypes.string, name='servers')
+ self._group = ops.convert_to_tensor(
+ group, dtype=dtypes.string, name='group')
+ self._eof = ops.convert_to_tensor(eof, dtype=dtypes.bool, name='eof')
+ self._timeout = ops.convert_to_tensor(
+ timeout, dtype=dtypes.int64, name='timeout')
+ config_global = config_global if config_global else []
+ self._config_global = ops.convert_to_tensor(
+ config_global, dtype=dtypes.string, name='config_global')
+ config_topic = config_topic if config_topic else []
+ self._config_topic = ops.convert_to_tensor(
+ config_topic, dtype=dtypes.string, name='config_topic')
+ self._message_key = message_key
+ self._message_offset = message_offset
+ super(KafkaDataset, self).__init__()
+
+ def _inputs(self):
+ return []
+
+ def _as_variant_tensor(self):
+ return gen_kafka_ops.io_kafka_dataset_v2(
+ self._topics,
+ self._servers,
+ self._group,
+ self._eof,
+ self._timeout,
+ self._config_global,
+ self._config_topic,
+ self._message_key,
+ self._message_offset,
+ )
+
+ @property
+ def output_classes(self):
+ if self._message_key ^ self._message_offset:
+ return (ops.Tensor, ops.Tensor)
+ elif self._message_key and self._message_offset:
+ return (ops.Tensor, ops.Tensor, ops.Tensor)
+ return (ops.Tensor)
+
+ @property
+ def output_shapes(self):
+ if self._message_key ^ self._message_offset:
+ return ((tensor_shape.TensorShape([]), tensor_shape.TensorShape([])))
+ elif self._message_key and self._message_offset:
+ return ((tensor_shape.TensorShape([]), tensor_shape.TensorShape([]),
+ tensor_shape.TensorShape([])))
+ return ((tensor_shape.TensorShape([])))
+
+ @property
+ def output_types(self):
+ if self._message_key ^ self._message_offset:
+ return ((dtypes.string, dtypes.string))
+ elif self._message_key and self._message_offset:
+ return ((dtypes.string, dtypes.string, dtypes.string))
+ return ((dtypes.string))
+
+
+def write_kafka_v2(message, topic, servers='localhost', name=None):
+ """Write kafka.
+
+ Args:
+ message: A `Tensor` of type `string`. 0-D.
+ topic: A `tf.string` tensor containing one subscription,
+ in the format of topic:partition.
+ servers: A list of bootstrap servers.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of type `string`. 0-D.
+ """
+ return gen_kafka_ops.io_write_kafka_v2(
+ message=message, topic=topic, servers=servers, name=name)
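Usage note (not part of the patch): a minimal sketch of reading the new KafkaDataset, mirroring the iterator pattern used in kafka_test.py below; the broker address, topic and group names are illustrative.

import tensorflow as tf
from tensorflow.python.data.ops import iterator_ops

from easy_rec.python.input.kafka_dataset import KafkaDataset

# consume partitions 0 and 1 from offset 0, emitting (value, key, offset) per message
dataset = KafkaDataset(
    topics=['demo_topic:0:0', 'demo_topic:1:0'],
    servers='localhost:9092',
    group='demo_group',
    eof=True,
    message_key=True,
    message_offset=True).batch(4)

iterator = iterator_ops.Iterator.from_structure(dataset.output_types)
init_op = iterator.make_initializer(dataset)
values, keys, offsets = iterator.get_next()

with tf.Session() as sess:
  sess.run(init_op)
  print(sess.run([values, keys, offsets]))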
diff --git a/easy_rec/python/input/kafka_input.py b/easy_rec/python/input/kafka_input.py
index 63bf5a4d2..40956c425 100644
--- a/easy_rec/python/input/kafka_input.py
+++ b/easy_rec/python/input/kafka_input.py
@@ -2,10 +2,19 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import sys
+import traceback
+import json
+import six
import tensorflow as tf
from easy_rec.python.input.input import Input
+from easy_rec.python.input.kafka_dataset import KafkaDataset
+
+try:
+ from kafka import KafkaConsumer
+except ImportError as ex:
+  logging.warning('kafka-python is not installed: %s' % traceback.format_exc())
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -13,6 +22,8 @@
class KafkaInput(Input):
+ DATA_OFFSET = 'DATA_OFFSET'
+
def __init__(self,
data_config,
feature_config,
@@ -22,92 +33,113 @@ def __init__(self,
super(KafkaInput, self).__init__(data_config, feature_config, '',
task_index, task_num)
self._kafka = kafka_config
+ self._offset_dict = {}
+ if self._kafka is not None:
+ # each topic in the format: topic:partition_id:offset
+ self._topics = []
+ if self._kafka.offset_info:
+ offset_dict = json.loads(self._kafka.offset_info)
+ for part in offset_dict:
+ part_id = int(part)
+ if (part_id % self._task_num) == self._task_index:
+ self._offset_dict[part_id] = offset_dict[part]
+ consumer = KafkaConsumer(group_id='kafka_dataset_consumer',
+ bootstrap_servers=[self._kafka.server])
+ partitions = consumer.partitions_for_topic(self._kafka.topic)
+ num_partition = len(partitions)
+ logging.info('all partitions[%d]: %s' % (num_partition, partitions))
+ for part_id in range(num_partition):
+ if (part_id % self._task_num) == self._task_index:
+ offset = self._offset_dict.get(part_id, 0)
+ self._topics.append('%s:%d:%d' % (self._kafka.topic, part_id, offset))
+ logging.info('assigned topic partitions: %s' % (','.join(self._topics)))
+ assert len(self._topics) > 0, 'no partitions are assigned for this task(%d/%d)' % (
+ self._task_index, self._task_num)
+ else:
+ self._topics = None
+
-  def _parse_csv(self, line):
+ def _parse_csv(self, line, message_key, message_offset):
record_defaults = [
self.get_type_defaults(t, v)
for t, v in zip(self._input_field_types, self._input_field_defaults)
]
- def _check_data(line):
- sep = self._data_config.separator
- if type(sep) != type(str):
- sep = sep.encode('utf-8')
- field_num = len(line[0].split(sep))
- assert field_num == len(record_defaults),\
- 'sep[%s] maybe invalid: field_num=%d, required_num=%d' % (sep, field_num, len(record_defaults))
- return True
-
- check_op = tf.py_func(_check_data, [line], Tout=tf.bool)
- with tf.control_dependencies([check_op]):
- fields = tf.decode_csv(
- line,
- field_delim=self._data_config.separator,
- record_defaults=record_defaults,
- name='decode_csv')
+ fields = tf.decode_csv(
+ line,
+ use_quote_delim=False,
+ field_delim=self._data_config.separator,
+ record_defaults=record_defaults,
+ name='decode_csv')
inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids}
for x in self._label_fids:
inputs[self._input_fields[x]] = fields[x]
+
+ # record current offset
+ def _parse_offset(message_offset):
+ for kv in message_offset:
+ if six.PY3:
+ kv = kv.decode('utf-8')
+ k,v = kv.split(':')
+ v = int(v)
+ if k not in self._offset_dict or v > self._offset_dict[k]:
+ self._offset_dict[k] = v
+ return json.dumps(self._offset_dict)
+
+ inputs[Input.DATA_OFFSET] = tf.py_func(_parse_offset, [message_offset], tf.string)
return inputs
- def _build(self, mode, params):
- try:
- import tensorflow_io.kafka as kafka_io
- except ImportError:
- logging.error(
- 'Please install tensorflow-io, '
- 'version compatibility can refer to https://github.com/tensorflow/io#tensorflow-version-compatibility'
- )
+ def _preprocess(self, field_dict):
+ output_dict = super(KafkaInput, self)._preprocess(field_dict)
+ # append offset fields
+ if Input.DATA_OFFSET in field_dict:
+ output_dict[Input.DATA_OFFSET] = field_dict[Input.DATA_OFFSET]
+
+ # for _get_features to include DATA_OFFSET
+ if Input.DATA_OFFSET not in self._appended_fields:
+ self._appended_fields.append(Input.DATA_OFFSET)
+
+ return output_dict
+
+ def _build(self, mode, params):
num_parallel_calls = self._data_config.num_parallel_calls
if mode == tf.estimator.ModeKeys.TRAIN:
- train = self._kafka
- topics = []
- i = self._task_index
- assert len(train.offset) == 1 or len(train.offset) == train.partitions, \
- 'number of train.offset must be 1 or train.partitions'
- while i < train.partitions:
- offset_i = train.offset[i] if i < len(
- train.offset) else train.offset[-1]
- topics.append(train.topic + ':' + str(i) + ':' + str(offset_i) + ':-1')
- i = i + self._task_num
-
+ train_kafka = self._kafka
logging.info(
'train kafka server: %s topic: %s task_num: %d task_index: %d topics: %s'
%
- (train.server, train.topic, self._task_num, self._task_index, topics))
- if len(topics) == 0:
- logging.info('train kafka topic is empty')
- sys.exit(1)
-
- dataset = kafka_io.KafkaDataset(
- topics, servers=train.server, group=train.group, eof=False)
- dataset = dataset.repeat(1)
+ (train_kafka.server, train_kafka.topic, self._task_num, self._task_index, self._topics))
+
+ dataset = KafkaDataset(
+ self._topics,
+ servers=train_kafka.server,
+ group=train_kafka.group,
+ eof=False,
+ config_global = list(self._kafka.config_global),
+ config_topic = list(self._kafka.config_topic),
+ message_key=True,
+ message_offset=True)
else:
- eval = self._kafka
- topics = []
- i = 0
- assert len(eval.offset) == 1 or len(eval.offset) == eval.partitions, \
- 'number of eval.offset must be 1 or eval.partitions'
- while i < eval.partitions:
- offset_i = eval.offset[i] if i < len(eval.offset) else eval.offset[-1]
- topics.append(eval.topic + ':' + str(i) + ':' + str(eval.offset) +
- ':-1')
- i = i + 1
-
+ eval_kafka = self._kafka
logging.info(
'eval kafka server: %s topic: %s task_num: %d task_index: %d topics: %s'
- % (eval.server, eval.topic, self._task_num, self._task_index, topics))
-
- if len(topics) == 0:
- logging.info('eval kafka topic is empty')
- sys.exit(1)
+ % (eval_kafka.server, eval_kafka.topic, self._task_num, self._task_index, self._topics))
- dataset = kafka_io.KafkaDataset(
- topics, servers=eval.server, group=eval.group, eof=False)
- dataset = dataset.repeat(1)
+ dataset = KafkaDataset(self._topics, servers=self._kafka.server,
+ group=eval_kafka.group, eof=True,
+ config_global = list(self._kafka.config_global),
+ config_topic = list(self._kafka.config_topic),
+ message_key=True, message_offset=True)
dataset = dataset.batch(self._data_config.batch_size)
dataset = dataset.map(
diff --git a/easy_rec/python/main.py b/easy_rec/python/main.py
index cbaaf5ed2..aee434fe1 100644
--- a/easy_rec/python/main.py
+++ b/easy_rec/python/main.py
@@ -288,8 +288,12 @@ def _train_and_evaluate_impl(pipeline_config, continue_train=False):
eval_data = _get_input_object_by_name(pipeline_config, 'eval')
distribution = strategy_builder.build(train_config)
+ params = {}
+ if train_config.is_profiling:
+ params['log_device_placement'] = True
estimator, run_config = _create_estimator(
- pipeline_config, distribution=distribution)
+ pipeline_config, distribution=distribution,
+ params=params)
master_stat_file = os.path.join(pipeline_config.model_dir, 'master.stat')
version_file = os.path.join(pipeline_config.model_dir, 'version')
@@ -312,6 +316,20 @@ def _train_and_evaluate_impl(pipeline_config, continue_train=False):
if data_config.input_type == data_config.InputType.OdpsRTPInputV2:
input_fn_kwargs['fg_json_path'] = pipeline_config.fg_json_path
+ # support for datahub/kafka offset restore
+ final_ckpt = estimator_utils.latest_checkpoint(pipeline_config.model_dir)
+ if final_ckpt is not None:
+ final_offset_path = final_ckpt + '.offset'
+ logging.info('restore offset_info from %s' % final_offset_path)
+ if gfile.Exists(final_offset_path):
+ with gfile.GFile(final_offset_path) as fin:
+ offset_info = json.load(fin)
+ if train_data:
+ train_data.offset_info = json.dumps(offset_info)
+ if eval_data is not None:
+        eval_data.offset_info = json.dumps(offset_info)
+
+
# create train input
train_input_fn = _get_input_fn(data_config, feature_configs, train_data,
**input_fn_kwargs)
@@ -362,10 +380,7 @@ def evaluate(pipeline_config,
pipeline_config.eval_input_path = eval_data_path
train_config = pipeline_config.train_config
- if pipeline_config.WhichOneof('eval_path') == 'kafka_eval_input':
- eval_data = pipeline_config.kafka_eval_input
- else:
- eval_data = pipeline_config.eval_input_path
+ eval_data = _get_input_object_by_name(pipeline_config, 'eval')
server_target = None
if 'TF_CONFIG' in os.environ:
@@ -726,6 +741,13 @@ def export(export_dir,
serving_input_fn = _get_input_fn(data_config, feature_configs, None,
export_config, **input_fn_kwargs)
if 'oss_path' in extra_params:
+ if pipeline_config.train_config.HasField('incr_save_config'):
+ incr_save_config = pipeline_config.train_config.incr_save_config
+ extra_params['incr_save'] = {}
+ if incr_save_config.HasField('kafka'):
+ extra_params['incr_save']['kafka'] = incr_save_config.kafka
+ if incr_save_config.HasField('datahub'):
+ extra_params['incr_save']['datahub'] = incr_save_config.datahub
return export_big_model_to_oss(export_dir, pipeline_config, extra_params,
serving_input_fn, estimator, checkpoint_path,
verbose)
diff --git a/easy_rec/python/model/easy_rec_estimator.py b/easy_rec/python/model/easy_rec_estimator.py
index 7cb7d56bf..69b6a13ac 100644
--- a/easy_rec/python/model/easy_rec_estimator.py
+++ b/easy_rec/python/model/easy_rec_estimator.py
@@ -29,7 +29,10 @@
from easy_rec.python.protos.train_pb2 import DistributionStrategy
from easy_rec.python.utils import estimator_utils
from easy_rec.python.utils import pai_util
+from easy_rec.python.utils import constant
from easy_rec.python.utils.multi_optimizer import MultiOptimizer
+from easy_rec.python.input.input import Input
+from tensorflow.python.platform import gfile
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -70,6 +73,11 @@ def eval_config(self):
def train_config(self):
return self._pipeline_config.train_config
+ @property
+ def incr_save_config(self):
+ return self.train_config.incr_save_config \
+ if self.train_config.HasField('incr_save_config') else None
+
@property
def export_config(self):
return self._pipeline_config.export_config
@@ -106,9 +114,25 @@ def _train_model_fn(self, features, labels, run_config):
for key in loss_dict:
tf.summary.scalar(key, loss_dict[key], family='loss')
+ if Input.DATA_OFFSET in features:
+ task_index, task_num = estimator_utils.get_task_index_and_num()
+ data_offset_var = tf.get_variable(name=Input.DATA_OFFSET, dtype=tf.string,
+ shape=[task_num],
+ collections=[tf.GraphKeys.GLOBAL_VARIABLES, Input.DATA_OFFSET],
+ trainable=False)
+ update_offset = tf.assign(data_offset_var[task_index], features[Input.DATA_OFFSET])
+ tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_offset)
+ else:
+ data_offset_var = None
+
# update op, usually used for batch-norm
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if update_ops:
+ # register for increment update, such as batchnorm moving_mean and moving_variance
+ global_vars = { x.name:x for x in tf.global_variables() }
+ for x in update_ops:
+ if x.inputs[0].name in global_vars:
+ ops.add_to_collection(constant.DENSE_UPDATE_VARIABLES, global_vars[x.inputs[0].name])
update_op = tf.group(*update_ops, name='update_barrier')
with tf.control_dependencies([update_op]):
loss = tf.identity(loss, name='total_loss')
@@ -231,7 +255,9 @@ def _train_model_fn(self, features, labels, run_config):
colocate_gradients_with_ops=True,
not_apply_grad_after_first_step=run_config.is_chief and
self._pipeline_config.data_config.chief_redundant,
- name='') # Preventing scope prefix on all variables.
+ name='', # Preventing scope prefix on all variables.
+ incr_save=(self.incr_save_config is not None))
+
# online evaluation
metric_update_op_dict = None
@@ -284,6 +310,7 @@ def format_fn(tensor_dict):
if self.train_config.train_distribute in [
DistributionStrategy.CollectiveAllReduceStrategy,
+ DistributionStrategy.MirroredStrategy,
DistributionStrategy.MultiWorkerMirroredStrategy
]:
# for multi worker strategy, we could not replace the
@@ -294,35 +321,36 @@ def format_fn(tensor_dict):
var_list = (
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) +
tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS))
- initialize_var_list = [
- x for x in var_list if 'WorkQueue' not in str(type(x))
- ]
+
+ # exclude data_offset_var
+ var_list = [ x for x in var_list if x != data_offset_var ]
# early_stop flag will not be saved in checkpoint
# and could not be restored from checkpoint
early_stop_var = find_early_stop_var(var_list)
+ var_list = [x for x in var_list if x != early_stop_var]
+
+ initialize_var_list = [
+ x for x in var_list if 'WorkQueue' not in str(type(x))
+ ]
+
# incompatiable shape restore will not be saved in checkpoint
# but must be able to restore from checkpoint
incompatiable_shape_restore = tf.get_collection('T_E_M_P_RESTROE')
- if early_stop_var is not None:
- var_list = [x for x in var_list if x != early_stop_var]
- local_init_op = tf.group([
- tf.initializers.local_variables(),
- tf.initializers.variables([early_stop_var] +
- incompatiable_shape_restore)
- ])
- elif len(incompatiable_shape_restore) > 0:
- local_init_op = tf.group([
- tf.initializers.local_variables(),
- tf.initializers.variables(incompatiable_shape_restore)
- ])
- else:
- local_init_op = None
+
+ local_init_ops = [tf.train.Scaffold.default_local_init_op()]
+ if data_offset_var is not None and estimator_utils.is_chief():
+ local_init_ops.append(tf.initializers.variables([data_offset_var]))
+ if early_stop_var is not None and estimator_utils.is_chief():
+ local_init_ops.append(tf.initializers.variables([early_stop_var]))
+ if len(incompatiable_shape_restore) > 0:
+ local_init_ops.append(tf.initializers.variables(incompatiable_shape_restore))
+
scaffold = tf.train.Scaffold(
saver=tf.train.Saver(
var_list=var_list,
sharded=True,
max_to_keep=self.train_config.keep_checkpoint_max),
- local_init_op=local_init_op,
+ local_init_op=tf.group(local_init_ops),
ready_for_local_init_op=tf.report_uninitialized_variables(
var_list=initialize_var_list))
# saver hook
@@ -331,11 +359,19 @@ def format_fn(tensor_dict):
save_secs=self._config.save_checkpoints_secs,
save_steps=self._config.save_checkpoints_steps,
scaffold=scaffold,
- write_graph=self.train_config.write_graph)
+ write_graph=self.train_config.write_graph,
+ data_offset_var=data_offset_var,
+ increment_save_config=self.incr_save_config)
chief_hooks = []
if estimator_utils.is_chief():
hooks.append(saver_hook)
+ # oss stop signal hook
+ if self.train_config.enable_oss_stop_signal:
+ oss_stop_signal = estimator_utils.OssStopSignalHook(
+ model_dir=self.model_dir)
+ hooks.append(oss_stop_signal)
+
# profiling hook
if self.train_config.is_profiling and estimator_utils.is_chief():
profile_hook = tf.train.ProfilerHook(
@@ -461,13 +497,25 @@ def _export_model_fn(self, features, labels, run_config, params):
# save train pipeline.config for debug purpose
pipeline_path = os.path.join(self._model_dir, 'pipeline.config')
- if tf.gfile.Exists(pipeline_path):
+ if gfile.Exists(pipeline_path):
tf.add_to_collection(
tf.GraphKeys.ASSET_FILEPATHS,
tf.constant(pipeline_path, dtype=tf.string, name='pipeline.config'))
else:
print('train pipeline_path(%s) does not exist' % pipeline_path)
+ # restore DENSE_UPDATE_VARIABLES collection
+ dense_train_var_path = os.path.join(self.model_dir, constant.DENSE_UPDATE_VARIABLES)
+ if gfile.Exists(dense_train_var_path):
+ with gfile.GFile(dense_train_var_path, 'r') as fin:
+ var_name_to_id_map = json.load(fin)
+ var_name_id_lst = [ (x, var_name_to_id_map[x]) for x in var_name_to_id_map ]
+ var_name_id_lst.sort(key=lambda x : x[1])
+ all_vars = { x.op.name:x for x in tf.global_variables() }
+ for var_name, var_id in var_name_id_lst:
+ assert var_name in all_vars, 'dense_train_var[%s] is not found' % var_name
+ tf.add_to_collection(constant.DENSE_UPDATE_VARIABLES, all_vars[var_name])
+
# add more asset files
if 'asset_files' in params:
for asset_name in params['asset_files']:
@@ -505,7 +553,7 @@ def _write_rtp_fg_config_to_col(fg_config=None, fg_config_path=None):
fg_config_path: path to the RTP config file.
"""
if fg_config is None:
- with tf.gfile.GFile(fg_config_path, 'r') as f:
+ with gfile.GFile(fg_config_path, 'r') as f:
fg_config = json.load(f)
col = ops.get_collection_ref(GraphKeys.RANK_SERVICE_FG_CONF)
if len(col) == 0:
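For reference (not part of the patch): the DENSE_UPDATE_VARIABLES asset restored above is a json file mapping dense variable names to slot ids, presumably written under model_dir by the incremental-save machinery during training (not shown in this diff); the entries below are illustrative only.

import json

# illustrative contents of <model_dir>/DENSE_UPDATE_VARIABLES
var_name_to_id_map = {
    'dnn/hiddenlayer_0/kernel': 0,
    'dnn/hiddenlayer_0/bias': 1,
}
print(json.dumps(var_name_to_id_map))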
diff --git a/easy_rec/python/ops/1.12/incr_record.so b/easy_rec/python/ops/1.12/incr_record.so
new file mode 100755
index 000000000..3f258e06e
Binary files /dev/null and b/easy_rec/python/ops/1.12/incr_record.so differ
diff --git a/easy_rec/python/ops/1.12/kafka.so b/easy_rec/python/ops/1.12/kafka.so
index d5b33cc46..42164529a 100755
Binary files a/easy_rec/python/ops/1.12/kafka.so and b/easy_rec/python/ops/1.12/kafka.so differ
diff --git a/easy_rec/python/ops/1.12/libembed_op.so b/easy_rec/python/ops/1.12/libembed_op.so
index 5f46ee7f8..8a41da0b2 100644
Binary files a/easy_rec/python/ops/1.12/libembed_op.so and b/easy_rec/python/ops/1.12/libembed_op.so differ
diff --git a/easy_rec/python/ops/1.12/librdkafka++.so.1 b/easy_rec/python/ops/1.12/librdkafka++.so.1
new file mode 100755
index 000000000..8a448378c
Binary files /dev/null and b/easy_rec/python/ops/1.12/librdkafka++.so.1 differ
diff --git a/easy_rec/python/ops/1.12/librdkafka.so.1 b/easy_rec/python/ops/1.12/librdkafka.so.1
new file mode 100755
index 000000000..c7ab65e96
Binary files /dev/null and b/easy_rec/python/ops/1.12/librdkafka.so.1 differ
diff --git a/easy_rec/python/ops/1.12_pai/__init__.py b/easy_rec/python/ops/1.12_pai/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/easy_rec/python/ops/1.12_pai/incr_record.so b/easy_rec/python/ops/1.12_pai/incr_record.so
new file mode 100755
index 000000000..e6c0d42b0
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/incr_record.so differ
diff --git a/easy_rec/python/ops/1.12_pai/kafka.so b/easy_rec/python/ops/1.12_pai/kafka.so
new file mode 100755
index 000000000..2df02c3a5
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/kafka.so differ
diff --git a/easy_rec/python/ops/1.12_pai/kafka.so.bak b/easy_rec/python/ops/1.12_pai/kafka.so.bak
new file mode 100755
index 000000000..e232014c0
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/kafka.so.bak differ
diff --git a/easy_rec/python/ops/1.12_pai/libembed_op.so b/easy_rec/python/ops/1.12_pai/libembed_op.so
new file mode 100644
index 000000000..5f46ee7f8
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/libembed_op.so differ
diff --git a/easy_rec/python/ops/1.12_pai/libhiredis.so.1.0.0 b/easy_rec/python/ops/1.12_pai/libhiredis.so.1.0.0
new file mode 100644
index 000000000..63ae04d40
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/libhiredis.so.1.0.0 differ
diff --git a/easy_rec/python/ops/1.12_pai/libkafka.so b/easy_rec/python/ops/1.12_pai/libkafka.so
new file mode 100755
index 000000000..566ce198b
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/libkafka.so differ
diff --git a/easy_rec/python/ops/1.12_pai/librdkafka++.so.1 b/easy_rec/python/ops/1.12_pai/librdkafka++.so.1
new file mode 100755
index 000000000..8a448378c
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/librdkafka++.so.1 differ
diff --git a/easy_rec/python/ops/1.12_pai/librdkafka.so.1 b/easy_rec/python/ops/1.12_pai/librdkafka.so.1
new file mode 100755
index 000000000..c7ab65e96
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/librdkafka.so.1 differ
diff --git a/easy_rec/python/ops/1.12_pai/libredis++.so b/easy_rec/python/ops/1.12_pai/libredis++.so
new file mode 100644
index 000000000..cadfccc27
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/libredis++.so differ
diff --git a/easy_rec/python/ops/1.12_pai/libredis++.so.1 b/easy_rec/python/ops/1.12_pai/libredis++.so.1
new file mode 100644
index 000000000..cadfccc27
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/libredis++.so.1 differ
diff --git a/easy_rec/python/ops/1.12_pai/libredis++.so.1.2.3 b/easy_rec/python/ops/1.12_pai/libredis++.so.1.2.3
new file mode 100644
index 000000000..cadfccc27
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/libredis++.so.1.2.3 differ
diff --git a/easy_rec/python/ops/1.12_pai/libwrite_sparse_kv.so b/easy_rec/python/ops/1.12_pai/libwrite_sparse_kv.so
new file mode 100755
index 000000000..d50ee8edc
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/libwrite_sparse_kv.so differ
diff --git a/easy_rec/python/ops/1.15/incr_record.so b/easy_rec/python/ops/1.15/incr_record.so
new file mode 100755
index 000000000..e92ea0f36
Binary files /dev/null and b/easy_rec/python/ops/1.15/incr_record.so differ
diff --git a/easy_rec/python/ops/1.15/kafka.so b/easy_rec/python/ops/1.15/kafka.so
index 3ba64834a..6886446d8 100755
Binary files a/easy_rec/python/ops/1.15/kafka.so and b/easy_rec/python/ops/1.15/kafka.so differ
diff --git a/easy_rec/python/ops/1.15/libembed_op.so b/easy_rec/python/ops/1.15/libembed_op.so
index 69d396100..4d8f6275c 100755
Binary files a/easy_rec/python/ops/1.15/libembed_op.so and b/easy_rec/python/ops/1.15/libembed_op.so differ
diff --git a/easy_rec/python/ops/1.15/librdkafka++.so b/easy_rec/python/ops/1.15/librdkafka++.so
new file mode 100755
index 000000000..969f8ab1d
Binary files /dev/null and b/easy_rec/python/ops/1.15/librdkafka++.so differ
diff --git a/easy_rec/python/ops/1.15/librdkafka++.so.1 b/easy_rec/python/ops/1.15/librdkafka++.so.1
new file mode 100755
index 000000000..969f8ab1d
Binary files /dev/null and b/easy_rec/python/ops/1.15/librdkafka++.so.1 differ
diff --git a/easy_rec/python/ops/1.15/librdkafka.so b/easy_rec/python/ops/1.15/librdkafka.so
new file mode 100755
index 000000000..c83248971
Binary files /dev/null and b/easy_rec/python/ops/1.15/librdkafka.so differ
diff --git a/easy_rec/python/ops/1.15/librdkafka.so.1 b/easy_rec/python/ops/1.15/librdkafka.so.1
new file mode 100755
index 000000000..c83248971
Binary files /dev/null and b/easy_rec/python/ops/1.15/librdkafka.so.1 differ
diff --git a/easy_rec/python/ops/gen_kafka_ops.py b/easy_rec/python/ops/gen_kafka_ops.py
new file mode 100644
index 000000000..d971f4563
--- /dev/null
+++ b/easy_rec/python/ops/gen_kafka_ops.py
@@ -0,0 +1,189 @@
+"""Python wrappers around TensorFlow ops.
+
+This file is MACHINE GENERATED! Do not edit.
+Original C++ source file: kafka_ops_deprecated.cc
+"""
+
+import os
+import logging
+
+import six as _six
+import tensorflow as tf
+from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
+from tensorflow.python.eager import context as _context
+from tensorflow.python.eager import core as _core
+from tensorflow.python.eager import execute as _execute
+
+# Needed to trigger the call to _set_call_cpp_shape_fn.
+from tensorflow.python.framework import dtypes as _dtypes
+from tensorflow.python.framework import ops as _ops
+from tensorflow.python.util.tf_export import tf_export
+import easy_rec
+
+
+try:
+ kafka_module = tf.load_op_library(os.path.join(easy_rec.ops_dir, 'kafka.so'))
+except Exception as ex:
+ logging.error("failed to load kafka.so: %s" % str(ex))
+ kafka_module = None
+
+
+@tf_export('io_kafka_dataset_v2')
+def io_kafka_dataset_v2(topics,
+ servers,
+ group,
+ eof,
+ timeout,
+ config_global,
+ config_topic,
+ message_key,
+ message_offset,
+ name=None):
+ """Creates a dataset that emits the messages of one or more Kafka topics.
+
+ Args:
+ topics: A `Tensor` of type `string`.
+ A `tf.string` tensor containing one or more subscriptions,
+ in the format of [topic:partition:offset].
+ servers: A `Tensor` of type `string`. A list of bootstrap servers.
+ group: A `Tensor` of type `string`. The consumer group id.
+ eof: A `Tensor` of type `bool`.
+ If True, the kafka reader will stop on EOF.
+ timeout: A `Tensor` of type `int64`.
+ The timeout value for the Kafka Consumer to wait
+ (in millisecond).
+ config_global: A `Tensor` of type `string`.
+ A `tf.string` tensor containing global configuration
+ properties in [Key=Value] format,
+ eg. ["enable.auto.commit=false", "heartbeat.interval.ms=2000"],
+ please refer to 'Global configuration properties' in librdkafka doc.
+ config_topic: A `Tensor` of type `string`.
+ A `tf.string` tensor containing topic configuration
+ properties in [Key=Value] format, eg. ["auto.offset.reset=earliest"],
+ please refer to 'Topic configuration properties' in librdkafka doc.
+ message_key: A `Tensor` of type `bool`.
+ message_offset: A `Tensor` of type `bool`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of type `variant`.
+ """
+ return kafka_module.io_kafka_dataset_v2(
+ topics=topics,
+ servers=servers,
+ group=group,
+ eof=eof,
+ timeout=timeout,
+ config_global=config_global,
+ config_topic=config_topic,
+ message_key=message_key,
+ message_offset=message_offset,
+ name=name)
+
+
+def io_kafka_dataset_eager_fallback(topics,
+ servers,
+ group,
+ eof,
+ timeout,
+ config_global,
+ config_topic,
+ message_key,
+ message_offset,
+ name=None,
+ ctx=None):
+ """This is the slowpath function for Eager mode.
+
+ This is for function io_kafka_dataset
+ """
+ _ctx = ctx if ctx else _context.context()
+ topics = _ops.convert_to_tensor(topics, _dtypes.string)
+ servers = _ops.convert_to_tensor(servers, _dtypes.string)
+ group = _ops.convert_to_tensor(group, _dtypes.string)
+ eof = _ops.convert_to_tensor(eof, _dtypes.bool)
+ timeout = _ops.convert_to_tensor(timeout, _dtypes.int64)
+ config_global = _ops.convert_to_tensor(config_global, _dtypes.string)
+ config_topic = _ops.convert_to_tensor(config_topic, _dtypes.string)
+ message_key = _ops.convert_to_tensor(message_key, _dtypes.bool)
+ message_offset = _ops.convert_to_tensor(message_offset, _dtypes.bool)
+ _inputs_flat = [
+ topics, servers, group, eof, timeout, config_global, config_topic,
+ message_key, message_offset
+ ]
+ _attrs = None
+ _result = _execute.execute(
+ b'IOKafkaDataset',
+ 1,
+ inputs=_inputs_flat,
+ attrs=_attrs,
+ ctx=_ctx,
+ name=name)
+ _execute.record_gradient('IOKafkaDataset', _inputs_flat, _attrs, _result,
+ name)
+ _result, = _result
+ return _result
+
+
+@tf_export('io_write_kafka_v2')
+def io_write_kafka_v2(message, topic, servers, name=None):
+  r"""Write a message to a kafka topic.
+
+ Args:
+ message: A `Tensor` of type `string`.
+ topic: A `Tensor` of type `string`.
+ servers: A `Tensor` of type `string`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of type `string`.
+ """
+ _ctx = _context._context
+ if _ctx is None or not _ctx._eager_context.is_eager:
+ _op = kafka_module.io_write_kafka_v2(
+ message=message, topic=topic, servers=servers, name=name)
+ _result = _op.outputs[:]
+ _inputs_flat = _op.inputs
+ _attrs = None
+ _execute.record_gradient('IOWriteKafka', _inputs_flat, _attrs, _result,
+ name)
+ _result, = _result
+ return _result
+
+ else:
+ try:
+ _result = _pywrap_tensorflow.TFE_Py_FastPathExecute(
+ _ctx._context_handle, _ctx._eager_context.device_name, 'IOWriteKafka',
+ name, _ctx._post_execution_callbacks, message, topic, servers)
+ return _result
+ except _core._FallbackException:
+ return io_write_kafka_eager_fallback(
+ message, topic, servers, name=name, ctx=_ctx)
+ except _core._NotOkStatusException as e:
+ if name is not None:
+ message = e.message + ' name: ' + name
+ else:
+ message = e.message
+ _six.raise_from(_core._status_to_exception(e.code, message), None)
+
+
+def io_write_kafka_eager_fallback(message, topic, servers, name=None, ctx=None):
+ """This is the slowpath function for Eager mode.
+
+ This is for function io_write_kafka
+ """
+ _ctx = ctx if ctx else _context.context()
+ message = _ops.convert_to_tensor(message, _dtypes.string)
+ topic = _ops.convert_to_tensor(topic, _dtypes.string)
+ servers = _ops.convert_to_tensor(servers, _dtypes.string)
+ _inputs_flat = [message, topic, servers]
+ _attrs = None
+ _result = _execute.execute(
+ b'IOWriteKafka',
+ 1,
+ inputs=_inputs_flat,
+ attrs=_attrs,
+ ctx=_ctx,
+ name=name)
+ _execute.record_gradient('IOWriteKafka', _inputs_flat, _attrs, _result, name)
+ _result, = _result
+ return _result
diff --git a/easy_rec/python/ops/incr_record.py b/easy_rec/python/ops/incr_record.py
new file mode 100644
index 000000000..eee1f9e17
--- /dev/null
+++ b/easy_rec/python/ops/incr_record.py
@@ -0,0 +1,20 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import logging
+import easy_rec
+import tensorflow as tf
+
+try:
+ op_path = os.path.join(easy_rec.ops_dir, "incr_record.so")
+ op = tf.load_op_library(op_path)
+ get_sparse_indices = op.get_sparse_indices
+ set_sparse_indices = op.set_sparse_indices
+except Exception as ex:
+  get_sparse_indices = None
+  set_sparse_indices = None
+  logging.warning('failed to load incr_record.so: %s' % str(ex))
diff --git a/easy_rec/python/protos/data_source.proto b/easy_rec/python/protos/data_source.proto
index a05134d12..d30e04a96 100644
--- a/easy_rec/python/protos/data_source.proto
+++ b/easy_rec/python/protos/data_source.proto
@@ -5,8 +5,12 @@ message KafkaServer {
required string server = 1;
required string topic = 2;
required string group = 3;
- required uint32 partitions = 4;
- repeated uint32 offset = 5;
+  // in json format: {"0":10, "1":20}
+ optional string offset_info = 4;
+ // kafka global config, such as: fetch.max.bytes=1024
+ repeated string config_global = 5;
+ // kafka topic config, such as: max.partition.fetch.bytes=1024
+ repeated string config_topic = 6;
}
message DatahubServer{
@@ -15,6 +19,5 @@ message DatahubServer{
required string region = 3;
required string project = 4;
required string topic = 5;
- required uint32 shard_num = 6;
- required uint32 life_cycle = 7;
+ optional string offset_info = 6;
}
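Usage note (not part of the patch): a hedged sketch of filling the new KafkaServer fields programmatically. The generated module name easy_rec.python.protos.data_source_pb2 is assumed from the proto path; all values are illustrative.

import json

from easy_rec.python.protos.data_source_pb2 import KafkaServer

kafka_config = KafkaServer()
kafka_config.server = 'localhost:9092'
kafka_config.topic = 'train_topic'
kafka_config.group = 'easy_rec_train'
# per-partition start offsets, json encoded as described above
kafka_config.offset_info = json.dumps({'0': 10, '1': 20})
kafka_config.config_global.append('fetch.max.bytes=1048576')
kafka_config.config_topic.append('max.partition.fetch.bytes=1048576')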
diff --git a/easy_rec/python/protos/train.proto b/easy_rec/python/protos/train.proto
index b86e37a72..536fc614d 100644
--- a/easy_rec/python/protos/train.proto
+++ b/easy_rec/python/protos/train.proto
@@ -21,6 +21,44 @@ enum DistributionStrategy {
MultiWorkerMirroredStrategy = 5;
}
+message IncrementSaveConfig {
+ message Kafka {
+ message Consumer {
+ optional string config_topic = 1;
+ optional string config_global = 2;
+ optional int64 offset = 3 [default=0];
+ optional int32 timeout = 4 [default=600];
+ }
+ required string server = 1;
+ required string topic = 2;
+ required Consumer consumer = 3;
+ }
+
+ message Datahub {
+ message Consumer {
+ optional int64 offset = 1 [default=0];
+ optional int32 timeout = 2 [default=600];
+ }
+ required string akId = 1;
+ required string akSecret = 2;
+ required string region = 3;
+ required string project = 4;
+ required string topic = 5;
+ required Consumer consumer = 6;
+ }
+
+
+ optional int32 sparse_save_secs = 1 [default=0];
+ optional int32 dense_save_secs = 2 [default=0];
+ optional int32 sparse_save_steps = 3 [default=0];
+ optional int32 dense_save_steps = 4 [default=0];
+
+ oneof incr_update_hub {
+ Kafka kafka = 501;
+ Datahub datahub = 502;
+ }
+}
+
// Message for configuring EasyRecModel training jobs (train.py).
// Next id: 25
message TrainConfig {
@@ -107,4 +145,11 @@ message TrainConfig {
// match variable patterns to freeze
repeated string freeze_gradient = 30;
+
+ // increment save config
+ optional IncrementSaveConfig incr_save_config = 31;
+
+ // enable oss stop signal
+ // stop by create OSS_STOP_SIGNAL under model_dir
+ optional bool enable_oss_stop_signal = 32 [default = false];
}
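Usage note (not part of the patch): a hedged sketch of setting the new TrainConfig fields programmatically. The generated module name easy_rec.python.protos.train_pb2 is assumed from the proto path; the values are illustrative.

from easy_rec.python.protos.train_pb2 import TrainConfig

train_config = TrainConfig()
train_config.enable_oss_stop_signal = True

incr = train_config.incr_save_config
incr.sparse_save_steps = 100
incr.dense_save_steps = 1000
# pick kafka (or datahub) as the incremental update hub
incr.kafka.server = 'localhost:9092'
incr.kafka.topic = 'incr_update_topic'
incr.kafka.consumer.offset = 0
incr.kafka.consumer.timeout = 600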
diff --git a/easy_rec/python/test/dh_local_run.py b/easy_rec/python/test/dh_local_run.py
index a4282f891..c22fbaddb 100644
--- a/easy_rec/python/test/dh_local_run.py
+++ b/easy_rec/python/test/dh_local_run.py
@@ -37,7 +37,8 @@ def test_datahub_train_eval(self):
odps_cmd = OdpsCommand(odps_oss_config)
self._success = test_utils.test_datahub_train_eval(
- '%s/configs/deepfm.config' % odps_oss_config.temp_dir, self._test_dir)
+ '%s/configs/deepfm.config' % odps_oss_config.temp_dir,
+ odps_oss_config, self._test_dir)
odps_cmd.run_list(end)
self.assertTrue(self._success)
@@ -48,8 +49,6 @@ def test_datahub_train_eval(self):
'--odps_config', type=str, default=None, help='odps config path')
parser.add_argument(
'--oss_config', type=str, default=None, help='ossutilconfig path')
- parser.add_argument(
- '--datahub_config', type=str, default=None, help='datahub_config')
parser.add_argument(
'--bucket_name', type=str, default=None, help='test oss bucket name')
parser.add_argument('--arn', type=str, default=None, help='oss rolearn')
@@ -73,8 +72,6 @@ def test_datahub_train_eval(self):
if args.odps_config:
odps_oss_config.load_odps_config(args.odps_config)
os.environ['ODPS_CONFIG_FILE_PATH'] = args.odps_config
- if args.datahub_config:
- odps_oss_config.load_dh_config(args.datahub_config)
if args.oss_config:
odps_oss_config.load_oss_config(args.oss_config)
if args.odpscmd:
@@ -89,7 +86,6 @@ def test_datahub_train_eval(self):
odps_oss_config.arn = args.arn
if args.bucket_name:
odps_oss_config.bucket_name = args.bucket_name
- print(args)
prepare(odps_oss_config)
start = [
'deep_fm/create_external_deepfm_table.sql',
diff --git a/easy_rec/python/test/export_test.py b/easy_rec/python/test/export_test.py
index 05b0c4aa9..99f5e2f0c 100644
--- a/easy_rec/python/test/export_test.py
+++ b/easy_rec/python/test/export_test.py
@@ -440,8 +440,7 @@ def _test_big_model_export_to_oss(self,
--input_path %s
--output_path %s
""" % (config_path, test_data_path, result_path)
- proc = test_utils.run_cmd(predict_cmd % (),
- '%s/log_%s.txt' % (test_dir, 'predict'))
+ proc = test_utils.run_cmd(predict_cmd, '%s/log_%s.txt' % (test_dir, 'predict'))
proc.wait()
self.assertTrue(proc.returncode == 0)
with open(result_path, 'r') as fin:
diff --git a/easy_rec/python/test/kafka_test.py b/easy_rec/python/test/kafka_test.py
new file mode 100644
index 000000000..8dfd8a23e
--- /dev/null
+++ b/easy_rec/python/test/kafka_test.py
@@ -0,0 +1,317 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import numpy as np
+import os
+import json
+import time
+import logging
+import unittest
+import traceback
+import threading
+import six
+
+import tensorflow as tf
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.platform import gfile
+
+from easy_rec.python.inference.predictor import Predictor
+from easy_rec.python.input.kafka_dataset import KafkaDataset
+from easy_rec.python.utils import test_utils
+
+try:
+ import kafka
+ from kafka import KafkaProducer, KafkaAdminClient
+ from kafka.admin import NewTopic
+except ImportError as ex:
+  logging.warning('kafka-python is not installed: %s' % traceback.format_exc())
+
+
+class KafkaTest(tf.test.TestCase):
+
+ def setUp(self):
+ self._success = True
+ self._test_dir = test_utils.get_tmp_dir()
+ if self._testMethodName == 'test_session':
+ self._kafka_server_proc = None
+ self._zookeeper_proc = None
+ return
+
+ logging.info('Testing %s.%s, test_dir=%s' % (type(self).__name__, self._testMethodName,
+ self._test_dir))
+ self._log_dir = os.path.join(self._test_dir, 'logs')
+ if not gfile.IsDirectory(self._log_dir):
+ gfile.MakeDirs(self._log_dir)
+
+ self._kafka_servers = ['127.0.0.1:9092']
+ self._test_topic = 'kafka_op_test_topic'
+
+ if 'kafka_install_dir' in os.environ:
+ kafka_install_dir = os.environ.get('kafka_install_dir', None)
+
+ zookeeper_config_raw = '%s/config/zookeeper.properties' % kafka_install_dir
+ zookeeper_config = os.path.join(self._test_dir, 'zookeeper.properties')
+ with open(zookeeper_config, 'w') as fout:
+ with open(zookeeper_config_raw, 'r') as fin:
+ for line_str in fin:
+ if line_str.startswith('dataDir='):
+ fout.write('dataDir=%s/zookeeper\n' % self._test_dir)
+ else:
+ fout.write(line_str)
+ cmd = 'bash %s/bin/zookeeper-server-start.sh %s' % (
+ kafka_install_dir, zookeeper_config)
+ log_file = os.path.join(self._log_dir, 'zookeeper.log')
+ self._zookeeper_proc = test_utils.run_cmd(cmd, log_file)
+
+ kafka_config_raw = '%s/config/server.properties' % kafka_install_dir
+ kafka_config = os.path.join(self._test_dir, 'server.properties')
+ with open(kafka_config, 'w') as fout:
+ with open(kafka_config_raw, 'r') as fin:
+ for line_str in fin:
+ if line_str.startswith('log.dirs='):
+ fout.write('log.dirs=%s/kafka\n' % self._test_dir)
+ else:
+ fout.write(line_str)
+ cmd = 'bash %s/bin/kafka-server-start.sh %s' % (
+ kafka_install_dir, kafka_config)
+ log_file = os.path.join(self._log_dir, 'kafka_server.log')
+ self._kafka_server_proc = test_utils.run_cmd(cmd, log_file)
+
+ started = False
+ while not started:
+ if self._kafka_server_proc.poll() and self._kafka_server_proc.returncode:
+ logging.warning('start kafka server failed, will retry.')
+ os.system('cat %s' % log_file)
+ self._kafka_server_proc = test_utils.run_cmd(cmd, log_file)
+ time.sleep(5)
+ else:
+ try:
+ admin_clt = KafkaAdminClient(bootstrap_servers=self._kafka_servers)
+ logging.info('old topics: %s' % (','.join(admin_clt.list_topics())))
+ admin_clt.close()
+ started = True
+ except kafka.errors.NoBrokersAvailable:
+ time.sleep(2)
+ self._create_topic()
+ else:
+ self._zookeeper_proc = None
+ self._kafka_server_proc = None
+ self._should_stop = False
+ self._producer = None
+
+ def _create_topic(self, num_partitions=2):
+ admin_clt = KafkaAdminClient(bootstrap_servers=self._kafka_servers)
+
+ logging.info('create topic: %s' % self._test_topic)
+ topic_list = [NewTopic(name=self._test_topic, num_partitions=num_partitions,
+ replication_factor=1)]
+ admin_clt.create_topics(new_topics=topic_list, validate_only=False)
+ logging.info('all topics: %s' % (','.join(admin_clt.list_topics())))
+ admin_clt.close()
+
+ def _create_producer(self, generate_func):
+ # start produce thread
+
+ prod = threading.Thread(target=generate_func)
+ prod.start()
+ return prod
+
+ def _stop_producer(self):
+ if self._producer is not None:
+ self._should_stop = True
+ self._producer.join()
+
+ def tearDown(self):
+ try:
+ self._stop_producer()
+ if self._kafka_server_proc is not None:
+ self._kafka_server_proc.terminate()
+ except Exception as ex:
+ logging.warning('exception terminate kafka proc: %s' % str(ex))
+
+ try:
+ if self._zookeeper_proc is not None:
+ self._zookeeper_proc.terminate()
+ except Exception as ex:
+ logging.warning('exception terminate zookeeper proc: %s' % str(ex))
+
+ test_utils.set_gpu_id(None)
+ if self._success:
+ test_utils.clean_up(self._test_dir)
+
+ @unittest.skipIf('kafka_install_dir' not in os.environ, 'Only execute when kafka is available')
+ def test_kafka_ops(self):
+ try:
+ test_utils.set_gpu_id(None)
+
+ def _generate():
+ producer = KafkaProducer(
+ bootstrap_servers=self._kafka_servers, api_version=(0, 10, 1))
+ i = 0
+ while not self._should_stop:
+          msg = 'user_id_%d' % i
+          if six.PY3:
+            msg = msg.encode('utf-8')
+          producer.send(self._test_topic, msg)
+          i += 1
+ producer.close()
+
+ self._producer = self._create_producer(_generate)
+
+ group = 'dataset_consumer'
+ k = KafkaDataset(
+ servers=self._kafka_servers[0],
+ topics=[self._test_topic + ':0', self._test_topic + ':1'],
+ group=group,
+ eof=True,
+ # control the maximal read of each partition
+ config_global=['max.partition.fetch.bytes=1048576'],
+ message_key=True,
+ message_offset=True)
+
+ batch_dataset = k.batch(5)
+
+ iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
+ init_batch_op = iterator.make_initializer(batch_dataset)
+ get_next = iterator.get_next()
+
+ sess = tf.Session()
+ sess.run(init_batch_op)
+
+ p = sess.run(get_next)
+
+ self.assertEqual(len(p), 3)
+ offset = p[2]
+ self.assertEqual(offset[0], '0:0')
+ self.assertEqual(offset[1], '0:1')
+
+ p = sess.run(get_next)
+ offset = p[2]
+ self.assertEqual(offset[0], '0:5')
+ self.assertEqual(offset[1], '0:6')
+
+ max_iter = 300
+ while max_iter > 0:
+ sess.run(get_next)
+ max_iter -= 1
+ except tf.errors.OutOfRangeError:
+ pass
+ except Exception as ex:
+ self._success = False
+ raise ex
+
+ @unittest.skipIf('kafka_install_dir' not in os.environ, 'Only execute when kafka is available')
+ def test_kafka_train(self):
+ try:
+ # start produce thread
+ def _generate():
+ producer = KafkaProducer(
+ bootstrap_servers=self._kafka_servers, api_version=(0, 10, 1))
+ while not self._should_stop:
+ with open('data/test/dwd_avazu_ctr_deepmodel_10w.csv', 'r') as fin:
+ for line_str in fin:
+ line_str = line_str.strip()
+ if self._should_stop:
+ break
+ if six.PY3:
+ line_str = line_str.encode('utf-8')
+ producer.send(self._test_topic, line_str)
+ producer.close()
+ logging.info('data generation thread done.')
+
+ self._producer = self._create_producer(_generate)
+
+ test_utils.set_gpu_id(None)
+
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/deepfm_combo_avazu_kafka.config', self._test_dir)
+
+ self.assertTrue(self._success)
+ except Exception as ex:
+ self._success = False
+ raise ex
+
+ @unittest.skipIf('kafka_install_dir' not in os.environ, 'Only execute when kafka is available')
+ def test_kafka_train_v2(self):
+ try:
+ # start produce thread
+ def _generate():
+ producer = KafkaProducer(
+ bootstrap_servers=self._kafka_servers, api_version=(0, 10, 1))
+ while not self._should_stop:
+ with open('data/test/dwd_avazu_ctr_deepmodel_10w.csv', 'r') as fin:
+ for line_str in fin:
+ line_str = line_str.strip()
+ if self._should_stop:
+ break
+ if six.PY3:
+ line_str = line_str.encode('utf-8')
+ producer.send(self._test_topic, line_str)
+ producer.close()
+ logging.info('data generation thread done.')
+
+ self._producer = self._create_producer(_generate)
+
+ test_utils.set_gpu_id(None)
+
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/deepfm_combo_avazu_kafka_time_offset.config', self._test_dir)
+
+ self.assertTrue(self._success)
+ except Exception as ex:
+ self._success = False
+ raise ex
+
+ @unittest.skipIf('kafka_install_dir' not in os.environ or 'oss_path' not in os.environ \
+ or 'oss_endpoint' not in os.environ or 'oss_ak' not in os.environ \
+ or 'oss_sk' not in os.environ, 'Only execute when kafka and oss configs are available')
+ def test_kafka_processor(self):
+ self._success = False
+ success = test_utils.test_distributed_train_eval(
+ 'samples/model_config/taobao_fg_incr_save.config', self._test_dir)
+ self.assertTrue(success)
+ export_cmd = """
+ python -m easy_rec.python.export --pipeline_config_path %s/pipeline.config
+ --export_dir %s/export/sep/ --oss_path=%s --oss_ak=%s --oss_sk=%s --oss_endpoint=%s
+ --asset_files ./samples/rtp_fg/fg.json
+ --checkpoint_path %s/train/model.ckpt-0
+ """ % (self._test_dir, self._test_dir, os.environ['oss_path'], os.environ['oss_ak'],
+ os.environ['oss_sk'], os.environ['oss_endpoint'], self._test_dir)
+ proc = test_utils.run_cmd(export_cmd, '%s/log_export_sep.txt' % self._test_dir)
+ proc.wait()
+ self.assertTrue(proc.returncode == 0)
+ files = gfile.Glob(os.path.join(self._test_dir, 'export/sep/[1-9][0-9]*'))
+ export_sep_dir = files[0]
+
+ predict_cmd = """
+ python processor/test.py --saved_model_dir %s
+ --input_path data/test/rtp/taobao_test_feature.txt
+ --output_path %s/processor.out
+ --data_config processor/dataset.config
+ """ % (export_sep_dir, self._test_dir)
+ proc = test_utils.run_cmd(predict_cmd, '%s/log_processor.txt' % self._test_dir)
+ proc.wait()
+ self.assertTrue(proc.returncode == 0)
+
+ with open('%s/processor.out' % self._test_dir, 'r') as fin:
+ processor_out = []
+ for line_str in fin:
+ line_str = line_str.strip()
+ processor_out.append(json.loads(line_str))
+
+ predictor = Predictor(os.path.join(self._test_dir, 'train/export/final'))
+ with open('data/test/rtp/taobao_test_feature.txt', 'r') as fin:
+ inputs = []
+ for line_str in fin:
+ line_str = line_str.strip()
+ line_tok = line_str.split(';')[-1]
+ line_tok = line_tok.split(chr(2))
+ inputs.append(line_tok)
+ output_res = predictor.predict(inputs, batch_size=32)
+
+ for i in range(len(output_res)):
+ val0 = output_res[i]['probs']
+ val1 = processor_out[i]['probs']
+ diff = np.abs(val0 - val1)
+ assert diff < 1e-4, 'too much difference[%.6f] >= 1e-4' % diff
+ self._success = True
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/easy_rec/python/test/odps_test_prepare.py b/easy_rec/python/test/odps_test_prepare.py
index 775b6d1ce..de5294a0d 100644
--- a/easy_rec/python/test/odps_test_prepare.py
+++ b/easy_rec/python/test/odps_test_prepare.py
@@ -138,7 +138,7 @@ def put_data_to_bucket(odps_oss_config):
odps_oss_config.oss_secret,
odps_oss_config.endpoint,
odps_oss_config.bucket_name)
- for sub_dir in ['configs']: #, 'test_data']:
+ for sub_dir in ['configs']:
for root, dirs, files in os.walk(
os.path.join(odps_oss_config.temp_dir, sub_dir)):
for one_file in files:
diff --git a/easy_rec/python/test/odps_test_util.py b/easy_rec/python/test/odps_test_util.py
index f9bb1965f..cc7f9781c 100644
--- a/easy_rec/python/test/odps_test_util.py
+++ b/easy_rec/python/test/odps_test_util.py
@@ -59,16 +59,16 @@ def __init__(self, script_path='./samples/odps_script'):
self.odpscmd_path = os.environ.get('ODPS_CMD_PATH', 'odpscmd')
self.odps_config_path = ''
- # input table project name replace {ODPS_PROJ_NAME} in
- # samples/odps_script:
- # grep ODPS_PROJ_NAME -r samples/odps_script/
+
self.project_name = ''
self.dh_id = ''
self.dh_key = ''
- self.dh_endpoint = ''
- self.dh_topic = ''
- self.dh_project = ''
+
+ self.dh_endpoint = 'https://dh-cn-beijing.aliyuncs.com'
+ self.dh_topic = 'easy_rec_test'
+ self.dh_project = 'easy_rec_test'
+
self.odps_endpoint = ''
self.dh = None
@@ -83,17 +83,6 @@ def __init__(self, script_path='./samples/odps_script'):
# the difference are ossHost buckets arn settings
self.is_outer = True
- def load_dh_config(self, config_path):
- import pdb
- pdb.set_trace()
- configer = configparser.ConfigParser()
- configer.read(config_path, encoding='utf-8')
- self.dh_id = configer.get('datahub', 'access_id')
- self.dh_key = configer.get('datahub', 'access_key')
- self.dh_endpoint = configer.get('datahub', 'endpoint')
- self.dh_topic = configer.get('datahub', 'topic_name')
- self.dh_project = configer.get('datahub', 'project')
-
def load_oss_config(self, config_path):
with open(config_path, 'r') as fin:
for line_str in fin:
@@ -112,10 +101,18 @@ def load_odps_config(self, config_path):
for line_str in fin:
line_str = line_str.strip()
line_str = line_str.replace(' ', '')
- if line_str.startswith('project_name='):
- self.project_name = line_str[len('project_name='):]
- if line_str.startswith('end_point='):
- self.odps_endpoint = line_str[len('end_point='):]
+ key_str = 'project_name='
+ if line_str.startswith(key_str):
+ self.project_name = line_str[len(key_str):]
+ key_str = 'end_point='
+ if line_str.startswith(key_str):
+ self.odps_endpoint = line_str[len(key_str):]
+ key_str = 'access_id='
+ if line_str.startswith(key_str):
+ self.dh_id = line_str[len(key_str):]
+ key_str = 'access_key='
+ if line_str.startswith(key_str):
+ self.dh_key = line_str[len(key_str):]
def clean_topic(self, dh_project):
if not dh_project:
@@ -160,47 +157,44 @@ def init_dh_and_odps(self):
self.odpsTable = 'deepfm_train_%s' % self.time_stamp
self.clean_project()
read_odps = DataFrame(self.odps.get_table(self.odpsTable))
- col = read_odps.schema.names
+ col_name = read_odps.schema.names
col_type = [self.get_input_type(str(i)) for i in read_odps.schema.types]
try:
- self.dh.create_project(self.dh_project, 'EasyRecTest')
+ self.dh.create_project(self.dh_project, comment='EasyRecTest')
logging.info('create project success!')
except ResourceExistException:
- logging.info('project %s already exist!' % self.dh_project)
+ logging.warning('project %s already exist!' % self.dh_project)
except Exception as ex:
- logging.info(traceback.format_exc(ex))
- record_schema = RecordSchema.from_lists(col, col_type)
+ logging.error(traceback.format_exc())
+ record_schema = RecordSchema.from_lists(col_name, col_type)
try:
# project_name, topic_name, shard_count, life_cycle, record_schema, comment
self.dh.create_tuple_topic(self.dh_project, self.dh_topic, 7, 3,
- record_schema, 'easyrec_datahub')
- logging.info('create tuple topic success!')
+ record_schema, comment='EasyRecTest')
+ logging.info('create tuple topic %s success!' % self.dh_topic)
except ResourceExistException:
logging.info('topic %s already exist!' % self.dh_topic)
except Exception as ex:
- logging.error('exception:', ex)
+ logging.error('exception:%s' % str(ex))
logging.error(traceback.format_exc())
try:
self.dh.wait_shards_ready(self.dh_project, self.dh_topic)
- logging.info('shards all ready')
+ logging.info('datahub[%s,%s] shards all ready' % (self.dh_project, self.dh_topic))
topic_result = self.dh.get_topic(self.dh_project, self.dh_topic)
if topic_result.record_type != RecordType.TUPLE:
- logging.error('topic type illegal! ')
+ logging.error('invalid topic type: %s' % str(topic_result.record_type))
record_schema = topic_result.record_schema
t = self.odps.get_table(self.odpsTable)
with t.open_reader() as reader:
- size = 0
record_list = []
- for data in reader[0:1000]:
+ for data in reader:
record = TupleRecord(values=data.values, schema=record_schema)
record_list.append(record)
- if size % 1000:
- self.dh.put_records(self.dh_project, self.dh_topic, record_list)
- record_list = []
- size += 1
- except Exception as e:
- logging.error(e)
-
+ for i in range(10):
+ self.dh.put_records(self.dh_project, self.dh_topic, record_list)
+ except Exception as ex:
+ logging.error('exception: %s' % str(ex))
+ logging.error(traceback.format_exc())
def get_oss_bucket(oss_key, oss_secret, endpoint, bucket_name):
"""Build oss2.Bucket instance.
diff --git a/easy_rec/python/tools/predict_and_chk.py b/easy_rec/python/tools/predict_and_chk.py
index 51fa945be..0adf2724f 100644
--- a/easy_rec/python/tools/predict_and_chk.py
+++ b/easy_rec/python/tools/predict_and_chk.py
@@ -4,11 +4,19 @@
import json
import logging
import sys
+import os
+import easy_rec
import numpy as np
from easy_rec.python.inference.predictor import Predictor
+try:
+ import tensorflow as tf
+ tf.load_op_library(os.path.join(easy_rec.ops_dir, 'libembed_op.so'))
+except Exception as ex:
+ logging.warning('exception: %s' % str(ex))
+
logging.basicConfig(
level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
@@ -77,7 +85,7 @@
x for fid, x in enumerate(feature.split(args.separator))
if fid not in args.label_id
]
- if len(predictor.input_names) == 1:
+ if 'features' in predictor.input_names:
feature = args.separator.join(feature)
batch_input.append(feature)
output = predictor.predict(batch_input)
diff --git a/easy_rec/python/tools/read_kafka.py b/easy_rec/python/tools/read_kafka.py
new file mode 100644
index 000000000..27ccc35f0
--- /dev/null
+++ b/easy_rec/python/tools/read_kafka.py
@@ -0,0 +1,46 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import sys
+import logging
+import argparse
+from kafka import KafkaConsumer
+from kafka.structs import TopicPartition
+
+logging.basicConfig(
+ level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--servers', type=str, default='localhost:9092')
+ parser.add_argument('--topic', type=str, default=None)
+ parser.add_argument('--group', type=str, default='consumer')
+ parser.add_argument('--partitions', type=str, default=None)
+ parser.add_argument('--timeout', type=float, default=float('inf'))
+ args = parser.parse_args()
+
+ if args.topic is None:
+ logging.error('--topic is not set')
+ sys.exit(1)
+
+ servers = args.servers.split(',')
+ consumer = KafkaConsumer(group_id=args.group, bootstrap_servers=servers,
+ consumer_timeout_ms=args.timeout * 1000)
+
+ if args.partitions is not None:
+ partitions = [ int(x) for x in args.partitions.split(',') ]
+ else:
+ partitions = consumer.partitions_for_topic(args.topic)
+ logging.info('partitions: %s' % partitions)
+
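+ # assign the selected partitions explicitly and consume each one from the beginning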
+ topics = [ TopicPartition(topic=args.topic, partition=part_id) \
+ for part_id in partitions ]
+ consumer.assign(topics)
+ consumer.seek_to_beginning()
+
+ record_id = 0
+ for x in consumer:
+ logging.info("%d: key=%s\toffset=%d\ttimestamp=%d\tlen=%d" % (record_id, x.key, x.offset,
+ x.timestamp, len(x.value)))
+ record_id += 1
diff --git a/easy_rec/python/tools/write_kafka.py b/easy_rec/python/tools/write_kafka.py
new file mode 100644
index 000000000..8a2fbe2b2
--- /dev/null
+++ b/easy_rec/python/tools/write_kafka.py
@@ -0,0 +1,57 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from kafka import KafkaProducer, KafkaAdminClient
+from kafka.admin import NewTopic
+import sys
+import logging
+import argparse
+
+logging.basicConfig(
+ level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--servers', type=str, default='localhost:9092')
+ parser.add_argument('--topic', type=str, default=None)
+ parser.add_argument('--group', type=str, default='consumer')
+ parser.add_argument('--partitions', type=str, default=None)
+ parser.add_argument('--timeout', type=float, default=float('inf'))
+ # file to send
+ parser.add_argument('--input_path', type=str, default=None)
+ args = parser.parse_args()
+
+ if args.input_path is None:
+ logging.error('input_path is not set')
+ sys.exit(1)
+
+ if args.topic is None:
+ logging.error('topic is not set')
+ sys.exit(1)
+
+ servers = args.servers.split(',')
+
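+ # create the topic if it does not exist yet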
+ admin_clt = KafkaAdminClient(bootstrap_servers=servers)
+ if args.topic not in admin_clt.list_topics():
+ admin_clt.create_topics(new_topics=[NewTopic(name=args.topic,
+ num_partitions=1, replication_factor=1,
+ topic_configs={'max.message.bytes': 1024 * 1024 * 1024})],
+ validate_only=False)
+ logging.info('create increment save topic: %s' % args.topic)
+ admin_clt.close()
+
+ producer = KafkaProducer(
+ bootstrap_servers=servers,
+ request_timeout_ms=args.timeout * 1000,
+ api_version=(0, 10, 1))
+
+ i = 1
+ with open(args.input_path, 'r') as fin:
+ for line_str in fin:
+ producer.send(args.topic, line_str.encode('utf-8'))
+ i += 1
+ if i % 100 == 0:
+ logging.info('progress: %d' % i)
+ producer.close()
diff --git a/easy_rec/python/utils/constant.py b/easy_rec/python/utils/constant.py
index 9df831a89..8caecaba8 100644
--- a/easy_rec/python/utils/constant.py
+++ b/easy_rec/python/utils/constant.py
@@ -2,3 +2,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
SAMPLE_WEIGHT = 'SAMPLE_WEIGHT'
+
+DENSE_UPDATE_VARIABLES = 'DENSE_UPDATE_VARIABLES'
+
+SPARSE_UPDATE_VARIABLES = 'SPARSE_UPDATE_VARIABLES'
diff --git a/easy_rec/python/utils/embedding_utils.py b/easy_rec/python/utils/embedding_utils.py
new file mode 100644
index 000000000..2e0497f5e
--- /dev/null
+++ b/easy_rec/python/utils/embedding_utils.py
@@ -0,0 +1,45 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+from easy_rec.python.utils import proto_util
+from easy_rec.python.utils import constant
+from tensorflow.python.framework import ops
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+def get_norm_name_to_ids():
+ """Get normalize embedding name(including kv variables) to ids.
+
+ Return:
+ normalized names to ids mapping.
+ """
+ norm_name_to_ids = {}
+ for x in ops.get_collection(constant.SPARSE_UPDATE_VARIABLES):
+ norm_name, part_id = proto_util.get_norm_embed_name(x[0].name)
+ norm_name_to_ids[norm_name] = 1
+
+ for tid, t in enumerate(norm_name_to_ids.keys()):
+ norm_name_to_ids[t] = str(tid)
+ return norm_name_to_ids
+
+def get_sparse_name_to_ids():
+ """Get embedding variable(including kv variables) name to ids mapping.
+
+ Return:
+ variable names to ids mapping.
+ """
+ norm_name_to_ids = get_norm_name_to_ids()
+ name_to_ids = {}
+ for x in ops.get_collection(constant.SPARSE_UPDATE_VARIABLES):
+ norm_name, _ = proto_util.get_norm_embed_name(x[0].name)
+ name_to_ids[x[0].name] = norm_name_to_ids[norm_name]
+ return name_to_ids
+
+def get_dense_name_to_ids():
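+ """Get mapping from dense variable op names to ids.
+
+ Return:
+ dense variable op names to ids mapping.
+ """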
+ dense_train_vars = ops.get_collection(constant.DENSE_UPDATE_VARIABLES)
+ norm_name_to_ids = {}
+ for tid, x in enumerate(dense_train_vars):
+ norm_name_to_ids[x.op.name] = tid
+ return norm_name_to_ids
diff --git a/easy_rec/python/utils/estimator_utils.py b/easy_rec/python/utils/estimator_utils.py
index e406d1c73..950118ff8 100644
--- a/easy_rec/python/utils/estimator_utils.py
+++ b/easy_rec/python/utils/estimator_utils.py
@@ -13,12 +13,30 @@
import numpy as np
import six
+import threading
import tensorflow as tf
+from tensorflow.python.framework import ops
+from easy_rec.python.ops.incr_record import get_sparse_indices
+from tensorflow.python.ops import array_ops
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.python.framework import meta_graph
from tensorflow.python.training.summary_io import SummaryWriterCache
+from tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer
+from tensorflow.python.training import training_util
+from tensorflow.python.platform import gfile
+from tensorflow.python.framework import errors_impl
from easy_rec.python.utils import shape_utils
+from easy_rec.python.utils import embedding_utils
+from easy_rec.python.utils import constant
+
+try:
+ import kafka
+ from kafka import KafkaProducer, KafkaAdminClient
+ from kafka.admin import NewTopic
+except ImportError as ex:
+ logging.warning('kafka-python is not installed: %s' % str(ex))
+
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -111,10 +129,10 @@ def _check_flag_file(is_chief, flag_file):
logging.info('_check_flag_file: is_chief = %d flag_file=%s' %
(is_chief, flag_file))
if is_chief:
- with tf.gfile.GFile(flag_file, 'w') as fout:
+ with gfile.GFile(flag_file, 'w') as fout:
fout.write('atexit time: %d' % int(time.time()))
else:
- while not tf.gfile.Exists(flag_file):
+ while not gfile.Exists(flag_file):
time.sleep(1)
from atexit import register
@@ -208,10 +226,10 @@ def _check_flag_file(is_chief, flag_file):
logging.info('_check_flag_file: is_chief = %d flag_file=%s' %
(is_chief, flag_file))
if is_chief:
- with tf.gfile.GFile(flag_file, 'w') as fout:
+ with gfile.GFile(flag_file, 'w') as fout:
fout.write('atexit time: %d' % int(time.time()))
else:
- while not tf.gfile.Exists(flag_file):
+ while not gfile.Exists(flag_file):
time.sleep(1)
from atexit import register
@@ -235,7 +253,7 @@ def __init__(self, num_steps, filename, is_chief):
self._num_steps = num_steps
self._is_chief = is_chief
if self._is_chief:
- self._progress_file = tf.gfile.GFile(filename, 'w')
+ self._progress_file = gfile.GFile(filename, 'w')
self._progress_file.write('0.00\n')
self._progress_interval = 0.01 # 1%
self._last_progress_cnt = 0
@@ -276,7 +294,9 @@ def __init__(self,
checkpoint_basename='model.ckpt',
scaffold=None,
listeners=None,
- write_graph=True):
+ write_graph=True,
+ data_offset_var=None,
+ increment_save_config=None):
"""Initializes a `CheckpointSaverHook`.
Args:
@@ -290,6 +310,8 @@ def __init__(self,
Used for callbacks that run immediately before or after this hook saves
the checkpoint.
write_graph: whether to save graph.pbtxt.
+ data_offset_var: variable holding the input data offsets; its value is saved
+ to model.ckpt-<step>.offset beside each checkpoint.
+ increment_save_config: parameters for saving incremental checkpoints; when set,
+ dense and sparse updates are periodically sent out (e.g. to a kafka topic).
Raises:
ValueError: One of `save_steps` or `save_secs` should be set.
@@ -304,6 +326,61 @@ def __init__(self,
scaffold=scaffold,
listeners=listeners)
self._write_graph = write_graph
+ self._data_offset_var = data_offset_var
+
+ if increment_save_config is not None:
+ self._dense_name_to_ids = embedding_utils.get_dense_name_to_ids()
+ self._sparse_name_to_ids = embedding_utils.get_sparse_name_to_ids()
+
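+ # persist the dense variable name to id mapping so that consumers of the
+ # incremental update messages can map ids back to variables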
+ with gfile.GFile(os.path.join(checkpoint_dir, constant.DENSE_UPDATE_VARIABLES),
+ 'w') as fout:
+ json.dump(self._dense_name_to_ids, fout, indent=2)
+
+ save_secs = increment_save_config.dense_save_secs
+ save_steps = increment_save_config.dense_save_steps
+ self._dense_timer = SecondOrStepTimer(every_secs=save_secs if save_secs > 0 else None,
+ every_steps=save_steps if save_steps > 0 else None)
+ save_secs = increment_save_config.sparse_save_secs
+ save_steps = increment_save_config.sparse_save_steps
+ self._sparse_timer = SecondOrStepTimer(every_secs=save_secs if save_secs > 0 else None,
+ every_steps=save_steps if save_steps > 0 else None)
+
+ self._dense_timer.update_last_triggered_step(0)
+ self._sparse_timer.update_last_triggered_step(0)
+
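+ # build ops that fetch the sparse ids recorded by set_sparse_indices in the
+ # optimizer and gather the corresponding embedding values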
+ self._sparse_indices = []
+ self._sparse_values = []
+ sparse_train_vars = ops.get_collection(constant.SPARSE_UPDATE_VARIABLES)
+ for sparse_var, indice_dtype in sparse_train_vars:
+ with ops.control_dependencies([tf.train.get_global_step()]):
+ with ops.colocate_with(sparse_var):
+ sparse_indice = get_sparse_indices(var_name=sparse_var.op.name, ktype=indice_dtype)
+ sparse_indice = sparse_indice.global_indices
+ self._sparse_indices.append(sparse_indice)
+ if 'EmbeddingVariable' in str(type(sparse_var)):
+ self._sparse_values.append(sparse_var.sparse_read(sparse_indice))
+ else:
+ self._sparse_values.append(array_ops.gather(sparse_var, sparse_indice))
+ if increment_save_config.HasField('kafka'):
+ self._topic = increment_save_config.kafka.topic
+ logging.info('increment save topic: %s' % self._topic)
+
+ admin_clt = KafkaAdminClient(bootstrap_servers=increment_save_config.kafka.server)
+ if self._topic not in admin_clt.list_topics():
+ admin_clt.create_topics(new_topics=[NewTopic(name=self._topic,
+ num_partitions=1, replication_factor=1,
+ topic_configs={'max.message.bytes':1024 * 1024 * 1024})], validate_only=False)
+ logging.info('create increment save topic: %s' % self._topic)
+ admin_clt.close()
+
+ servers = increment_save_config.kafka.server.split(',')
+ self._kafka_producer = KafkaProducer(bootstrap_servers=servers,
+ max_request_size=1024 * 1024 * 64)
+ else:
+ self._kafka_producer = None
+ else:
+ self._dense_timer = None
+ self._sparse_timer = None
def after_create_session(self, session, coord):
global_step = session.run(self._global_step_tensor)
@@ -319,16 +396,93 @@ def after_create_session(self, session, coord):
graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
self._summary_writer.add_graph(graph)
self._summary_writer.add_meta_graph(meta_graph_def)
- # when tf version > 1.10.0, we use defaut training strategy, which saves ckpt
- # at first train step
- if LooseVersion(tf.__version__) >= LooseVersion('1.10.0'):
- # The checkpoint saved here is the state at step "global_step".
- self._save(session, global_step)
+
+ # save for step 0
+ self._save(session, global_step)
+
self._timer.update_last_triggered_step(global_step)
def before_run(self, run_context): # pylint: disable=unused-argument
return tf.train.SessionRunArgs(self._global_step_tensor)
+ def _send_dense(self, global_step, session):
+ dense_train_vars = ops.get_collection(constant.DENSE_UPDATE_VARIABLES)
+ dense_train_vals = session.run(dense_train_vars)
+ logging.info("global_step=%d, increment save dense variables" % global_step)
+
+ msg_num = len(dense_train_vals)
+ msg_ids = [ self._dense_name_to_ids[x.op.name] for x in dense_train_vars]
+ # 0 means dense update message
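+ # message layout: int32 header [0, msg_num, global_step, (var_id, var_size) * msg_num]
+ # followed by the raw bytes of each dense variable value, in the same order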
+ msg_header = [0, msg_num, global_step]
+ for msg_id, x in zip(msg_ids, dense_train_vals):
+ msg_header.append(msg_id)
+ msg_header.append(x.size)
+ bytes_buf = np.array(msg_header, dtype=np.int32).tobytes()
+ for x in dense_train_vals:
+ bytes_buf += x.tobytes()
+ if self._kafka_producer is not None:
+ msg_key = 'dense_update_%d' % global_step
+ send_res = self._kafka_producer.send(self._topic, bytes_buf, key=msg_key.encode('utf-8'))
+ logging.info('kafka send dense: %d exception: %s' % (global_step, send_res.exception))
+ logging.info("global_step=%d, increment update dense variables, msg_num=%d" \
+ % (global_step, msg_num))
+
+ def _send_sparse(self, global_step, session):
+ sparse_train_vars = ops.get_collection(constant.SPARSE_UPDATE_VARIABLES)
+ sparse_res = session.run(self._sparse_indices + self._sparse_values)
+ msg_num = int(len(sparse_res) / 2)
+
+ sel_ids = [ i for i in range(msg_num) if len(sparse_res[i]) > 0 ]
+ sparse_key_res = [ sparse_res[i] for i in sel_ids ]
+ sparse_val_res = [ sparse_res[i+msg_num] for i in sel_ids ]
+ sparse_train_vars = [ sparse_train_vars[i][0] for i in sel_ids ]
+
+ embed_ids = [ self._sparse_name_to_ids[x.name] for x in sparse_train_vars]
+
+ msg_num = len(sel_ids)
+
+ if msg_num == 0:
+ logging.warning('there are no sparse updates, will skip this send: %d' % global_step)
+ return
+
+
+ # 1 means sparse update messages
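+ # message layout: int32 header [1, msg_num, global_step, (embed_id, num_ids) * msg_num]
+ # followed by the (ids, values) bytes of each updated embedding table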
+ msg_header = [1, msg_num, global_step]
+ for i, x in enumerate(embed_ids):
+ msg_header.append(x)
+ msg_header.append(len(sparse_res[sel_ids[i]]))
+ bytes_buf = np.array(msg_header, dtype=np.int32).tobytes()
+ for tmp_id, tmp_key, tmp_val, tmp_var in zip(embed_ids, sparse_key_res,
+ sparse_val_res, sparse_train_vars):
+ # for non kv embedding variables, add partition offset to tmp_key
+ if 'EmbeddingVariable' not in str(type(tmp_var)):
+ if tmp_var._save_slice_info is not None:
+ tmp_key += tmp_var._save_slice_info.var_offset[0]
+ bytes_buf += tmp_key.tobytes()
+ bytes_buf += tmp_val.tobytes()
+ if self._kafka_producer is not None:
+ msg_key = 'sparse_update_%d' % global_step
+ send_res = self._kafka_producer.send(self._topic, bytes_buf, key=msg_key.encode('utf-8'))
+ logging.info('kafka send sparse: %d %s' % (global_step, send_res.exception))
+ logging.info("global_step=%d, increment update sparse variables, msg_num=%d, msg_size=%d" \
+ % (global_step, msg_num, len(bytes_buf)))
+
+ def after_run(self, run_context, run_values):
+ super(CheckpointSaverHook, self).after_run(run_context, run_values)
+ stale_global_step = run_values.results
+ global_step = -1
+ if self._dense_timer is not None and self._dense_timer.should_trigger_for_step(stale_global_step + self._steps_per_run):
+ global_step = run_context.session.run(self._global_step_tensor)
+ self._dense_timer.update_last_triggered_step(global_step)
+ self._send_dense(global_step, run_context.session)
+
+ if self._sparse_timer is not None and self._sparse_timer.should_trigger_for_step(stale_global_step + self._steps_per_run):
+ if global_step < 0:
+ global_step = run_context.session.run(self._global_step_tensor)
+
+ self._sparse_timer.update_last_triggered_step(global_step)
+ self._send_sparse(global_step, run_context.session)
+
def _save(self, session, step):
"""Saves the latest checkpoint, returns should_stop."""
logging.info('Saving checkpoints for %d into %s.', step, self._save_path)
@@ -343,6 +497,16 @@ def _save(self, session, step):
write_meta_graph=self._write_graph)
save_dir, save_name = os.path.split(self._save_path)
+ if self._data_offset_var is not None:
+ save_data_offset = session.run(self._data_offset_var)
+ data_offset_json = {}
+ for x in save_data_offset:
+ if x:
+ data_offset_json.update(json.loads(x))
+ save_offset_path = os.path.join(save_dir, 'model.ckpt-%d.offset' % step)
+ with gfile.GFile(save_offset_path, 'w') as fout:
+ json.dump(data_offset_json, fout)
+
self._summary_writer.add_session_log(
tf.SessionLog(
status=tf.SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
@@ -357,6 +521,18 @@ def _save(self, session, step):
should_stop = True
return should_stop
+ def end(self, session):
+ super(CheckpointSaverHook, self).end(session)
+ global_step = session.run(self._global_step_tensor)
+ if self._dense_timer is not None and \
+ global_step != self._dense_timer.last_triggered_step():
+ self._dense_timer.update_last_triggered_step(global_step)
+ self._send_dense(global_step, session)
+ if self._sparse_timer is not None and \
+ global_step != self._sparse_timer.last_triggered_step():
+ self._sparse_timer.update_last_triggered_step(global_step)
+ self._send_sparse(global_step, session)
+
class NumpyCheckpointRestoreHook(SessionRunHook):
"""Restore variable from numpy checkpoint."""
@@ -395,7 +571,7 @@ def begin(self):
vars_not_inited[var_name] = ','.join([str(s) for s in var_shape])
self._restore_op = tf.group(assign_ops)
- with tf.gfile.GFile(self._ckpt_path[:-4] + '_not_inited.txt', 'w') as f:
+ with gfile.GFile(self._ckpt_path[:-4] + '_not_inited.txt', 'w') as f:
for var_name in sorted(vars_not_inited.keys()):
f.write('%s:%s\n' % (var_name, vars_not_inited[var_name]))
assert not has_shape_unmatch, 'exist variable shape not match, restore failed'
@@ -491,6 +667,40 @@ def after_create_session(self, session, coord):
logging.info('restore checkpoint from %s' % ckpt_path)
saver.restore(session, ckpt_path)
+class OssStopSignalHook(SessionRunHook):
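+ """Request training stop when an OSS_STOP_SIGNAL file appears under model_dir."""
+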
+ def __init__(self, model_dir, secs_interval=60, step_interval=10):
+ self._stop_sig_file = os.path.join(model_dir, 'OSS_STOP_SIGNAL')
+ self._stop = False
+ self._check_stop = False
+ self._last_chk_step = 0
+ self._curr_step = 0
+ def _check_stop():
+ while not self._check_stop:
+ if self._curr_step < self._last_chk_step + step_interval:
+ time.sleep(1)
+ continue
+ self._last_chk_step = self._curr_step
+ if gfile.Exists(self._stop_sig_file):
+ self._stop = True
+ logging.info('OssStopSignalHook: stop on signal %s' % self._stop_sig_file)
+ break
+ time.sleep(secs_interval)
+ self._th = threading.Thread(target=_check_stop)
+ self._th.start()
+
+ def before_run(self, run_context):
+ if self._stop:
+ run_context.request_stop()
+ self._global_step_tensor = training_util._get_or_create_global_step_read()
+ return tf.train.SessionRunArgs(self._global_step_tensor)
+
+ def after_run(self, run_context, run_values):
+ self._curr_step = run_values.results
+
+ def end(self, session):
+ self._check_stop = True
+ self._th.join()
+
class OnlineEvaluationHook(SessionRunHook):
@@ -516,7 +726,7 @@ def end(self, session):
eval_result_file = os.path.join(self._output_dir,
'online_eval_result.txt-%s' % global_step)
logging.info('Saving online eval result to file %s' % eval_result_file)
- with tf.gfile.GFile(eval_result_file, 'w') as ofile:
+ with gfile.GFile(eval_result_file, 'w') as ofile:
result_to_write = {}
for key in sorted(metric_value_dict):
# convert numpy float to python float
@@ -580,7 +790,11 @@ def latest_checkpoint(model_dir):
Return:
model_path: xx/model.ckpt-2000
"""
- ckpt_metas = tf.gfile.Glob(os.path.join(model_dir, 'model.ckpt-*.meta'))
+ try:
+ ckpt_metas = gfile.Glob(os.path.join(model_dir, 'model.ckpt-*.meta'))
+ except errors_impl.NotFoundError as ex:
+ return None
+
if len(ckpt_metas) == 0:
return None
diff --git a/easy_rec/python/utils/export_big_model.py b/easy_rec/python/utils/export_big_model.py
index 248d6d021..e89b77043 100644
--- a/easy_rec/python/utils/export_big_model.py
+++ b/easy_rec/python/utils/export_big_model.py
@@ -6,6 +6,7 @@
import time
import numpy as np
+from google.protobuf import json_format
import tensorflow as tf
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import ops
@@ -18,12 +19,15 @@
from tensorflow.python.training.device_setter import replica_device_setter
from tensorflow.python.training.monitored_session import ChiefSessionCreator
from tensorflow.python.training.saver import export_meta_graph
+from tensorflow.python.training.monitored_session import Scaffold
import easy_rec
+import json
from easy_rec.python.utils import estimator_utils
from easy_rec.python.utils import io_util
from easy_rec.python.utils import proto_util
-from easy_rec.python.utils.meta_graph_editor import MetaGraphEditor
+from easy_rec.python.utils import constant
+from easy_rec.python.utils.meta_graph_editor import MetaGraphEditor, EMBEDDING_INITIALIZERS
if tf.__version__ >= '2.0':
from tensorflow.python.framework.ops import disable_eager_execution
@@ -33,6 +37,9 @@
GPUOptions = config_pb2.GPUOptions
+INCR_UPDATE_SIGNATURE_KEY = 'incr_update_sig'
+
+
def export_big_model(export_dir, pipeline_config, redis_params,
serving_input_fn, estimator, checkpoint_path, verbose):
for key in redis_params:
@@ -282,6 +289,7 @@ def export_big_model(export_dir, pipeline_config, redis_params,
saver = tf.train.Saver()
with tf.Session(target=server.target if server else '') as sess:
saver.restore(sess, checkpoint_path)
+
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
@@ -298,7 +306,7 @@ def export_big_model(export_dir, pipeline_config, redis_params,
return
-def export_big_model_to_oss(export_dir, pipeline_config, oss_params,
+def export_big_model_to_oss(export_dir, pipeline_config, oss_params,
serving_input_fn, estimator, checkpoint_path,
verbose):
for key in oss_params:
@@ -489,6 +497,7 @@ def export_big_model_to_oss(export_dir, pipeline_config, oss_params,
oss_timeout=oss_params.get('oss_timeout', 1500),
meta_graph_def=meta_graph_def,
norm_name_to_ids=norm_name_to_ids,
+ incr_update_params=oss_params.get('incr_save', None),
debug_dir=export_dir if verbose else '')
meta_graph_editor.edit_graph_for_oss()
tf.reset_default_graph()
@@ -500,11 +509,24 @@ def export_big_model_to_oss(export_dir, pipeline_config, oss_params,
with GFile(embed_name_to_id_file, 'w') as fout:
for tmp_norm_name in norm_name_to_ids:
fout.write('%s\t%s\n' % (tmp_norm_name, norm_name_to_ids[tmp_norm_name]))
- tf.add_to_collection(
- tf.GraphKeys.ASSET_FILEPATHS,
+ ops.add_to_collection(
+ ops.GraphKeys.ASSET_FILEPATHS,
tf.constant(
embed_name_to_id_file, dtype=tf.string, name='embed_name_to_ids.txt'))
+ dense_train_vars_path = os.path.join(os.path.dirname(checkpoint_path), constant.DENSE_UPDATE_VARIABLES)
+ ops.add_to_collection(
+ ops.GraphKeys.ASSET_FILEPATHS,
+ tf.constant(
+ dense_train_vars_path, dtype=tf.string, name=constant.DENSE_UPDATE_VARIABLES))
+
+ kafka_params_file = os.path.join(export_dir, "kafka.txt")
+ with GFile(kafka_params_file, 'w') as fout:
+ json.dump(json.loads(json_format.MessageToJson(oss_params['incr_save']['kafka'],
+ preserving_proto_field_name=True)), fout, indent=2)
+ ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS,
+ tf.constant(kafka_params_file, dtype=tf.string, name="kafka.txt"))
+
export_dir = os.path.join(export_dir,
meta_graph_def.meta_info_def.meta_graph_version)
export_dir = io_util.fix_oss_dir(export_dir)
@@ -518,6 +540,7 @@ def export_big_model_to_oss(export_dir, pipeline_config, oss_params,
tmp = graph.get_tensor_by_name(inputs[tmp_key].name)
tensor_info_inputs[tmp_key] = \
tf.saved_model.utils.build_tensor_info(tmp)
+
tensor_info_outputs = {}
for tmp_key in outputs:
tmp = graph.get_tensor_by_name(outputs[tmp_key].name)
@@ -529,19 +552,43 @@ def export_big_model_to_oss(export_dir, pipeline_config, oss_params,
outputs=tensor_info_outputs,
method_name=signature_constants.PREDICT_METHOD_NAME))
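+ # build a separate signature that exposes the placeholders and assign ops used to
+ # apply incremental sparse and dense updates on the serving side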
+ incr_update_inputs = meta_graph_editor.sparse_update_inputs
+ incr_update_outputs = meta_graph_editor.sparse_update_outputs
+ incr_update_inputs.update(meta_graph_editor.dense_update_inputs)
+ incr_update_outputs.update(meta_graph_editor.dense_update_outputs)
+ tensor_info_incr_update_inputs = {}
+ tensor_info_incr_update_outputs = {}
+ for tmp_key in incr_update_inputs:
+ tmp = graph.get_tensor_by_name(incr_update_inputs[tmp_key].name)
+ tensor_info_incr_update_inputs[tmp_key] = \
+ tf.saved_model.utils.build_tensor_info(tmp)
+ for tmp_key in incr_update_outputs:
+ tmp = graph.get_tensor_by_name(incr_update_outputs[tmp_key].name)
+ tensor_info_incr_update_outputs[tmp_key] = \
+ tf.saved_model.utils.build_tensor_info(tmp)
+ incr_update_signature = (
+ tf.saved_model.signature_def_utils.build_signature_def(
+ inputs=tensor_info_incr_update_inputs,
+ outputs=tensor_info_incr_update_outputs,
+ method_name=signature_constants.PREDICT_METHOD_NAME))
+
session_config = ConfigProto(
allow_soft_placement=True, log_device_placement=True)
saver = tf.train.Saver()
with tf.Session(target=server.target if server else '') as sess:
saver.restore(sess, checkpoint_path)
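+ # run the embedding lookup initializers (oss_init) together with the default
+ # local init op when the saved model is loaded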
+ main_op = tf.group([Scaffold.default_local_init_op(),
+ ops.get_collection(EMBEDDING_INITIALIZERS)])
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature,
+ INCR_UPDATE_SIGNATURE_KEY: incr_update_signature
},
assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS),
saver=saver,
+ main_op=main_op,
strip_default_attrs=True,
clear_devices=True)
builder.save()
diff --git a/easy_rec/python/utils/meta_graph_editor.py b/easy_rec/python/utils/meta_graph_editor.py
index d01c6fb8e..fb5194882 100644
--- a/easy_rec/python/utils/meta_graph_editor.py
+++ b/easy_rec/python/utils/meta_graph_editor.py
@@ -7,8 +7,14 @@
from tensorflow.python.platform.gfile import GFile
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model.loader_impl import SavedModelLoader
+from tensorflow.python.saved_model import constants
+from tensorflow.python.framework import ops
from easy_rec.python.utils import proto_util
+from easy_rec.python.utils import embedding_utils
+from easy_rec.python.utils import constant
+
+EMBEDDING_INITIALIZERS = 'embedding_initializers'
class MetaGraphEditor:
@@ -27,6 +33,7 @@ def __init__(self,
oss_timeout=0,
meta_graph_def=None,
norm_name_to_ids=None,
+ incr_update_params=None,
debug_dir=''):
self._lookup_op = tf.load_op_library(lookup_lib_path)
self._debug_dir = debug_dir
@@ -73,6 +80,37 @@ def __init__(self,
self._oss_ak = oss_ak
self._oss_sk = oss_sk
self._oss_timeout = oss_timeout
+
+ self._kafka_params = None
+ if incr_update_params is not None and 'kafka' in incr_update_params:
+ self._kafka_params = incr_update_params['kafka']
+
+ self._datahub_params = None
+ if incr_update_params is not None and 'datahub' in incr_update_params:
+ self._datahub_params = incr_update_params['datahub']
+
+ # increment update placeholders
+ self._embedding_update_inputs = {}
+ self._embedding_update_outputs = {}
+
+ self._dense_update_inputs = {}
+ self._dense_update_outputs = {}
+
+ @property
+ def sparse_update_inputs(self):
+ return self._embedding_update_inputs
+
+ @property
+ def sparse_update_outputs(self):
+ return self._embedding_update_outputs
+
+ @property
+ def dense_update_inputs(self):
+ return self._dense_update_inputs
+
+ @property
+ def dense_update_outputs(self):
+ return self._dense_update_outputs
@property
def graph_def(self):
@@ -378,7 +416,46 @@ def add_oss_lookup_op(self, lookup_input_indices, lookup_input_values,
combiners=self._embed_combiners,
embedding_dims=self._embed_dims,
embedding_names=self._embed_ids,
- embedding_is_kv=self._embed_is_kv)
+ embedding_is_kv=self._embed_is_kv,
+ shared_name='embedding_lookup_res',
+ name='embedding_lookup_fused/lookup')
+
+ lookup_init_op = self._lookup_op.oss_init(
+ osspath=self._oss_path,
+ endpoint=self._oss_endpoint,
+ ak=self._oss_ak,
+ sk=self._oss_sk,
+ combiners=self._embed_combiners,
+ embedding_dims=self._embed_dims,
+ embedding_names=self._embed_ids,
+ embedding_is_kv=self._embed_is_kv,
+ N=len(self._embed_is_kv),
+ shared_name='embedding_lookup_res',
+ name='embedding_lookup_fused/init')
+
+ ops.add_to_collection(EMBEDDING_INITIALIZERS, lookup_init_op)
+
+ if self._kafka_params:
+ # all sparse variables are updated by a single custom operation
+ message_ph = tf.placeholder(tf.int8, [None], name='incr_update/message')
+ embedding_update = self._lookup_op.embedding_update(
+ message=message_ph,
+ shared_name='embedding_lookup_res',
+ name='embedding_lookup_fused/embedding_update')
+ self._embedding_update_inputs['incr_update/sparse/message'] = message_ph
+ self._embedding_update_outputs['incr_update/sparse/embedding_update'] = embedding_update
+
+ # dense variables are updated one by one
+ dense_name_to_ids = embedding_utils.get_dense_name_to_ids()
+ for x in ops.get_collection(constant.DENSE_UPDATE_VARIABLES):
+ dense_var_id = dense_name_to_ids[x.op.name]
+ dense_input_name = 'incr_update/dense/%d/input' % dense_var_id
+ dense_output_name = 'incr_update/dense/%d/output' % dense_var_id
+ dense_update_input = tf.placeholder(tf.float32, x.get_shape(),
+ name=dense_input_name)
+ self._dense_update_inputs[dense_input_name] = dense_update_input
+ dense_assign_op = tf.assign(x, dense_update_input)
+ self._dense_update_outputs[dense_output_name] = dense_assign_op
meta_graph_def = tf.train.export_meta_graph()
diff --git a/easy_rec/python/utils/multi_optimizer.py b/easy_rec/python/utils/multi_optimizer.py
index 9e5cefbda..c34c4abe0 100644
--- a/easy_rec/python/utils/multi_optimizer.py
+++ b/easy_rec/python/utils/multi_optimizer.py
@@ -38,6 +38,9 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
update_ops.append(opt.apply_gradients(tmp, None))
return tf.group(update_ops)
+ def open_auto_record(self, flag=True):
+ super(MultiOptimizer, self).open_auto_record(flag)
+
def get_slot(self, var, name):
raise NotImplementedError('not implemented')
# for opt in self._opts:
diff --git a/easy_rec/python/utils/test_utils.py b/easy_rec/python/utils/test_utils.py
index 104f4faa0..74a8158c7 100644
--- a/easy_rec/python/utils/test_utils.py
+++ b/easy_rec/python/utils/test_utils.py
@@ -150,9 +150,9 @@ def _load_config_for_test(pipeline_config_path, test_dir, total_steps=50):
def test_datahub_train_eval(pipeline_config_path,
+ odps_oss_config,
test_dir,
process_pipeline_func=None,
- hyperparam_str='',
total_steps=50,
post_check_func=None):
gpus = get_available_gpus()
@@ -175,13 +175,26 @@ def test_datahub_train_eval(pipeline_config_path,
pipeline_config.train_config.train_distribute = 0
pipeline_config.train_config.num_gpus_per_worker = 1
pipeline_config.train_config.sync_replicas = False
+
+ pipeline_config.datahub_train_input.akId = odps_oss_config.dh_id
+ pipeline_config.datahub_train_input.akSecret = odps_oss_config.dh_key
+ pipeline_config.datahub_train_input.region = odps_oss_config.dh_endpoint
+ pipeline_config.datahub_train_input.project = odps_oss_config.dh_project
+ pipeline_config.datahub_train_input.topic = odps_oss_config.dh_topic
+
+ pipeline_config.datahub_eval_input.akId = odps_oss_config.dh_id
+ pipeline_config.datahub_eval_input.akSecret = odps_oss_config.dh_key
+ pipeline_config.datahub_eval_input.region = odps_oss_config.dh_endpoint
+ pipeline_config.datahub_eval_input.project = odps_oss_config.dh_project
+ pipeline_config.datahub_eval_input.topic = odps_oss_config.dh_topic
+
if process_pipeline_func is not None:
assert callable(process_pipeline_func)
pipeline_config = process_pipeline_func(pipeline_config)
config_util.save_pipeline_config(pipeline_config, test_dir)
test_pipeline_config_path = os.path.join(test_dir, 'pipeline.config')
- train_cmd = 'python3 -m easy_rec.python.train_eval --pipeline_config_path %s %s' % (
- test_pipeline_config_path, hyperparam_str)
+ train_cmd = 'python -m easy_rec.python.train_eval --pipeline_config_path %s' % \
+ test_pipeline_config_path
proc = run_cmd(train_cmd, '%s/log_%s.txt' % (test_dir, 'master'))
proc.wait()
if proc.returncode != 0:
diff --git a/pai_jobs/deploy.sh b/pai_jobs/deploy.sh
index e1f10f6f5..385669119 100755
--- a/pai_jobs/deploy.sh
+++ b/pai_jobs/deploy.sh
@@ -92,6 +92,16 @@ fi
cp easy_rec/__init__.py easy_rec/__init__.py.bak
sed -i -e "s/\[VERSION\]/$VERSION/g" easy_rec/__init__.py
find -L easy_rec -name "*.pyc" | xargs rm -rf
+
+if [ ! -d "datahub" ]
+then
+ wget http://easyrec.oss-cn-beijing.aliyuncs.com/third_party/pydatahub.tar.gz
+ if [ $? -ne 0 ]
+ then
+ echo "datahub download failed."
+ exit 1
+ fi
+ tar -zvxf pydatahub.tar.gz
+fi
tar -cvzhf $RES_PATH easy_rec run.py
mv easy_rec/__init__.py.bak easy_rec/__init__.py
diff --git a/pai_jobs/deploy_ext.sh b/pai_jobs/deploy_ext.sh
index 8426596d5..142e45880 100755
--- a/pai_jobs/deploy_ext.sh
+++ b/pai_jobs/deploy_ext.sh
@@ -92,7 +92,19 @@ fi
cp -R $root_dir/easy_rec ./easy_rec
sed -i -e "s/\[VERSION\]/$VERSION/g" easy_rec/__init__.py
find -L easy_rec -name "*.pyc" | xargs rm -rf
-tar -cvzhf $RES_PATH easy_rec run.py
+
+if [ ! -d "datahub" ]
+then
+ wget http://easyrec.oss-cn-beijing.aliyuncs.com/third_party/pydatahub.tar.gz
+ if [ $? -ne 0 ]
+ then
+ echo "datahub download failed."
+ exit 1
+ fi
+ tar -zvxf pydatahub.tar.gz
+ rm -rf pydatahub.tar.gz
+fi
+
+tar -cvzhf $RES_PATH easy_rec datahub lz4 cprotobuf run.py
# 2 means generate only
if [ $mode -ne 2 ]
diff --git a/pai_jobs/easy_rec_flow/easy_rec.xml b/pai_jobs/easy_rec_flow/easy_rec.xml
index d9dc17ddc..4c04c52b0 100644
--- a/pai_jobs/easy_rec_flow/easy_rec.xml
+++ b/pai_jobs/easy_rec_flow/easy_rec.xml
@@ -56,7 +56,7 @@
#
#g;s###g' index.html
+cp easy_rec/python/protos/dataset_pb2.py easy_rec/python/protos/tf_predict_pb2.py processor/
diff --git a/scripts/kafka_test.sh b/scripts/kafka_test.sh
new file mode 100644
index 000000000..e18193629
--- /dev/null
+++ b/scripts/kafka_test.sh
@@ -0,0 +1 @@
+kafka_install_dir=../kafka_2.13-3.1.0/ oss_path=oss://yangxi-bj/export_embedding_taobao_fg_step_0 oss_ak=xxx oss_sk=xxx oss_endpoint=oss-cn-beijing.aliyuncs.com TEST_DEVICES='' PYTHONPATH=.:pai_jobs/ python -m easy_rec.python.test.kafka_test KafkaTest.test_kafka_processor