Skip to content

Commit 2bd781c

Browse files
committed
add interaction script/custom prompt
1 parent 96d3da4 commit 2bd781c

File tree

2 files changed

+189
-0
lines changed

2 files changed

+189
-0
lines changed

configs/interact.yaml

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
hydra:
2+
run:
3+
dir: .
4+
output_subdir: null
5+
6+
seed: 1234
7+
8+
logger_level: INFO
9+
folder: ???
10+
11+
saving: ???
12+
length: ???
13+
text: ???
14+
15+
mean: false
16+
number_of_samples: 1
17+
fact: 1
18+
19+
ckpt_name: last.ckpt
20+
last_ckpt_path: ${get_last_checkpoint:${folder},${ckpt_name}}
21+
22+
# only used if trained with kit-amass-rot
23+
# so with smpl rotations
24+
jointstype: mmm
25+
26+
# if jointstype == vertices
27+
# can specify the gender
28+
# neutral / male / female
29+
gender: neutral
30+
31+
# Composing nested config with default
32+
defaults:
33+
- data: null
34+
- machine: null
35+
- trainer: null
36+
- sampler: all_conseq
37+
- /path@path
38+
- override hydra/job_logging: rich # custom
39+
- override hydra/hydra_logging: rich # custom
40+
- _self_
41+
42+
43+
data.batch_size: 1

interact.py

+146
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import logging
2+
3+
import hydra
4+
import os
5+
from pathlib import Path
6+
from omegaconf import DictConfig, OmegaConf
7+
import temos.launch.prepare # noqa
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
@hydra.main(version_base=None, config_path="configs", config_name="interact")
13+
def _interact(cfg: DictConfig):
14+
return interact(cfg)
15+
16+
17+
def cfg_mean_nsamples_resolution(cfg):
18+
if cfg.mean and cfg.number_of_samples > 1:
19+
logger.error("All the samples will be the mean.. cfg.number_of_samples=1 will be forced.")
20+
cfg.number_of_samples = 1
21+
22+
return cfg.number_of_samples == 1
23+
24+
25+
def load_checkpoint(model, last_ckpt_path, *, eval_mode):
26+
# Load the last checkpoint
27+
# model = model.load_from_checkpoint(last_ckpt_path)
28+
# this will overide values
29+
# for example relative to rots2joints
30+
# So only load state dict is preferable
31+
import torch
32+
model.load_state_dict(torch.load(last_ckpt_path)["state_dict"])
33+
logger.info("Model weights restored.")
34+
35+
if eval_mode:
36+
model.eval()
37+
logger.info("Model in eval mode.")
38+
39+
40+
def interact(newcfg: DictConfig) -> None:
41+
# Load last config
42+
output_dir = Path(hydra.utils.to_absolute_path(newcfg.folder))
43+
last_ckpt_path = newcfg.last_ckpt_path
44+
45+
# Load previous config
46+
prevcfg = OmegaConf.load(output_dir / ".hydra/config.yaml")
47+
# Overload it
48+
cfg = OmegaConf.merge(prevcfg, newcfg)
49+
oneinteract = cfg_mean_nsamples_resolution(cfg)
50+
51+
text = cfg.text
52+
logger.info(f"Interaction script. The result will be saved there: {cfg.saving}")
53+
logger.info(f"The sentence is: {text}")
54+
55+
filename = (text
56+
.lower()
57+
.strip()
58+
.replace(" ", "_")
59+
.replace(".", "") + "_len_" + str(cfg.length)
60+
)
61+
62+
os.makedirs(cfg.saving, exist_ok=True)
63+
path = Path(cfg.saving)
64+
65+
import pytorch_lightning as pl
66+
import numpy as np
67+
import torch
68+
from hydra.utils import instantiate
69+
pl.seed_everything(cfg.seed)
70+
71+
logger.info("Loading model")
72+
if cfg.jointstype == "vertices":
73+
assert cfg.gender in ["male", "female", "neutral"]
74+
logger.info(f"The topology will be {cfg.gender}.")
75+
cfg.model.transforms.rots2joints.gender = cfg.gender
76+
77+
logger.info("Loading data module")
78+
data_module = instantiate(cfg.data)
79+
logger.info(f"Data module '{cfg.data.dataname}' loaded")
80+
81+
model = instantiate(cfg.model,
82+
nfeats=data_module.nfeats,
83+
logger_name="none",
84+
nvids_to_save=None,
85+
_recursive_=False)
86+
87+
logger.info(f"Model '{cfg.model.modelname}' loaded")
88+
89+
load_checkpoint(model, last_ckpt_path, eval_mode=True)
90+
91+
if "amass" in cfg.data.dataname and "xyz" not in cfg.data.dataname:
92+
model.transforms.rots2joints.jointstype = cfg.jointstype
93+
94+
model.sample_mean = cfg.mean
95+
model.fact = cfg.fact
96+
97+
if not model.hparams.vae and cfg.number_of_samples > 1:
98+
raise TypeError("Cannot get more than 1 sample if it is not a VAE.")
99+
100+
from temos.data.tools.collate import collate_text_and_length
101+
102+
from temos.data.sampling import upsample
103+
from rich.progress import Progress
104+
from rich.progress import track
105+
106+
# remove printing for changing the seed
107+
logging.getLogger('pytorch_lightning.utilities.seed').setLevel(logging.WARNING)
108+
109+
import torch
110+
with torch.no_grad():
111+
if True:
112+
# with Progress(transient=True) as progress:
113+
# task = progress.add_task("Sampling", total=len(dataset.keyids))
114+
# progress.update(task, description=f"Sampling {keyid}..")
115+
for index in range(cfg.number_of_samples):
116+
# batch_size = 1 for reproductability
117+
element = {"text": text, "length": cfg.length}
118+
batch = collate_text_and_length([element])
119+
120+
# fix the seed
121+
pl.seed_everything(50 + index)
122+
123+
if cfg.jointstype == "vertices":
124+
vertices = model(batch)[0]
125+
motion = vertices.numpy()
126+
# no upsampling here to keep memory
127+
# vertices = upinteract(vertices, cfg.data.framerate, 100)
128+
else:
129+
joints = model(batch)[0]
130+
motion = joints.numpy()
131+
# upscaling to compare with other methods
132+
motion = upsample(motion, cfg.data.framerate, 100)
133+
134+
if cfg.number_of_samples > 1:
135+
npypath = path / f"{filename}_{index}.npy"
136+
else:
137+
npypath = path / f"{filename}.npy"
138+
np.save(npypath, motion)
139+
# progress.update(task, advance=1)
140+
141+
logger.info("All the sampling are done")
142+
logger.info(f"All the sampling are done. You can find them here:\n{path}")
143+
144+
145+
if __name__ == '__main__':
146+
_interact()

0 commit comments

Comments
 (0)