lin
2025-08-14 dae8bad597b6607a449b32bf76c523423f7720ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
# ==============================================================================
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Upgrade script to move from pre-release schema to new schema.
 
Usage examples:
 
bazel run tensorflow/lite/schema/upgrade_schema -- in.json out.json
bazel run tensorflow/lite/schema/upgrade_schema -- in.bin out.bin
bazel run tensorflow/lite/schema/upgrade_schema -- in.bin out.json
bazel run tensorflow/lite/schema/upgrade_schema -- in.json out.bin
bazel run tensorflow/lite/schema/upgrade_schema -- in.tflite out.tflite
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import argparse
import contextlib
import json
import os
import shutil
import subprocess
import sys
import tempfile
 
import tensorflow as tf
from tensorflow.python.platform import resource_loader
 
parser = argparse.ArgumentParser(
    description="Script to move TFLite models from pre-release schema to "
    "new schema.")
parser.add_argument(
    "input",
    type=str,
    help="Input TensorFlow lite file in `.json`, `.bin` or `.tflite` format.")
parser.add_argument(
    "output",
    type=str,
    help="Output json or bin TensorFlow lite model compliant with "
    "the new schema. Extension must be `.json`, `.bin` or `.tflite`.")
 
 
# RAII Temporary Directory, because flatc doesn't allow direct use of tempfiles.
@contextlib.contextmanager
def TemporaryDirectoryResource():
  temporary = tempfile.mkdtemp()
  try:
    yield temporary
  finally:
    shutil.rmtree(temporary)
 
 
class Converter(object):
  """Converts TensorFlow flatbuffer models from old to new version of schema.
 
  This can convert between any version to the latest version. It uses
  an incremental upgrade strategy to go from version to version.
 
  Usage:
    converter = Converter()
    converter.Convert("a.tflite", "a.json")
    converter.Convert("b.json", "b.tflite")
  """
 
  def __init__(self):
    # TODO(aselle): make this work in the open source version with better
    # path.
    paths_to_try = [
        "../../../../flatbuffers/flatc",  # not bazel
        "../../../../external/flatbuffers/flatc"  # bazel
    ]
    for p in paths_to_try:
      self._flatc_path = resource_loader.get_path_to_datafile(p)
      if os.path.exists(self._flatc_path): break
 
    def FindSchema(base_name):
      return resource_loader.get_path_to_datafile("%s" % base_name)
 
    # Supported schemas for upgrade.
    self._schemas = [
        (0, FindSchema("schema_v0.fbs"), True, self._Upgrade0To1),
        (1, FindSchema("schema_v1.fbs"), True, self._Upgrade1To2),
        (2, FindSchema("schema_v2.fbs"), True, self._Upgrade2To3),
        (3, FindSchema("schema_v3.fbs"), False, None)  # Non-callable by design.
    ]
    # Ensure schemas are sorted, and extract latest version and upgrade
    # dispatch function table.
    self._schemas.sort()
    self._new_version, self._new_schema = self._schemas[-1][:2]
    self._upgrade_dispatch = {
        version: dispatch
        for version, unused1, unused2, dispatch in self._schemas}
 
  def _Read(self, input_file, schema, raw_binary=False):
    """Read a tflite model assuming the given flatbuffer schema.
 
    If `input_file` is in bin, then we must use flatc to convert the schema
    from binary to json.
 
    Args:
      input_file: a binary (flatbuffer) or json file to read from. Extension
        must  be `.tflite`, `.bin`, or `.json` for FlatBuffer Binary or
        FlatBuffer JSON.
      schema: which schema to use for reading
      raw_binary: whether to assume raw_binary (versions previous to v3)
        that lacked file_identifier require this.
 
    Raises:
      RuntimeError: When flatc cannot be invoked.
      ValueError: When the extension is not json or bin.
 
    Returns:
      A dictionary representing the read tflite model.
    """
    raw_binary = ["--raw-binary"] if raw_binary else []
    with TemporaryDirectoryResource() as tempdir:
      basename = os.path.basename(input_file)
      basename_no_extension, extension = os.path.splitext(basename)
      if extension in [".bin", ".tflite"]:
        # Convert to json using flatc
        returncode = subprocess.call([
            self._flatc_path,
            "-t",
            "--strict-json",
            "--defaults-json",
        ] + raw_binary + ["-o", tempdir, schema, "--", input_file])
        if returncode != 0:
          raise RuntimeError("flatc failed to convert from binary to json.")
        json_file = os.path.join(tempdir, basename_no_extension + ".json")
        if not os.path.exists(json_file):
          raise RuntimeError("Could not find %r" % json_file)
      elif extension == ".json":
        json_file = input_file
      else:
        raise ValueError("Invalid extension on input file %r" % input_file)
      return json.load(open(json_file))
 
  def _Write(self, data, output_file):
    """Output a json or bin version of the flatbuffer model.
 
    Args:
      data: Dict representing the TensorFlow Lite model to write.
      output_file: filename to write the converted flatbuffer to. (json,
        tflite, or bin extension is required).
    Raises:
      ValueError: When the extension is not json or bin
      RuntimeError: When flatc fails to convert json data to binary.
    """
    _, extension = os.path.splitext(output_file)
    with TemporaryDirectoryResource() as tempdir:
      if extension == ".json":
        json.dump(data, open(output_file, "w"), sort_keys=True, indent=2)
      elif extension in [".tflite", ".bin"]:
        input_json = os.path.join(tempdir, "temp.json")
        with open(input_json, "w") as fp:
          json.dump(data, fp, sort_keys=True, indent=2)
        returncode = subprocess.call([
            self._flatc_path, "-b", "--defaults-json", "--strict-json", "-o",
            tempdir, self._new_schema, input_json
        ])
        if returncode != 0:
          raise RuntimeError("flatc failed to convert upgraded json to binary.")
 
        shutil.copy(os.path.join(tempdir, "temp.tflite"), output_file)
      else:
        raise ValueError("Invalid extension on output file %r" % output_file)
 
  def _Upgrade0To1(self, data):
    """Upgrade data from Version 0 to Version 1.
 
    Changes: Added subgraphs (which contains a subset of formally global
    entries).
 
    Args:
      data: Dictionary representing the TensorFlow lite data to be upgraded.
        This will be modified in-place to be an upgraded version.
    """
    subgraph = {}
    for key_to_promote in ["tensors", "operators", "inputs", "outputs"]:
      subgraph[key_to_promote] = data[key_to_promote]
      del data[key_to_promote]
    data["subgraphs"] = [subgraph]
 
  def _Upgrade1To2(self, data):
    """Upgrade data from Version 1 to Version 2.
 
    Changes: Rename operators to Conform to NN API.
 
    Args:
      data: Dictionary representing the TensorFlow lite data to be upgraded.
        This will be modified in-place to be an upgraded version.
    Raises:
      ValueError: Throws when model builtins are numeric rather than symbols.
    """
 
    def RemapOperator(opcode_name):
      """Go from old schema op name to new schema op name.
 
      Args:
        opcode_name: String representing the ops (see :schema.fbs).
      Returns:
        Converted opcode_name from V1 to V2.
      """
      old_name_to_new_name = {
          "CONVOLUTION": "CONV_2D",
          "DEPTHWISE_CONVOLUTION": "DEPTHWISE_CONV_2D",
          "AVERAGE_POOL": "AVERAGE_POOL_2D",
          "MAX_POOL": "MAX_POOL_2D",
          "L2_POOL": "L2_POOL_2D",
          "SIGMOID": "LOGISTIC",
          "L2NORM": "L2_NORMALIZATION",
          "LOCAL_RESPONSE_NORM": "LOCAL_RESPONSE_NORMALIZATION",
          "Basic_RNN": "RNN",
      }
 
      return (old_name_to_new_name[opcode_name]
              if opcode_name in old_name_to_new_name else opcode_name)
 
    def RemapOperatorType(operator_type):
      """Remap operator structs from old names to new names.
 
      Args:
        operator_type: String representing the builtin operator data type
          string.
        (see :schema.fbs).
      Returns:
        Upgraded builtin operator data type as a string.
      """
      old_to_new = {
          "PoolOptions": "Pool2DOptions",
          "DepthwiseConvolutionOptions": "DepthwiseConv2DOptions",
          "ConvolutionOptions": "Conv2DOptions",
          "LocalResponseNormOptions": "LocalResponseNormalizationOptions",
          "BasicRNNOptions": "RNNOptions",
      }
      return (old_to_new[operator_type]
              if operator_type in old_to_new else operator_type)
 
    for subgraph in data["subgraphs"]:
      for ops in subgraph["operators"]:
        ops["builtin_options_type"] = RemapOperatorType(
            ops["builtin_options_type"])
 
    # Upgrade the operator codes
    for operator_code in data["operator_codes"]:
      # Check if builtin_code is the appropriate string type
      # use type("") instead of str or unicode. for py2and3
      if not isinstance(operator_code["builtin_code"], type(u"")):
        raise ValueError("builtin_code %r is non-string. this usually means "
                         "your model has consistency problems." %
                         (operator_code["builtin_code"]))
      operator_code["builtin_code"] = (RemapOperator(
          operator_code["builtin_code"]))
 
  def _Upgrade2To3(self, data):
    """Upgrade data from Version 2 to Version 3.
 
    Changed actual read-only tensor data to be in a buffers table instead
    of inline with the tensor.
 
    Args:
      data: Dictionary representing the TensorFlow lite data to be upgraded.
        This will be modified in-place to be an upgraded version.
    """
    buffers = [{"data": []}]  # Start with 1 empty buffer
    for subgraph in data["subgraphs"]:
      if "tensors" not in subgraph:
        continue
      for tensor in subgraph["tensors"]:
        if "data_buffer" not in tensor:
          tensor["buffer"] = 0
        else:
          if tensor["data_buffer"]:
            tensor[u"buffer"] = len(buffers)
            buffers.append({"data": tensor["data_buffer"]})
          else:
            tensor["buffer"] = 0
          del tensor["data_buffer"]
    data["buffers"] = buffers
 
  def _PerformUpgrade(self, data):
    """Manipulate the `data` (parsed JSON) based on changes in format.
 
    This incrementally will upgrade from version to version within data.
 
    Args:
      data: Dictionary representing the TensorFlow data. This will be upgraded
        in place.
    """
    while data["version"] < self._new_version:
      self._upgrade_dispatch[data["version"]](data)
      data["version"] += 1
 
  def Convert(self, input_file, output_file):
    """Perform schema conversion from input_file to output_file.
 
    Args:
      input_file: Filename of TensorFlow Lite data to convert from. Must
        be `.json` or `.bin` extension files for JSON or Binary forms of
        the TensorFlow FlatBuffer schema.
      output_file: Filename to write to. Extension also must be `.json`
        or `.bin`.
 
    Raises:
      RuntimeError: Generated when none of the upgrader supported schemas
        matche the `input_file` data.
    """
    # Read data in each schema (since they are incompatible). Version is
    # always present. Use the read data that matches the version of the
    # schema.
    for version, schema, raw_binary, _ in self._schemas:
      try:
        data_candidate = self._Read(input_file, schema, raw_binary)
      except RuntimeError:
        continue  # Skip and hope another schema works
      if "version" not in data_candidate:  # Assume version 1 if not present.
        data_candidate["version"] = 1
      elif data_candidate["version"] == 0:  # Version 0 doesn't exist in wild.
        data_candidate["version"] = 1
 
      if data_candidate["version"] == version:
        self._PerformUpgrade(data_candidate)
        self._Write(data_candidate, output_file)
        return
    raise RuntimeError("No schema that the converter understands worked with "
                       "the data file you provided.")
 
 
def main(argv):
  del argv
  Converter().Convert(FLAGS.input, FLAGS.output)
 
 
if __name__ == "__main__":
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)