diff --git a/codegen/examples/hello_world/Makefile.inc b/codegen/examples/hello_world/Makefile.inc
index 56e2da712f6..de8d902d736 100644
--- a/codegen/examples/hello_world/Makefile.inc
+++ b/codegen/examples/hello_world/Makefile.inc
@@ -1,9 +1,13 @@
+# TODO(rjascani): The codegen runtime files (ie, in runtime subdir) should be a
+# separate library.
 CODEGEN_HELLO_WORLD_SRCS := \
 $(TENSORFLOW_ROOT)codegen/examples/hello_world/hello_world.cc \
-$(TENSORFLOW_ROOT)codegen/examples/hello_world/hello_world_model.cc
+$(TENSORFLOW_ROOT)codegen/examples/hello_world/hello_world_model.cc \
+$(TENSORFLOW_ROOT)codegen/runtime/micro_codegen_context.cc

 CODEGEN_HELLO_WORLD_HDRS := \
-$(TENSORFLOW_ROOT)codegen/examples/hello_world/hello_world_model.h
+$(TENSORFLOW_ROOT)codegen/examples/hello_world/hello_world_model.h \
+$(TENSORFLOW_ROOT)codegen/runtime/micro_codegen_context.h

 # Builds a standalone binary.
 $(eval $(call microlite_test,codegen_hello_world,\
diff --git a/codegen/examples/hello_world/hello_world_model.cc b/codegen/examples/hello_world/hello_world_model.cc
index 36ab962f43c..6b9be5f2475 100644
--- a/codegen/examples/hello_world/hello_world_model.cc
+++ b/codegen/examples/hello_world/hello_world_model.cc
@@ -17,11 +17,14 @@ limitations under the License.

 #include "hello_world_model.h"

+#include "codegen/runtime/micro_codegen_context.h"
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_types.h"
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
 #include "tensorflow/lite/micro/kernels/micro_ops.h"
 #include "tensorflow/lite/micro/micro_common.h"
+#include "tensorflow/lite/micro/micro_context.h"

 namespace hello_world_model {
 namespace {
@@ -101,6 +104,10 @@ alignas(16) uint8_t buffer_7[16] = {

 // buffer_10 is located in the arena

+constexpr size_t kSubgraph0Inputs[1] = {0};
+
+constexpr size_t kSubgraph0Outputs[1] = {9};
+
 struct Node0_0 {
   struct Inputs {
     int size = 3;
@@ -202,13 +209,34 @@ struct Tensor0_9Dims {
   int data[2] = {1, 1};
 } tensor0_9_dims;

+TfLiteStatus InvokeSubgraph0(TfLiteContext* context,
+                             tflite::Span<TfLiteNode> nodes) {
+  TFLITE_DCHECK(nodes.size() == 3);
+  TF_LITE_ENSURE_OK(
+      context, op_table[OpCode::kFullyConnected].invoke(context, &nodes[0]));
+  TF_LITE_ENSURE_OK(
+      context, op_table[OpCode::kFullyConnected].invoke(context, &nodes[1]));
+  TF_LITE_ENSURE_OK(
+      context, op_table[OpCode::kFullyConnected].invoke(context, &nodes[2]));
+
+  return kTfLiteOk;
+}
+
 }  // namespace

-Model::Model() {
-  context_.impl_ = nullptr;
+Model::Model()
+    : subgraphs_{
+          {.inputs = {&kSubgraph0Inputs[0], 1},
+           .outputs = {&kSubgraph0Outputs[0], 1},
+           .nodes = {&subgraph0_nodes_[0], 3},
+           .tensors = {&subgraph0_tensors_[0], 10},
+           .invoke = &InvokeSubgraph0},
+      },
+      micro_context_{&context_, {&subgraphs_[0], 1}} {
+  context_.impl_ = static_cast<tflite::MicroContext*>(&micro_context_);
   context_.ReportError = nullptr;
   context_.GetTensor = nullptr;
-  context_.GetEvalTensor = nullptr;
+  context_.GetEvalTensor = tflite::MicroContextGetEvalTensor;
   context_.profiler = nullptr;
   context_.GetExternalContext = nullptr;
   context_.GetScratchBuffer = nullptr;
@@ -280,17 +308,6 @@ Model::Model() {
       .type = kTfLiteInt8};
 }

-TfLiteStatus Model::Invoke() { return InvokeSubgraph0(); }
-
-TfLiteStatus Model::InvokeSubgraph0() {
-  TF_LITE_ENSURE_OK(context_, op_table[OpCode::kFullyConnected].invoke(
-      &context_, &subgraph0_nodes_[0]));
-  TF_LITE_ENSURE_OK(context_, op_table[OpCode::kFullyConnected].invoke(
-      &context_, &subgraph0_nodes_[1]));
-  TF_LITE_ENSURE_OK(context_, op_table[OpCode::kFullyConnected].invoke(
-      &context_, &subgraph0_nodes_[2]));
-
-  return kTfLiteOk;
-}
+TfLiteStatus Model::Invoke() { return micro_context_.InvokeSubgraph(0); }

 }  // namespace hello_world_model
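With this change the generated model no longer hand-rolls subgraph dispatch: it builds a `tflite::Subgraph` table at construction and routes `Invoke()` through `MicroCodegenContext`. For orientation, a minimal sketch of a caller (the `main` here is hypothetical; the real driver lives in hello_world.cc, which this patch leaves untouched):

```cpp
#include "hello_world_model.h"

int main() {
  // Constructing Model wires context_, subgraphs_, and micro_context_
  // together, so Invoke() routes through MicroCodegenContext::InvokeSubgraph.
  hello_world_model::Model model;
  return model.Invoke() == kTfLiteOk ? 0 : 1;
}
```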
diff --git a/codegen/examples/hello_world/hello_world_model.h b/codegen/examples/hello_world/hello_world_model.h
index 3a78b77202a..80cfe2c3221 100644
--- a/codegen/examples/hello_world/hello_world_model.h
+++ b/codegen/examples/hello_world/hello_world_model.h
@@ -17,6 +17,7 @@ limitations under the License.

 #pragma once

+#include "codegen/runtime/micro_codegen_context.h"
 #include "tensorflow/lite/c/c_api_types.h"
 #include "tensorflow/lite/c/common.h"

@@ -29,9 +30,9 @@ class Model {
   TfLiteStatus Invoke();

  private:
-  TfLiteStatus InvokeSubgraph0();
-
   TfLiteContext context_ = {};
+  tflite::Subgraph subgraphs_[1];
+  tflite::MicroCodegenContext micro_context_;
   TfLiteNode subgraph0_nodes_[3] = {};
   TfLiteEvalTensor subgraph0_tensors_[10] = {};
 };
diff --git a/codegen/graph.py b/codegen/graph.py
index 8aa5ad11fd8..ad5a700c696 100644
--- a/codegen/graph.py
+++ b/codegen/graph.py
@@ -73,6 +73,14 @@ def __init__(self, model: schema_fb.ModelT, buffers: Sequence[tensor.Buffer],
   def index(self) -> int:
     return self._subgraph_idx

+  @property
+  def inputs(self) -> Sequence[int]:
+    return self._subgraph.inputs
+
+  @property
+  def outputs(self) -> Sequence[int]:
+    return self._subgraph.outputs
+
   @property
   def operators(self) -> Sequence[operator.Operator]:
     return self._operators
@@ -85,6 +93,18 @@ def tensors(self) -> Sequence[tensor.Tensor]:
   def needs_zero_length_int_array(self) -> bool:
     return any(t.needs_zero_length_int_array for t in self.tensors)

+  @property
+  def invoke_fn_name(self) -> str:
+    return f"InvokeSubgraph{self.index}"
+
+  @property
+  def inputs_array_name(self) -> str:
+    return f"kSubgraph{self.index}Inputs"
+
+  @property
+  def outputs_array_name(self) -> str:
+    return f"kSubgraph{self.index}Outputs"
+
   @property
   def nodes_array(self) -> str:
     return f"subgraph{self.index}_nodes_"
@@ -116,16 +136,54 @@ def generate_c_node_init(self, indent: str) -> str:
     return textwrap.indent("\n".join(node_init_strs), indent)

   def generate_c_invoke(self, indent: str) -> str:
-    invoke_template = string.Template(
-        "TF_LITE_ENSURE_OK(context_, op_table[${op_code}].invoke(\n"
-        "    &context_, &${node}));\n")
+    function_template = string.Template(
+        "TfLiteStatus ${function_name}(TfLiteContext* context,\n"
+        "                              tflite::Span<TfLiteNode> nodes) {\n"
+        "  TFLITE_DCHECK(nodes.size() == ${num_nodes});\n"
+        "${body}\n"
+        "  return kTfLiteOk;\n"
+        "}")
+
+    body_template = string.Template(
+        "  TF_LITE_ENSURE_OK(\n"
+        "      context, op_table[${op_code}].invoke(context, &${node}));\n")
     invoke_strs: List[str] = []
     for op_idx, op in enumerate(self.operators):
       invoke_strs.append(
-          invoke_template.substitute(
+          body_template.substitute(
              op_code=self._op_codes[op.op_code_index].full_enum_name,
-              node=self.nodes_element(op_idx)))
-    return textwrap.indent("".join(invoke_strs), indent)
+              node=f"nodes[{op_idx}]"))
+
+    invoke = function_template.substitute(function_name=self.invoke_fn_name,
+                                          num_nodes=len(self.operators),
+                                          body="".join(invoke_strs))
+    return textwrap.indent(invoke, indent)
+
+  def generate_c_input_array(self, indent: str) -> str:
+    return utils.generate_c_int_array(indent, "size_t", self.inputs_array_name,
+                                      self.inputs)
+
+  def generate_c_output_array(self, indent: str) -> str:
+    return utils.generate_c_int_array(indent, "size_t",
+                                      self.outputs_array_name, self.outputs)
+
+  def generate_c_subgraph_init(self, indent: str) -> str:
+    init_template = string.Template(
+        "{.inputs = {&${input_array}[0], ${input_size}},\n"
+        " .outputs = {&${output_array}[0], ${output_size}},\n"
+        " .nodes = {&${node_array}[0], ${node_size}},\n"
+        " .tensors = {&${tensor_array}[0], ${tensor_size}},\n"
+        " .invoke = &${invoke}},")
+    return textwrap.indent(
+        init_template.substitute(input_array=self.inputs_array_name,
+                                 input_size=len(self.inputs),
+                                 output_array=self.outputs_array_name,
+                                 output_size=len(self.outputs),
+                                 node_array=self.nodes_array,
+                                 node_size=len(self.operators),
+                                 tensor_array=self.tensors_array,
+                                 tensor_size=len(self.tensors),
+                                 invoke=self.invoke_fn_name), indent)

   @property
   def tensors_array(self) -> str:
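For a single-subgraph model these helpers emit exactly the definitions visible in the hello_world_model.cc diff above; for example, `generate_c_input_array` yields:

```cpp
// Emitted by generate_c_input_array for the hello_world model, whose one
// subgraph has tensor 0 as its sole input:
constexpr size_t kSubgraph0Inputs[1] = {0};
```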
diff --git a/codegen/runtime/micro_codegen_context.cc b/codegen/runtime/micro_codegen_context.cc
new file mode 100644
index 00000000000..858c823c1f3
--- /dev/null
+++ b/codegen/runtime/micro_codegen_context.cc
@@ -0,0 +1,139 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "codegen/runtime/micro_codegen_context.h"
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/micro_log.h"
+
+namespace tflite {
+
+MicroCodegenContext::MicroCodegenContext(TfLiteContext* context,
+                                         Span<Subgraph> subgraphs)
+    : context_(context), subgraphs_(subgraphs) {}
+
+void* MicroCodegenContext::GetScratchBuffer(int buffer_idx) {
+  // TODO(rjascani): Implement scratch buffers
+  return nullptr;
+}
+
+TfLiteEvalTensor* MicroCodegenContext::GetEvalTensor(int tensor_idx) {
+  TFLITE_DCHECK(static_cast<size_t>(tensor_idx) <
+                subgraphs_[current_subgraph_idx_].tensors.size());
+  return &subgraphs_[current_subgraph_idx_].tensors[tensor_idx];
+}
+
+TfLiteStatus MicroCodegenContext::set_external_context(
+    void* external_context_payload) {
+  if (external_context_payload == nullptr ||
+      external_context_payload_ != nullptr) {
+    MicroPrintf(
+        "Attempting to set external context to %x but it was %x already",
+        external_context_payload, external_context_payload_);
+    return kTfLiteError;
+  }
+
+  external_context_payload_ = external_context_payload;
+  return kTfLiteOk;
+}
+
+void* MicroCodegenContext::external_context() {
+  return external_context_payload_;
+}
+
+MicroGraph& MicroCodegenContext::graph() { return *this; }
+
+void* MicroCodegenContext::AllocatePersistentBuffer(size_t) {
+  // Not allowed at Eval
+  TFLITE_ABORT;
+  return nullptr;
+}
+
+TfLiteStatus MicroCodegenContext::RequestScratchBufferInArena(size_t, int*) {
+  // Not allowed at Eval
+  TFLITE_ABORT;
+  return kTfLiteError;
+}
+
+TfLiteTensor* MicroCodegenContext::AllocateTempTfLiteTensor(int) {
+  // Not allowed at Eval
+  TFLITE_ABORT;
+  return nullptr;
+}
+
+void MicroCodegenContext::DeallocateTempTfLiteTensor(TfLiteTensor*) {
+  // Not allowed at Eval
+  TFLITE_ABORT;
+}
+
+uint8_t* MicroCodegenContext::AllocateTempBuffer(size_t, size_t) {
+  // Not allowed at Eval
+  TFLITE_ABORT;
+  return nullptr;
+}
+
+void MicroCodegenContext::DeallocateTempBuffer(uint8_t*) {
+  // Not allowed at Eval
+  TFLITE_ABORT;
+}
+
+TfLiteStatus MicroCodegenContext::InvokeSubgraph(int subgraph_idx) {
+  TF_LITE_ENSURE(context_,
+                 static_cast<size_t>(subgraph_idx) < subgraphs_.size());
+  size_t previous_subgraph_idx = current_subgraph_idx_;
+  current_subgraph_idx_ = subgraph_idx;
+  TfLiteStatus status =
+      subgraphs_[subgraph_idx].invoke(context_, subgraphs_[subgraph_idx].nodes);
+  current_subgraph_idx_ = previous_subgraph_idx;
+  return status;
+}
+
+size_t MicroCodegenContext::NumSubgraphInputs(int subgraph_idx) {
+  TFLITE_DCHECK(static_cast<size_t>(subgraph_idx) < subgraphs_.size());
+  return subgraphs_[subgraph_idx].inputs.size();
+}
+
+TfLiteEvalTensor* MicroCodegenContext::GetSubgraphInput(int subgraph_idx,
+                                                        int input_idx) {
+  TFLITE_DCHECK(static_cast<size_t>(subgraph_idx) < subgraphs_.size());
+  TFLITE_DCHECK(static_cast<size_t>(input_idx) <
+                subgraphs_[subgraph_idx].inputs.size());
+  const size_t tensor_idx = subgraphs_[subgraph_idx].inputs[input_idx];
+  return &subgraphs_[subgraph_idx].tensors[tensor_idx];
+}
+
+size_t MicroCodegenContext::NumSubgraphOutputs(int subgraph_idx) {
+  TFLITE_DCHECK(static_cast<size_t>(subgraph_idx) < subgraphs_.size());
+  return subgraphs_[subgraph_idx].outputs.size();
+}
+
+TfLiteEvalTensor* MicroCodegenContext::GetSubgraphOutput(int subgraph_idx,
+                                                         int output_idx) {
+  TFLITE_DCHECK(static_cast<size_t>(subgraph_idx) < subgraphs_.size());
+  TFLITE_DCHECK(static_cast<size_t>(output_idx) <
+                subgraphs_[subgraph_idx].outputs.size());
+  const size_t tensor_idx = subgraphs_[subgraph_idx].outputs[output_idx];
+  return &subgraphs_[subgraph_idx].tensors[tensor_idx];
+}
+
+int MicroCodegenContext::NumSubgraphs() { return subgraphs_.size(); }
+
+MicroResourceVariables* MicroCodegenContext::GetResourceVariables() {
+  return nullptr;
+}
+
+}  // namespace tflite
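Because `InvokeSubgraph` saves and restores `current_subgraph_idx_`, a kernel running inside one subgraph can invoke another, and tensor lookups inside the callee resolve against the callee's tensor table. A hedged sketch of what such a kernel body could look like (the kernel itself and the subgraph index 1 are hypothetical; `tflite::GetMicroContext` is the existing accessor from micro_context.h):

```cpp
#include "tensorflow/lite/micro/micro_context.h"

// Hypothetical eval body for a CALL-style operator.
TfLiteStatus CallSubgraphEval(TfLiteContext* context, TfLiteNode* node) {
  tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
  // graph() returns the same MicroCodegenContext object, viewed through its
  // MicroGraph interface.
  tflite::MicroGraph& graph = micro_context->graph();
  // Runs subgraph 1; the caller's subgraph index is restored afterwards, so
  // subsequent GetEvalTensor calls resolve in the caller's subgraph again.
  return graph.InvokeSubgraph(1);
}
```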
diff --git a/codegen/runtime/micro_codegen_context.h b/codegen/runtime/micro_codegen_context.h
new file mode 100644
index 00000000000..ca01a63bce4
--- /dev/null
+++ b/codegen/runtime/micro_codegen_context.h
@@ -0,0 +1,90 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef CODEGEN_RUNTIME_MICRO_CODEGEN_CONTEXT_H_
+#define CODEGEN_RUNTIME_MICRO_CODEGEN_CONTEXT_H_
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/micro/micro_context.h"
+#include "tensorflow/lite/micro/micro_graph.h"
+
+namespace tflite {
+
+// A poor man's std::span, we should consider using the Pigweed span instead.
+template <typename T>
+class Span {
+ public:
+  constexpr Span(T* data, size_t size) noexcept : data_(data), size_(size) {}
+
+  constexpr T& operator[](size_t idx) const noexcept { return *(data_ + idx); }
+
+  constexpr T* data() const noexcept { return data_; }
+  constexpr size_t size() const noexcept { return size_; }
+
+ private:
+  T* data_;
+  size_t size_;
+};
+
+struct Subgraph {
+  Span<const size_t> inputs;
+  Span<const size_t> outputs;
+  Span<TfLiteNode> nodes;
+  Span<TfLiteEvalTensor> tensors;
+  TfLiteStatus (*invoke)(TfLiteContext*, Span<TfLiteNode>);
+};
+
+class MicroCodegenContext : public MicroContext, MicroGraph {
+ public:
+  MicroCodegenContext(TfLiteContext* context, Span<Subgraph> subgraphs);
+
+  ~MicroCodegenContext() = default;
+
+  // MicroContext API
+  void* AllocatePersistentBuffer(size_t bytes) override;
+  TfLiteStatus RequestScratchBufferInArena(size_t bytes,
+                                           int* buffer_idx) override;
+  void* GetScratchBuffer(int buffer_idx) override;
+  TfLiteTensor* AllocateTempTfLiteTensor(int tensor_idx) override;
+  void DeallocateTempTfLiteTensor(TfLiteTensor* tensor) override;
+  uint8_t* AllocateTempBuffer(size_t size, size_t alignment) override;
+  void DeallocateTempBuffer(uint8_t* buffer) override;
+  TfLiteEvalTensor* GetEvalTensor(int tensor_idx) override;
+  TfLiteStatus set_external_context(void* external_context_payload) override;
+  void* external_context() override;
+  MicroGraph& graph() override;
+
+  // MicroGraph API
+  TfLiteStatus InvokeSubgraph(int subgraph_idx) override;
+  size_t NumSubgraphInputs(int subgraph_idx) override;
+  TfLiteEvalTensor* GetSubgraphInput(int subgraph_idx, int input_idx) override;
+  size_t NumSubgraphOutputs(int subgraph_idx) override;
+  TfLiteEvalTensor* GetSubgraphOutput(int subgraph_idx,
+                                      int output_idx) override;
+  int NumSubgraphs() override;
+  MicroResourceVariables* GetResourceVariables() override;
+
+ private:
+  TfLiteContext* context_;
+  Span<Subgraph> subgraphs_;
+  size_t current_subgraph_idx_ = 0;
+  void* external_context_payload_ = nullptr;
+
+  TF_LITE_REMOVE_VIRTUAL_DELETE
+};
+
+}  // namespace tflite
+
+#endif  // CODEGEN_RUNTIME_MICRO_CODEGEN_CONTEXT_H_
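The `Span` template above deliberately covers only the subset the generated code needs: construction from a pointer and length, unchecked indexing, `data()`, and `size()`. A minimal sketch of it in use; the node array here is an arbitrary placeholder:

```cpp
#include "codegen/runtime/micro_codegen_context.h"

void IterateNodes() {
  TfLiteNode nodes[3] = {};
  tflite::Span<TfLiteNode> node_span(&nodes[0], 3);
  for (size_t i = 0; i < node_span.size(); ++i) {
    // operator[] is unchecked, mirroring std::span.
    TfLiteNode& node = node_span[i];
    (void)node;
  }
}
```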
diff --git a/codegen/templates/inference.cc.mako b/codegen/templates/inference.cc.mako
index cfb735417a0..cb6e59ad2d2 100644
--- a/codegen/templates/inference.cc.mako
+++ b/codegen/templates/inference.cc.mako
@@ -17,11 +17,14 @@ limitations under the License.

 #include "${header_file}"

+#include "codegen/runtime/micro_codegen_context.h"
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_types.h"
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
 #include "tensorflow/lite/micro/kernels/micro_ops.h"
 #include "tensorflow/lite/micro/micro_common.h"
+#include "tensorflow/lite/micro/micro_context.h"

 namespace ${model_name} {
 namespace {
@@ -44,21 +47,36 @@ TFLMInferenceRegistration op_table[OpCode::kCount] = {
 ${buffer.generate_c_buffer_array("")}
 % endfor
 % for subgraph in graph.subgraphs:
+${subgraph.generate_c_input_array("")}
+
+${subgraph.generate_c_output_array("")}
+
 ${subgraph.generate_c_node_data("")}

 ${subgraph.generate_c_tensor_data("")}
 % endfor
-
 % if graph.needs_zero_length_int_array:
+
 TfLiteIntArray zero_length_int_array = {};
 % endif
+
+% for subgraph in graph.subgraphs:
+${subgraph.generate_c_invoke("")}
+% endfor
+
 }  // namespace

-Model::Model() {
-  context_.impl_ = nullptr;
+Model::Model()
+    : subgraphs_{
+%for subgraph in graph.subgraphs:
+${subgraph.generate_c_subgraph_init("          ")}
+%endfor
+      },
+      micro_context_{&context_, {&subgraphs_[0], ${len(graph.subgraphs)}}} {
+  context_.impl_ = static_cast<tflite::MicroContext*>(&micro_context_);
   context_.ReportError = nullptr;
   context_.GetTensor = nullptr;
-  context_.GetEvalTensor = nullptr;
+  context_.GetEvalTensor = tflite::MicroContextGetEvalTensor;
   context_.profiler = nullptr;
   context_.GetExternalContext = nullptr;
   context_.GetScratchBuffer = nullptr;
@@ -70,13 +88,6 @@ ${subgraph.generate_c_tensor_init("  ")}
 % endfor
 }

-TfLiteStatus Model::Invoke() { return InvokeSubgraph0(); }
-
-% for subgraph in graph.subgraphs:
-TfLiteStatus Model::InvokeSubgraph${subgraph.index}() {
-${subgraph.generate_c_invoke("  ")}
-  return kTfLiteOk;
-}
-% endfor
+TfLiteStatus Model::Invoke() { return micro_context_.InvokeSubgraph(0); }

 }  // namespace ${model_name}
#include "${header_file}" +#include "codegen/runtime/micro_codegen_context.h" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/micro/kernels/micro_ops.h" #include "tensorflow/lite/micro/micro_common.h" +#include "tensorflow/lite/micro/micro_context.h" namespace ${model_name} { namespace { @@ -44,21 +47,36 @@ TFLMInferenceRegistration op_table[OpCode::kCount] = { ${buffer.generate_c_buffer_array("")} % endfor % for subgraph in graph.subgraphs: +${subgraph.generate_c_input_array("")} + +${subgraph.generate_c_output_array("")} + ${subgraph.generate_c_node_data("")} ${subgraph.generate_c_tensor_data("")} % endfor - % if graph.needs_zero_length_int_array: + TfLiteIntArray zero_length_int_array = {}; % endif + +% for subgraph in graph.subgraphs: +${subgraph.generate_c_invoke("")} +% endfor + } // namespace -Model::Model() { - context_.impl_ = nullptr; +Model::Model() + : subgraphs_{ +%for subgraph in graph.subgraphs: +${subgraph.generate_c_subgraph_init(" ")} +%endfor + }, + micro_context_{&context_, {&subgraphs_[0], ${len(graph.subgraphs)}}} { + context_.impl_ = static_cast(µ_context_); context_.ReportError = nullptr; context_.GetTensor = nullptr; - context_.GetEvalTensor = nullptr; + context_.GetEvalTensor = tflite::MicroContextGetEvalTensor; context_.profiler = nullptr; context_.GetExternalContext = nullptr; context_.GetScratchBuffer = nullptr; @@ -70,13 +88,6 @@ ${subgraph.generate_c_tensor_init(" ")} % endfor } -TfLiteStatus Model::Invoke() { return InvokeSubgraph0(); } - -% for subgraph in graph.subgraphs: -TfLiteStatus Model::InvokeSubgraph${subgraph.index}() { -${subgraph.generate_c_invoke(" ")} - return kTfLiteOk; -} -% endfor +TfLiteStatus Model::Invoke() { return micro_context_.InvokeSubgraph(0); } } // namespace ${model_name} diff --git a/codegen/templates/inference.h.mako b/codegen/templates/inference.h.mako index b45c7cad4f2..5ab64e108c1 100644 --- a/codegen/templates/inference.h.mako +++ b/codegen/templates/inference.h.mako @@ -17,6 +17,7 @@ limitations under the License. #pragma once +#include "codegen/runtime/micro_codegen_context.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" @@ -29,11 +30,9 @@ class Model { TfLiteStatus Invoke(); private: -% for subgraph_idx in range(len(graph.subgraphs)): - TfLiteStatus InvokeSubgraph${subgraph_idx}(); -% endfor - TfLiteContext context_ = {}; + tflite::Subgraph subgraphs_[${len(graph.subgraphs)}]; + tflite::MicroCodegenContext micro_context_; % for subgraph in graph.subgraphs: TfLiteNode ${subgraph.nodes_array}[${len(subgraph.operators)}] = {}; % endfor diff --git a/codegen/utils.py b/codegen/utils.py index 8baa30860a9..c6c31c8ada2 100644 --- a/codegen/utils.py +++ b/codegen/utils.py @@ -14,8 +14,9 @@ # ============================================================================== """ Utility functions and classes for code generation. 
""" -from typing import Any, Generator, Iterable, List, Optional, Tuple +from typing import Any, Generator, Iterable, List, Optional, Sequence, Tuple import string +import textwrap import itertools @@ -43,6 +44,40 @@ def split_into_chunks( yield chunk +def generate_c_int_array(indent: str, int_type: str, name: str, + ints: Sequence[int]) -> str: + int_strs = ['{}'.format(i) for i in ints] + + # Try to do it on a single line first + single_line_array_template = string.Template( + "constexpr ${int_type} ${name}[${size}] = {${data}};") + single_line = textwrap.indent( + single_line_array_template.substitute(int_type=int_type, + name=name, + size=len(int_strs), + data=', '.join(int_strs)), indent) + + if len(single_line) < 81: + return single_line + + # Couldn't fit, so split it across multiple lines + multi_line_array_template = string.Template( + "constexpr ${int_type} ${name}[${size}] = {\n" + "${body}\n" + "};\n") + + lines = [] + for int_strs_for_line in split_into_chunks(int_strs, 12): + ints_segment = ', '.join(int_strs_for_line) + lines.append(f' {ints_segment},') + + return textwrap.indent( + multi_line_array_template.substitute(int_type=int_type, + name=name, + size=len(ints), + body='\n'.join(lines)), indent) + + class IntArray(object): """ A helper class for generating int arrays that can be used to provide the backing storage for a TfLiteIntArray. """