
Keras Saves a Model Using a Distribution Strategy

How does Keras place the variables using a distribution strategy

If users want to use a distribution strategy to train a Keras model, they need to enter the strategy scope to build and compile the model, like:

with strategy.scope():
    inputs = tf.keras.layers.Input(4,)
    dense = tf.keras.layers.Dense(4)(inputs)
    output = tf.keras.layers.Dense(1)(dense)
    model = tf.keras.models.Model(inputs=inputs, outputs=output)
    model.compile(optimizer='rmsprop',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

The strategy.scope() returns a _CurrentDistributionContext. When we enter the context, we also enter the variable_scope, variable_creator_scope and device_scope.

def __enter__(self):
    # Allow this scope to be entered if this strategy is already in scope.
    if distribution_strategy_context.has_strategy():
      _require_cross_replica_or_default_context_extended(
          self._context.strategy.extended)
      self._same_scope_again_count += 1
    else:
      _push_per_thread_mode(self._context)
      if self._var_scope:
        self._var_scope.__enter__()
      self._var_creator_scope.__enter__()
      if self._device_scope:
        self._device_scope.__enter__()
    return self._context.strategy
  • variable_scope: A context manager for defining ops that creates variables (layers).
  • variable_creator_scope: Scope which defines a variable creation function to be used by variable().
  • device_scope: Context-manager to force placement of operations and Tensors on a device.
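
As a quick illustration of how a variable_creator_scope intercepts variable creation, here is a minimal sketch with a hypothetical logging_creator (not the creator that a distribution strategy actually installs):

import tensorflow as tf

def logging_creator(next_creator, **kwargs):
    # Intercept every variable created inside the scope, print its name,
    # then delegate to the default creator.
    print("Creating variable:", kwargs.get("name"))
    return next_creator(**kwargs)

with tf.variable_creator_scope(logging_creator):
    v = tf.Variable(tf.zeros([2, 2]), name="demo")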

Now we focus on the device_scope, which determines variable placement. After entering the device_scope, Keras gets the local device information in the current process and sets the device information on the eager context through _EagerDeviceContext. We can view the device name from the context with the following code snippet:

from tensorflow.python.eager import context as _context
with strategy.scope():
    _ctx = _context._context
    print(_ctx)
    print(_ctx._thread_local_data.device_name)

The _ctx is the eager execution context and printing it displays all available devices, including CPU and GPU. The _ctx._thread_local_data.device_name is the device name used to run ops in the current thread.
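
For reference, the same set of devices can also be listed through the public API (a minimal sketch; as of TF 2.0 this lives under tf.config.experimental, and later releases also expose it as tf.config.list_logical_devices):

import tensorflow as tf

# List all logical devices visible to the current process,
# e.g. /device:CPU:0 and any /device:GPU:N entries.
for device in tf.config.experimental.list_logical_devices():
    print(device.name)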

After setting the execution context, Keras executes ops on the device whose name is _ctx._thread_local_data.device_name. Before calling each layer in a Keras model, Keras needs to execute some TensorFlow ops in add_weight and the initializer to create ResourceVariable variables for each layer, so the variables are placed on the device in the execution context. TensorFlow ops are executed in the C++ runtime. The Python code that executes the ResourceVariable creation ops, "tensorflow_core/python/ops/gen_resource_variable_ops.py", is generated from the C++ source file "resource_variable_ops.cc". The op execution code is as follows:

  _ctx = _context._context or _context.context()
  if _ctx is not None and _ctx._thread_local_data.is_eager:
    try: 
      _result = _pywrap_tensorflow.TFE_Py_FastPathExecute(
        _ctx._context_handle, _ctx._thread_local_data.device_name,
        "VarHandleOp", name, _ctx._post_execution_callbacks, "container",
        container, "shared_name", shared_name, "dtype", dtype, "shape", shape)
      return _result
    except _core._FallbackException:
      try: 
        return var_handle_op_eager_fallback(
            container=container, shared_name=shared_name, dtype=dtype,
            shape=shape, name=name, ctx=_ctx)
      except _core._SymbolicException:
        pass  # Add nodes to the TensorFlow graph.
    except _core._NotOkStatusException as e:
      if name is not None:
        message = e.message + " name: " + name 
      else:
        message = e.message
      _six.raise_from(_core._status_to_exception(e.code, message), None)

We can see that the runtime first gets the context constructed by strategy.scope() and then executes the ops on the device named by _ctx._thread_local_data.device_name.
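
The same placement behavior can be reproduced with an explicit device scope (a minimal sketch; in eager mode tf.device updates the same thread-local device name):

import tensorflow as tf

# Creating a variable under an explicit device scope places its
# VarHandleOp, and thus the ResourceVariable, on that device.
with tf.device('/CPU:0'):
    v = tf.Variable(tf.zeros([2, 2]))
print(v.device)  # e.g. /job:localhost/replica:0/task:0/device:CPU:0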

CentralStorageStrategy uses ParameterServerStrategyExtended, which extends StrategyExtendedV2. The default device is None in StrategyExtendedV2 and ParameterServerStrategyExtended does not set a default device for the strategy, so automatic placement is performed for CentralStorageStrategy. We can verify this conclusion with the following snippet.

import tensorflow as tf
from tensorflow.python.eager import context as _context

_ctx = _context._context
print("The context before entering startegy scope is:")
print(_ctx)
print("current device name: ", _ctx._thread_local_data.device_name)

strategy = tf.distribute.experimental.CentralStorageStrategy()

with strategy.scope():
    _ctx = _context._context
    print("The context in the startegy scope is:")
    print(_ctx)
    print("current device name: ", _ctx._thread_local_data.device_name)

MultiWorkerMirroredStrategy uses CollectiveAllReduceExtended, which sets _default_device for the strategy. CollectiveAllReduceExtended initializes the strategy according to the cluster_spec generated from TF_CONFIG and sets the current device for the strategy. So, with MultiWorkerMirroredStrategy, variables are initialized on each worker's device. To verify this conclusion, we can execute the following snippet:

import os
import json
import tensorflow as tf
from tensorflow.python.eager import context as _context

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'chief': ["localhost:12347"],
        'worker': ["localhost:12348"]
    },
    'task': {'type': 'worker', 'index': 0}
})
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

with strategy.scope():
    _ctx = _context._context or _context.context()
    print(_ctx)
    print("current device name: ", _ctx._thread_local_data.device_name)

How does Keras execute ops using a distribution strategy

Keras trains the model with model.fit(), which calls model_iteration in "training_arrays.py". The execution steps in model_iteration are:

  1. Make the execution function fn, e.g. model.train_on_batch, for each replica.
  2. Run fn on each mini-batch of samples once per replica. tf.distribute.Strategy.experimental_run_v2 creates a _MirroredReplicaThread for each replica/worker device to run fn under its own device scope.
  3. Collect all outputs from each replica through the strategy.

The per-replica execution function is built as follows:
def _make_execution_function_without_cloning(model, mode):
  """Creates a function to run one step of distributed model execution."""
  strategy = model._distribution_strategy

  with strategy.scope():
    per_replica_function = _make_replica_execution_function(model, mode)

    def distributed_function(input_fn):
      """A single step of the distributed execution across replicas."""
      x, y, sample_weights = input_fn()
      # Call `Model.{train,test,predict}_on_batch` on every replica passing
      # PerReplicas as arguments.  On every replica inside this call, each
      # PerReplica object will return the value for that replica.  The outputs
      # are PerReplicas too.
      outputs = strategy.experimental_run_v2(
          per_replica_function, args=(x, y, sample_weights))
      # Out of PerReplica outputs reduce or pick values to return.
      all_outputs = unwrap_outputs(
          strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT))
      return all_outputs

    ...
    execution_function = distributed_function

  return execution_function
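
To make the flow concrete, here is a minimal sketch of running a per-replica function with a strategy and reducing the per-replica outputs (using MirroredStrategy for illustration; experimental_run_v2 was later renamed to Strategy.run):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    w = tf.Variable(2.0)

def per_replica_fn(x):
    # Runs once on every replica with that replica's slice of the input.
    return w * x

# Each replica computes w * x; the outputs are PerReplica values.
per_replica_outputs = strategy.experimental_run_v2(
    per_replica_fn, args=(tf.constant(1.0),))
# Reduce the PerReplica outputs to a single value, like unwrap_outputs does.
total = strategy.reduce(
    tf.distribute.ReduceOp.SUM, per_replica_outputs, axis=None)
print(total)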

How does Keras save the model using a distribution strategy

The steps in tf.saved_model.save are:

  1. Construct an _AugmentedGraphView based on the Keras model instance.
  2. Generate signatures for the model instance.
  3. Gather all tensors in the graph view and group the tensors by device name.
  4. Save all tensors by ops.save.
  5. Save the graph view to saved_model.pb file.
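
For example, saving a model built under a strategy scope goes through exactly these steps. Below is a minimal sketch; "/tmp/saved_model" is just an illustrative export path:

import tensorflow as tf

strategy = tf.distribute.experimental.CentralStorageStrategy()
with strategy.scope():
    inputs = tf.keras.layers.Input(4,)
    outputs = tf.keras.layers.Dense(1)(inputs)
    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)

# Builds the graph view, generates signatures, gathers and saves the
# variables, and writes saved_model.pb.
tf.saved_model.save(model, "/tmp/saved_model")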

In this section, we focus on how Keras saves variables to files under the supported distribution strategies. After Keras places all variables on devices, each variable carries its placement device information.

import tensorflow as tf

# An empty device string leaves the placement to the runtime.
with tf.device(''):
    v = tf.Variable(tf.zeros([10, 10]))
# The variable records the device it was placed on.
print(v.device)

tf.saved_model.save collects all variables in the current process and groups them into different shards by their device names. Then Keras runs a save op (io_ops.save_v2) to save each shard to a shard file under the corresponding device scope. In the following snippet, we simulate how Keras saves multiple variable groups, where each group represents the variables on one device.

import os
import uuid
import tensorflow as tf
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.lib.io import file_io

def save_tensor(tensors, export_dir):
    filename_tensor = export_dir
    tensor_names = []
    tensor_slices = []
    for tensor in tensors:
        tensor_names.append(tensor.name)
        tensor_slices.append('')
    # Write the given tensors into a single checkpoint shard file.
    io_ops.save_v2(filename_tensor, tensor_names, tensor_slices, tensors)

def save_checkpoint_with_shards(shard_tensor, checkpoint_prefix):
    file_io.recursive_create_dir(os.path.dirname(checkpoint_prefix))
    _SHARDED_SUFFIX = "_temp_%s/part" % uuid.uuid4().hex
    tmp_checkpoint_prefix = string_ops.string_join(
        [checkpoint_prefix, _SHARDED_SUFFIX])
    sharded_prefixes = []
    num_shards = len(shard_tensor)
    for shard_id, tensors in shard_tensor:
        # Each shard gets a temporary file name like <prefix>_temp_xxx/part-00000-of-00003.
        sharded_filename = gen_io_ops.sharded_filename(
            tmp_checkpoint_prefix, shard_id, num_shards)
        save_tensor(tensors, sharded_filename)
        sharded_prefixes.append(sharded_filename)
    # Merge the temporary shard files into the final checkpoint files.
    gen_io_ops.merge_v2_checkpoints(
        sharded_prefixes, checkpoint_prefix, delete_old_dirs=True)
        
shard_tensor = [(0, [tf.Variable([1, 2, 3], name='t1-0'), tf.Variable([1, 2, 3], name='t1-1')]),
                (1, [tf.Variable([4, 5, 6], name='t2')]),
                (2, [tf.Variable([7, 8, 9], name='t3')]),
               ]
save_checkpoint_with_shards(shard_tensor, checkpoint_prefix='ckpt-0/variable')

restore_variable = io_ops.restore_v2('ckpt-0/variable', ['t1-0:0'], [''], [tf.int32])[0]
print('Restore tensor is:', restore_variable)

The saved files are:

|-- ckpt-0
    |-- variable.data-00000-of-00003 
    |-- variable.data-00001-of-00003 
    |-- variable.data-00002-of-00003 
    |-- variable.index

In "variable.data-00000-of-00003", the "00000" is the first shard and "00003" is the number of shards.