xref: /llvm-project/llvm/lib/Analysis/models/saved-model-to-tflite.py (revision b71edfaa4ec3c998aadb35255ce2f60bba2940b0)
"""Convert a saved model to a TFLite model.

Usage: python3 saved-model-to-tflite.py <mlgo saved_model_dir> <tflite dest_dir>

The <tflite dest_dir> will contain:
  model.tflite: the converted saved model.
  output_spec.json: the output spec, copied from the saved_model dir (only if
    the saved_model dir contains one).
"""
9
10import tensorflow as tf
11import os
12import sys
13from tf_agents.policies import greedy_policy
14
15
def main(argv):
    """Convert the saved model in argv[1] to a TFLite model in argv[2].

    Args:
      argv: the full command line, i.e. ``sys.argv``:
        ``[script, saved_model_dir, tflite_dest_dir]``.

    Writes ``<tflite_dest_dir>/model.tflite``. If
    ``<saved_model_dir>/output_spec.json`` exists, it is copied to
    ``<tflite_dest_dir>/output_spec.json``. Existing files in the destination
    are overwritten, so the conversion can be re-run into the same directory.

    Raises:
      SystemExit: if argv does not hold exactly two arguments after the
        script name.
    """
    # Validate explicitly instead of with `assert`: asserts are stripped under
    # `python -O`, and a bare AssertionError gives no hint about usage.
    if len(argv) != 3:
        sys.exit(
            "Usage: python3 saved-model-to-tflite.py "
            "<mlgo saved_model_dir> <tflite dest_dir>"
        )
    sm_dir = argv[1]
    tfl_dir = argv[2]
    tf.io.gfile.makedirs(tfl_dir)
    tfl_path = os.path.join(tfl_dir, "model.tflite")
    converter = tf.lite.TFLiteConverter.from_saved_model(sm_dir)
    # Allow only TFLite builtin ops in the converted model.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
    ]
    tfl_model = converter.convert()
    with tf.io.gfile.GFile(tfl_path, "wb") as f:
        f.write(tfl_model)

    json_file = "output_spec.json"
    src_json = os.path.join(sm_dir, json_file)
    # The output spec is optional; copy it next to the model when present.
    # overwrite=True mirrors the "wb" write of model.tflite above. Without it,
    # tf.io.gfile.copy raises if the destination file already exists, so
    # re-running the conversion into the same directory would fail here.
    if tf.io.gfile.exists(src_json):
        tf.io.gfile.copy(
            src_json, os.path.join(tfl_dir, json_file), overwrite=True
        )
35
if __name__ == "__main__":
    # Hand the whole command line (script name included) to main().
    main(argv=sys.argv)
38