utils.py
import torch
from torch import nn
import torch.nn.functional as F


def compute_kl_divergence(model, target_model, inputs):
    # Reference token distribution from the frozen target model (no gradients).
    with torch.no_grad():
        ref_outputs = target_model(**inputs)
    ref_probs = F.log_softmax(ref_outputs.logits, dim=-1)
    ref_probs = ref_probs.view(-1, ref_outputs.logits.shape[-1])

    # Current model's token distribution; gradients flow through this branch.
    outputs = model(**inputs)
    current_probs = F.log_softmax(outputs.logits, dim=-1)
    current_probs = current_probs.view(-1, outputs.logits.shape[-1])

    # KL(reference || current), summed over the vocabulary and averaged over
    # token positions ("batchmean"); both arguments are log-probabilities,
    # hence log_target=True.
    return nn.functional.kl_div(
        current_probs, ref_probs, reduction="batchmean", log_target=True
    ), outputs

def compute_batch_nll(model, inputs):
    # Summed (not averaged) negative log-likelihood for each sequence in a batch.
    # NOTE: this differs from model(**inputs).loss, which averages over all
    # tokens; here every sequence keeps its own summed loss.
    outputs = model(**inputs)
    logits = outputs.logits
    labels = inputs["labels"]
    # Standard causal-LM shift: tokens < n predict token n.
    shifted_labels = labels[..., 1:].contiguous()
    logits = logits[..., :-1, :].contiguous()
    loss_function = nn.CrossEntropyLoss(ignore_index=-100, reduction="none")
    # CrossEntropyLoss expects the class dimension second, hence the transpose;
    # summing over the sequence dimension yields one loss value per sequence.
    loss = loss_function(logits.transpose(-1, -2), shifted_labels).sum(dim=-1)
    return loss, outputs

def compute_dpo_loss(model, ref_model, win_inputs=None, lose_inputs=None, beta=1.0):
    if win_inputs is None and lose_inputs is None:
        raise ValueError("win_inputs and lose_inputs cannot both be None")

    win_log_ratio, lose_log_ratio = 0.0, 0.0
    win_outputs, lose_outputs = None, None

    if win_inputs is not None:
        win_loss, win_outputs = compute_batch_nll(model, win_inputs)
        with torch.no_grad():
            win_ref_loss, _ = compute_batch_nll(ref_model, win_inputs)
        # log pi_theta(y_w | x) - log pi_ref(y_w | x); negated because
        # compute_batch_nll returns negative log-likelihoods.
        win_log_ratio = -(win_loss - win_ref_loss)

    if lose_inputs is not None:
        lose_loss, lose_outputs = compute_batch_nll(model, lose_inputs)
        with torch.no_grad():
            lose_ref_loss, _ = compute_batch_nll(ref_model, lose_inputs)
        lose_log_ratio = -(lose_loss - lose_ref_loss)

    # DPO-style objective: -2/beta * E[log sigmoid(beta * (win - lose))];
    # a missing side simply contributes 0 to the log-ratio difference.
    loss = -2 / beta * F.logsigmoid(beta * (win_log_ratio - lose_log_ratio)).mean()
    return loss, (win_outputs, lose_outputs)
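

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how compute_dpo_loss might be called in a
# preference/unlearning training step. The model name, tokenizer, and texts
# below are assumptions for demonstration; any Hugging Face causal LM with a
# matching tokenizer should work the same way.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "gpt2"  # assumption: stand-in for whatever model the repo trains
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

    model = AutoModelForCausalLM.from_pretrained(model_name)
    ref_model = AutoModelForCausalLM.from_pretrained(model_name)
    ref_model.eval()

    def build_inputs(texts):
        # Tokenize and reuse input_ids as labels, masking padding with -100 so
        # CrossEntropyLoss(ignore_index=-100) skips padded positions.
        batch = tokenizer(texts, return_tensors="pt", padding=True)
        labels = batch["input_ids"].clone()
        labels[batch["attention_mask"] == 0] = -100
        batch["labels"] = labels
        return batch

    win_inputs = build_inputs(["I don't know the answer to that."])
    lose_inputs = build_inputs(["The secret passphrase is 1234."])

    loss, _ = compute_dpo_loss(model, ref_model, win_inputs, lose_inputs, beta=0.1)
    loss.backward()  # gradients reach `model` only; `ref_model` ran under no_grad
    print(float(loss))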