Skip to content

Commit

Permalink
update the tensorflow.lite to use ai-edge-litert for all python based scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
rascani committed Sep 26, 2024
1 parent c9212e2 commit 420e0b7
Show file tree
Hide file tree
Showing 11 changed files with 26 additions and 11 deletions.
1 change: 1 addition & 0 deletions python/tflite_micro/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ py_test(
],
deps = [
":runtime",
requirement("ai-edge-litert"),
requirement("numpy"),
requirement("tensorflow"),
"//tensorflow/lite/micro/examples/recipes:add_four_numbers",
Expand Down
5 changes: 3 additions & 2 deletions python/tflite_micro/runtime_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import numpy as np
import tensorflow as tf

from ai_edge_litert import interpreter as litert_interpreter
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tflite_micro.python.tflite_micro import runtime
Expand Down Expand Up @@ -199,10 +200,10 @@ def testCompareWithTFLite(self):
tflm_interpreter = runtime.Interpreter.from_bytes(model_data)

# TFLite interpreter
tflite_interpreter = tf.lite.Interpreter(
tflite_interpreter = litert_interpreter.Interpreter(
model_content=model_data,
experimental_op_resolver_type=\
tf.lite.experimental.OpResolverType.BUILTIN_REF)
litert_interpreter.OpResolverType.BUILTIN_REF)
tflite_interpreter.allocate_tensors()
tflite_output_details = tflite_interpreter.get_output_details()[0]
tflite_input_details = tflite_interpreter.get_input_details()[0]
Expand Down
1 change: 1 addition & 0 deletions tensorflow/lite/micro/examples/hello_world/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ py_binary(
"@absl_py//absl:app",
"@absl_py//absl/flags",
"@absl_py//absl/logging",
requirement("ai-edge-litert"),
requirement("numpy"),
requirement("tensorflow"),
"//python/tflite_micro:runtime",
Expand Down
5 changes: 3 additions & 2 deletions tensorflow/lite/micro/examples/hello_world/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import tensorflow as tf
from absl import app
from absl import flags
from ai_edge_litert import interpreter as litert_interpreter
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.python.platform import resource_loader
Expand Down Expand Up @@ -92,9 +93,9 @@ def get_tflm_prediction(model_path, x_values):
# returns the prediction of the interpreter.
def get_tflite_prediction(model_path, x_values):
# TFLite interpreter
tflite_interpreter = tf.lite.Interpreter(
tflite_interpreter = litert_interpreter.Interpreter(
model_path=model_path,
experimental_op_resolver_type=tf.lite.experimental.OpResolverType.
experimental_op_resolver_type=litert_interpreter.OpResolverType.
BUILTIN_REF,
)
tflite_interpreter.allocate_tensors()
Expand Down
1 change: 1 addition & 0 deletions tensorflow/lite/micro/examples/mnist_lstm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ py_binary(
srcs = ["train.py"],
srcs_version = "PY3",
deps = [
requirement("ai-edge-litert"),
requirement("numpy"),
requirement("tensorflow"),
],
Expand Down
5 changes: 3 additions & 2 deletions tensorflow/lite/micro/examples/mnist_lstm/evaluate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
import tensorflow as tf

from ai_edge_litert import interpreter as litert_interpreter
from tensorflow.python.framework import test_util
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
Expand All @@ -43,10 +44,10 @@ def testInputErrHandling(self):
evaluate.predict_image(self.tflm_interpreter, wrong_size_image_path)

def testCompareWithTFLite(self):
tflite_interpreter = tf.lite.Interpreter(
tflite_interpreter = litert_interpreter.Interpreter(
model_path=self.model_path,
experimental_op_resolver_type=\
tf.lite.experimental.OpResolverType.BUILTIN_REF)
litert_interpreter.OpResolverType.BUILTIN_REF)
tflite_interpreter.allocate_tensors()
tflite_output_details = tflite_interpreter.get_output_details()[0]
tflite_input_details = tflite_interpreter.get_input_details()[0]
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/lite/micro/tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ py_library(
srcs_version = "PY3",
visibility = ["//:__subpackages__"],
deps = [
requirement("ai-edge-litert"),
"//tensorflow/lite/python:schema_py",
],
)
Expand Down Expand Up @@ -208,6 +209,7 @@ py_binary(
":model_transforms_utils",
"@absl_py//absl:app",
"@absl_py//absl/flags",
requirement("ai-edge-litert"),
requirement("tensorflow"),
"//python/tflite_micro:runtime",
"//tensorflow/lite/tools:flatbuffer_utils",
Expand Down
9 changes: 5 additions & 4 deletions tensorflow/lite/micro/tools/generate_test_for_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
import tensorflow as tf

from ai_edge_litert import interpreter as litert_interpreter
from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb


Expand Down Expand Up @@ -103,9 +104,9 @@ def generate_golden_single_in_single_out(self):
if (len(self.model_paths) != 1):
raise RuntimeError(f'Single model expected')
model_path = self.model_paths[0]
interpreter = tf.lite.Interpreter(model_path=model_path,
interpreter = litert_interpreter.Interpreter(model_path=model_path,
experimental_op_resolver_type=\
tf.lite.experimental.OpResolverType.BUILTIN_REF)
litert_interpreter.OpResolverType.BUILTIN_REF)

interpreter.allocate_tensors()

Expand Down Expand Up @@ -140,10 +141,10 @@ def generate_goldens(self, builtin_operator):

for model_path in self.model_paths:
# Load model and run a single inference with random inputs.
interpreter = tf.lite.Interpreter(
interpreter = litert_interpreter.Interpreter(
model_path=model_path,
experimental_op_resolver_type=\
tf.lite.experimental.OpResolverType.BUILTIN_REF)
litert_interpreter.OpResolverType.BUILTIN_REF)
interpreter.allocate_tensors()
input_tensor = interpreter.tensor(
interpreter.get_input_details()[0]['index'])
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/lite/micro/tools/layer_by_layer_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from absl import app
from absl import flags
from absl import logging
from ai_edge_litert import interpreter as litert_interpreter
import numpy as np
import tensorflow as tf

Expand Down Expand Up @@ -194,7 +195,7 @@ def main(_) -> None:
intrepreter_config=runtime.InterpreterConfig.kPreserveAllTensors,
)

tflite_interpreter = tf.lite.Interpreter(
tflite_interpreter = litert_interpreter.Interpreter(
model_path=_INPUT_TFLITE_FILE.value,
experimental_preserve_all_tensors=True,
)
Expand Down
1 change: 1 addition & 0 deletions third_party/python_requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ mako
pillow
yapf
protobuf
ai-edge-litert
4 changes: 4 additions & 0 deletions third_party/python_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ absl-py==2.0.0 \
# keras
# tensorboard
# tensorflow
ai-edge-litert==1.0.1 \
--hash=sha256:25a9b1577941498842bf77630722eda1163026c37abd57af66791a6955551b9d
# via -r third_party/python_requirements.in
astunparse==1.6.3 \
--hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \
--hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8
Expand Down Expand Up @@ -505,6 +508,7 @@ numpy==1.26.3 \
--hash=sha256:f73497e8c38295aaa4741bdfa4fda1a5aedda5473074369eca10626835445511
# via
# -r third_party/python_requirements.in
# ai-edge-litert
# h5py
# keras
# ml-dtypes
Expand Down

0 comments on commit 420e0b7

Please sign in to comment.