-
Notifications
You must be signed in to change notification settings - Fork 5
/
tunable_nasfpn_search_space.py
141 lines (127 loc) · 5.38 KB
/
tunable_nasfpn_search_space.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tunable NAS-FPN search-space.
Golnaz Ghiasi, Tsung-Yi Lin, Ruoming Pang, Quoc V. Le.
NAS-FPN: Learning Scalable Feature Pyramid Architecture for Object Detection.
https://arxiv.org/abs/1904.07392. CVPR 2019.
"""
import pyglove as pg
@pg.members([
    ('level',
     pg.typing.Int(min_value=0),
     'The feature level of the current block.'),
    ('combine_fn',
     pg.typing.Enum('sum', ['sum', 'attention']),
     'The type of block functions.'),
    ('input_offsets',
     pg.typing.List(pg.typing.Int(min_value=0), size=2),
     'The offsets of two input features from the previously accumlated '
     'features.'),
    ('is_output',
     pg.typing.Bool(),
     'Whether the current block is an output.'),
])
class TunableBlockSpec(pg.Object):
  """Symbolic, schema-validated specification of a single NAS-FPN block.

  Fields (declared by the `pg.members` decorator on this class):
    level: feature level of the block; an int >= 0.
    combine_fn: either 'sum' or 'attention' (defaults to 'sum').
    input_offsets: exactly two non-negative ints selecting the block's inputs
      from the features accumulated before it.
    is_output: whether this block is one of the network's outputs.
  """
@pg.functor([
    ('intermediate_blocks',
     pg.typing.List(
         pg.typing.Dict([
             ('level', pg.typing.Int(min_value=1).noneable(),
              'The base feature level of the current block.'),
             ('combine_fn', pg.typing.Enum('sum', ['sum', 'attention']),
              'The type of combine functions.'),
             ('input_offsets',
              pg.typing.List(pg.typing.Int(min_value=0), size=2),
              'The offsets of two input features from the previous accumlated '
              'features.'),
         ])), 'A list of specifications that define the intermediate blocks.'),
    ('output_blocks',
     pg.typing.List(
         pg.typing.Dict([
             ('level', pg.typing.Int(min_value=1).noneable(),
              'The base feature level of the current block.'),
             ('combine_fn', pg.typing.Enum('sum', ['sum', 'attention']),
              'The type of combine functions.'),
             ('input_offsets',
              pg.typing.List(pg.typing.Int(min_value=0), size=2),
              'The offsets of two input features from the previous accumlated '
              'features.'),
         ])), 'A list of specifications that define the output blocks.'),
    ('output_block_levels', pg.typing.List(pg.typing.Int(min_value=1)),
     'A list of intgers that specify the feature level of output blocks'),
])
def build_tunable_block_specs(intermediate_blocks, output_blocks,
                              output_block_levels):
  """Builds the NAS-FPN block specification.

  Intermediate blocks keep their own `level`. The i-th output block takes its
  level from `output_block_levels[i]`, which overrides whatever `level` the
  block dict carries (None by default).

  Args:
    intermediate_blocks: list of dicts with `level`, `combine_fn` and
      `input_offsets` describing the intermediate blocks.
    output_blocks: list of dicts with the same keys describing the output
      blocks; their `level` is ignored in favor of `output_block_levels`.
    output_block_levels: one feature level per entry of `output_blocks`.

  Returns:
    A list of `TunableBlockSpec`. It holds all intermediate blocks
    (`is_output=False`) in order, followed by all output blocks
    (`is_output=True`) in order.

  Raises:
    ValueError: if `output_block_levels` and `output_blocks` differ in length.
  """
  if len(output_block_levels) != len(output_blocks):
    raise ValueError(
        f'`output_block_levels` has {len(output_block_levels)} entries but '
        f'there are {len(output_blocks)} output blocks; they must match.')

  def _build_block_spec(block, is_output, level):
    """Converts one block dict into a `TunableBlockSpec` at `level`."""
    return TunableBlockSpec(
        level=level,
        combine_fn=block.combine_fn,
        input_offsets=block.input_offsets,
        is_output=is_output)

  intermediate_specs = [
      _build_block_spec(block, is_output=False, level=block.level)
      for block in intermediate_blocks
  ]
  # Output levels are passed explicitly rather than written back into the
  # block dicts with `rebind`. This way, calling the functor does not mutate
  # its own bound arguments.
  output_specs = [
      _build_block_spec(block, is_output=True, level=level)
      for block, level in zip(output_blocks, output_block_levels)
  ]
  return intermediate_specs + output_specs
def nasfpn_search_space(min_level,
                        max_level,
                        level_candidates,
                        num_intermediate_blocks):
  """Builds the NAS-FPN search space.

  Args:
    min_level: an integer that represents the min feature level of the output
      features.
    max_level: an integer that represents the max feature level of the output
      features. Must be greater than `min_level`.
    level_candidates: an iterable of candidate levels for the intermediate
      blocks. It is materialized into a list once, so a one-shot iterable
      (e.g. a generator) is shared correctly by every intermediate block.
    num_intermediate_blocks: the number of blocks to be searched.

  Returns:
    A `build_tunable_block_specs` functor whose arguments contain hyper values
    (`pg.one_of` / `pg.sublist_of`). Once those are decided, calling it
    builds the list of `TunableBlockSpec` for the NAS-FPN architecture.

  Raises:
    ValueError: if `max_level <= min_level`, because every block picks 2
      distinct inputs and at least two input levels are then required. Also
      raised if `level_candidates` is empty while intermediate blocks are
      requested.
  """
  num_inputs = max_level - min_level + 1
  if num_inputs < 2:
    raise ValueError(
        f'max_level ({max_level}) must be greater than min_level '
        f'({min_level}): each block selects 2 distinct inputs, so at least '
        f'two input levels are required.')
  num_outputs = num_inputs

  level_candidates = list(level_candidates)
  if num_intermediate_blocks > 0 and not level_candidates:
    raise ValueError(
        '`level_candidates` must be non-empty when intermediate blocks are '
        'searched.')

  def _block(level, num_available_inputs):
    """Returns a block dict choosing 2 sorted, distinct input offsets."""
    return pg.Dict(
        level=level,
        combine_fn=pg.one_of(['sum', 'attention']),
        input_offsets=pg.sublist_of(
            k=2,
            candidates=list(range(num_available_inputs)),
            choices_distinct=True,
            choices_sorted=True))

  # Offsets index into the features accumulated so far. These are the
  # `num_inputs` inputs followed by every block created before the current
  # one, so the candidate range grows by one per block.
  intermediate_blocks = [
      _block(pg.one_of(level_candidates), num_inputs + i)
      for i in range(num_intermediate_blocks)
  ]
  # Output block levels are left as None here. They are supplied separately
  # through `output_block_levels`.
  output_blocks = [
      _block(None, num_inputs + num_intermediate_blocks + i)
      for i in range(num_outputs)
  ]
  # Picking `num_outputs` distinct, unsorted items out of `num_outputs`
  # candidates searches over permutations of [min_level, max_level].
  output_levels = pg.sublist_of(
      k=num_outputs,
      candidates=list(range(min_level, max_level + 1)),
      choices_distinct=True,
      choices_sorted=False)
  return build_tunable_block_specs(
      intermediate_blocks=intermediate_blocks,
      output_blocks=output_blocks,
      output_block_levels=output_levels)