transformations.cpp
5.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Copyright (C) 2019 Intel Corporation
#include "precomp.hpp"
#include <ade/util/zip_range.hpp>
#include <ade/graph.hpp>
#include "api/gcomputation_priv.hpp"
#include "compiler/gmodel.hpp"
#include "compiler/gmodelbuilder.hpp"
#include "compiler/passes/passes.hpp"
#include "compiler/passes/pattern_matching.hpp"
#include <sstream>
namespace cv { namespace gimpl { namespace passes {
namespace
{
using Graph = GModel::Graph;
using Metadata = typename Graph::CMetadataT;
// Checks pairs of {pattern node, substitute node} and asserts if there are any incompatibilities
void checkDataNodes(const Graph& pattern,
const Graph& substitute,
const std::vector<ade::NodeHandle>& patternNodes,
const std::vector<ade::NodeHandle>& substituteNodes)
{
for (auto it : ade::util::zip(patternNodes, substituteNodes)) {
auto pNodeMeta = pattern.metadata(std::get<0>(it));
auto sNodeMeta = substitute.metadata(std::get<1>(it));
GAPI_Assert(pNodeMeta.get<NodeType>().t == NodeType::DATA);
GAPI_Assert(pNodeMeta.get<NodeType>().t == sNodeMeta.get<NodeType>().t);
GAPI_Assert(pNodeMeta.get<Data>().shape == sNodeMeta.get<Data>().shape);
}
}
// Checks compatibility of pattern and substitute graphs based on in/out nodes
void checkCompatibility(const Graph& pattern,
const Graph& substitute,
const Protocol& patternP,
const Protocol& substituteP)
{
const auto& patternDataInputs = patternP.in_nhs;
const auto& patternDataOutputs = patternP.out_nhs;
const auto& substituteDataInputs = substituteP.in_nhs;
const auto& substituteDataOutputs = substituteP.out_nhs;
// number of data nodes must be the same
GAPI_Assert(patternDataInputs.size() == substituteDataInputs.size());
GAPI_Assert(patternDataOutputs.size() == substituteDataOutputs.size());
// for each pattern input node, verify a corresponding substitute input node
checkDataNodes(pattern, substitute, patternDataInputs, substituteDataInputs);
// for each pattern output node, verify a corresponding substitute output node
checkDataNodes(pattern, substitute, patternDataOutputs, substituteDataOutputs);
}
// Tries to substitute __single__ pattern with substitute in the given graph
bool tryToSubstitute(ade::Graph& main,
const std::unique_ptr<ade::Graph>& patternG,
const cv::GComputation& substitute)
{
GModel::Graph gm(main);
// 1. find a pattern in main graph
auto match1 = findMatches(*patternG, gm);
if (!match1.ok()) {
return false;
}
// 2. build substitute graph inside the main graph
cv::gimpl::GModelBuilder builder(main);
auto expr = cv::util::get<cv::GComputation::Priv::Expr>(substitute.priv().m_shape);
const auto& proto_slots = builder.put(expr.m_ins, expr.m_outs);
Protocol substituteP;
std::tie(substituteP.inputs, substituteP.outputs, substituteP.in_nhs, substituteP.out_nhs) =
proto_slots;
const Protocol& patternP = GModel::Graph(*patternG).metadata().get<Protocol>();
// 3. check that pattern and substitute are compatible
// FIXME: in theory, we should always have compatible pattern/substitute. if not, we're in
// half-completed state where some transformations are already applied - what can we do
// to handle the situation better? -- use transactional API as in fuse_islands pass?
checkCompatibility(*patternG, gm, patternP, substituteP);
// 4. make substitution
performSubstitution(gm, patternP, substituteP, match1);
return true;
}
} // anonymous namespace
void applyTransformations(ade::passes::PassContext& ctx,
const gapi::GKernelPackage& pkg,
const std::vector<std::unique_ptr<ade::Graph>>& patterns)
{
const auto& transforms = pkg.get_transformations();
const auto size = transforms.size();
if (0u == size) return;
// Note: patterns are already generated at this point
GAPI_Assert(patterns.size() == transforms.size());
// transform as long as it is possible
bool canTransform = true;
while (canTransform)
{
canTransform = false;
// iterate through every transformation and try to transform graph parts
for (auto it : ade::util::zip(ade::util::toRange(transforms), ade::util::toRange(patterns)))
{
const auto& t = std::get<0>(it);
auto& pattern = std::get<1>(it); // Note: using pre-created graphs
GAPI_Assert(nullptr != pattern);
// if transformation is successful (pattern found and substituted), it is possible that
// other transformations will also be successful, so set canTransform to the returned
// value from tryToSubstitute
canTransform = tryToSubstitute(ctx.graph, pattern, t.substitute());
// Note: apply the __same__ substitution as many times as possible and only after go to
// the next one. BUT it can happen that after applying some substitution, some
// _previous_ patterns will also be found and these will be applied first
if (canTransform) {
break;
}
}
}
}
} // namespace passes
} // namespace gimpl
} // namespace cv