-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathhifiapi.py
executable file
·52 lines (40 loc) · 1.45 KB
/
hifiapi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
from hifi.models import Generator
class AttrDict(dict):
    """Dictionary whose entries are also reachable as attributes.

    The instance's attribute namespace is bound to the mapping itself, so
    ``d.key`` and ``d["key"]`` read and write the same storage.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias attribute storage to the dict so attribute access == item access.
        self.__dict__ = self
class HIFIapi:
    """Inference wrapper around a HiFi-GAN ``Generator`` vocoder.

    Builds the generator from ``config.hifi``, optionally loads generator
    weights from a checkpoint, strips weight norm and puts the model in eval
    mode on the requested device.
    """

    def __init__(self, config, device="gpu"):
        """Build the generator and prepare it for inference.

        Args:
            config: project config object. Reads
                ``config.model_config["vocoder"]["use_cpu"]``, ``config.hifi``
                (passed to ``Generator``), ``config.hifi.weights_path`` and,
                in ``generate``, ``config.hifi.MAX_WAV_VALUE``.
            device: target torch device. ``"gpu"`` is accepted as an alias
                for ``"cuda"`` because torch has no ``"gpu"`` device type.
                Forced to ``"cpu"`` when the vocoder config sets ``use_cpu``.
        """
        if config.model_config["vocoder"]["use_cpu"]:
            device = "cpu"
        device = self._resolve_device(device)
        # Load checkpoint if one is configured.
        weights_path = config.hifi.weights_path
        self.model = Generator(config.hifi)
        if weights_path is not None:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(weights_path, map_location="cpu")
            self.model.load_state_dict(checkpoint["generator"])
        self.cfg = config
        self.device = device
        self.model.to(device)
        self.model.remove_weight_norm()
        self.model.eval()

    @staticmethod
    def _resolve_device(device):
        """Map the legacy ``"gpu"`` alias to torch's ``"cuda"``; pass anything else through."""
        if isinstance(device, str) and device == "gpu":
            return "cuda"
        return device

    def train(self):
        """Training is not supported by this wrapper.

        Raises:
            NotImplementedError: always.
        """
        # TODO: implement HiFi-GAN training.
        raise NotImplementedError(" Train for HiFi was not implemented yet")

    def __call__(self, x):
        """Run the generator on ``x`` after moving it to ``self.device``.

        Provided as ``__call__`` for compatibility with other vocoders or
        functions. Returns the raw generator output tensor (left on
        ``self.device``). No ``torch.no_grad`` is applied here.
        """
        x = x.to(self.device)
        return self.model(x)

    def generate(self, mel_specs):
        """Convert a batch of mel spectrograms into int16 audio on the CPU.

        Args:
            mel_specs: a batch of mel spectrograms accepted by ``Generator``
                (presumably (batch, n_mels, frames) — TODO confirm). Moved to
                ``self.device`` before inference.

        Returns:
            A numpy array of dtype int16: the generator output scaled by
            ``cfg.hifi.MAX_WAV_VALUE`` and clamped to the int16 range.
        """
        self.model.eval()
        with torch.no_grad():
            audio = self.model(mel_specs.to(self.device))
            audio = audio * self.cfg.hifi.MAX_WAV_VALUE
            # Saturate instead of wrapping: e.g. 1.0 * 32768 would otherwise
            # overflow the int16 cast and become -32768.
            audio = audio.clamp(-32768.0, 32767.0)
        return audio.cpu().numpy().astype("int16")