Blame view

3rdparty/opencv-4.5.4/samples/dnn/shrink_tf_graph_weights.py 2.25 KB
f4334277   Hu Chunming   提交3rdparty
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
  # This file is part of OpenCV project.
  # It is subject to the license terms in the LICENSE file found in the top-level directory
  # of this distribution and at http://opencv.org/license.html.
  #
  # Copyright (C) 2017, Intel Corporation, all rights reserved.
  # Third party copyrights are property of their respective owners.
  import tensorflow as tf
  import struct
  import argparse
  import numpy as np
  
  # Command-line interface: the frozen graph to read, where to write the
  # converted graph, and which op types have their weights converted.
  parser = argparse.ArgumentParser(
      description='Convert weights of a frozen TensorFlow graph to fp16.')
  parser.add_argument('--input', help='Path to frozen graph.', required=True)
  parser.add_argument('--output', help='Path to output graph.', required=True)
  parser.add_argument('--ops', nargs='+', default=['Conv2D', 'MatMul'],
                      help='List of ops which weights are converted.')
  args = parser.parse_args()
  
  # Numeric DataType enum values as written into GraphDef attrs:
  # 1 is float32 and 19 is float16 (half).
  DT_FLOAT, DT_HALF = 1, 19
  
  # In frozen graphs, every node that uses weights is connected to the Const node
  # holding them through an Identity node, usually named like the Const plus a
  # '/read' suffix. We replace all of these Identity nodes with Cast nodes.
  
  # Load the model. A frozen graph is a binary protobuf, so the file must be
  # opened in 'rb' mode; reading it in the default text mode fails under
  # Python 3 before ParseFromString ever sees valid bytes.
  # NOTE: tf.io.gfile / tf.compat.v1 replace the TF1-only tf.gfile / tf.GraphDef
  # so the script also runs on TensorFlow 2.x (assumes TF >= 1.13).
  with tf.io.gfile.GFile(args.input, 'rb') as f:
      graph_def = tf.compat.v1.GraphDef()
      graph_def.ParseFromString(f.read())
  
  # Set of the input names of every node whose op is listed in --ops.
  # A real set (not a list) because it is queried once per graph node below,
  # which would otherwise be quadratic in the graph size.
  inputs = set()
  for node in graph_def.node:
      if node.op in args.ops:
          inputs.update(node.input)
  
  # Names of the nodes feeding the rewritten Identity nodes (their first input);
  # these hold the weights converted to fp16 afterwards. A set for O(1) lookup.
  weightsNodes = set()
  for node in graph_def.node:
      # Of all collected inputs, keep only the float32 Identity nodes.
      if node.name in inputs and node.op == 'Identity' and node.attr['T'].type == DT_FLOAT:
          weightsNodes.add(node.input[0])

          # Replace the Identity with a Cast from fp16 (the converted weights)
          # back to fp32, so downstream consumers still receive float32.
          node.op = 'Cast'
          node.attr['DstT'].type = DT_FLOAT
          node.attr['SrcT'].type = DT_HALF
          del node.attr['T']
          # '_class' is optional; deleting a missing map key raises KeyError,
          # which used to abort the script for Identity nodes without it.
          if '_class' in node.attr:
              del node.attr['_class']
  
  # Convert the weight nodes collected above to fp16 in place.
  # NOTE(review): if such a node also feeds consumers other than the new Cast
  # nodes, those consumers will now receive fp16 data; this is not handled here.
  for node in graph_def.node:
      if node.name in weightsNodes:
          node.attr['dtype'].type = DT_HALF
          tensor = node.attr['value'].tensor
          tensor.dtype = DT_HALF

          if tensor.tensor_content:
              # Packed raw values: reinterpret as float32 and re-encode as
              # float16, both in native byte order (same as the former struct
              # 'f'/'H' codes). The old `len(...) / 4` produced a float in
              # Python 3 and made struct.unpack raise TypeError.
              floats = np.frombuffer(tensor.tensor_content, dtype=np.float32)
              tensor.tensor_content = floats.astype(np.float16).tobytes()
          elif tensor.float_val:
              # Small constants may keep their values in float_val instead of
              # tensor_content. DT_HALF values belong in half_val as raw 16-bit
              # patterns (widened to int32), so move them there.
              halfs = np.array(list(tensor.float_val), dtype=np.float32).astype(np.float16)
              tensor.half_val.extend(int(bits) for bits in halfs.view(np.uint16))
              del tensor.float_val[:]
  
  # Serialize the modified graph as a binary protobuf to --output.
  # tf.io.write_graph is the same function as tf.train.write_graph in TF1
  # (>= 1.13) and the only public name for it in TensorFlow 2.x.
  tf.io.write_graph(graph_def, "", args.output, as_text=False)