Skip to content

Commit ecb81ac

Browse files
committed
switch from tqdm to rich
1 parent 68bd1c5 commit ecb81ac

File tree

12 files changed

+134
-44
lines changed

12 files changed

+134
-44
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ pip install torchmetrics==0.7
4343
pip install hydra-core --upgrade
4444
pip install hydra_colorlog --upgrade
4545
pip install shortuuid
46-
pip install tqdm
46+
pip install rich
4747
pip install pandas
4848
pip install transformers
4949
pip install psutil

configs/evaluate.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ logger_level: INFO
2323
defaults:
2424
- /path@path
2525
- /transforms/rots2joints/smplh@rots2joints
26-
- override hydra/job_logging: console
27-
- override hydra/hydra_logging: console
26+
- override hydra/job_logging: rich # console
27+
- override hydra/hydra_logging: rich # console
2828
- _self_
2929

3030
machine:

configs/hydra/hydra_logging/rich.yaml

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
version: 1
2+
3+
formatters:
4+
colorlog:
5+
(): colorlog.ColoredFormatter
6+
format: '[%(cyan)s%(asctime)s%(reset)s][%(purple)sHYDRA%(reset)s] %(message)s'
7+
datefmt: '%d/%m/%y %H:%M:%S'
8+
9+
handlers:
10+
console:
11+
class: rich.logging.RichHandler # logging.StreamHandler
12+
formatter: colorlog
13+
14+
root:
15+
level: INFO
16+
handlers:
17+
- console
18+
19+
disable_existing_loggers: false

configs/hydra/job_logging/rich.yaml

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
version: 1
2+
3+
filters:
4+
onlyimportant:
5+
(): temos.tools.logging.LevelsFilter
6+
levels:
7+
- CRITICAL
8+
- ERROR
9+
- WARNING
10+
noimportant:
11+
(): temos.tools.logging.LevelsFilter
12+
levels:
13+
- INFO
14+
- DEBUG
15+
- NOTSET
16+
17+
formatters:
18+
verysimple:
19+
format: '%(message)s'
20+
simple:
21+
format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
22+
datefmt: '%d/%m/%y %H:%M:%S'
23+
24+
colorlog:
25+
(): colorlog.ColoredFormatter
26+
format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s]
27+
- %(message)s'
28+
datefmt: '%d/%m/%y %H:%M:%S'
29+
30+
log_colors:
31+
DEBUG: purple
32+
INFO: green
33+
WARNING: yellow
34+
ERROR: red
35+
CRITICAL: red
36+
37+
handlers:
38+
console:
39+
class: rich.logging.RichHandler # logging.StreamHandler
40+
formatter: verysimple # colorlog
41+
rich_tracebacks: true
42+
43+
file_out:
44+
class: logging.FileHandler
45+
formatter: simple
46+
filename: logs.out
47+
filters:
48+
- noimportant
49+
50+
file_err:
51+
class: logging.FileHandler
52+
formatter: simple
53+
filename: logs.err
54+
filters:
55+
- onlyimportant
56+
57+
root:
58+
level: ${logger_level}
59+
handlers:
60+
- console
61+
- file_out
62+
- file_err
63+
64+
disable_existing_loggers: false

configs/sample.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ defaults:
3333
- trainer: null
3434
- sampler: all_conseq
3535
- /path@path
36-
- override hydra/job_logging: custom
37-
- override hydra/hydra_logging: custom
36+
- override hydra/job_logging: rich # custom
37+
- override hydra/hydra_logging: rich # custom
3838
- _self_
3939

4040

configs/train.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ defaults:
2222
- logger: none
2323
- callback: base
2424
- /path@path
25-
- override hydra/job_logging: custom
26-
- override hydra/hydra_logging: custom
25+
- override hydra/job_logging: rich # custom
26+
- override hydra/hydra_logging: rich # custom
2727
- _self_
2828

2929
data:

evaluate.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,22 @@ def regroup_metrics(metrics):
2424
APE_pose = dico.pop("APE_pose")
2525
for name, ape in zip(pose_names, APE_pose):
2626
dico[f"APE_pose_{name}"] = ape
27-
27+
2828
if "APE_joints" in dico:
2929
APE_joints = dico.pop("APE_joints")
3030
for name, ape in zip(mmm_joints, APE_joints):
3131
dico[f"APE_joints_{name}"] = ape
32-
32+
3333
if "AVE_pose" in dico:
3434
AVE_pose = dico.pop("AVE_pose")
3535
for name, ave in zip(pose_names, AVE_pose):
3636
dico[f"AVE_pose_{name}"] = ave
37-
37+
3838
if "AVE_joints" in dico:
3939
AVE_joints = dico.pop("AVE_joints")
4040
for name, ape in zip(mmm_joints, AVE_joints):
4141
dico[f"AVE_joints_{name}"] = ave
42-
42+
4343
return dico
4444

4545
def sanitize(dico):

sample.py

+30-27
Original file line numberDiff line numberDiff line change
@@ -115,39 +115,42 @@ def sample(newcfg: DictConfig) -> None:
115115
dataset = getattr(data_module, f"{cfg.split}_dataset")
116116

117117
from temos.data.sampling import upsample
118-
from tqdm import tqdm
118+
from rich.progress import Progress
119+
from rich.progress import track
119120

120121
# remove printing for changing the seed
121122
logging.getLogger('pytorch_lightning.utilities.seed').setLevel(logging.WARNING)
122123

123124
import torch
124125
with torch.no_grad():
125-
for keyid in (pbar := tqdm(dataset.keyids)):
126-
pbar.set_description(f"Processing {keyid}")
127-
for index in range(cfg.number_of_samples):
128-
one_data = dataset.load_keyid(keyid)
129-
# batch_size = 1 for reproductability
130-
batch = collate_datastruct_and_text([one_data])
131-
# fix the seed
132-
pl.seed_everything(index)
133-
134-
if cfg.jointstype == "vertices":
135-
vertices = model(batch)[0]
136-
motion = vertices.numpy()
137-
# no upsampling here to keep memory
138-
# vertices = upsample(vertices, cfg.data.framerate, 100)
139-
else:
140-
joints = model(batch)[0]
141-
motion = joints.numpy()
142-
# upscaling to compare with other methods
143-
motion = upsample(motion, cfg.data.framerate, 100)
144-
145-
if cfg.number_of_samples > 1:
146-
npypath = path / f"{keyid}_{index}.npy"
147-
else:
148-
npypath = path / f"{keyid}.npy"
149-
150-
np.save(npypath, motion)
126+
with Progress(transient=True) as progress:
127+
task = progress.add_task("Sampling", total=len(dataset.keyids))
128+
for keyid in dataset.keyids:
129+
progress.update(task, description=f"Sampling {keyid}..")
130+
for index in range(cfg.number_of_samples):
131+
one_data = dataset.load_keyid(keyid)
132+
# batch_size = 1 for reproductability
133+
batch = collate_datastruct_and_text([one_data])
134+
# fix the seed
135+
pl.seed_everything(index)
136+
137+
if cfg.jointstype == "vertices":
138+
vertices = model(batch)[0]
139+
motion = vertices.numpy()
140+
# no upsampling here to keep memory
141+
# vertices = upsample(vertices, cfg.data.framerate, 100)
142+
else:
143+
joints = model(batch)[0]
144+
motion = joints.numpy()
145+
# upscaling to compare with other methods
146+
motion = upsample(motion, cfg.data.framerate, 100)
147+
148+
if cfg.number_of_samples > 1:
149+
npypath = path / f"{keyid}_{index}.npy"
150+
else:
151+
npypath = path / f"{keyid}.npy"
152+
np.save(npypath, motion)
153+
progress.update(task, advance=1)
151154

152155
logger.info("All the sampling are done")
153156
logger.info(f"All the sampling are done. You can find them here:\n{path}")

temos/callback/progress.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule,
2525
if trainer.sanity_checking:
2626
logger.info("Sanity checking ok.")
2727

28-
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule, **kwargs) -> None:
28+
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule, padding=False, **kwargs) -> None:
2929
metric_format = f"{{:.{self.precision}e}}"
3030
line = f"Epoch {trainer.current_epoch}"
31-
line = f"{line:>{len('Epoch xxxx')}}" # Right padding
31+
if padding:
32+
line = f"{line:>{len('Epoch xxxx')}}" # Right padding
3233
metrics_str = []
3334

3435
losses_dict = trainer.callback_metrics

temos/data/kit.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@
99
import torch
1010
from torch import nn
1111
from torch.utils.data import Dataset
12-
from tqdm import tqdm
1312
from pathlib import Path
1413

1514
from temos.tools.easyconvert import matrix_to, axis_angle_to
1615
from temos.transforms import Transform
1716
from temos.data.sampling import subsample
1817
from temos.data.tools.smpl import smpl_data_to_matrix_and_trans
1918

19+
from rich.progress import track
20+
2021
from .base import BASEDataModule
2122
from .utils import get_split_keyids
2223

@@ -90,7 +91,7 @@ def __init__(self, datapath: str,
9091
kitml_correspondances = json.load(correspondance_path_file)
9192

9293
if progress_bar:
93-
enumerator = enumerate(tqdm(keyids, f"Loading KIT {split}"))
94+
enumerator = enumerate(track(keyids, f"Loading KIT {split}"))
9495
else:
9596
enumerator = enumerate(keyids)
9697

temos/tools/logging.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
import logging
2-
import tqdm
3-
42

53
class LevelsFilter(logging.Filter):
64
def __init__(self, levels):
@@ -32,6 +30,7 @@ def __init__(self, level=logging.NOTSET):
3230
super().__init__(level)
3331

3432
def emit(self, record):
33+
import tqdm
3534
try:
3635
msg = self.format(record)
3736
tqdm.tqdm.write(msg)

train.py

+3
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,14 @@ def train(cfg: DictConfig) -> None:
5050
"AVE root": "Metrics/AVE_root",
5151
"AVE mean pose": "Metrics/AVE_mean_pose"
5252
}
53+
5354
callbacks = [
55+
pl.callbacks.RichProgressBar(),
5456
instantiate(cfg.callback.progress, metric_monitor=metric_monitor),
5557
instantiate(cfg.callback.latest_ckpt),
5658
instantiate(cfg.callback.last_ckpt)
5759
]
60+
5861
logger.info("Callbacks initialized")
5962

6063
logger.info("Loading trainer")

0 commit comments

Comments
 (0)