forked from tensorflow/tensorrt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
106 lines (88 loc) · 3.89 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import argparse
import json
from .object_detection import download_model, download_dataset, optimize_model, benchmark_model
def test(test_config_path):
    """Runs an object detection test configuration

    This runs an object detection test configuration. This involves

    1. Download a model architecture (or use cached).
    2. Optimize the downloaded model architecture
    3. Benchmark the optimized model against a dataset
    4. (optional) Run assertions to check the benchmark output

    The input to this function is a JSON file which specifies the test
    configuration.

    example_test_config.json:

        {
          "source_model": { ... },
          "optimization_config": { ... },
          "benchmark_config": { ... },
          "assertions": [ ... ]
        }

    source_model: A dictionary of arguments passed to download_model, which
        specify the pre-optimized model architecture. The model downloaded (or
        the cached model if found) will be passed to optimize_model.
    optimization_config: A dictionary of arguments passed to optimize_model.
        Please see help(optimize_model) for more details.
    benchmark_config: A dictionary of arguments passed to benchmark_model.
        Please see help(benchmark_model) for more details.
    assertions: A list of strings containing python code that will be
        evaluated. If the code returns false, an error will be thrown. These
        assertions can reference any variables local to this 'test' function.
        Some useful values are

            statistics['map']
            statistics['avg_latency']
            statistics['avg_throughput']

    Args
    ----
        test_config_path: A string corresponding to the test configuration
            JSON file.
    """
    # BUG FIX: original read the module-global ``args.test_config_path``,
    # which raises NameError when test() is called as a library function.
    with open(test_config_path, 'r') as f:
        test_config = json.load(f)
    print(json.dumps(test_config, sort_keys=True, indent=4))

    # download model or use cached
    config_path, checkpoint_path = download_model(**test_config['source_model'])

    # optimize model using source model
    frozen_graph = optimize_model(
        config_path=config_path,
        checkpoint_path=checkpoint_path,
        **test_config['optimization_config'])

    # benchmark optimized model
    statistics = benchmark_model(
        frozen_graph=frozen_graph,
        **test_config['benchmark_config'])

    # print some statistics to command line; use a shallow copy so popping
    # the (potentially huge) 'runtimes_ms' list does not mutate the
    # ``statistics`` dict that the assertions below may reference
    print_statistics = dict(statistics)
    print_statistics.pop('runtimes_ms', None)
    print(json.dumps(print_statistics, sort_keys=True, indent=4))

    # run assertions
    # SECURITY NOTE: assertions are arbitrary python evaluated with eval();
    # only run test configurations from trusted sources.
    for a in test_config.get('assertions', []):
        if not eval(a):
            raise AssertionError('ASSERTION FAILED: %s' % a)
        else:
            print('ASSERTION PASSED: %s' % a)
if __name__ == '__main__':
    # Command-line entry point: takes a single positional argument, the
    # path of the JSON test-configuration file, and runs test() on it.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'test_config_path',
        # BUG FIX: original implicit string concatenation produced
        # "Pleasesee" (missing separating space between the two literals).
        help='Path of JSON file containing test configuration. Please '
             'see help(tftrt.examples.object_detection.test) for more information')
    args = parser.parse_args()
    test(args.test_config_path)