Skip to content

Commit e34c6ac

Browse files
authored
Development pipeline and Module APIs unified (#385)
1. Interface Unify 1.1 Abstract uniform training pipeline into modules: *Augmentor* : used to augment the image training data. Basic class is `Model.BasicAugmentor`. *PreProcessor* : used to generate target heatmap from the augmented image data and the key point annotations for model to regression. Basic class is `Model.BasicPreProcessor`. *PostProcessor*: used to generated detected human body joints from the model predict heatmap. Basic class is `Model.BasicPostProcessor`. *Visualizer*: used to visualize the model predict heatmap and the humans detected together with the origin images. Basic class is `Model.BasicVisualizer`. The difference of the training procedure between different pose estimation methods are then divided into pre-processing handled by the PreProcessor, post-processing handled by the PostProcessor and visualizing handled by the Visualizer. Inherit the corresponding Basic class, implementing member functions according to the provided function protocals, and then use `Config` module to set these 3 custom modules to implement any pose estimation pipeline you want. 1.2 Abtract uniform model API protocal uniform model APIs including `forwarding`, `cal_loss` and `infer` and provide a basic class `Model.BasicModel` for model customization. Inherite the model basic class to implement any pose estimation model you want. 2. Additional handy component 2.1 Metric Manager Introducing `Model.MetricManager` that provides `update` and `report` function to statistic loss values during training and generate report massages for logging. 2.2 Image Processor Introducing `Model.ImageProcessor` that provides useful interfaces to read images, pad and scale them for easily converting images into model input format. 2.3 Weight Examination Introducing weight examination APIs in Model module for model, npz file and npz_dict file to easily exam the model weights. 3. Issue Fix 3.1 Python Demo Providing `python_demo.py` as a python demo to replace the old problematic demo program `infer.py`. python_demo.py is used to easily try the npz model weights (both pretrained model weights or weights trained by users themselves) and also demonstraste the usage of PostProcessor, Visualizer and ImageProcessor modules. 3.2 Shape mismatch issues Fix the shape mismatch issue that occurs when loading the pretrained model weights. (The issue was introduced by version compatibility) 3.3 Domain adapation Fix the domain adaptation loss calculation and optimization issue occurs in tensorflow tape scope. warp all the domain adaptation data pipeline into the Domainadapt_dataset class. Domain adaptation can be put into pratical usage now. 3.4 Parallel training Fit Kungfu new APIs for parallel and distributed training. 3.5 Other other known issues in Processor modules such as tensorflow eager tensor and numpy ndarray compatibility issue, pyplot value clip issue. 4. Standarize 4.1 Logging info standardize Use standard file stream and std-out stream in logging module to output log, split the logging information of hyperpose into [DATA],[MODEL],[TRAIN] 3 parts to regulate the logging information. Formate human body joints output string format 4.2 Channels format standardize. Adapt all the pre-processing, post-processing, and visualizing functions to accept `channel_first` data in default to make the system clearer. 4.3 Model weights format standardize. Adapat all the model weights loading and saving procedure default format from `npz` to `npz_dict` format to make model weight conversion and examination convenient. (`npz` format weights are ordered arrays while `npz_dict` format weights are dictionary, which is more convenient to locate and examine specific weight.) 4.4 Help information standardize. Add help informations about variable definition, object definition, development platform basic usage, development platform custom usage, additional features when constructing models. 4.5 Tidy up all 10 model backbones provided. 4.6 Tidy up custom APIs in Config module.
1 parent 31d79ca commit e34c6ac

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+2797
-3148
lines changed

.gitignore

+6
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ debug.*
6060
!docs/Makefile
6161
!docs/markdown/images/*
6262

63+
test_cmd.txt
64+
6365

6466
/.build
67+
test_dir
68+
*_save_dir
69+
cvt_dir
70+
test_cmd.txt
6571

eval.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
help='number of visible evaluation')
5959
parser.add_argument('--multiscale',
6060
type=bool,
61-
default=False,
61+
default=True,
6262
help='enable multiscale_search')
6363

6464

export_pb.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def analyze_inputs_outputs(graph):
7272

7373
export_batch_size=args.export_batch_size
7474
export_h,export_w=args.export_h,args.export_w
75-
print(f"{export_batch_size = }\t{export_h = }\t{export_w = }")
75+
print(f"export_batch_size={export_batch_size}\texport_h={export_h}\texport_w={export_w}")
7676
input_path=f"{config.model.model_dir}/newest_model.npz"
7777
output_dir=f"{args.output_dir}/{config.model.model_name}"
7878
output_path=f"{output_dir}/frozen_{config.model.model_name}.pb"

hyperpose/Config/__init__.py

+120-22
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,17 @@
1010

1111
#default train
1212
update_train.optim_type=OPTIM.Adam
13+
update_train.kungfu_option = KUNGFU.Sync_avg
1314

1415
#defualt model config
1516
update_model.model_type=MODEL.Openpose
1617
#userdef model
17-
update_model.userdef_parts=None
18-
update_model.userdef_limbs=None
18+
update_model.custom_parts = None
19+
update_model.custom_limbs = None
20+
update_model.custom_augmentor = None
21+
update_model.custom_preprocessor = None
22+
update_model.custom_postprocessor = None
23+
update_model.custom_visualizer = None
1924

2025
#default dataset config
2126
#official dataset
@@ -56,7 +61,7 @@ def get_config():
5661
an edict object contains all the configuration information.
5762
5863
'''
59-
#import basic configurations
64+
# import basic configurations
6065
if(update_model.model_type==MODEL.Openpose):
6166
from .config_opps import model,train,eval,test,data,log
6267
elif(update_model.model_type==MODEL.LightweightOpenpose):
@@ -67,15 +72,15 @@ def get_config():
6772
from .config_ppn import model,train,eval,test,data,log
6873
elif(update_model.model_type==MODEL.Pifpaf):
6974
from .config_pifpaf import model,train,eval,test,data,log
70-
#merge settings with basic configurations
75+
# merge settings with basic configurations
7176
model.update(update_model)
7277
train.update(update_train)
7378
eval.update(update_eval)
7479
test.update(update_test)
7580
data.update(update_data)
7681
log.update(update_log)
7782
pretrain.update(update_pretrain)
78-
#assemble configure
83+
# assemble configure
7984
config=edict()
8085
config.model=model
8186
config.train=train
@@ -84,7 +89,7 @@ def get_config():
8489
config.data=data
8590
config.log=log
8691
config.pretrain=pretrain
87-
#path configure
92+
# path configure
8893
import tensorflow as tf
8994
import tensorlayer as tl
9095
tl.files.exists_or_mkdir(config.model.model_dir, verbose=True) # to save model files
@@ -93,18 +98,78 @@ def get_config():
9398
tl.files.exists_or_mkdir(config.test.vis_dir, verbose=True) # to save visualization results
9499
tl.files.exists_or_mkdir(config.data.vis_dir, verbose=True) # to save visualization results
95100
tl.files.exists_or_mkdir(config.pretrain.pretrain_model_dir,verbose=True)
96-
#device configure
97-
#FIXME: replace experimental tf functions when in tf 2.1 version
101+
# device configure
102+
# FIXME: replace experimental tf functions when in tf 2.1 version
98103
tf.debugging.set_log_device_placement(False)
99104
tf.config.set_soft_device_placement(True)
100105
for gpu in tf.config.experimental.get_visible_devices("GPU"):
101106
tf.config.experimental.set_memory_growth(gpu,True)
102-
#limit the cpu usage when pretrain
103-
#logging configure
107+
108+
# logging configure
109+
110+
# logging file path init
104111
tl.files.exists_or_mkdir(os.path.dirname(config.log.log_path),verbose=True)
105112
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.WARN)
106113
tl.logging.set_verbosity(tl.logging.WARN)
107-
return deepcopy(config)
114+
115+
# Info logging configure
116+
info_logger = logging.getLogger(name="INFO")
117+
info_logger.setLevel(logging.INFO)
118+
# stream handler
119+
info_cHandler = logging.StreamHandler()
120+
info_cFormat = logging.Formatter("[%(name)s]: %(message)s")
121+
info_cHandler.setFormatter(info_cFormat)
122+
info_logger.addHandler(info_cHandler)
123+
# file handler
124+
info_fHandler = logging.FileHandler(config.log.log_path,mode="a")
125+
info_fFormat = logging.Formatter("%(asctime)s [%(name)s] %(levelname)s: %(message)s")
126+
info_fHandler.setFormatter(info_fFormat)
127+
info_logger.addHandler(info_fHandler)
128+
129+
# Dataset logging configure
130+
data_logger = logging.getLogger(name="DATA")
131+
data_logger.setLevel(logging.INFO)
132+
# stream handler
133+
data_cHandler = logging.StreamHandler()
134+
data_cFormat = logging.Formatter("[%(name)s] %(levelname)s: %(message)s")
135+
data_cHandler.setFormatter(data_cFormat)
136+
data_logger.addHandler(data_cHandler)
137+
# file handler
138+
data_fHandler = logging.FileHandler(config.log.log_path,mode="a")
139+
data_fFormat = logging.Formatter("%(asctime)s [%(name)s] %(levelname)s: %(message)s")
140+
data_fHandler.setFormatter(data_fFormat)
141+
data_logger.addHandler(data_fHandler)
142+
143+
# Model logging configure
144+
model_logger = logging.getLogger(name="MODEL")
145+
model_logger.setLevel(logging.INFO)
146+
# stream handler
147+
model_cHandler = logging.StreamHandler()
148+
model_cFormat = logging.Formatter("[%(name)s] %(levelname)s: %(message)s")
149+
model_cHandler.setFormatter(model_cFormat)
150+
model_logger.addHandler(model_cHandler)
151+
# file handler
152+
model_fHandler = logging.FileHandler(config.log.log_path,mode="a")
153+
model_fFormat = logging.Formatter("%(asctime)s [%(name)s] %(levelname)s: %(message)s")
154+
model_fHandler.setFormatter(model_fFormat)
155+
model_logger.addHandler(model_fHandler)
156+
157+
# Train logging configure
158+
train_logger = logging.getLogger(name="TRAIN")
159+
train_logger.setLevel(logging.INFO)
160+
# stream handler
161+
train_cHandler = logging.StreamHandler()
162+
train_cFormat = logging.Formatter("[%(name)s] %(levelname)s: %(message)s")
163+
train_cHandler.setFormatter(train_cFormat)
164+
train_logger.addHandler(train_cHandler)
165+
# file handler
166+
train_fHandler = logging.FileHandler(config.log.log_path,mode="a")
167+
train_fFormat = logging.Formatter("%(asctime)s [%(name)s] %(levelname)s: %(message)s")
168+
train_fHandler.setFormatter(train_fFormat)
169+
train_logger.addHandler(train_fHandler)
170+
171+
info("Configuration initialized!")
172+
return config
108173

109174
#set configure api
110175
#model configure api
@@ -235,12 +300,6 @@ def set_model_name(model_name):
235300
update_data.vis_dir=f"./save_dir/{update_model.model_name}/data_vis_dir"
236301
update_log.log_path= f"./save_dir/{update_model.model_name}/log.txt"
237302

238-
def set_model_parts(userdef_parts):
239-
update_model.userdef_parts=userdef_parts
240-
241-
def set_model_limbs(userdef_limbs):
242-
update_model.userdef_limbs=userdef_limbs
243-
244303
#train configure api
245304
def set_train_type(train_type):
246305
'''set single_train or parallel train
@@ -280,9 +339,6 @@ def set_learning_rate(learning_rate):
280339
'''
281340
update_train.lr_init=learning_rate
282341

283-
def set_save_interval(save_interval):
284-
update_train.save_interval=save_interval
285-
286342
def set_batch_size(batch_size):
287343
'''set the batch size in training
288344
@@ -423,7 +479,8 @@ def set_dataset_filter(dataset_filter):
423479
'''
424480
update_data.dataset_filter=dataset_filter
425481

426-
#log configure api
482+
# interval APIs
483+
# configure log interval
427484
def set_log_interval(log_interval):
428485
'''set the frequency of logging
429486
@@ -439,10 +496,51 @@ def set_log_interval(log_interval):
439496
-------
440497
None
441498
'''
442-
update_log.log_interval=log_interval
499+
if(log_interval is not None):
500+
update_log.log_interval=log_interval
501+
502+
# configure save_interval
503+
def set_save_interval(save_interval):
504+
if(save_interval is not None):
505+
update_train.save_interval = save_interval
506+
507+
# configure vis_interval
508+
def set_vis_interval(vis_interval):
509+
if(vis_interval is not None):
510+
update_train.vis_interval = vis_interval
511+
512+
# custome module interfaces
513+
# custom parts
514+
def set_custom_parts(custom_parts):
515+
update_model.custom_parts = custom_parts
516+
517+
# custom limbs
518+
def set_custom_limbs(custom_limbs):
519+
update_model.custom_limbs = custom_limbs
520+
521+
# custom augmentor
522+
def set_custom_augmentor(augmentor):
523+
update_model.augmentor = augmentor
524+
525+
# custom preprocessor
526+
def set_custom_preprocessor(preprocessor):
527+
update_model.preprocessor = preprocessor
528+
529+
# custom postprocessor
530+
def set_custom_postprocessor(postprocessor):
531+
update_model.postprocessor = postprocessor
532+
533+
# custom visualizer
534+
def set_custom_visualizer(visualizer):
535+
update_model.visualizer = visualizer
536+
443537

444538
def set_pretrain(enable):
445539
update_pretrain.enable=enable
446540

447541
def set_pretrain_dataset_path(pretrain_dataset_path):
448542
update_pretrain.pretrain_dataset_path=pretrain_dataset_path
543+
544+
def info(msg):
545+
info_logger = logging.getLogger("INFO")
546+
info_logger.info(msg)

hyperpose/Config/config_lopps.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@
3535
train.weight_decay_factor = 2e-4
3636
train.train_type=TRAIN.Single_train
3737
train.vis_dir=f"./save_dir/{model.model_name}/train_vis_dir"
38+
train.vis_interval=1000
3839

3940
#eval configuration
4041
eval =edict()
41-
eval.batch_size=22
42+
eval.batch_size=8
4243
eval.vis_dir=f"./save_dir/{model.model_name}/eval_vis_dir"
4344

4445
#test configuration

hyperpose/Config/config_opps.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
#train configuration
2424
train=edict()
25-
train.batch_size = 8
25+
train.batch_size = 4
2626
train.save_interval = 2000
2727
# total number of step
2828
train.n_step = 1000000

hyperpose/Config/define.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
class BACKBONE(Enum):
44
Default=0
55
Mobilenetv1=1
6-
Vgg19=2
7-
Resnet18=3
8-
Resnet50=4
9-
Vggtiny=5
10-
Mobilenetv2=6
11-
Vgg16=7
6+
Mobilenetv2=2
7+
MobilenetDilated=3
8+
MobilenetThin=4
9+
MobilenetSmall=5
10+
Vggtiny=6
11+
Vgg19=7
12+
Vgg16=8
13+
Resnet18=9
14+
Resnet50=10
1215

1316
class MODEL(Enum):
1417
Openpose=0

hyperpose/Dataset/__init__.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from .mpii_dataset import MPII_dataset
66
from .mscoco_dataset import MSCOCO_dataset
77
from .imagenet_dataset import Imagenet_dataset
8-
from .common import imread_rgb_float,imwrite_rgb_float
8+
from .dmadapt_dataset import Domainadapt_dataset
9+
from .common import log_data as log
910

1011
def get_dataset(config):
1112
'''get dataset object based on the config object
@@ -49,7 +50,7 @@ def get_dataset(config):
4950
model_type=config.model.model_type
5051
dataset_type=config.data.dataset_type
5152
if(dataset_type==DATA.MSCOCO):
52-
print("using MSCOCO dataset!")
53+
log("Using MSCOCO dataset!")
5354
if(model_type==MODEL.LightweightOpenpose or model_type==MODEL.MobilenetThinOpenpose or model_type==MODEL.Openpose):
5455
from .mscoco_dataset.define import opps_input_converter as input_kpt_cvter
5556
from .mscoco_dataset.define import opps_output_converter as output_kpt_cvter
@@ -62,7 +63,7 @@ def get_dataset(config):
6263
dataset=MSCOCO_dataset(config,input_kpt_cvter,output_kpt_cvter)
6364
dataset.prepare_dataset()
6465
elif(dataset_type==DATA.MPII):
65-
print("using MPII dataset!")
66+
log("Using MPII dataset!")
6667
if(model_type==MODEL.LightweightOpenpose or model_type==MODEL.MobilenetThinOpenpose or model_type==MODEL.Openpose):
6768
from .mpii_dataset.define import opps_input_converter as input_kpt_cvter
6869
from .mpii_dataset.define import opps_output_converter as output_kpt_cvter
@@ -72,18 +73,18 @@ def get_dataset(config):
7273
dataset=MPII_dataset(config,input_kpt_cvter,output_kpt_cvter)
7374
dataset.prepare_dataset()
7475
elif(dataset_type==DATA.USERDEF):
75-
print("using user-defined dataset!")
76+
log("Using user-defined dataset!")
7677
userdef_dataset=config.data.userdef_dataset
7778
dataset=userdef_dataset(config)
7879
elif(dataset_type==DATA.MULTIPLE):
79-
print("using multiple-combined dataset!")
80+
log("Using multiple-combined dataset!")
8081
combined_dataset_list=[]
8182
multiple_dataset_configs=config.data.multiple_dataset_configs
82-
print(f"total {len(multiple_dataset_configs)} datasets settled, initializing combined datasets individualy....")
83+
log(f"Total {len(multiple_dataset_configs)} datasets settled, initializing combined datasets individualy....")
8384
for dataset_idx,dataset_config in enumerate(multiple_dataset_configs):
84-
print(f"initializing combined dataset {dataset_idx},config:{dataset_config.data}...")
85+
log(f"Initializing combined dataset {dataset_idx},config:{dataset_config.data}...")
8586
combined_dataset_list.append(get_dataset(dataset_config))
86-
print("initialization finished")
87+
log("Initialization finished")
8788
dataset=Multi_dataset(config,combined_dataset_list)
8889
else:
8990
raise NotImplementedError(f"invalid dataset_type:{dataset_type}")
@@ -93,12 +94,15 @@ def get_dataset(config):
9394
def get_pretrain_dataset(config):
9495
return Imagenet_dataset(config)
9596

97+
def get_domainadapt_dataset(config):
98+
return Domainadapt_dataset(config.domainadapt_train_img_paths)
99+
96100
def enum2dataset(dataset_type):
97101
if(dataset_type==DATA.MSCOCO):
98102
return MSCOCO_dataset
99103
elif(dataset_type==DATA.MPII):
100104
return MPII_dataset
101105
elif(dataset_type==DATA.MULTIPLE):
102-
raise NotImplementedError("multiple dataset shouldn't be nested!")
106+
raise NotImplementedError("Multiple dataset shouldn't be nested!")
103107
else:
104108
raise NotImplementedError("Unknow dataset!")

0 commit comments

Comments
 (0)