【ATU Book-i.MX8系列 - TFLite 進階】 模組轉換(三)

一.   概述

邊緣運算的重點技術之中,除了 模組輕量化網路架構模組量化 技術之外。另一項技術就是將各家神經網路框架 進行所謂的 模組轉換 技術,能幫助開發者快速部屬至不同的神經網路框架中。故這裡想跟各位讀者探討關於各種模型格式轉換為 TensorFlow Lite 的方式,依序分為 TensorFlow 各模組格式轉換、 Pytorch 與 ONNX 格式轉換、以及 逆轉換 TensorFlow Lite 三個章節。

 

本篇章將介紹 從浮點數的 TensorFlow Lite 格式轉換成 SavedModel 格式,再重新量化一次為整數的 TensorFlow Lite 格式,如圖下所示。此技術要點參考 @PINTO0309 的 github 文章 。

各模組轉換至 TensorFlow Lite 格式示意圖

 

如下圖所示,本系列是隸屬於 機器學習開發環境 eIQ 推理引擎層 (Inference Engines Layer) 中的 TensorFlow Lite 進階系列,故後續將向讀者介紹 模組轉換()
若新讀者欲理解更多人工智慧、機器學習以及深度學習的資訊,可點選查閱下方博文

 大大通精彩博文   【ATU Book-i.MX8系列】博文索引

 

 

TensorFlow Lite 進階系列博文-文章架構示意圖

 

二.  模組轉換

TensorFlow Lite Inverter 顧名思義就是將 TensorFlow Lite 逆轉換回 SavedModel 儲存格式。主要目的是將強制將無法取得來源的 tflite 檔案進行轉換與調用,比如說 TensorFlow Lite 模組經逆轉換後,重新藉由 SavedModel 轉為整數的 TensorFlow Lite 模組。故此技術較為艱難複雜,必須對於 TensorFlow 模組設計須有一定程度的理解 !! 接下來,將以 Google 提供的 MediaPipe API 作解析 !! 有興趣者可查看 @PINTO 所撰寫的文章 !

第一步,安裝必要套件

Flatbuffers 套件下載 : https://github.com/google/flatbuffers.git

$ git clone https://github.com/google/flatbuffers.git 
$ cd /root/flatbuffers
$ cmake -G "Unix Makefiles"
$ make

MediaPipe 模組下載 : https://github.com/google/mediapipe

$ git clone https://github.com/google/mediapipe
$ cd /root/mediapipe
$ cp mediapipe/modules/palm_detection/palm_detection.tflite /root/flatbuffers/palm_detection.tflite # 移動至flatbuffers目錄下​

 

第二步,下載 schema.fbs 檔案

此檔案須與模組的 TensorFlow 版本相符

$ git clone https://github.com/tensorflow/tensorflow
$ cd /root/tensorflow
$ cp tensorflow/lite/schema/schema.fbs /root/flatbuffers/schema.fbs # 移動至flatbuffers目錄下

 

第三步,搭建重建函式與架構 API

# (1) 載入必要函式庫
import os
import numpy as np
import json
import tensorflow as tf
import shutil
from pathlib import Path

# (2) API config
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
schema = "schema.fbs"
binary = "./flatc"
model_path = "palm_detection.tflite"
output_pb_path = "palm_detection.pb"
output_savedmodel_path = "saved_model"
model_json_path = "palm_detection.json"
output_node_names = [ 'classificators' , 'regressors' ]

#( 3) API : 共三個
# 產生架構檔(JSON)
def gen_model_json():
cmd = (binary + " -t --strict-json --defaults-json -o . {schema} -- {input}".format(input=model_path, schema=schema))
print("output json command =", cmd)
os.system(cmd)

# 解析架構
def parse_json():
j = json.load(open(model_json_path))
op_types = [ v['builtin_code’] for v in j[ 'operator_codes'] ]
ops = j['subgraphs'][0]['operators'] # subgraphs => 分為 tensors / inputs / outputs / operators
return ops, op_types

# 描述架構圖
# make_graph 函式內需要撰寫 tflite 內所有用到 架構層 layer
# 因此非常龐大 這裡僅介紹 conv2 的描述,完整代碼請至 “該連結查看”
def make_graph(ops, op_types, interpreter):
tensors = {}
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
for input_detail in input_details:
tensors[input_detail['index']] = tf.compat.v1.placeholder(
dtype=input_detail['dtype'], shape=input_detail['shape'], name=input_detail['name'])

for index, op in enumerate(ops):
print('op: ', op)
op_type = op_types[op['opcode_index']]
if op_type == 'CONV_2D':
input_tensor = tensors[op['inputs'][0]]
weights_detail = interpreter._get_tensor_details(op['inputs'][1])
bias_detail = interpreter._get_tensor_details(op['inputs'][2])
output_detail = interpreter._get_tensor_details(op['outputs'][0])
weights_array = interpreter.get_tensor(weights_detail['index'])
weights_array = np.transpose(weights_array, (1, 2, 3, 0))
bias_array = interpreter.get_tensor(bias_detail['index'])
weights = tf.compat.v1.Variable(weights_array, name=weights_detail['name'])
bias = tf.compat.v1.Variable(bias_array, name=bias_detail['name'])
options = op['builtin_options']
output_tensor = tf.compat.v1.nn.conv2d(
input_tensor,
weights,
strides=[1, options['stride_h'], options['stride_w'], 1],
padding=options['padding'],
dilations=[1, options['dilation_h_factor'],options['dilation_w_factor'], 1],
name=output_detail['name'] + '/conv2d')
output_tensor = tf.compat.v1.add(output_tensor, bias, name=output_detail['name'])

tensors[output_detail['index']] = output_tensor

 

第四步,調用 API 進行逆轉換

# 將 tf lite 架構轉成 json 格式
tf.compat.v1.disable_eager_execution()
gen_model_json()
ops, op_types = parse_json()

# 調用 tflite 取得資訊
interpreter = Interpreter(model_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details)

# 將 tf lite 重新拆解回 savedmodel
# 若欲將 tflite 轉成全整數的格式,請查閱並結合上一小節方式。
make_graph(ops, op_types, interpreter)
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
graph = tf.compat.v1.get_default_graph()
with tf.compat.v1.Session(config=config, graph=graph) as sess:
sess.run(tf.compat.v1.global_variables_initializer())
graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(sess=sess,\
input_graph_def=graph.as_graph_def(),\
output_node_names=output_node_names)
with tf.io.gfile.GFile(output_pb_path, 'wb') as f:
f.write(graph_def.SerializeToString())
shutil.rmtree('saved_model', ignore_errors=True)
tf.compat.v1.saved_model.simple_save(sess,
output_savedmodel_path,
inputs={'input': graph.get_tensor_by_name('input:0')},
outputs={
'classificators': graph.get_tensor_by_name('classificators:0'),
'regressors': graph.get_tensor_by_name('regressors:0')
})

 

第五步,SavedModel 轉換 TensorFlow Lite

# 因擔心代碼過於冗長,故此代碼轉成 tflite 格式為主。 
import tensorflow as tf
converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model( 'saved_model' )
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.compat.v1.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.compat.v1.uint8
converter.inference_output_type = tf.compat.v1.uint8
converter.representative_dataset = representative_dataset_gen # representative_dataset 請參閱上述的量化方式
tflite_model = converter.convert()
with tf.io.gfile.GFile( "/root/palm_detection_uint8.tflite" , 'wb') as f:
f.write(tflite_model)
print("Integer Quantization complete! ")

※ tips 分享 : 此方式相當費時費工,且因 TensorFlow 版本因素 再次提高轉換的實現難度

 

三.  結語

模組轉換是一項相當實用的技術,但取決於各家神經網路框架的版本不同,仍會出現無法轉換成功的問題。此章節提供讀者一套 TensorFlow Lite 重新拆解的方法,拆解為 SavedModel 後就可以再次轉成 ONNX 或是 TensorFlow Lite 的格式。這可以用在純整數的 AI 晶片上,而 NXP I.MX8MPlus 的 NPU 神經處理單元正好可以適用此項技術,並能幫助讀者移植至該平台中實現應用。若有任何模組轉換的問題,可以至下方留言所遇到的問題,讓我們一起互相切磋,一起成長!! 下一篇文章,將會介紹如何建置 TensorFlow Lite 的物件偵測應用,並應用在 NXP i.MX8M Plus 的平台上。

 

四.  參考文件

[1] 官方文件 - i.MX Machine Learning User's Guide pdf
[2] 官方文件 - TensorFlow Lite 轉換工具
[3] 官方文件 - Post-training quantization
[4] 官方文件 - TensorFlow API
[5] 第三方文件 - Tensorflow模型量化(Quantization)原理及其实现方法
[6] 官方文件 - TensorFlow Lite 現有應用資源
[7] 官方文件 - TensorFlow Lite Hub
[8] 官方文件 - TensorFlow Lite Slim
[9] 官方文件 - TensorFlow Model Garden
[10] 官方文件 - TensorFlow Model JS
[11] 官方文件 -  Keras
[12] 官方文件 - TensorFlow - Using the SavedModel Format
[13] 第三方文件 - TensorFlow 模型導出總結
[14] 官方文件 - TensorFlow Model Garden
[15] 官方文件 -Megnta github
[16] 官方文件 - Pytorch
[17] 官方文件 - ONNX
[18] 官方文件 - OpenVINO
[19] 官方文件 - ONNX Model Zoo
[20] 官方文件 - OpenVINO Model Zoo
[21] 官方文件 - YOLO : Real-Time Object Detection
[22] 第三方文件 -
@PINTO github 系列

 

如有任何相關 TensorFlow Lite 進階技術問題,歡迎至博文底下留言提問 !!
接下來還會分享更多 TensorFlow Lite 進階的技術文章 !!敬請期待 ATU Book-i.MX8系列 – TFLite 進階 !!

★博文內容均由個人提供,與平台無關,如有違法或侵權,請與網站管理員聯繫。

★文明上網,請理性發言。內容一周內被舉報5次,發文人進小黑屋喔~

評論