diff --git a/tensorflow/lite/micro/compression.h b/tensorflow/lite/micro/compression.h index cfc8a1c8e8..197d5cd963 100644 --- a/tensorflow/lite/micro/compression.h +++ b/tensorflow/lite/micro/compression.h @@ -26,7 +26,8 @@ namespace tflite { // Compressed tensors // -static constexpr const char* kCompressionMetadataString = "TFLM_COMPRESSION"; +static constexpr const char* kCompressionMetadataString = + "COMPRESSION_METADATA"; enum class CompressionScheme : uint8_t { kBinQuant, diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index cde1b55bb1..cfa3e8a507 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -13,11 +13,6 @@ flatbuffer_cc_library( srcs = ["metadata.fbs"], ) -flatbuffer_py_library( - name = "original_flatbuffer_py", - srcs = ["original.fbs"], -) - flatbuffer_py_library( name = "metadata_flatbuffer_py", srcs = ["metadata.fbs"], @@ -34,33 +29,6 @@ cc_test( size = "small", ) -py_binary( - name = "compress", - srcs = ["compress.py"], - deps = [ - "@absl_py//absl:app", - "@absl_py//absl/flags", - "@absl_py//absl/logging", - "@flatbuffers//:runtime_py", - "metadata_flatbuffer_py", - "//tensorflow/lite/python:schema_py", - requirement("bitarray"), - requirement("numpy"), - requirement("scikit-learn"), - ], -) - -py_binary( - name = "view", - srcs = [ - "view.py", - ], - deps = [ - "metadata_flatbuffer_py", - "//tensorflow/lite/python:schema_py", - ], -) - py_test( name = "metadata_test_py", main = "metadata_test.py", @@ -72,23 +40,3 @@ py_test( ], size = "small", ) - -py_test( - name = "original_test_py", - main = "original_test.py", - srcs = ["original_test.py"], - deps = [ - "original_flatbuffer_py", - "@flatbuffers//:runtime_py", - requirement("hexdump"), - ], - size = "small", -) - -genrule( - name = "hello_world_int8.compressed", - srcs = ["//tensorflow/lite/micro/examples/hello_world/models:hello_world_int8.tflite"], - outs = ["hello_world_int8.compressed.tflite"], - cmd = "$(location :compress) --input_model_path $< --output_model_path $@", - tools = [":compress"], -) diff --git a/tensorflow/lite/micro/compression/compress.py b/tensorflow/lite/micro/compression/compress.py deleted file mode 100644 index 18834982f2..0000000000 --- a/tensorflow/lite/micro/compression/compress.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Reduces the number of weights in a .tflite model using various strategies.""" - -# Usage information: -# Default: -# `bazel run tensorflow/lite/micro/tools:compress -- \ -# --input_model_path=` \ -# --output_model_path=` - - -from tensorflow.lite.micro.compression import metadata_flatbuffer_py_generated as compression_schema -from tensorflow.lite.python import schema_py_generated as tflite_schema - -from absl import app -from absl import flags -from absl import logging -import bitarray -import bitarray.util -import numpy as np -import flatbuffers -import sklearn.cluster -import struct - - -_INPUT_MODEL_PATH = flags.DEFINE_string( - "input_model_path", - None, - ".tflite input model path", - required=True, -) - -_TEST_COMPRESSED_MODEL = flags.DEFINE_bool( - "test_compressed_model", - False, - "optional config to test models with random data and" - " report on the differences in output.", -) - -_OUTPUT_MODEL_PATH = flags.DEFINE_string( - "output_model_path", - None, - ".tflite output path. Leave blank if same as input+.compressed.tflite", -) - - -def read_model(path): - with open(path, 'rb') as file: - buffer = bytearray(file.read()) - return tflite_schema.ModelT.InitFromPackedBuf(buffer, 0) - - -def write_model(model, path): - builder = flatbuffers.Builder(32) - root = model.Pack(builder) - builder.Finish(root) - buffer: bytearray = builder.Output() - - with open(path, 'wb') as file: - file.write(buffer) - - -def pack_compression_metadata(m): - builder = flatbuffers.Builder(32) - root = m.Pack(builder) - builder.Finish(root) - buffer: bytearray = builder.Output() - return buffer - - -def pack_lut_indexes(indexes, bitwidth): - """Pack the sequence of integers given in `indexes` into bitwidth-wide fields - in a buffer, and return the buffer. Raise an OverflowError if any element - does not fit into a bitwidth-wide field. """ - ba = bitarray.bitarray(endian="big") - for i in indexes: - field = bitarray.util.int2ba(i, length=bitwidth, endian="big") - ba.extend(field) - return ba.tobytes() - - -def pack_lut_values(values, struct_format): - """Pack the `values` into a buffer of bytes, using a `struct_format` - character from the standard module `struct` to determine the type of values - and corresponding encoding into bytes. Always little-endian byte order. - """ - buffer = bytearray() - little_endian = "<" - packer = struct.Struct(little_endian + struct_format) - for v in values: - buffer.extend(packer.pack(v)) - return buffer - - -def unpack_buffer_values(data, struct_format): - little_endian = "<" - unpacker = struct.Struct(little_endian + struct_format) - values = [v[0] for v in unpacker.iter_unpack(bytes(data))] - return values - - -def tensor_type_to_struct_format(type): - m = { - tflite_schema.TensorType.INT8: "b", - tflite_schema.TensorType.INT16: "h", - tflite_schema.TensorType.FLOAT32: "f", - } - return m[type] - - -def bq(sequence, num_values): - """Quantize a sequence of integers, minimizing the total error using k-means - clustering. 
- - Parameters: - sequence :list - a sequence of integers to be quanized - num_values :int - the number of quantization levels - - Returns: - (indexes, values): a tuple with the list of indexes and list of values - """ - sequence = np.array(sequence).reshape(-1, 1) - kmeans = sklearn.cluster.KMeans(n_clusters=num_values, - random_state=0).fit(sequence) - values = kmeans.cluster_centers_.flatten() - values = np.round(values).astype(int).tolist() - indexes = kmeans.predict(sequence).tolist() - return (indexes, values) - - -def compress_tensor(subgraph_id, tensor_id, model): - subgraph = model.subgraphs[subgraph_id] - tensor = subgraph.tensors[tensor_id] - struct_format = tensor_type_to_struct_format(tensor.type) - buffer_id = tensor.buffer - buffer = model.buffers[buffer_id] - sequence = unpack_buffer_values(buffer.data, struct_format) - bitwidth = 2 - indexes, values = bq(sequence, 2 ** bitwidth) - - # append index buffer - buffer = tflite_schema.BufferT() - buffer.data = pack_lut_indexes(indexes, bitwidth) - model.buffers.append(buffer) - index_id = len(model.buffers) - 1 - - # append value buffer - buffer = tflite_schema.BufferT() - buffer.data = pack_lut_values(values, struct_format) - model.buffers.append(buffer) - value_id = len(model.buffers) - 1 - - # create metadata - lut_tensor = compression_schema.LutTensorT() - lut_tensor.subgraph = subgraph_id - lut_tensor.tensor = tensor_id - lut_tensor.indexBitwidth = bitwidth - lut_tensor.indexBuffer = index_id - lut_tensor.valueBuffer = value_id - - return lut_tensor - - -def compress_fully_connected(subgraph_id, operator_id, model): - # On a fully_connected operator, we compress the 2nd - subgraph = model.subgraphs[subgraph_id] - operator = subgraph.operators[operator_id] - tensor_id_2 = operator.inputs[1] - # tensor_id_3 = operator.inputs[2] - lut_tensor_2 = compress_tensor(subgraph_id, tensor_id_2, model) - # lut_tensor_3 = compress_tensor(subgraph_id, tensor_id_2, model) - return (lut_tensor_2,) - - -def get_opcode_compressions(model): - """Return a map of operator_code indexes to compression functions, for those - operators we wish to and know how to compress. 
- """ - compressable = {tflite_schema.BuiltinOperator.FULLY_CONNECTED: compress_fully_connected} - compressions = {} - for index, code in enumerate(model.operatorCodes): - if code.builtinCode in compressable: - compressions[index] = compressable[code.builtinCode] - return compressions - - -def compress(model): - # Walk op codes, identify those we compress, note index - # Walk operators, match op code indexes, note tensors to compress - # Walk those tensors, creating LUTs in buffers and metadata - - compressions = get_opcode_compressions(model) - - lut_tensors = [] - - for subgraph_id, subgraph in enumerate(model.subgraphs): - for operator_id, operator in enumerate(subgraph.operators): - fn = compressions.get(operator.opcodeIndex) - if fn is not None: - result = fn(subgraph_id, operator_id, model) - if result is not None: - lut_tensors.extend(result) - - compression_metadata = compression_schema.MetadataT() - compression_metadata.lutTensors = lut_tensors - - return compression_metadata - - -def main(_) -> None: - output_model_path = _OUTPUT_MODEL_PATH.value or ( - _INPUT_MODEL_PATH.value.split(".tflite")[0] + ".compressed.tflite") - logging.info("compressing %s to %s", _INPUT_MODEL_PATH.value, output_model_path) - - model = read_model(_INPUT_MODEL_PATH.value) - - compression_metadata = compress(model) - - buffer = tflite_schema.BufferT() - buffer.data = pack_compression_metadata(compression_metadata) - model.buffers.append(buffer) - - metadata = tflite_schema.MetadataT() - metadata.name = "COMPRESSION_METADATA" - metadata.buffer = len(model.buffers) - 1 - model.metadata.append(metadata) - - write_model(model, output_model_path) - - -if __name__ == "__main__": - app.run(main) diff --git a/tensorflow/lite/micro/compression/metadata.fbs b/tensorflow/lite/micro/compression/metadata.fbs index dcfb1ccafb..dbbe7b0e40 100644 --- a/tensorflow/lite/micro/compression/metadata.fbs +++ b/tensorflow/lite/micro/compression/metadata.fbs @@ -12,27 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Flatbuffer schema describing a TFLM compressed model. Use as the value for -// the key "TFLM_COMPRESSION" in the metadata table in a .tflite flatbuffer. - namespace tflite.micro.compression; table Metadata { - lut_tensors:[LutTensor]; // list of tensors that are compressed by LUT + // Compression data root, to be used in a tflite.Model.metadata field with + // the key "COMPRESSION_METADATA". + + subgraphs:[Subgraph]; // compression data indexed by subgraph index +} + +table Subgraph { + // Per-subgraph compression metadata. + + lut_tensors:[LutTensor]; + // ^ A list of tensors which are compressed using the + // (L)ook-(U)p-(T)able method. The indices of this vector are not + // significant. } -struct LutTensor { - subgraph:uint16; // the index of the subgraph - tensor:uint16; // the index of the tensor in its subgraph - index_bitwidth:uint8; // the bit-width of LUT indexes - index_buffer:uint16; // the index of the buffer containing LUT indexes - value_buffer:uint16; // the index of the buffer containing LUT values +table LutTensor { + // Look-Up-Table Tensor: a tensor representation where elements are + // compressed into indices into a table of values. The indices are unsigned + // integers, index_bitwidth-wide, in big-endian bit order, packed into the + // buffer identified by the corresponding tflite.Tensor's buffer field. The + // values are located in a newly-created buffer, encoded according to the + // tflite.Tensor.type. 
Tensors with multiple channels have distinct value + // tables for each channel, concatenated one after another in the buffer. + // An element's LUT index must be looked up in the value table for its + // channel. + + tensor:int; // index of the corresponding tflite.Tensor + value_buffer:uint; // index of the buffer containing LUT values + index_bitwidth:uint8; // bit-width of LUT indexes } -// Look-Up-Table tensors are encoded in two buffers: an index buffer and a -// value buffer. The indexes are unsigned integers packed into the index buffer -// in bitwidth-wide bit fields with a big-endian bit order. The data in the -// value buffer is encoded as usual according to the type of the tensor. -// Tensors with multiple channels have distinct values tables for each channel, -// concatinated into one value buffer. (Will elaborate this comment.) root_type Metadata;
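To make the packing rule in the LutTensor comment concrete, here is a minimal sketch, not part of this change, of packing LUT indices into index_bitwidth-wide, big-endian bit fields; the function name and signature are illustrative only.

#include <cstdint>
#include <vector>

// Pack each index into `bitwidth` bits, most-significant bit first, matching
// the big-endian bit order described by the schema comment above. Assumes
// every index fits in `bitwidth` bits.
std::vector<uint8_t> PackLutIndices(const std::vector<uint32_t>& indices,
                                    unsigned bitwidth) {
  std::vector<uint8_t> out;
  uint8_t current = 0;
  unsigned used = 0;  // bits filled so far in `current`
  for (uint32_t index : indices) {
    for (int bit = bitwidth - 1; bit >= 0; --bit) {
      current = static_cast<uint8_t>((current << 1) | ((index >> bit) & 1u));
      if (++used == 8) {
        out.push_back(current);
        current = 0;
        used = 0;
      }
    }
  }
  if (used != 0) {
    out.push_back(static_cast<uint8_t>(current << (8 - used)));  // zero-pad
  }
  return out;
}

With bitwidth 2, for example, the indices {1, 2, 3, 0} pack into the single byte 0b01101100.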
diff --git a/tensorflow/lite/micro/compression/metadata_generated.h b/tensorflow/lite/micro/compression/metadata_generated.h index eaa03cb21e..6b3af3b3e2 100644 --- a/tensorflow/lite/micro/compression/metadata_generated.h +++ b/tensorflow/lite/micro/compression/metadata_generated.h @@ -6,6 +6,13 @@ #include "flatbuffers/flatbuffers.h" +// Ensure the included flatbuffers.h is the same version as when this file was +// generated, otherwise it may not be compatible. +static_assert(FLATBUFFERS_VERSION_MAJOR == 23 && + FLATBUFFERS_VERSION_MINOR == 5 && + FLATBUFFERS_VERSION_REVISION == 26, + "Non-compatible flatbuffers version included"); + namespace tflite { namespace micro { namespace compression { @@ -13,131 +20,204 @@ namespace compression { struct Metadata; struct MetadataBuilder; +struct Subgraph; +struct SubgraphBuilder; + struct LutTensor; +struct LutTensorBuilder; -FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(2) LutTensor FLATBUFFERS_FINAL_CLASS { - private: - uint16_t subgraph_; - uint16_t tensor_; - uint8_t index_bitwidth_; - int8_t padding0__; - uint16_t index_buffer_; - uint16_t value_buffer_; - - public: - LutTensor() - : subgraph_(0), - tensor_(0), - index_bitwidth_(0), - padding0__(0), - index_buffer_(0), - value_buffer_(0) { - (void)padding0__; - } - LutTensor(uint16_t _subgraph, uint16_t _tensor, uint8_t _index_bitwidth, uint16_t _index_buffer, uint16_t _value_buffer) - : subgraph_(flatbuffers::EndianScalar(_subgraph)), - tensor_(flatbuffers::EndianScalar(_tensor)), - index_bitwidth_(flatbuffers::EndianScalar(_index_bitwidth)), - padding0__(0), - index_buffer_(flatbuffers::EndianScalar(_index_buffer)), - value_buffer_(flatbuffers::EndianScalar(_value_buffer)) { - } - uint16_t subgraph() const { - return flatbuffers::EndianScalar(subgraph_); - } - uint16_t tensor() const { - return flatbuffers::EndianScalar(tensor_); +struct Metadata FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef MetadataBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_SUBGRAPHS = 4 + }; + const ::flatbuffers::Vector<::flatbuffers::Offset<tflite::micro::compression::Subgraph>> *subgraphs() const { + return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<tflite::micro::compression::Subgraph>> *>(VT_SUBGRAPHS); } - uint8_t index_bitwidth() const { - return flatbuffers::EndianScalar(index_bitwidth_); + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_SUBGRAPHS) && + verifier.VerifyVector(subgraphs()) && + verifier.VerifyVectorOfTables(subgraphs()) && + verifier.EndTable(); } - uint16_t index_buffer() const { - return flatbuffers::EndianScalar(index_buffer_); +}; + +struct MetadataBuilder { + typedef Metadata Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_subgraphs(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<tflite::micro::compression::Subgraph>>> subgraphs) { + fbb_.AddOffset(Metadata::VT_SUBGRAPHS, subgraphs); + } + explicit MetadataBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); } - uint16_t value_buffer() const { - return flatbuffers::EndianScalar(value_buffer_); + ::flatbuffers::Offset<Metadata> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<Metadata>(end); + return o; } }; -FLATBUFFERS_STRUCT_END(LutTensor, 10); -struct Metadata FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef MetadataBuilder Builder; +inline ::flatbuffers::Offset<Metadata> CreateMetadata( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<tflite::micro::compression::Subgraph>>> subgraphs = 0) { + MetadataBuilder builder_(_fbb); + builder_.add_subgraphs(subgraphs); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<Metadata> CreateMetadataDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<::flatbuffers::Offset<tflite::micro::compression::Subgraph>> *subgraphs = nullptr) { + auto subgraphs__ = subgraphs ? _fbb.CreateVector<::flatbuffers::Offset<tflite::micro::compression::Subgraph>>(*subgraphs) : 0; + return tflite::micro::compression::CreateMetadata( + _fbb, + subgraphs__); +} + +struct Subgraph FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SubgraphBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_LUT_TENSORS = 4 }; - const flatbuffers::Vector<const tflite::micro::compression::LutTensor *> *lut_tensors() const { - return GetPointer<const flatbuffers::Vector<const tflite::micro::compression::LutTensor *> *>(VT_LUT_TENSORS); + const ::flatbuffers::Vector<::flatbuffers::Offset<tflite::micro::compression::LutTensor>> *lut_tensors() const { + return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<tflite::micro::compression::LutTensor>> *>(VT_LUT_TENSORS); } - bool Verify(flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_LUT_TENSORS) && verifier.VerifyVector(lut_tensors()) && + verifier.VerifyVectorOfTables(lut_tensors()) && verifier.EndTable(); } }; -struct MetadataBuilder { - typedef Metadata Table; - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_lut_tensors(flatbuffers::Offset<flatbuffers::Vector<const tflite::micro::compression::LutTensor *>> lut_tensors) { - fbb_.AddOffset(Metadata::VT_LUT_TENSORS, lut_tensors); +struct SubgraphBuilder { + typedef Subgraph Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_lut_tensors(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<tflite::micro::compression::LutTensor>>> lut_tensors) { + fbb_.AddOffset(Subgraph::VT_LUT_TENSORS, lut_tensors); } - explicit MetadataBuilder(flatbuffers::FlatBufferBuilder &_fbb) + explicit SubgraphBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } - flatbuffers::Offset<Metadata> Finish() { + ::flatbuffers::Offset<Subgraph> Finish() { const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<Metadata>(end); + auto o = ::flatbuffers::Offset<Subgraph>(end); return o; } }; -inline flatbuffers::Offset<Metadata> CreateMetadata( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset<flatbuffers::Vector<const tflite::micro::compression::LutTensor *>> lut_tensors = 0) { - MetadataBuilder builder_(_fbb); +inline ::flatbuffers::Offset<Subgraph> CreateSubgraph( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<tflite::micro::compression::LutTensor>>> lut_tensors = 0) { + SubgraphBuilder builder_(_fbb); builder_.add_lut_tensors(lut_tensors); return builder_.Finish(); } -inline flatbuffers::Offset<Metadata> CreateMetadataDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const std::vector<tflite::micro::compression::LutTensor> *lut_tensors = nullptr) { - auto lut_tensors__ =
lut_tensors ? _fbb.CreateVectorOfStructs(*lut_tensors) : 0; - return tflite::micro::compression::CreateMetadata( +inline ::flatbuffers::Offset<Subgraph> CreateSubgraphDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<::flatbuffers::Offset<tflite::micro::compression::LutTensor>> *lut_tensors = nullptr) { + auto lut_tensors__ = lut_tensors ? _fbb.CreateVector<::flatbuffers::Offset<tflite::micro::compression::LutTensor>>(*lut_tensors) : 0; + return tflite::micro::compression::CreateSubgraph( _fbb, lut_tensors__); } +struct LutTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef LutTensorBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_TENSOR = 4, + VT_VALUE_BUFFER = 6, + VT_INDEX_BITWIDTH = 8 + }; + int32_t tensor() const { + return GetField<int32_t>(VT_TENSOR, 0); + } + uint32_t value_buffer() const { + return GetField<uint32_t>(VT_VALUE_BUFFER, 0); + } + uint8_t index_bitwidth() const { + return GetField<uint8_t>(VT_INDEX_BITWIDTH, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_TENSOR, 4) && + VerifyField<uint32_t>(verifier, VT_VALUE_BUFFER, 4) && + VerifyField<uint8_t>(verifier, VT_INDEX_BITWIDTH, 1) && + verifier.EndTable(); + } +}; + +struct LutTensorBuilder { + typedef LutTensor Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_tensor(int32_t tensor) { + fbb_.AddElement<int32_t>(LutTensor::VT_TENSOR, tensor, 0); + } + void add_value_buffer(uint32_t value_buffer) { + fbb_.AddElement<uint32_t>(LutTensor::VT_VALUE_BUFFER, value_buffer, 0); + } + void add_index_bitwidth(uint8_t index_bitwidth) { + fbb_.AddElement<uint8_t>(LutTensor::VT_INDEX_BITWIDTH, index_bitwidth, 0); + } + explicit LutTensorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<LutTensor> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<LutTensor>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<LutTensor> CreateLutTensor( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t tensor = 0, + uint32_t value_buffer = 0, + uint8_t index_bitwidth = 0) { + LutTensorBuilder builder_(_fbb); + builder_.add_value_buffer(value_buffer); + builder_.add_tensor(tensor); + builder_.add_index_bitwidth(index_bitwidth); + return builder_.Finish(); +} + inline const tflite::micro::compression::Metadata *GetMetadata(const void *buf) { - return flatbuffers::GetRoot<tflite::micro::compression::Metadata>(buf); + return ::flatbuffers::GetRoot<tflite::micro::compression::Metadata>(buf); } inline const tflite::micro::compression::Metadata *GetSizePrefixedMetadata(const void *buf) { - return flatbuffers::GetSizePrefixedRoot<tflite::micro::compression::Metadata>(buf); + return ::flatbuffers::GetSizePrefixedRoot<tflite::micro::compression::Metadata>(buf); } inline bool VerifyMetadataBuffer( - flatbuffers::Verifier &verifier) { + ::flatbuffers::Verifier &verifier) { return verifier.VerifyBuffer<tflite::micro::compression::Metadata>(nullptr); } inline bool VerifySizePrefixedMetadataBuffer( - flatbuffers::Verifier &verifier) { + ::flatbuffers::Verifier &verifier) { return verifier.VerifySizePrefixedBuffer<tflite::micro::compression::Metadata>(nullptr); } inline void FinishMetadataBuffer( - flatbuffers::FlatBufferBuilder &fbb, - flatbuffers::Offset<tflite::micro::compression::Metadata> root) { + ::flatbuffers::FlatBufferBuilder &fbb, + ::flatbuffers::Offset<tflite::micro::compression::Metadata> root) { fbb.Finish(root); } inline void FinishSizePrefixedMetadataBuffer( - flatbuffers::FlatBufferBuilder &fbb, - flatbuffers::Offset<tflite::micro::compression::Metadata> root) { + ::flatbuffers::FlatBufferBuilder &fbb, + ::flatbuffers::Offset<tflite::micro::compression::Metadata> root) { fbb.FinishSizePrefixed(root); }
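Since the generated reader API above is new in this change, a brief usage sketch may help: the following is illustrative only (the function and its name are hypothetical, not part of the patch), showing how a client might verify and walk a buffer found under the "COMPRESSION_METADATA" key.

#include "tensorflow/lite/micro/compression/metadata_generated.h"
#include "tensorflow/lite/micro/micro_log.h"

// Verify a compression metadata buffer and print each LutTensor entry.
// `data` and `size` are assumed to describe the tflite.Buffer referenced by
// the model's "COMPRESSION_METADATA" metadata entry.
bool DumpCompressionMetadata(const uint8_t* data, size_t size) {
  flatbuffers::Verifier verifier(data, size);
  if (!tflite::micro::compression::VerifyMetadataBuffer(verifier)) {
    return false;
  }
  const auto* metadata = tflite::micro::compression::GetMetadata(data);
  if (metadata->subgraphs() == nullptr) {
    return false;  // production code must null-check, as the allocator does
  }
  for (size_t i = 0; i < metadata->subgraphs()->size(); ++i) {
    const auto* subgraph = metadata->subgraphs()->Get(i);
    if (subgraph->lut_tensors() == nullptr) {
      continue;
    }
    for (const auto* lut_tensor : *subgraph->lut_tensors()) {
      // tensor() indexes the subgraph's tensors; value_buffer() indexes the
      // model's buffers; index_bitwidth() is the packed field width.
      MicroPrintf("subgraph %u: tensor %d, value_buffer %u, bitwidth %u",
                  static_cast<unsigned>(i), lut_tensor->tensor(),
                  lut_tensor->value_buffer(), lut_tensor->index_bitwidth());
    }
  }
  return true;
}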
diff --git a/tensorflow/lite/micro/compression/metadata_test.cc b/tensorflow/lite/micro/compression/metadata_test.cc index 74b567c7d1..dd0575fa68 100644 --- a/tensorflow/lite/micro/compression/metadata_test.cc +++ b/tensorflow/lite/micro/compression/metadata_test.cc @@ -21,51 +21,57 @@ limitations under the License. */ #include "metadata_generated.h" #include "tensorflow/lite/micro/hexdump.h" -using tflite::micro::compression::LutTensor; using tflite::micro::compression::Metadata; using tflite::micro::compression::MetadataT; +using tflite::micro::compression::Subgraph; +using tflite::micro::compression::SubgraphT; +using tflite::micro::compression::LutTensor; +using tflite::micro::compression::LutTensorT; -bool operator==(const LutTensor& a, const LutTensor& b) { +bool operator==(const LutTensorT& a, const LutTensor& b) { return - a.subgraph() == b.subgraph() && - a.tensor() == b.tensor() && - a.index_bitwidth() == b.index_bitwidth() && - a.index_buffer() == b.index_buffer() && - a.value_buffer() == b.value_buffer(); + a.tensor == b.tensor() && + a.value_buffer == b.value_buffer() && + a.index_bitwidth == b.index_bitwidth(); } int main(int argc, char* argv[]) { - const LutTensor lut_tensor0 { - 0, // subgraph - 127, // tensor - 2, // index_bitwidth - 128, // index_buffer - 129, // value_buffer - }; - const LutTensor lut_tensor1 { - 1, // subgraph - 164, // tensor - 2, // index_bitwidth - 136, // index_buffer - 129, // value_buffer - }; - MetadataT metadata; - metadata.lut_tensors = {lut_tensor0, lut_tensor1}; + // Create these objects on the stack and copy them into the subgraph's vector + // later, so that we can compare these objects to what we read back from the + // flatbuffer. + LutTensorT lut_tensor0; + lut_tensor0.tensor = 63; + lut_tensor0.value_buffer = 128; + lut_tensor0.index_bitwidth = 2; + + LutTensorT lut_tensor1; + lut_tensor1.tensor = 64; + lut_tensor1.value_buffer = 129; + lut_tensor1.index_bitwidth = 4; + + auto subgraph0 = std::make_unique<SubgraphT>(); + subgraph0->lut_tensors.push_back(std::make_unique<LutTensorT>(lut_tensor0)); + subgraph0->lut_tensors.push_back(std::make_unique<LutTensorT>(lut_tensor1)); + + auto metadata = std::make_unique<MetadataT>(); + metadata->subgraphs.push_back(std::move(subgraph0)); flatbuffers::FlatBufferBuilder builder; - auto root = Metadata::Pack(builder, &metadata); + auto root = Metadata::Pack(builder, metadata.get()); builder.Finish(root); const uint8_t* buffer = builder.GetBufferPointer(); + const size_t buffer_size = builder.GetSize(); tflite::hexdump( - {reinterpret_cast<const std::byte*>(buffer), builder.GetSize()}); - std::cout << "length: " << builder.GetSize() << "\n"; - - auto readback = tflite::micro::compression::GetMetadata(buffer); - auto& read_lut_tensor0 = *readback->lut_tensors()->Get(0); - auto& read_lut_tensor1 = *readback->lut_tensors()->Get(1); - assert(read_lut_tensor0 == lut_tensor0); - assert(read_lut_tensor1 == lut_tensor1); + {reinterpret_cast<const std::byte*>(buffer), buffer_size}); + std::cout << "length: " << buffer_size << "\n"; + + const Metadata* read_metadata = tflite::micro::compression::GetMetadata(buffer); + const Subgraph* read_subgraph0 = read_metadata->subgraphs()->Get(0); + const LutTensor* read_lut_tensor0 = read_subgraph0->lut_tensors()->Get(0); + const LutTensor* read_lut_tensor1 = read_subgraph0->lut_tensors()->Get(1); + assert(lut_tensor0 == *read_lut_tensor0); + assert(lut_tensor1 == *read_lut_tensor1); return 0; } diff --git a/tensorflow/lite/micro/compression/metadata_test.py b/tensorflow/lite/micro/compression/metadata_test.py index 3d954154b8..be3daa09d3 100644 --- a/tensorflow/lite/micro/compression/metadata_test.py +++ b/tensorflow/lite/micro/compression/metadata_test.py @@ -28,21 +28,20 @@ def main(): # The
classes with a `T` suffix provide an object-oriented representation of # the object tree in the flatbuffer using native data structures. lut_tensor0 = schema.LutTensorT() - lut_tensor0.subgraph = 1 - lut_tensor0.tensor = 127 + lut_tensor0.tensor = 63 + lut_tensor0.valueBuffer = 128 lut_tensor0.indexBitwidth = 2 - lut_tensor0.indexBuffer = 128 - lut_tensor0.valueBuffer = 129 lut_tensor1 = schema.LutTensorT() - lut_tensor1.subgraph = 1 - lut_tensor1.tensor = 164 - lut_tensor1.indexBitwidth = 2 - lut_tensor1.indexBuffer = 136 + lut_tensor1.tensor = 64 lut_tensor1.valueBuffer = 129 + lut_tensor1.indexBitwidth = 4 + + subgraph0 = schema.SubgraphT() + subgraph0.lutTensors = [lut_tensor0, lut_tensor1] metadata = schema.MetadataT() - metadata.lutTensors = [lut_tensor0, lut_tensor1] + metadata.subgraphs = [subgraph0] # Build the flatbuffer itself using the flatbuffers runtime module. builder = flatbuffers.Builder(32) @@ -56,9 +55,11 @@ def main(): def attrs_equal(a, b): return all(vars(a)[key] == vars(b)[key] for key in vars(a)) - readback = schema.MetadataT.InitFromPackedBuf(buffer, 0) - assert attrs_equal(readback.lutTensors[0], lut_tensor0) - assert attrs_equal(readback.lutTensors[1], lut_tensor1) + read_metadata = schema.MetadataT.InitFromPackedBuf(buffer, 0) + read_subgraph0 = read_metadata.subgraphs[0] + + assert attrs_equal(read_subgraph0.lutTensors[0], lut_tensor0) + assert attrs_equal(read_subgraph0.lutTensors[1], lut_tensor1) sys.exit() diff --git a/tensorflow/lite/micro/compression/original.fbs b/tensorflow/lite/micro/compression/original.fbs deleted file mode 100644 index 3a05a6cd4f..0000000000 --- a/tensorflow/lite/micro/compression/original.fbs +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2024 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - - -namespace tflite.micro; - -table ValuesInt8 { - values:[int8]; -} - -table ValuesInt16 { - values:[int16]; -} - -table ValuesInt32 { - values:[int32]; -} - -table ValuesInt64 { - values:[int64]; -} - -table ValuesFloat32 { - values:[float32]; -} - -union ValuesUnion { - ValuesFloat32, - ValuesInt8, - ValuesInt16, - ValuesInt32, - ValuesInt64 -} - -table Values { - values:ValuesUnion; -} - -table BinQuantBufferOptions { - value_table_index:int; - compressed_bit_width:uint8; // Should be 2 or 4 -} - -union CompressedBufferOptions { - BinQuantBufferOptions, - // HuffmanBufferOptions, // Future -} - -table CompressedBuffer { - buffer_index:int; // Buffer index from the top-level Model buffer vector - options:CompressedBufferOptions; -} - -table BinQuantCompression { - version:uint8; - // For a given value table, if the corresponding buffer was per-tensor quantized, there should be 4 or 16 elements (2 bit or 4 bit indexes). - // If the buffer was per-channel quantized, there should be 4/16 x number of channels elements. These will be laid out in the table as: - // [c0v0, c0v1, c0v2, c0v3, c1v0, c1v1, ... 
cNv3] - value_tables:[Values]; -} - -table CompressionMetadata { - // List of compressed buffers - buffers:[CompressedBuffer]; - - // (Optional) Model-wide Bin & Quant compression parameters. Only needed if a - // CompressedBuffer contains BinQuantBufferOptions. - bin_quant_compression:BinQuantCompression; -} - -root_type CompressionMetadata; diff --git a/tensorflow/lite/micro/compression/original_test.py b/tensorflow/lite/micro/compression/original_test.py deleted file mode 100644 index edc8ad4d11..0000000000 --- a/tensorflow/lite/micro/compression/original_test.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Test validity of the flatbuffer schema and illustrate use of the flatbuffer -# machinery with Python - -import sys -import hexdump -import flatbuffers - -# `.*_generated` is the name of the module created by the Bazel rule -# `flatbuffer_py_library' based on the schema. -from tensorflow.lite.micro.compression import original_flatbuffer_py_generated as schema - - -def main(): - # The classes with a `T` suffix provide an object-oriented representation of - # the object tree in the flatbuffer using native data structures. - bq0_options = schema.BinQuantBufferOptionsT() - bq0_options.valueTableIndex = 0 - bq0_options.compressedBitWidth = 2 - - bq1_options = schema.BinQuantBufferOptionsT() - bq1_options.valueTableIndex = 1 - bq1_options.compressedBitBidth = 4 - - buffer0 = schema.CompressedBufferT() - buffer0.bufferIndex = 0 - buffer0.options = bq0_options - buffer0.optionsType = schema.CompressedBufferOptions.BinQuantBufferOptions - - buffer1 = schema.CompressedBufferT() - buffer1.bufferIndex = 1 - buffer1.options = bq1_options - buffer1.optionsType = schema.CompressedBufferOptions.BinQuantBufferOptions - - valuesInt8 = schema.ValuesInt8T() - valuesInt8.values = [65] - values0 = schema.ValuesT() - values0.values = valuesInt8 - values0.values.Type = schema.ValuesUnion.ValuesInt8 - - bq_compression = schema.BinQuantCompressionT() - bq_compression.valueTables = [values0] - - metadata = schema.CompressionMetadataT() - metadata.buffers = [buffer0, buffer1] - metadata.binQuantCompression = bq_compression - - # Build the flatbuffer itself using the flatbuffers runtime module. - builder = flatbuffers.Builder(32) - root = metadata.Pack(builder) - builder.Finish(root) - buffer: bytearray = builder.Output() - - print(hexdump.hexdump(buffer, result='return')) - print(f"length: {len(buffer)}") - - readback = schema.CompressionMetadataT.InitFromPackedBuf(buffer, 0) - - sys.exit() - - -if __name__ == "__main__": - main() diff --git a/tensorflow/lite/micro/compression/view.py b/tensorflow/lite/micro/compression/view.py deleted file mode 100644 index 55c4255ede..0000000000 --- a/tensorflow/lite/micro/compression/view.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import pprint - -from tensorflow.lite.micro.compression import metadata_flatbuffer_py_generated as compression_schema -from tensorflow.lite.python import schema_py_generated as tflite_schema - - -def read_model(path): - with open(path, 'rb') as file: - buffer = bytearray(file.read()) - return tflite_schema.ModelT.InitFromPackedBuf(buffer, 0) - - -def unpack_list(source): - result = [] - for index, s in enumerate(source): - d = {"_index": index} | vars(s) - result.append(d) - return result - - -def unpack_operators(operators): - result = [] - for index, o in enumerate(operators): - d = {"_index": index, - "opcode_index": o.opcodeIndex, - "inputs": unpack_array(o.inputs), - "outputs": unpack_array(o.outputs), - } - result.append(d) - return result - - -def unpack_TensorType(type): - attrs = [attr for attr in dir(tflite_schema.TensorType) if not - attr.startswith("__")] - lut = {getattr(tflite_schema.TensorType, attr): attr for attr in attrs} - return lut[type] - - -def unpack_tensors(tensors): - result = [] - for index, t in enumerate(tensors): - d = {"_index": index, - "name": t.name.decode("utf-8"), - "type": unpack_TensorType(t.type), - "shape": unpack_array(t.shape), - "quantization": [unpack_array(t.quantization.scale), unpack_array(t.quantization.zeroPoint)], - "buffer": t.buffer, - } - result.append(d) - return result - - -def unpack_subgraphs(subgraphs): - result = [] - for index, s in enumerate(subgraphs): - d = {"_index": index, - "name": s.name, - # "inputs": s.inputs, - # "outputs": s.outputs, - "operators": unpack_operators(s.operators), - "tensors": unpack_tensors(s.tensors), - } - result.append(d) - return result - - -def unpack_metadata(metadata): - return [{"name": m.name.decode("utf-8"), "buffer": m.buffer} for m in - metadata] - - -def unpack_compression_metadata(buffer): - metadata = compression_schema.MetadataT.InitFromPackedBuf(buffer, 0) - result = [] - for index, t in enumerate(metadata.lutTensors): - d = {"_index": index, - "subgraph": t.subgraph, - "tensor": t.tensor, - "indexBitwidth": t.indexBitwidth, - "indexBuffer": t.indexBuffer, - "valueBuffer": t.valueBuffer, - } - result.append(d) - return {"lut_tensors": result} - - -def unpack_array(a): - try: - # Avoid printing as numpy arrays if possible. The pprint module does not - # format them well. 
- a = a.tolist() - except AttributeError: - pass - return a - - -def unpack_buffers(buffers, compression_metadata=None): - result = [] - for index, b in enumerate(buffers): - d = {"_index": index} - d = d | {"data": unpack_array(b.data)} - if index == compression_metadata: d = d | {"_compression_metadata_decoded": - unpack_compression_metadata(bytes(b.data))} - result.append(d) - return result - - -def get_compression_metadata_buffer(model): - # Return the metadata buffer data or None - for item in model.metadata: - if item.name.decode("utf-8") == "COMPRESSION_METADATA": - return item.buffer - else: - return None - - -def print_model(model, format=None): - output = { - "description": model.description.decode("utf-8"), - "version": model.version, - "operator_codes": unpack_list(model.operatorCodes), - "metadata": unpack_metadata(model.metadata), - "subgraphs": unpack_subgraphs(model.subgraphs), - "buffers": unpack_buffers(model.buffers, - get_compression_metadata_buffer(model)), - } - - pprint.pprint(output, width=90, sort_dicts=False, compact=True) - - -def main(argv=None): - filename = argv[1] - model = read_model(filename) - print_model(model) - - -if __name__ == "__main__": - import sys - main(sys.argv) diff --git a/tensorflow/lite/micro/micro_allocator.cc b/tensorflow/lite/micro/micro_allocator.cc index 5caefa3476..c83a009b4b 100644 --- a/tensorflow/lite/micro/micro_allocator.cc +++ b/tensorflow/lite/micro/micro_allocator.cc @@ -418,16 +418,12 @@ const tflite::micro::compression::Metadata* GetCompressionMetadata( } TfLiteStatus InitializeCompressionTensorDataFromFlatbuffer( - const Model& model, const tflite::micro::compression::LutTensor& lut_tensor, + const Model& model, const size_t subgraph_index, + const tflite::micro::compression::LutTensor& lut_tensor, CompressionTensorData* ctd) { + // TODO(ddavis-2015): support multiple compression schemes ctd->scheme = CompressionScheme::kBinQuant; - const size_t subgraph_index = lut_tensor.subgraph(); - if (subgraph_index >= model.subgraphs()->size()) { - MicroPrintf("Compression: invalid subgraph index %u in LutTensor", - subgraph_index); - return kTfLiteError; - } const size_t tensor_index = lut_tensor.tensor(); auto tensors = model.subgraphs()->Get(subgraph_index)->tensors(); if (tensor_index >= tensors->size()) { @@ -461,11 +457,6 @@ TfLiteStatus InitializeCompressionTensorDataFromFlatbuffer( MicroPrintf("Compression: scalar tensors not supported"); return kTfLiteError; } - if (tensor->buffer() != lut_tensor.index_buffer()) { - MicroPrintf("Compression: mismatched index_buffer %u != %u in LutTensor", - lut_tensor.index_buffer(), tensor->buffer()); - return kTfLiteError; - } TfLiteType tensor_type = kTfLiteNoType; TfLiteStatus status = ConvertTensorType(tensor->type(), &tensor_type); if (status != kTfLiteOk) { @@ -932,83 +923,96 @@ TfLiteStatus MicroAllocator::AllocateCompressedTensorsList( // no compression metadata is available return kTfLiteOk; } - if (compression_metadata->lut_tensors() == nullptr) { - MicroPrintf("Compression: invalid LutTensor vector"); + if (compression_metadata->subgraphs() == nullptr) { + MicroPrintf("Compression: invalid Subgraph vector"); return kTfLiteError; } - if (compression_metadata->lut_tensors()->size() == 0) { - MicroPrintf("Compression: zero length LutTensor vector"); + if (compression_metadata->subgraphs()->size() == 0) { + MicroPrintf("Compression: zero length Subgraph vector"); return kTfLiteError; } - for (size_t lut_tensors_index = 0; - lut_tensors_index < 
compression_metadata->lut_tensors()->size(); - lut_tensors_index++) { - auto lut_tensor = - compression_metadata->lut_tensors()->Get(lut_tensors_index); - - CompressionTensorData* ctd = reinterpret_cast<CompressionTensorData*>( - persistent_buffer_allocator_->AllocatePersistentBuffer( - sizeof(CompressionTensorData), alignof(CompressionTensorData))); - if (ctd == nullptr) { - MicroPrintf( - "Compressions: failed to allocate memory for CompressionTensorData, " - "%d bytes required", - sizeof(CompressionTensorData)); - return kTfLiteError; - } + for (size_t subgraph_index = 0; + subgraph_index < compression_metadata->subgraphs()->size(); + subgraph_index++) { + auto subgraph = compression_metadata->subgraphs()->Get(subgraph_index); - LookupTableData* lut_table = reinterpret_cast<LookupTableData*>( - persistent_buffer_allocator_->AllocatePersistentBuffer( - sizeof(LookupTableData), alignof(LookupTableData))); - if (lut_table == nullptr) { - MicroPrintf( - "Compressions: failed to allocate memory for LookupTableData, " - "%d bytes required", - sizeof(LookupTableData)); + if (subgraph->lut_tensors() == nullptr) { + MicroPrintf("Compression: invalid LutTensor vector"); return kTfLiteError; } - ctd->data.lut_data = lut_table; - - TfLiteStatus status = - internal::InitializeCompressionTensorDataFromFlatbuffer( - *model, *lut_tensor, ctd); - if (status != kTfLiteOk) { - MicroPrintf("Compression: failed to initialize data for LutTensor %u", - lut_tensors_index); + if (subgraph->lut_tensors()->size() == 0) { + MicroPrintf("Compression: zero length LutTensor vector"); return kTfLiteError; } - const size_t subgraph_index = lut_tensor->subgraph(); - if (subgraph_allocations[subgraph_index].compressed.tensors == nullptr) { - size_t alloc_count = - model->subgraphs()->Get(subgraph_index)->tensors()->size(); - const CompressionTensorData** tensors = - reinterpret_cast<const CompressionTensorData**>( - persistent_buffer_allocator_->AllocatePersistentBuffer( - sizeof(CompressionTensorData*) * alloc_count, - alignof(CompressionTensorData*))); - if (tensors == nullptr) { + for (size_t lut_tensors_index = 0; + lut_tensors_index < subgraph->lut_tensors()->size(); + lut_tensors_index++) { + auto lut_tensor = subgraph->lut_tensors()->Get(lut_tensors_index); + + CompressionTensorData* ctd = reinterpret_cast<CompressionTensorData*>( + persistent_buffer_allocator_->AllocatePersistentBuffer( + sizeof(CompressionTensorData), alignof(CompressionTensorData))); + if (ctd == nullptr) { MicroPrintf( - "Compression: failed to allocate memory for compression tensor " - "list, %d bytes required", - sizeof(CompressionTensorData*) * alloc_count); + "Compressions: failed to allocate memory for " + "CompressionTensorData, %d bytes required", + sizeof(CompressionTensorData)); return kTfLiteError; } - subgraph_allocations[subgraph_index].compressed.tensors = tensors; - std::fill(tensors, tensors + alloc_count, nullptr); - } + LookupTableData* lut_table = reinterpret_cast<LookupTableData*>( + persistent_buffer_allocator_->AllocatePersistentBuffer( + sizeof(LookupTableData), alignof(LookupTableData))); + if (lut_table == nullptr) { + MicroPrintf( + "Compressions: failed to allocate memory for LookupTableData, " + "%d bytes required", + sizeof(LookupTableData)); + return kTfLiteError; + } + ctd->data.lut_data = lut_table; - const size_t tensor_index = lut_tensor->tensor(); - if (subgraph_allocations[subgraph_index].compressed.tensors[tensor_index] != - nullptr) { - MicroPrintf("Compression: duplicate LutTensor subgraph %u tensor %u", - subgraph_index, tensor_index); - return kTfLiteError; - } else { -
subgraph_allocations[subgraph_index].compressed.tensors[tensor_index] = - ctd; + TfLiteStatus status = + internal::InitializeCompressionTensorDataFromFlatbuffer( + *model, subgraph_index, *lut_tensor, ctd); + if (status != kTfLiteOk) { + MicroPrintf("Compression: failed to initialize data for LutTensor %u", + lut_tensors_index); + return kTfLiteError; + } + + if (subgraph_allocations[subgraph_index].compressed.tensors == nullptr) { + size_t alloc_count = + model->subgraphs()->Get(subgraph_index)->tensors()->size(); + const CompressionTensorData** tensors = + reinterpret_cast<const CompressionTensorData**>( + persistent_buffer_allocator_->AllocatePersistentBuffer( + sizeof(CompressionTensorData*) * alloc_count, + alignof(CompressionTensorData*))); + if (tensors == nullptr) { + MicroPrintf( + "Compression: failed to allocate memory for compression tensor " + "list, %d bytes required", + sizeof(CompressionTensorData*) * alloc_count); + return kTfLiteError; + } + + subgraph_allocations[subgraph_index].compressed.tensors = tensors; + std::fill(tensors, tensors + alloc_count, nullptr); + } + + const size_t tensor_index = lut_tensor->tensor(); + if (subgraph_allocations[subgraph_index] + .compressed.tensors[tensor_index] != nullptr) { + MicroPrintf("Compression: duplicate LutTensor subgraph %u tensor %u", + subgraph_index, tensor_index); + return kTfLiteError; + } else { + subgraph_allocations[subgraph_index].compressed.tensors[tensor_index] = + ctd; + } } } diff --git a/tensorflow/lite/micro/test_helpers.cc b/tensorflow/lite/micro/test_helpers.cc index aeb9a439ec..33535ec866 100644 --- a/tensorflow/lite/micro/test_helpers.cc +++ b/tensorflow/lite/micro/test_helpers.cc @@ -587,16 +587,20 @@ const Model* BuildSimpleMockModel() { #ifdef USE_TFLM_COMPRESSION -const flatbuffers::span<const uint8_t> BuildLutMetadata( - const std::initializer_list<tflite::micro::compression::LutTensor>& - lut_tensor_structs) { +const flatbuffers::span<const uint8_t> BuildLutMetadata(uint tensor_index, + uint value_table_buffer_index, + uint bit_width) { using flatbuffers::Offset; namespace compression = tflite::micro::compression; flatbuffers::FlatBufferBuilder* builder = BuilderInstance(); - auto lut_tensors = builder->CreateVectorOfStructs(lut_tensor_structs.begin(), - lut_tensor_structs.size()); - auto metadata = compression::CreateMetadata(*builder, lut_tensors); + + auto lut_tensor = compression::CreateLutTensor( + *builder, tensor_index, value_table_buffer_index, bit_width); + auto subgraph = compression::CreateSubgraph( + *builder, builder->CreateVector(&lut_tensor, 1)); + auto metadata = compression::CreateMetadata( + *builder, builder->CreateVector(&subgraph, 1)); compression::FinishMetadataBuffer(*builder, metadata); return builder->GetBufferSpan(); } @@ -612,14 +616,10 @@ const Model* BuildSimpleMockModelCompressed() { // constexpr uint kInputTensor = 0; constexpr uint kWeightsTensor = 1; // constexpr uint kOutputTensor = 2; - constexpr uint kSubgraphIndex = 0; constexpr uint kCompressedBitWidth = 4; - const std::initializer_list<LutTensor> lut_tensors = { - LutTensor(kSubgraphIndex, kWeightsTensor, kCompressedBitWidth, - kWeightsBuffer, kValueTableBuffer), - }; - auto lut_tensors_span = BuildLutMetadata(lut_tensors); + auto lut_tensors_span = + BuildLutMetadata(kWeightsTensor, kValueTableBuffer, kCompressedBitWidth); flatbuffers::FlatBufferBuilder* builder = BuilderInstance();
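The updated BuildLutMetadata above exercises the new builder API end to end. For reference, a self-contained equivalent looks like the following; this is a sketch with illustrative tensor, buffer, and bitwidth values, not code from this change.

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/micro/compression/metadata_generated.h"

// Serialize metadata describing one LUT-compressed tensor in subgraph 0:
// tensor 1, value table in buffer 2, 4-bit packed indices.
flatbuffers::DetachedBuffer MakeExampleMetadata() {
  flatbuffers::FlatBufferBuilder builder;
  namespace compression = tflite::micro::compression;
  auto lut_tensor = compression::CreateLutTensor(
      builder, /*tensor=*/1, /*value_buffer=*/2, /*index_bitwidth=*/4);
  auto subgraph = compression::CreateSubgraph(
      builder, builder.CreateVector(&lut_tensor, 1));
  auto metadata = compression::CreateMetadata(
      builder, builder.CreateVector(&subgraph, 1));
  compression::FinishMetadataBuffer(builder, metadata);
  return builder.Release();
}

The resulting buffer is what a model-building tool would store in a tflite.Buffer and reference from a tflite.Metadata entry named "COMPRESSION_METADATA", matching the key now defined by kCompressionMetadataString.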