发布于 

AI杂谈001. 使用TFLite实现MNIST推理

import tensorflow as tf
from tensorflow import keras

import numpy as np
import matplotlib.pyplot as plt
import random

# Print the installed TensorFlow version (the run recorded below reports 2.11.0).
print(tf.__version__)
2023-02-22 22:47:18.176362: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


2.11.0

下载MNIST数据集

# Load the MNIST dataset through Keras and unpack the train/test splits
# into their image and label arrays.
mnist = keras.datasets.mnist
train_split, test_split = mnist.load_data()
train_images, train_labels = train_split
test_images, test_labels = test_split
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 [==============================] - 6s 0us/step
# Normalize the image pixel values into the 0-1 range.
train_images, test_images = train_images / 255.0, test_images / 255.0
# Plot the first 25 training images in a 5x5 grid, each captioned with its
# label. Ticks and grid lines are hidden so only the digit is visible.
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.gray)
    plt.xlabel(train_labels[i])
plt.show()


[图片:训练集前25张手写数字图片及其标签]

定义模型

# Define the model architecture: two convolution layers, max pooling,
# dropout, and a final dense layer with 10 outputs and no activation.
cnn_layers = [
    keras.layers.InputLayer(input_shape=(28, 28)),
    # Reshape (28, 28) inputs to (28, 28, 1) before the convolutions.
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation=tf.nn.relu),
    keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation=tf.nn.relu),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Dropout(0.25),
    keras.layers.Flatten(),
    keras.layers.Dense(10),
]
model = keras.Sequential(cnn_layers)

# Configure training: Adam optimizer, sparse categorical cross-entropy
# computed from logits, and accuracy as the reported metric.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])

# Train the model for 5 epochs on the training set.
model.fit(train_images, train_labels, epochs=5)
Epoch 1/5
1875/1875 [==============================] - 44s 23ms/step - loss: 0.1370 - accuracy: 0.9589
Epoch 2/5
1875/1875 [==============================] - 47s 25ms/step - loss: 0.0534 - accuracy: 0.9835
Epoch 3/5
1875/1875 [==============================] - 48s 26ms/step - loss: 0.0401 - accuracy: 0.9876
Epoch 4/5
1875/1875 [==============================] - 49s 26ms/step - loss: 0.0312 - accuracy: 0.9897
Epoch 5/5
1875/1875 [==============================] - 50s 27ms/step - loss: 0.0256 - accuracy: 0.9919
# Print the per-layer summary (output shapes and parameter counts).
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 reshape_1 (Reshape)         (None, 28, 28, 1)         0         

 conv2d_2 (Conv2D)           (None, 26, 26, 32)        320       

 conv2d_3 (Conv2D)           (None, 24, 24, 64)        18496     

 max_pooling2d_1 (MaxPooling  (None, 12, 12, 64)       0         
 2D)                                                             

 dropout_1 (Dropout)         (None, 12, 12, 64)        0         

 flatten_1 (Flatten)         (None, 9216)              0         

 dense_1 (Dense)             (None, 10)                92170     

=================================================================
Total params: 110,986
Trainable params: 110,986
Non-trainable params: 0
_________________________________________________________________

评估模型

# Evaluate the trained Keras model on the test set and report its accuracy.
evaluation = model.evaluate(test_images, test_labels)
test_loss, test_acc = evaluation

print("Test accuracy:", test_acc)
313/313 [==============================] - 2s 7ms/step - loss: 0.0310 - accuracy: 0.9903
Test accuracy: 0.9902999997138977
def get_label_color(val1, val2):
    """Return the caption colour for a pair of values.

    Args:
        val1: First value to compare (the predicted digit at the call site).
        val2: Second value to compare (the true label at the call site).

    Returns:
        ``'green'`` when the two values compare equal, ``'red'`` otherwise.
    """
    return 'green' if val1 == val2 else 'red'

predictions = model.predict(test_images)

# The model outputs 10 floating-point scores per image, one for each digit
# 0-9. They are raw logits rather than probabilities (the final Dense layer
# has no softmax and the loss uses from_logits=True), but the index of the
# largest score is still the most likely digit.
prediction_digits = np.argmax(predictions, axis=1)

# Show 100 randomly chosen test images in a 10x10 grid. Each caption is green
# when the prediction matches the label and red otherwise.
plt.figure(figsize=(18, 18))
for i in range(100):
    ax = plt.subplot(10, 10, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    # randrange excludes the upper bound. random.randint(0, n) is inclusive
    # and could return n, one past the last valid index (IndexError).
    image_index = random.randrange(len(prediction_digits))
    plt.imshow(test_images[image_index], cmap=plt.cm.gray)
    ax.xaxis.label.set_color(
        get_label_color(prediction_digits[image_index], test_labels[image_index])
    )
    plt.xlabel('Predicted: %d' % prediction_digits[image_index])
plt.show()
313/313 [==============================] - 2s 7ms/step

[图片:随机抽取的100张测试图片及其预测结果,预测正确的标签为绿色,预测错误的为红色]

转换为tflite模型

# Convert the Keras model into a TF Lite float model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_float_model = converter.convert()

# Show the float model size: byte length divided by 1024, printed with %d
# (which drops the fractional part).
float_model_size = len(tflite_float_model) / 1024
print('Float model size = %dKBs' % float_model_size)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op, _update_step_xla while saving (showing 3 of 3). These functions will not be directly callable after loading.


INFO:tensorflow:Assets written to: /var/folders/mp/n8hmdv6j5y1f9th2v0k78zz80000gp/T/tmp614pyr8m/assets


INFO:tensorflow:Assets written to: /var/folders/mp/n8hmdv6j5y1f9th2v0k78zz80000gp/T/tmp614pyr8m/assets


Float model size = 437KBs


2023-02-22 23:22:07.283679: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2023-02-22 23:22:07.283696: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
2023-02-22 23:22:07.284211: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: /var/folders/mp/n8hmdv6j5y1f9th2v0k78zz80000gp/T/tmp614pyr8m
2023-02-22 23:22:07.285685: I tensorflow/cc/saved_model/reader.cc:89] Reading meta graph with tags { serve }
2023-02-22 23:22:07.285696: I tensorflow/cc/saved_model/reader.cc:130] Reading SavedModel debug info (if present) from: /var/folders/mp/n8hmdv6j5y1f9th2v0k78zz80000gp/T/tmp614pyr8m
2023-02-22 23:22:07.289811: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:357] MLIR V1 optimization pass is not enabled
2023-02-22 23:22:07.290950: I tensorflow/cc/saved_model/loader.cc:229] Restoring SavedModel bundle.
2023-02-22 23:22:07.323036: I tensorflow/cc/saved_model/loader.cc:213] Running initialization op on SavedModel bundle at path: /var/folders/mp/n8hmdv6j5y1f9th2v0k78zz80000gp/T/tmp614pyr8m
2023-02-22 23:22:07.331699: I tensorflow/cc/saved_model/loader.cc:305] SavedModel load for tags { serve }; Status: success: OK. Took 47489 microseconds.
2023-02-22 23:22:07.351673: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
# Re-convert the model with quantization enabled. This reuses the existing
# `converter` object and sets its optimizations before converting again.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quantized_model = converter.convert()

# Show the quantized model size: byte length divided by 1024, printed with %d
# (which drops the fractional part).
quantized_model_size = len(tflite_quantized_model) / 1024
print('Quantized model size = %dKBs' % quantized_model_size)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op, _update_step_xla while saving (showing 3 of 3). These functions will not be directly callable after loading.


INFO:tensorflow:Assets written to: /var/folders/mp/n8hmdv6j5y1f9th2v0k78zz80000gp/T/tmpio1p9wql/assets


INFO:tensorflow:Assets written to: /var/folders/mp/n8hmdv6j5y1f9th2v0k78zz80000gp/T/tmpio1p9wql/assets


Quantized model size = 114KBs


2023-02-22 23:24:13.132530: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2023-02-22 23:24:13.132544: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
2023-02-22 23:24:13.132644: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: /var/folders/mp/n8hmdv6j5y1f9th2v0k78zz80000gp/T/tmpio1p9wql
2023-02-22 23:24:13.134124: I tensorflow/cc/saved_model/reader.cc:89] Reading meta graph with tags { serve }
2023-02-22 23:24:13.134133: I tensorflow/cc/saved_model/reader.cc:130] Reading SavedModel debug info (if present) from: /var/folders/mp/n8hmdv6j5y1f9th2v0k78zz80000gp/T/tmpio1p9wql
2023-02-22 23:24:13.138687: I tensorflow/cc/saved_model/loader.cc:229] Restoring SavedModel bundle.
2023-02-22 23:24:13.172666: I tensorflow/cc/saved_model/loader.cc:213] Running initialization op on SavedModel bundle at path: /var/folders/mp/n8hmdv6j5y1f9th2v0k78zz80000gp/T/tmpio1p9wql
2023-02-22 23:24:13.181701: I tensorflow/cc/saved_model/loader.cc:305] SavedModel load for tags { serve }; Status: success: OK. Took 49057 microseconds.

运行tflite模型

def evaluate_tflite_model(tflite_model):
    """Run a TF Lite model over the test set and return its accuracy.

    Reads the module-level ``test_images`` and ``test_labels``.

    Args:
        tflite_model: Serialized TF Lite model, passed to
            ``tf.lite.Interpreter`` as ``model_content``.

    Returns:
        The fraction of test images whose predicted digit equals the label.
    """
    # Initialize the TF Lite interpreter.
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    input_tensor_index = interpreter.get_input_details()[0]["index"]
    output = interpreter.tensor(interpreter.get_output_details()[0]["index"])

    # Run inference on every image in the test set.
    prediction_digits = []
    for test_image in test_images:
        # Pre-processing: add a batch dimension and convert to float32 to
        # match the model's input format.
        test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_tensor_index, test_image)

        # Run inference.
        interpreter.invoke()

        # Post-processing: drop the batch dimension and pick the digit with
        # the highest score.
        digit = np.argmax(output()[0])
        prediction_digits.append(digit)

    # Compare the predictions with the ground-truth labels to get accuracy.
    accurate_count = sum(
        1
        for predicted, label in zip(prediction_digits, test_labels)
        if predicted == label
    )
    return accurate_count / len(prediction_digits)

# Accuracy of the float TF Lite model on the test set.
float_accuracy = evaluate_tflite_model(tflite_float_model)
print('Float model accuracy = %.4f' % float_accuracy)

# Accuracy of the quantized TF Lite model on the same test set.
quantized_accuracy = evaluate_tflite_model(tflite_quantized_model)
print('Quantized model accuracy = %.4f' % quantized_accuracy)
Float model accuracy = 0.9903
Quantized model accuracy = 0.9903

参考

[1] TensorFlow示例