ronnie
2022-10-14 1504bb53e29d3d46222c0b3ea994fc494b48e153
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite tooling helper functionality."""
 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import warnings
import enum
from six import PY3
 
from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.lite.experimental.examples.lstm.rnn import dynamic_rnn  # pylint: disable=unused-import
from tensorflow.lite.experimental.examples.lstm.rnn_cell import TFLiteLSTMCell  # pylint: disable=unused-import
from tensorflow.lite.experimental.examples.lstm.rnn_cell import TfLiteRNNCell  # pylint: disable=unused-import
from tensorflow.lite.python import lite_constants as constants
from tensorflow.lite.python.convert import build_toco_convert_protos  # pylint: disable=unused-import
from tensorflow.lite.python.convert import ConverterError  # pylint: disable=unused-import
from tensorflow.lite.python.convert import OpsSet
from tensorflow.lite.python.convert import tensor_name as _tensor_name
from tensorflow.lite.python.convert import toco_convert  # pylint: disable=unused-import
from tensorflow.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def
from tensorflow.lite.python.convert import toco_convert_impl as _toco_convert_impl
from tensorflow.lite.python.convert import toco_convert_protos  # pylint: disable=unused-import
from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
from tensorflow.lite.python.convert_saved_model import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
from tensorflow.lite.python.convert_saved_model import set_tensor_shapes as _set_tensor_shapes
from tensorflow.lite.python.interpreter import Interpreter  # pylint: disable=unused-import
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs  # pylint: disable=unused-import
from tensorflow.lite.python.op_hint import OpHint  # pylint: disable=unused-import
from tensorflow.lite.python.optimize import calibrator as _calibrator
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2 as _rewriter_config_pb2
from tensorflow.core.protobuf import config_pb2 as _config_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session
from tensorflow.python.eager import def_function as _def_function
from tensorflow.python.eager import function as _function
from tensorflow.python.framework import convert_to_constants as _convert_to_constants
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import graph_util as _tf_graph_util
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
from tensorflow.python.grappler import tf_optimizer as _tf_optimizer
from tensorflow.python.lib.io import file_io as _file_io
from tensorflow.python.saved_model import signature_constants as _signature_constants
from tensorflow.python.saved_model import tag_constants as _tag_constants
from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph
from tensorflow.python.util import deprecation as _deprecation
from tensorflow.python.util.tf_export import tf_export as _tf_export
 
 
def _run_graph_optimizations(graph_def, input_arrays, output_arrays,
                             graph=None):
  """Apply standard TensorFlow optimizations to the graph_def.
 
  Args:
    graph_def: Frozen GraphDef to be optimized.
    input_arrays: List of arrays that are considered inputs of the graph.
    output_arrays: List of arrays that are considered outputs of the graph.
    graph: TensorFlow Graph. Required when Eager mode is enabled. (default None)
 
  Returns:
    A new, optimized GraphDef.
  """
  meta_graph = _export_meta_graph(graph_def=graph_def, graph=graph)
 
  # We need to add a collection called 'train_op' so that grappler
  # knows what the outputs are.
  fetch_collection = _meta_graph_pb2.CollectionDef()
  for array in input_arrays + output_arrays:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
 
  config = _config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  rewrite_options.layout_optimizer = _rewriter_config_pb2.RewriterConfig.ON
  # Avoid remapping as it creates ops like _FusedConv2D, which are not
  # supported by TF Lite.
  rewrite_options.remapping = _rewriter_config_pb2.RewriterConfig.OFF
  return _tf_optimizer.OptimizeGraph(config, meta_graph)
 
 
@_tf_export("lite.Optimize")
class Optimize(enum.Enum):
  """Enum defining the optimizations to apply when generating tflite graphs.
 
  Some optimizations may come at the cost of accuracy.
  """
 
  # Optimize for size.
  #
  # Optimizations that reduce the size of the model.
  # The model size will be reduced. Optimizations can include quantizing the
  # weights of the floating point model.
  OPTIMIZE_FOR_SIZE = "OPTIMIZE_FOR_SIZE"
 
  # Optimize for latency.
  #
  # Optimizations that reduce the latency of the model.
  # The model latency will be reduced. Optimizations can include quantizing the
  # weights of the floating point model.
  OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY"
 
  def __str__(self):
    return self.value
 
 
@_tf_export("lite.RepresentativeDataset")
class RepresentativeDataset(object):
  """Representative dataset to evaluate optimizations.
 
  A representative dataset that can be used to evaluate optimizations by the
  converter. E.g. converter can use these examples to estimate (min, max) ranges
  by calibrating the model on inputs. This can allow converter to quantize a
  converted floating point model.
  """
 
  def __init__(self, input_gen, output_gen=None):
    """Creates a representative dataset.
 
    Args:
      input_gen: an input generator that can be used to generate input samples
        for the model. This must be a callable object that returns an object
        that supports the `iter()` protocol (e.g. a generator function). The
        elements generated must have same type and shape as inputs to the model.
      output_gen: (optional) an output generator that can be used to generate
        output samples for the model. This must be a callable object that
        returns an object that supports the `iter()` protocol (e.g. a generator
        function). The elements generated must have same type and shape as
        outputs to the model. (default None)
    """
    self.input_gen = input_gen
    self.output_gen = output_gen
 
 
@_tf_export("lite.TargetSpec")
class TargetSpec(object):
  """Specification of target device.
 
  Details about target device. Converter optimizes the generated model for
  specific device.
 
  Attributes:
    supported_ops: Experimental flag, subject to change. Set of OpsSet options
      supported by the device. (default set([OpsSet.TFLITE_BUILTINS]))
  """
 
  def __init__(self, supported_ops=None):
    if supported_ops is None:
      supported_ops = set([OpsSet.TFLITE_BUILTINS])
    self.supported_ops = supported_ops
 
 
@_tf_export("lite.TFLiteConverter", v1=[])
class TFLiteConverterV2(object):
  """Converts a TensorFlow model into TensorFlow Lite model.
 
  Attributes:
    allow_custom_ops: Boolean indicating whether to allow custom operations.
      When false any unknown operation is an error. When true, custom ops are
      created for any op that is unknown. The developer will need to provide
      these to the TensorFlow Lite runtime with a custom resolver. (default
      False)
    target_spec: Experimental flag, subject to change. Specification of target
      device.
    optimizations: Experimental flag, subject to change, A list of optimizations
      to apply when converting the model. The converter applies the
      optimizations by giving priority to the optimizations specified earlier in
      the list. E.g. `[optimize.OPTIMIZE_FOR_SIZE,
      optimize.OPTIMIZE_FOR_LATENCY]` requires the converter to do both size and
      latency optimizations giving priority to size optimizations over latency
      optimizations.
    representative_dataset: A representative dataset that can be used to
      generate input and output samples for the model. The converter can use the
      dataset to evaluate different optimizations.
 
  Example usage:
 
    ```python
    # Converting a GraphDef from a ConcreteFunction.
    converter = lite.TFLiteConverter.from_concrete_function(func)
    tflite_model = converter.convert()
    open("converted_model.tflite", "wb").write(tflite_model)
    ```
  """
 
  def __init__(self, func):
    """Constructor for TFLiteConverter.
 
    Args:
      func: TensorFlow ConcreteFunction.
    """
    self._func = func
    self.allow_custom_ops = False
    self.target_spec = TargetSpec()
    self.representative_dataset = None
    self.optimizations = []
 
  @classmethod
  def from_concrete_function(cls, func):
    """Creates a TFLiteConverter class from a ConcreteFunction.
 
    Args:
      func: TensorFlow ConcreteFunction.
 
    Returns:
      TFLiteConverter class.
    """
    if not isinstance(func, _function.ConcreteFunction):
      message = "This function takes in a ConcreteFunction."
      if isinstance(func, _def_function.Function):
        message += (" To get the ConcreteFunction from a Function,"
                    " call from_concrete_function.")
      raise ValueError(message)
    return cls(func)
 
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.
 
    Returns:
      The converted data in serialized format.
 
    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
        self._func)
    input_tensors = [
        tensor for tensor in frozen_func.inputs
        if tensor.dtype != _dtypes.resource
    ]
    output_tensors = frozen_func.outputs
 
    # Run a Grappler pass.
    graph_def = _run_graph_optimizations(frozen_func.graph.as_graph_def(),
                                         input_tensors, output_tensors,
                                         frozen_func.graph)
 
    # Checks dimensions in input tensor.
    for tensor in input_tensors:
      # Note that shape_list might be empty for scalar shapes.
      shape_list = tensor.shape.as_list()
      if None in shape_list[1:]:
        raise ValueError(
            "None is only supported in the 1st dimension. Tensor '{0}' has "
            "invalid shape '{1}'.".format(_tensor_name(tensor), shape_list))
      elif shape_list and shape_list[0] is None:
        # Set the batch size to 1 if undefined.
        shape = tensor.shape.as_list()
        shape[0] = 1
        tensor.set_shape(shape)
 
    if self.representative_dataset:
      if not isinstance(self.representative_dataset, RepresentativeDataset):
        raise TypeError("`representative_dataset` must be an instance of "
                        "`RepresentativeDataset`")
      if self.representative_dataset.input_gen is None:
        raise ValueError(
            "Provide an input generator for `representative_dataset`")
 
    # TODO(shashishekhar): For now use optimizations order is ignored.
    # Both size and latency optimizations decide whether to apply post
    # training optimizations.
    post_training_optimize = bool(
        len(
            set(self.optimizations)
            & set([Optimize.OPTIMIZE_FOR_LATENCY, Optimize.OPTIMIZE_FOR_SIZE])))
    # Do weights only quantization if there is no dataset for calibration.
    weights_only_quantize_flag = (
        post_training_optimize and (self.representative_dataset is None))
 
    converter_kwargs = {
        "input_format": constants.TENSORFLOW_GRAPHDEF,
        "allow_custom_ops": self.allow_custom_ops,
        "post_training_quantize": weights_only_quantize_flag,
        "target_ops": self.target_spec.supported_ops,
    }
 
    # Converts model.
    result = _toco_convert_impl(
        input_data=graph_def,
        input_tensors=input_tensors,
        output_tensors=output_tensors,
        **converter_kwargs)
 
    if self.representative_dataset and post_training_optimize:
      calibrate_quantize = _calibrator.Calibrator(result)
      result = calibrate_quantize.calibrate_and_quantize(
          self.representative_dataset.input_gen)
 
    return result
 
 
@_tf_export(v1=["lite.TFLiteConverter"])
class TFLiteConverter(object):
  """Convert a TensorFlow model into `output_format` using TOCO.
 
  This is used to convert from a TensorFlow GraphDef or SavedModel into either a
  TFLite FlatBuffer or graph visualization.
 
  Attributes:
 
    inference_type: Target data type of real-number arrays in the output file.
      Must be `{tf.float32, tf.uint8}`. (default tf.float32)
    inference_input_type: Target data type of real-number input arrays. Allows
      for a different type for input arrays in the case of quantization.
      Must be `{tf.float32, tf.uint8}`. (default `inference_type`)
    output_format: Output file format. Currently must be `{TFLITE,
      GRAPHVIZ_DOT}`. (default TFLITE)
    quantized_input_stats: Dict of strings representing input tensor names
      mapped to tuple of floats representing the mean and standard deviation
      of the training data (e.g., {"foo" : (0., 1.)}). Only need if
      `inference_input_type` is `QUANTIZED_UINT8`.
      real_input_value = (quantized_input_value - mean_value) / std_dev_value.
      (default {})
    default_ranges_stats: Tuple of integers representing (min, max) range values
      for all arrays without a specified range. Intended for experimenting with
      quantization via "dummy quantization". (default None)
    drop_control_dependency: Boolean indicating whether to drop control
      dependencies silently. This is due to TFLite not supporting control
      dependencies. (default True)
    reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
      nodes in unexpected locations. Used when the location of the FakeQuant
      nodes is preventing graph transformations necessary to convert the graph.
      Results in a graph that differs from the quantized training graph,
      potentially causing differing arithmetic behavior. (default False)
    change_concat_input_ranges: Boolean to change behavior of min/max ranges for
      inputs and outputs of the concat operator for quantized models. Changes
      the ranges of concat operator overlap when true. (default False)
    allow_custom_ops: Boolean indicating whether to allow custom operations.
      When false any unknown operation is an error. When true, custom ops are
      created for any op that is unknown. The developer will need to provide
      these to the TensorFlow Lite runtime with a custom resolver.
      (default False)
    post_training_quantize: deprecated, please specify
     `[optimize.OPTIMIZE_FOR_SIZE]` for `optimizations` instead. Boolean
     indicating whether to quantize the weights of the converted float model.
     Model size will be reduced and there will be latency improvements
     (at the cost of accuracy). (default False)
    dump_graphviz_dir: Full filepath of folder to dump the graphs at various
      stages of processing GraphViz .dot files. Preferred over
      --output_format=GRAPHVIZ_DOT in order to keep the requirements of the
      output file. (default None)
    dump_graphviz_video: Boolean indicating whether to dump the graph after
      every graph transformation. (default False)
    target_ops: Experimental flag, subject to change. Set of OpsSet
      options indicating which converter to use.
      (default set([OpsSet.TFLITE_BUILTINS]))
    optimizations: Experimental flag, subject to change, A list of
      optimizations to apply when converting the model. The converter applies
      the optimizations by giving priority to the optimizations specified
      earlier in the list. E.g.
      `[optimize.OPTIMIZE_FOR_SIZE, optimize.OPTIMIZE_FOR_LATENCY]` requires
      the converter to do both size and latency optimizations giving priority
      to size optimizations over latency optimizations.
    representative_dataset: A representative dataset that can be used to
      generate input and output samples for the model. The converter can use
      the dataset to evaluate different optimizations.
 
  Example usage:
 
    ```python
    # Converting a GraphDef from session.
    converter = lite.TFLiteConverter.from_session(sess, in_tensors, out_tensors)
    tflite_model = converter.convert()
    open("converted_model.tflite", "wb").write(tflite_model)
 
    # Converting a GraphDef from file.
    converter = lite.TFLiteConverter.from_frozen_graph(
      graph_def_file, input_arrays, output_arrays)
    tflite_model = converter.convert()
    open("converted_model.tflite", "wb").write(tflite_model)
 
    # Converting a SavedModel.
    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
 
    # Converting a tf.keras model.
    converter = lite.TFLiteConverter.from_keras_model_file(keras_model)
    tflite_model = converter.convert()
    ```
  """
 
  def __init__(self,
               graph_def,
               input_tensors,
               output_tensors,
               input_arrays_with_shape=None,
               output_arrays=None):
    """Constructor for TFLiteConverter.
 
    Args:
      graph_def: Frozen TensorFlow GraphDef.
      input_tensors: List of input tensors. Type and shape are computed using
        `foo.shape` and `foo.dtype`.
      output_tensors: List of output tensors (only .name is used from this).
      input_arrays_with_shape: Tuple of strings representing input tensor names
        and list of integers representing input shapes
        (e.g., [("foo" : [1, 16, 16, 3])]). Use only when graph cannot be loaded
          into TensorFlow and when `input_tensors` and `output_tensors` are
          None. (default None)
      output_arrays: List of output tensors to freeze graph with. Use only when
        graph cannot be loaded into TensorFlow and when `input_tensors` and
        `output_tensors` are None. (default None)
 
    Raises:
      ValueError: Invalid arguments.
    """
    self._graph_def = graph_def
    self._input_tensors = input_tensors
    self._output_tensors = output_tensors
    self.inference_type = constants.FLOAT
    self.inference_input_type = None
    self.output_format = constants.TFLITE
    self.quantized_input_stats = {}
    self.default_ranges_stats = None
    self.drop_control_dependency = True
    self.reorder_across_fake_quant = False
    self.change_concat_input_ranges = False
    self.allow_custom_ops = False
    self._post_training_quantize = False
    self.dump_graphviz_dir = None
    self.dump_graphviz_video = False
    self.target_ops = set([OpsSet.TFLITE_BUILTINS])
    self.representative_dataset = None
    self.optimizations = []
 
    # Attributes are used by models that cannot be loaded into TensorFlow.
    if not self._has_valid_tensors():
      if not input_arrays_with_shape or not output_arrays:
        raise ValueError(
            "If input_tensors and output_tensors are None, both "
            "input_arrays_with_shape and output_arrays must be defined.")
      self._input_arrays_with_shape = input_arrays_with_shape
      self._output_arrays = output_arrays
 
  @classmethod
  def from_session(cls, sess, input_tensors, output_tensors):
    """Creates a TFLiteConverter class from a TensorFlow Session.
 
    Args:
      sess: TensorFlow Session.
      input_tensors: List of input tensors. Type and shape are computed using
        `foo.shape` and `foo.dtype`.
      output_tensors: List of output tensors (only .name is used from this).
 
    Returns:
      TFLiteConverter class.
    """
    graph_def = _freeze_graph(sess, output_tensors)
    return cls(graph_def, input_tensors, output_tensors)
 
  @classmethod
  def from_frozen_graph(cls,
                        graph_def_file,
                        input_arrays,
                        output_arrays,
                        input_shapes=None):
    """Creates a TFLiteConverter class from a file containing a frozen GraphDef.
 
    Args:
      graph_def_file: Full filepath of file containing frozen GraphDef.
      input_arrays: List of input tensors to freeze graph with.
      output_arrays: List of output tensors to freeze graph with.
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when input shapes is None (e.g., {"foo" :
          None}). (default None)
 
    Returns:
      TFLiteConverter class.
 
    Raises:
      IOError:
        File not found.
        Unable to parse input file.
      ValueError:
        The graph is not frozen.
        input_arrays or output_arrays contains an invalid tensor name.
        input_shapes is not correctly defined when required
    """
    with _ops.Graph().as_default():
      with _session.Session() as sess:
        # Read GraphDef from file.
        if not _file_io.file_exists(graph_def_file):
          raise IOError("File '{0}' does not exist.".format(graph_def_file))
        with _file_io.FileIO(graph_def_file, "rb") as f:
          file_content = f.read()
 
        try:
          graph_def = _graph_pb2.GraphDef()
          graph_def.ParseFromString(file_content)
        except (_text_format.ParseError, DecodeError):
          try:
            print("Ignore 'tcmalloc: large alloc' warnings.")
 
            if not isinstance(file_content, str):
              if PY3:
                file_content = file_content.decode("utf-8")
              else:
                file_content = file_content.encode("utf-8")
            graph_def = _graph_pb2.GraphDef()
            _text_format.Merge(file_content, graph_def)
          except (_text_format.ParseError, DecodeError):
            raise IOError(
                "Unable to parse input file '{}'.".format(graph_def_file))
 
        # Handles models with custom TFLite ops that cannot be resolved in
        # TensorFlow.
        load_model_in_session = True
        try:
          _import_graph_def(graph_def, name="")
        except _NotFoundError:
          load_model_in_session = False
 
        if load_model_in_session:
          # Check if graph is frozen.
          if not _is_frozen_graph(sess):
            raise ValueError("Please freeze the graph using freeze_graph.py.")
 
          # Get input and output tensors.
          input_tensors = _get_tensors_from_tensor_names(
              sess.graph, input_arrays)
          output_tensors = _get_tensors_from_tensor_names(
              sess.graph, output_arrays)
          _set_tensor_shapes(input_tensors, input_shapes)
 
          return cls(sess.graph_def, input_tensors, output_tensors)
        else:
          if not input_shapes:
            raise ValueError("input_shapes must be defined for this model.")
          if set(input_arrays) != set(input_shapes.keys()):
            raise ValueError("input_shapes must contain a value for each item "
                             "in input_array.")
 
          input_arrays_with_shape = [
              (name, input_shapes[name]) for name in input_arrays
          ]
          return cls(
              graph_def,
              input_tensors=None,
              output_tensors=None,
              input_arrays_with_shape=input_arrays_with_shape,
              output_arrays=output_arrays)
 
  @classmethod
  def from_saved_model(cls,
                       saved_model_dir,
                       input_arrays=None,
                       input_shapes=None,
                       output_arrays=None,
                       tag_set=None,
                       signature_key=None):
    """Creates a TFLiteConverter class from a SavedModel.
 
    Args:
      saved_model_dir: SavedModel directory to convert.
      input_arrays: List of input tensors to freeze graph with. Uses input
        arrays from SignatureDef when none are provided. (default None)
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when input shapes is None (e.g., {"foo" :
          None}). (default None)
      output_arrays: List of output tensors to freeze graph with. Uses output
        arrays from SignatureDef when none are provided. (default None)
      tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
        analyze. All tags in the tag set must be present. (default set("serve"))
      signature_key: Key identifying SignatureDef containing inputs and outputs.
        (default DEFAULT_SERVING_SIGNATURE_DEF_KEY)
 
    Returns:
      TFLiteConverter class.
    """
    if tag_set is None:
      tag_set = set([_tag_constants.SERVING])
    if signature_key is None:
      signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
 
    result = _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
                                 output_arrays, tag_set, signature_key)
    return cls(
        graph_def=result[0], input_tensors=result[1], output_tensors=result[2])
 
  @classmethod
  def from_keras_model_file(cls,
                            model_file,
                            input_arrays=None,
                            input_shapes=None,
                            output_arrays=None):
    """Creates a TFLiteConverter class from a tf.keras model file.
 
    Args:
      model_file: Full filepath of HDF5 file containing the tf.keras model.
      input_arrays: List of input tensors to freeze graph with. Uses input
        arrays from SignatureDef when none are provided. (default None)
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when input shapes is None (e.g., {"foo" :
          None}). (default None)
      output_arrays: List of output tensors to freeze graph with. Uses output
        arrays from SignatureDef when none are provided. (default None)
 
    Returns:
      TFLiteConverter class.
    """
    _keras.backend.clear_session()
    _keras.backend.set_learning_phase(False)
    keras_model = _keras.models.load_model(model_file)
    sess = _keras.backend.get_session()
 
    # Get input and output tensors.
    if input_arrays:
      input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
    else:
      input_tensors = keras_model.inputs
 
    if output_arrays:
      output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays)
    else:
      output_tensors = keras_model.outputs
    _set_tensor_shapes(input_tensors, input_shapes)
 
    graph_def = _freeze_graph(sess, output_tensors)
    return cls(graph_def, input_tensors, output_tensors)
 
  def __setattr__(self, name, value):
    if name == "post_training_quantize":
      warnings.warn("Property %s is deprecated, "
                    "please use optimizations=[Optimize.OPTIMIZE_FOR_SIZE]"
                    " instead." % name)
      if value:
        # Use OPTIMIZE_FOR_SIZE for post training for now.
        self.optimizations = [Optimize.OPTIMIZE_FOR_SIZE]
      else:
        self.optimizations = []
      return
    object.__setattr__(self, name, value)
 
  def __getattribute__(self, name):
    if name == "post_training_quantize":
      warnings.warn("Property %s is deprecated, "
                    "please use optimizations=[Optimize.OPTIMIZE_FOR_SIZE]"
                    " instead." % name)
      return Optimize.OPTIMIZE_FOR_SIZE in set(self.optimizations)
    return object.__getattribute__(self, name)
 
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.
 
    Returns:
      The converted data in serialized format. Either a TFLite Flatbuffer or a
      Graphviz graph depending on value in `output_format`.
 
    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    # Checks dimensions in input tensor.
    if self._has_valid_tensors():
      for tensor in self._input_tensors:
        shape = tensor.shape
        if not shape:
          raise ValueError("Provide an input shape for input array "
                           "'{0}'.".format(_tensor_name(tensor)))
        # Note that shape_list might be empty for scalar shapes.
        shape_list = shape.as_list()
        if None in shape_list[1:]:
          raise ValueError(
              "None is only supported in the 1st dimension. Tensor '{0}' has "
              "invalid shape '{1}'.".format(_tensor_name(tensor), shape_list))
        elif shape_list and shape_list[0] is None:
          self._set_batch_size(batch_size=1)
 
    # Get quantization stats. Ensures there is one stat per name if the stats
    # are specified.
    if self.quantized_input_stats:
      quantized_stats = []
      invalid_stats = []
      for name in self.get_input_arrays():
        if name in self.quantized_input_stats:
          quantized_stats.append(self.quantized_input_stats[name])
        else:
          invalid_stats.append(name)
 
      if invalid_stats:
        raise ValueError("Quantization input stats are not available for input "
                         "tensors '{0}'.".format(",".join(invalid_stats)))
    else:
      quantized_stats = None
    if self.representative_dataset:
      if not isinstance(self.representative_dataset, RepresentativeDataset):
        raise TypeError(
            "representative_dataset must be an instance of "
            "RepresentativeDataset")
      if self.representative_dataset.input_gen is None:
        raise ValueError(
            "Provide an input generator for representative_dataset")
 
    # TODO(shashishekhar): For now use optimizations order is ignored.
    # Both size and latency optimizations decide whether to apply post
    # training optimizations.
    post_training_optimize = bool(
        len(set(self.optimizations) & set([Optimize.OPTIMIZE_FOR_LATENCY,
                                           Optimize.OPTIMIZE_FOR_SIZE])))
    # Do weights only quantization if there is no dataset for calibration.
    weights_only_quantize_flag = (
        post_training_optimize and (self.representative_dataset is None))
 
    converter_kwargs = {
        "inference_type": self.inference_type,
        "inference_input_type": self.inference_input_type,
        "input_format": constants.TENSORFLOW_GRAPHDEF,
        "output_format": self.output_format,
        "quantized_input_stats": quantized_stats,
        "default_ranges_stats": self.default_ranges_stats,
        "drop_control_dependency": self.drop_control_dependency,
        "reorder_across_fake_quant": self.reorder_across_fake_quant,
        "change_concat_input_ranges": self.change_concat_input_ranges,
        "allow_custom_ops": self.allow_custom_ops,
        "post_training_quantize": weights_only_quantize_flag,
        "target_ops": self.target_ops,
        "dump_graphviz_dir": self.dump_graphviz_dir,
        "dump_graphviz_video": self.dump_graphviz_video
    }
 
    optimized_graph = None
    if self.inference_type == constants.QUANTIZED_UINT8:
      optimized_graph = self._graph_def
    else:
      try:
        optimized_graph = _run_graph_optimizations(
            self._graph_def, self._input_tensors, self._output_tensors)
      except Exception:
        optimized_graph = self._graph_def
 
    # Converts model.
    if self._has_valid_tensors():
      result = _toco_convert_impl(
          input_data=optimized_graph,
          input_tensors=self._input_tensors,
          output_tensors=self._output_tensors,
          **converter_kwargs)
    else:
      result = _toco_convert_graph_def(
          input_data=optimized_graph,
          input_arrays_with_shape=self._input_arrays_with_shape,
          output_arrays=self._output_arrays,
          **converter_kwargs)
 
    if self.representative_dataset and post_training_optimize:
      calibrate_quantize = _calibrator.Calibrator(result)
      result = calibrate_quantize.calibrate_and_quantize(
          self.representative_dataset.input_gen)
 
    return result
 
  def get_input_arrays(self):
    """Returns a list of the names of the input tensors.
 
    Returns:
      List of strings.
    """
    if self._has_valid_tensors():
      return [_tensor_name(tensor) for tensor in self._input_tensors]
    else:
      return [name for name, _ in self._input_arrays_with_shape]
 
  def _has_valid_tensors(self):
    """Checks if the input and output tensors have been initialized.
 
    Returns:
      Bool.
    """
    return self._input_tensors and self._output_tensors
 
  def _set_batch_size(self, batch_size):
    """Sets the first dimension of the input tensor to `batch_size`.
 
    Args:
      batch_size: Batch size for the model. Replaces the first dimension of an
        input size array if undefined. (default 1)
 
    Raises:
      ValueError: input_tensor is not defined.
    """
    if not self._has_valid_tensors():
      raise ValueError("The batch size cannot be set for this model. Please "
                       "use input_shapes parameter.")
 
    for tensor in self._input_tensors:
      shape = tensor.shape.as_list()
      shape[0] = batch_size
      tensor.set_shape(shape)
 
 
@_tf_export(v1=["lite.TocoConverter"])
class TocoConverter(object):
  """Convert a TensorFlow model into `output_format` using TOCO.
 
  This class has been deprecated. Please use `lite.TFLiteConverter` instead.
  """
 
  @classmethod
  @_deprecation.deprecated(None,
                           "Use `lite.TFLiteConverter.from_session` instead.")
  def from_session(cls, sess, input_tensors, output_tensors):
    """Creates a TocoConverter class from a TensorFlow Session."""
    return TFLiteConverter.from_session(sess, input_tensors, output_tensors)
 
  @classmethod
  @_deprecation.deprecated(
      None, "Use `lite.TFLiteConverter.from_frozen_graph` instead.")
  def from_frozen_graph(cls,
                        graph_def_file,
                        input_arrays,
                        output_arrays,
                        input_shapes=None):
    """Creates a TocoConverter class from a file containing a frozen graph."""
    return TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays,
                                             output_arrays, input_shapes)
 
  @classmethod
  @_deprecation.deprecated(
      None, "Use `lite.TFLiteConverter.from_saved_model` instead.")
  def from_saved_model(cls,
                       saved_model_dir,
                       input_arrays=None,
                       input_shapes=None,
                       output_arrays=None,
                       tag_set=None,
                       signature_key=None):
    """Creates a TocoConverter class from a SavedModel."""
    return TFLiteConverter.from_saved_model(saved_model_dir, input_arrays,
                                            input_shapes, output_arrays,
                                            tag_set, signature_key)
 
  @classmethod
  @_deprecation.deprecated(
      None, "Use `lite.TFLiteConverter.from_keras_model_file` instead.")
  def from_keras_model_file(cls,
                            model_file,
                            input_arrays=None,
                            input_shapes=None,
                            output_arrays=None):
    """Creates a TocoConverter class from a tf.keras model file."""
    return TFLiteConverter.from_keras_model_file(model_file, input_arrays,
                                                 input_shapes, output_arrays)
 
 
def _is_frozen_graph(sess):
  """Determines if the graph is frozen.
 
  Determines if a graph has previously been frozen by checking for any
  operations of type Variable*. If variables are found, the graph is not frozen.
 
  Args:
    sess: TensorFlow Session.
 
  Returns:
    Bool.
  """
  for op in sess.graph.get_operations():
    if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
      return False
  return True
 
 
def _freeze_graph(sess, output_tensors):
  """Returns a frozen GraphDef.
 
  Freezes a graph with Variables in it. Otherwise the existing GraphDef is
  returned.
 
  Args:
    sess: TensorFlow Session.
    output_tensors: List of output tensors (only .name is used from this).
 
  Returns:
    Frozen GraphDef.
  """
  if not _is_frozen_graph(sess):
    output_arrays = [_tensor_name(tensor) for tensor in output_tensors]
    return _tf_graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_arrays)
  else:
    return sess.graph_def