@@ -115,39 +115,42 @@ def sample(newcfg: DictConfig) -> None:
115
115
dataset = getattr (data_module , f"{ cfg .split } _dataset" )
116
116
117
117
from temos .data .sampling import upsample
118
- from tqdm import tqdm
118
+ from rich .progress import Progress
119
+ from rich .progress import track
119
120
120
121
# remove printing for changing the seed
121
122
logging .getLogger ('pytorch_lightning.utilities.seed' ).setLevel (logging .WARNING )
122
123
123
124
import torch
124
125
with torch .no_grad ():
125
- for keyid in (pbar := tqdm (dataset .keyids )):
126
- pbar .set_description (f"Processing { keyid } " )
127
- for index in range (cfg .number_of_samples ):
128
- one_data = dataset .load_keyid (keyid )
129
- # batch_size = 1 for reproductability
130
- batch = collate_datastruct_and_text ([one_data ])
131
- # fix the seed
132
- pl .seed_everything (index )
133
-
134
- if cfg .jointstype == "vertices" :
135
- vertices = model (batch )[0 ]
136
- motion = vertices .numpy ()
137
- # no upsampling here to keep memory
138
- # vertices = upsample(vertices, cfg.data.framerate, 100)
139
- else :
140
- joints = model (batch )[0 ]
141
- motion = joints .numpy ()
142
- # upscaling to compare with other methods
143
- motion = upsample (motion , cfg .data .framerate , 100 )
144
-
145
- if cfg .number_of_samples > 1 :
146
- npypath = path / f"{ keyid } _{ index } .npy"
147
- else :
148
- npypath = path / f"{ keyid } .npy"
149
-
150
- np .save (npypath , motion )
126
+ with Progress (transient = True ) as progress :
127
+ task = progress .add_task ("Sampling" , total = len (dataset .keyids ))
128
+ for keyid in dataset .keyids :
129
+ progress .update (task , description = f"Sampling { keyid } .." )
130
+ for index in range (cfg .number_of_samples ):
131
+ one_data = dataset .load_keyid (keyid )
132
+ # batch_size = 1 for reproductability
133
+ batch = collate_datastruct_and_text ([one_data ])
134
+ # fix the seed
135
+ pl .seed_everything (index )
136
+
137
+ if cfg .jointstype == "vertices" :
138
+ vertices = model (batch )[0 ]
139
+ motion = vertices .numpy ()
140
+ # no upsampling here to keep memory
141
+ # vertices = upsample(vertices, cfg.data.framerate, 100)
142
+ else :
143
+ joints = model (batch )[0 ]
144
+ motion = joints .numpy ()
145
+ # upscaling to compare with other methods
146
+ motion = upsample (motion , cfg .data .framerate , 100 )
147
+
148
+ if cfg .number_of_samples > 1 :
149
+ npypath = path / f"{ keyid } _{ index } .npy"
150
+ else :
151
+ npypath = path / f"{ keyid } .npy"
152
+ np .save (npypath , motion )
153
+ progress .update (task , advance = 1 )
151
154
152
155
logger .info ("All the sampling are done" )
153
156
logger .info (f"All the sampling are done. You can find them here:\n { path } " )
0 commit comments