Add CMSIS-NN int8 and int16 batch matmul (#2669)
* Moves some common functions shared with the reference kernel into a new header file
* Creates a new cmsis_nn batch_matmul.cc

Authored-by:    Ryan O'Shea <[email protected]>
Co-authored-by: Adrian Lundell <[email protected]>

BUG=Add BatchMatmul kernels to cmsisnn
ArmRyan authored Sep 9, 2024
1 parent 19aaea8 commit 89f99a9
Showing 5 changed files with 730 additions and 139 deletions.
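
The two files rendered below are the BUILD change and the slimmed-down reference batch_matmul.cc; the new cmsis_nn batch_matmul.cc and the shared header added by the same commit are not expanded on this page. As orientation only, here is a minimal sketch of what the shared declarations moved into tensorflow/lite/micro/kernels/batch_matmul.h might look like, reconstructed purely from identifiers visible in this diff (the kBatchMatmul* tensor indices and the OpDataBatchMatmul struct); the QuantizationOpDataBatchMatmul name and the exact layout are assumptions, not the committed header.

// Hypothetical sketch only -- not the committed batch_matmul.h.
#include <cstdint>

#include "tensorflow/lite/c/common.h"  // TfLiteEvalTensor

namespace tflite {

// Tensor indices previously defined in batch_matmul.cc, now shared by the
// reference and CMSIS-NN kernels under kBatchMatmul* names.
constexpr int kBatchMatmulInputLhsTensor = 0;
constexpr int kBatchMatmulInputRhsTensor = 1;
constexpr int kBatchMatmulOutputTensor = 0;

// Quantization parameters (struct name assumed; the diff only shows the
// pre-move QuantizationOpData definition).
struct QuantizationOpDataBatchMatmul {
  // The scaling factor from input to output (the 'real multiplier') is
  // represented as a fixed-point multiplier plus a left shift.
  int32_t output_multiplier;
  int output_shift;  // exponent
  // Fused activation range, e.g. -128 and 127 for kNone with int8_t.
  int32_t output_activation_min;
  int32_t output_activation_max;
  int32_t lhs_zero_point;
  int32_t rhs_zero_point;
  int32_t output_zero_point;
};

// Per-op data, renamed from OpData so kernel variants can share it.
struct OpDataBatchMatmul {
  QuantizationOpDataBatchMatmul* quantization;
  // Transpose tensors and state.
  TfLiteEvalTensor* lhs_transposed_tensor;
  TfLiteEvalTensor* rhs_transposed_tensor;
  bool rhs_is_transposed;
  bool lhs_is_constant_tensor;
  bool rhs_is_constant_tensor;
};

}  // namespace tflite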
1 change: 1 addition & 0 deletions tensorflow/lite/micro/kernels/BUILD
@@ -300,6 +300,7 @@ tflm_kernel_cc_library(
hdrs = [
"activations.h",
"add.h",
"batch_matmul.h",
"circular_buffer.h",
"conv.h",
"depthwise_conv.h",
159 changes: 22 additions & 137 deletions tensorflow/lite/micro/kernels/batch_matmul.cc
@@ -1,4 +1,4 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -24,60 +24,31 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/transpose.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/batch_matmul.h"
#include "tensorflow/lite/micro/micro_log.h"

namespace tflite {
namespace {

constexpr int kInputLhsTensor = 0;
constexpr int kInputRhsTensor = 1;
constexpr int kOutputTensor = 0;

struct QuantizationOpData {
// The scaling factor from input to output (aka the 'real multiplier') can
// be represented as a fixed point multiplier plus a left shift.
int32_t output_multiplier;
int output_shift; // exponent

// The range of the fused activation layer. For example for kNone and
// int8_t these would be -128 and 127.
int32_t output_activation_min;
int32_t output_activation_max;

int32_t lhs_zero_point;
int32_t rhs_zero_point;
int32_t output_zero_point;
};

struct OpData {
QuantizationOpData* quantization;

// Transpose tensors and state
TfLiteEvalTensor* lhs_transposed_tensor;
TfLiteEvalTensor* rhs_transposed_tensor;
bool rhs_is_transposed;
bool lhs_is_constant_tensor;
bool rhs_is_constant_tensor;
};

struct OpContext {
OpContext(TfLiteContext* context, TfLiteNode* node)
: params(static_cast<TfLiteBatchMatMulParams*>(node->builtin_data)),
op_data(static_cast<OpData*>(node->user_data)) {}
op_data(static_cast<OpDataBatchMatmul*>(node->user_data)) {}

TfLiteBatchMatMulParams* params;
OpData* op_data;
OpDataBatchMatmul* op_data;
};

struct PrepareOpContext : OpContext {
PrepareOpContext(TfLiteContext* context, TfLiteNode* node)
: OpContext(context, node),
micro_context_(GetMicroContext(context)),
lhs(micro_context_->AllocateTempInputTensor(node, kInputLhsTensor)),
rhs(micro_context_->AllocateTempInputTensor(node, kInputRhsTensor)),
output(micro_context_->AllocateTempOutputTensor(node, kOutputTensor)) {}
lhs(micro_context_->AllocateTempInputTensor(
node, kBatchMatmulInputLhsTensor)),
rhs(micro_context_->AllocateTempInputTensor(
node, kBatchMatmulInputRhsTensor)),
output(micro_context_->AllocateTempOutputTensor(
node, kBatchMatmulOutputTensor)) {}

~PrepareOpContext() {
if (lhs != nullptr) {
@@ -103,56 +74,18 @@ struct PrepareOpContext : OpContext {
struct EvalOpContext : OpContext {
EvalOpContext(TfLiteContext* context, TfLiteNode* node)
: OpContext(context, node),
lhs(tflite::micro::GetEvalInput(context, node, kInputLhsTensor)),
rhs(tflite::micro::GetEvalInput(context, node, kInputRhsTensor)),
output(tflite::micro::GetEvalOutput(context, node, kOutputTensor)) {}
lhs(tflite::micro::GetEvalInput(context, node,
kBatchMatmulInputLhsTensor)),
rhs(tflite::micro::GetEvalInput(context, node,
kBatchMatmulInputRhsTensor)),
output(tflite::micro::GetEvalOutput(context, node,
kBatchMatmulOutputTensor)) {}

const TfLiteEvalTensor* lhs;
const TfLiteEvalTensor* rhs;
TfLiteEvalTensor* output;
};

TfLiteStatus ReshapeOutputTensor(TfLiteContext* context, TfLiteNode* node,
const RuntimeShape& extended_lhs_shape,
const RuntimeShape& extended_rhs_shape,
bool adj_x, bool adj_y, int output_rank,
TfLiteTensor* output) {
int64_t orig_size = NumElements(output);

// make sure the new output dims rank does not exceed the original rank
TF_LITE_ENSURE(context, output_rank <= NumDimensions(output));

// make sure output tensor dims are not in the FlatBuffer
TfLiteEvalTensor* output_eval =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_OK(context, tflite::micro::CreateWritableTensorDimsWithCopy(
context, output, output_eval));

// Fill in any broadcast dimensions.
for (int i = 0; i < output_rank - 2; ++i) {
const int lhs_dim = extended_lhs_shape.Dims(i);
const int rhs_dim = extended_rhs_shape.Dims(i);
int broadcast_dim = lhs_dim;
if ((lhs_dim != rhs_dim) && (lhs_dim == 1)) {
broadcast_dim = rhs_dim;
}
output->dims->data[i] = broadcast_dim;
}
// Fill in the matmul dimensions.
int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;

output->dims->data[output_rank - 2] = extended_lhs_shape.Dims(lhs_rows_index);
output->dims->data[output_rank - 1] = extended_rhs_shape.Dims(rhs_cols_index);
output->dims->size = output_rank;

// Check that output tensor has not been resized
// since TFLM doesn't support tensor resizing.
TF_LITE_ENSURE_EQ(context, orig_size, NumElements(output));

return kTfLiteOk;
}

TfLiteEvalTensor* AllocInitTransposeTensorFromTfLiteTensor(
TfLiteContext* context, const TfLiteTensor& tensor) {
MicroContext* micro_context = GetMicroContext(context);
@@ -195,7 +128,7 @@ TfLiteEvalTensor* AllocInitTransposeTensorFromTfLiteTensor(
// Allocate normal quantization data if needed.
TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
const PrepareOpContext& op_context) {
OpData* op_data = op_context.op_data;
OpDataBatchMatmul* op_data = op_context.op_data;
const TfLiteTensor* lhs = op_context.lhs;
const TfLiteTensor* rhs = op_context.rhs;
MicroContext* micro_context = GetMicroContext(context);
@@ -231,62 +164,14 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk;
}

template <typename Scalar>
void TransposeRowsColumnsImpl(const TfLiteEvalTensor& tensor_in,
TfLiteEvalTensor* tensor_out) {
const Scalar* input = tflite::micro::GetTensorData<Scalar>(&tensor_in);
Scalar* output = tflite::micro::GetTensorData<Scalar>(tensor_out);
RuntimeShape transposed_shape(tflite::micro::GetTensorShape(&tensor_in));
RuntimeShape shape(transposed_shape);
TransposeParams params;
const int rank = shape.DimensionsCount();
params.perm_count = rank;
for (int i = 0; i < rank - 2; ++i) {
params.perm[i] = i;
}
// Transpose the last two dimensions.
params.perm[rank - 2] = rank - 1;
params.perm[rank - 1] = rank - 2;
transposed_shape.SetDim(rank - 1, shape.Dims(rank - 2));
transposed_shape.SetDim(rank - 2, shape.Dims(rank - 1));
reference_ops::Transpose(params, shape, input, transposed_shape, output);
}

TfLiteStatus TransposeRowsColumns(const TfLiteEvalTensor& tensor_in,
TfLiteEvalTensor* tensor_out) {
if (tensor_in.type == kTfLiteFloat32) {
TransposeRowsColumnsImpl<float>(tensor_in, tensor_out);
return kTfLiteOk;
} else if (tensor_in.type == kTfLiteInt8) {
TransposeRowsColumnsImpl<int8_t>(tensor_in, tensor_out);
return kTfLiteOk;
} else if (tensor_in.type == kTfLiteInt16) {
TransposeRowsColumnsImpl<int16_t>(tensor_in, tensor_out);
return kTfLiteOk;
} else {
MicroPrintf(
"BATCH_MATMUL can only transpose tensors with FLOAT32, INT8, INT16 "
"type.");
}
return kTfLiteError;
}

RuntimeShape SwapRowColumnDims(const RuntimeShape& shape) {
RuntimeShape swapped_shape(shape);
const int32_t dims = shape.DimensionsCount();
swapped_shape.SetDim(dims - 2, shape.Dims(dims - 1));
swapped_shape.SetDim(dims - 1, shape.Dims(dims - 2));
return swapped_shape;
}

void* BatchMatMulInit(TfLiteContext* context, const char* buffer,
size_t length) {
// This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to carry information from Prepare() to
// Eval().
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
MicroContext* micro_context = GetMicroContext(context);
return micro_context->AllocatePersistentBuffer(sizeof(OpData));
return micro_context->AllocatePersistentBuffer(sizeof(OpDataBatchMatmul));
}

TfLiteStatus BatchMatMulPrepare(TfLiteContext* context, TfLiteNode* node) {
@@ -323,7 +208,7 @@ TfLiteStatus BatchMatMulPrepare(TfLiteContext* context, TfLiteNode* node) {

TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, op_context));

OpData* op_data = op_context.op_data;
OpDataBatchMatmul* op_data = op_context.op_data;
// If the RHS is constant, we only transpose once.
op_data->rhs_is_transposed = false;
op_data->lhs_is_constant_tensor = IsConstantTensor(lhs_data);
@@ -393,7 +278,7 @@ TfLiteStatus BatchMatMulPrepare(TfLiteContext* context, TfLiteNode* node) {
return status;
}

TfLiteStatus EvalInt8(TfLiteContext* context, const OpData& data,
TfLiteStatus EvalInt8(TfLiteContext* context, const OpDataBatchMatmul& data,
const RuntimeShape& lhs_shape,
const TfLiteEvalTensor& lhs,
const RuntimeShape& rhs_shape,
@@ -423,7 +308,7 @@ TfLiteStatus EvalInt8(TfLiteContext* context, const OpData& data,
return kTfLiteOk;
}

TfLiteStatus EvalInt16(TfLiteContext* context, const OpData& data,
TfLiteStatus EvalInt16(TfLiteContext* context, const OpDataBatchMatmul& data,
const RuntimeShape& lhs_shape,
const TfLiteEvalTensor& lhs,
const RuntimeShape& rhs_shape,
@@ -466,7 +351,7 @@ TfLiteStatus EvalInt16(TfLiteContext* context, const OpData& data,
// A X C row-oriented.
TfLiteStatus BatchMatMulEval(TfLiteContext* context, TfLiteNode* node) {
EvalOpContext op_context(context, node);
OpData* op_data = op_context.op_data;
OpDataBatchMatmul* op_data = op_context.op_data;
const TfLiteEvalTensor* lhs = op_context.lhs;
const TfLiteEvalTensor* rhs = op_context.rhs;
TfLiteEvalTensor* output = op_context.output;

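For readers unfamiliar with the quantization fields this kernel carries (output_multiplier, output_shift, the zero points, and the activation bounds), the following is a minimal, hedged sketch of the conventional TFLM requantization step those fields drive; it is not code from this commit. RequantizeOutput is a hypothetical helper name, while tflite::MultiplyByQuantizedMultiplier is the existing helper from tensorflow/lite/kernels/internal/common.h.

#include <algorithm>
#include <cstdint>

#include "tensorflow/lite/kernels/internal/common.h"

// Sketch of the standard int8 requantization applied to each output element:
// the accumulator is the int32 sum of (lhs - lhs_zero_point) * (rhs - rhs_zero_point).
inline int8_t RequantizeOutput(int32_t accumulator, int32_t output_multiplier,
                               int output_shift, int32_t output_zero_point,
                               int32_t output_activation_min,
                               int32_t output_activation_max) {
  // Scale the accumulator back to the output scale using the fixed-point
  // multiplier plus left shift stored in the op data.
  int32_t scaled = tflite::MultiplyByQuantizedMultiplier(
      accumulator, output_multiplier, output_shift);
  // Re-center on the output zero point and clamp to the fused activation
  // range before narrowing to int8.
  scaled += output_zero_point;
  scaled = std::max(scaled, output_activation_min);
  scaled = std::min(scaled, output_activation_max);
  return static_cast<int8_t>(scaled);
}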