From 700269e00d9af8092289a2e4f4894214a7307ca1 Mon Sep 17 00:00:00 2001 From: ddavis-2015 Date: Tue, 13 Aug 2024 20:03:35 -0700 Subject: [PATCH] added DEPTHWISE_CONV --- .../lite/micro/kernels/depthwise_conv.cc | 39 +++++++++++++++++++ .../micro/kernels/depthwise_conv_common.cc | 21 +++++++++- 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/micro/kernels/depthwise_conv.cc b/tensorflow/lite/micro/kernels/depthwise_conv.cc index fa55a705606..8d15f3bd1b4 100644 --- a/tensorflow/lite/micro/kernels/depthwise_conv.cc +++ b/tensorflow/lite/micro/kernels/depthwise_conv.cc @@ -52,6 +52,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) { ? tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor) : nullptr; +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* filter_comp_td = + micro_context->GetTensorCompressionData(node, + kDepthwiseConvWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + switch (input->type) { // Already know in/out types are same. case kTfLiteFloat32: { tflite::reference_ops::DepthwiseConv( @@ -59,9 +71,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + filter_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); break; @@ -94,9 +115,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + filter_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); break; @@ -118,9 +148,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + filter_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); break; diff --git a/tensorflow/lite/micro/kernels/depthwise_conv_common.cc b/tensorflow/lite/micro/kernels/depthwise_conv_common.cc index 52804de3315..23e0b8b1e01 100644 --- a/tensorflow/lite/micro/kernels/depthwise_conv_common.cc +++ b/tensorflow/lite/micro/kernels/depthwise_conv_common.cc @@ -127,7 +127,9 @@ TfLiteStatus CalculateOpDataDepthwiseConv( micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(filter); - micro_context->DeallocateTempTfLiteTensor(bias); + if (has_bias) { + micro_context->DeallocateTempTfLiteTensor(bias); + } micro_context->DeallocateTempTfLiteTensor(output); return kTfLiteOk; @@ -209,6 +211,23 @@ TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node) { context, node, params, input_width, input_height, filter_width, filter_height, output_width, output_height, input->type, data)); +#ifdef USE_TFLM_COMPRESSION + + // Compression scratch buffers. + // These will only be allocated if the tensor is compressed. + if (micro_context->IsTensorCompressed(node, kDepthwiseConvWeightsTensor) && + filter->type == kTfLiteInt4) { + MicroPrintf("Compression not supported with INT4 tensors"); + return kTfLiteError; + } + data->weights_scratch_index = + micro_context->AllocateDecompressionScratchBuffer( + node, kDepthwiseConvWeightsTensor); + data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer( + node, kDepthwiseConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(output); micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(filter);