1
+ # !/usr/bin/env python3
2
+ """
3
+ ==== No Bugs in code, just some Random Unexpected FEATURES ====
4
+ ┌─────────────────────────────────────────────────────────────┐
5
+ │┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐│
6
+ ││Esc│!1 │@2 │#3 │$4 │%5 │^6 │&7 │*8 │(9 │)0 │_- │+= │|\ │`~ ││
7
+ │├───┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴───┤│
8
+ ││ Tab │ Q │ W │ E │ R │ T │ Y │ U │ I │ O │ P │{[ │}] │ BS ││
9
+ │├─────┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴─────┤│
10
+ ││ Ctrl │ A │ S │ D │ F │ G │ H │ J │ K │ L │: ;│" '│ Enter ││
11
+ │├──────┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴────┬───┤│
12
+ ││ Shift │ Z │ X │ C │ V │ B │ N │ M │< ,│> .│? /│Shift │Fn ││
13
+ │└─────┬──┴┬──┴──┬┴───┴───┴───┴───┴───┴──┬┴───┴┬──┴┬─────┴───┘│
14
+ │ │Fn │ Alt │ Space │ Alt │Win│ HHKB │
15
+ │ └───┴─────┴───────────────────────┴─────┴───┘ │
16
+ └─────────────────────────────────────────────────────────────┘
17
+
18
+ BLEU指标。
19
+
20
+ Reference:
21
+ https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/metrics/bleu.py
22
+
23
+ Author: pankeyu
24
+ Date: 2022/1/5
25
+ """
26
+ import sys
27
+ import math
28
+ from typing import List
29
+ from collections import defaultdict
30
+
31
+ import numpy as np
32
+
33
+
34
+ def default_trans_func (output , label , seq_mask , vocab ):
35
+ seq_mask = np .expand_dims (seq_mask , axis = 2 ).repeat (output .shape [2 ], axis = 2 )
36
+ output = output * seq_mask
37
+ idx = np .argmax (output , axis = 2 )
38
+ prediction , references = [], []
39
+ for i in range (idx .shape [0 ]):
40
+ token_list = []
41
+ for j in range (idx .shape [1 ]):
42
+ if seq_mask [i ][j ][0 ] == 0 :
43
+ break
44
+ token_list .append (vocab [idx [i ][j ]])
45
+ prediction .append (token_list )
46
+
47
+ label = np .squeeze (label , axis = 2 )
48
+ for i in range (label .shape [0 ]):
49
+ token_list = []
50
+ for j in range (label .shape [1 ]):
51
+ if seq_mask [i ][j ][0 ] == 0 :
52
+ break
53
+ token_list .append (vocab [label [i ][j ]])
54
+
55
+ references .append ([token_list ])
56
+ return prediction , references
57
+
58
+
59
+ def get_match_size (prediction_ngram , refs_ngram ):
60
+ ref_set = defaultdict (int )
61
+ for ref_ngram in refs_ngram :
62
+ tmp_ref_set = defaultdict (int )
63
+ for ngram in ref_ngram :
64
+ tmp_ref_set [tuple (ngram )] += 1
65
+ for ngram , count in tmp_ref_set .items ():
66
+ ref_set [tuple (ngram )] = max (ref_set [tuple (ngram )], count )
67
+ prediction_set = defaultdict (int )
68
+ for ngram in prediction_ngram :
69
+ prediction_set [tuple (ngram )] += 1
70
+ match_size = 0
71
+ for ngram , count in prediction_set .items ():
72
+ match_size += min (count , ref_set .get (tuple (ngram ), 0 ))
73
+ prediction_size = len (prediction_ngram )
74
+ return match_size , prediction_size
75
+
76
+
77
+ def get_ngram (sent , n_size , label = None ):
78
+ def _ngram (sent , n_size ):
79
+ ngram_list = []
80
+ for left in range (len (sent ) - n_size ):
81
+ ngram_list .append (sent [left : left + n_size + 1 ])
82
+ return ngram_list
83
+
84
+ ngram_list = _ngram (sent , n_size )
85
+ if label is not None :
86
+ ngram_list = [ngram + "_" + label for ngram in ngram_list ]
87
+ return ngram_list
88
+
89
+
90
+ class BLEU (object ):
91
+ """
92
+ BLEU 评估器。
93
+
94
+ Reference:
95
+ https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/metrics/bleu.py
96
+
97
+ Examples:
98
+ from bleu_metrics import BLEU
99
+
100
+ bleu = BLEU(n_size=1)
101
+ prediction = ["猫", "在", "桌", "上"]
102
+ references = [["猫", "在", "树", "上"]]
103
+ bleu.add_inst(prediction, references)
104
+ print(bleu.score()) # 0.75
105
+ """
106
+
107
+ def __init__ (self , trans_func = None , vocab = None , n_size = 4 , name = "bleu" ):
108
+ """
109
+ Args:
110
+ trans_func (callable, optional): `trans_func` transforms the network
111
+ output to string to calculate.
112
+ vocab (dict|paddlenlp.data.vocab, optional): Vocab for target language.
113
+ If `trans_func` is None and BLEU is used as `paddle.metric.Metric`
114
+ instance, `default_trans_func` will be performed and `vocab` must
115
+ be provided.
116
+ n_size (int, optional): Number of gram for BLEU metric. Defaults to 4.
117
+ weights (list, optional): The weights of precision of each gram.
118
+ Defaults to None.
119
+ name (str, optional): Name of `paddle.metric.Metric` instance.
120
+ Defaults to "bleu".
121
+ """
122
+ super ().__init__ ()
123
+ weights = [1 / n_size for _ in range (n_size )]
124
+ self ._name = name
125
+ self .match_ngram = {}
126
+ self .prediction_ngram = {}
127
+ self .weights = weights
128
+ self .bp_r = 0
129
+ self .bp_c = 0
130
+ self .n_size = n_size
131
+ self .vocab = vocab
132
+ self .trans_func = trans_func
133
+
134
+ def update (self , output , label , seq_mask = None ):
135
+ if self .trans_func is None :
136
+ if self .vocab is None :
137
+ raise AttributeError (
138
+ "The `update` method requires users to provide `trans_func` or `vocab` when initializing BLEU."
139
+ )
140
+ prediction_list , references = default_trans_func (output , label , seq_mask = seq_mask , vocab = self .vocab )
141
+ else :
142
+ prediction_list , references = self .trans_func (output , label , seq_mask )
143
+ if len (prediction_list ) != len (references ):
144
+ raise ValueError ("Length error! Please check the output of network." )
145
+ for i in range (len (prediction_list )):
146
+ self .add_inst (prediction_list [i ], references [i ])
147
+
148
+ def add_instance (self , prediction : List [str ], references : List [List [str ]]):
149
+ """
150
+ Update the states based on a pair of prediction and references.
151
+
152
+ Args:
153
+ prediction (list): Tokenized prediction sentence.
154
+ references (list of list): List of tokenized ground truth sentences.
155
+ """
156
+ for n_size in range (self .n_size ):
157
+ self .count_ngram (prediction , references , n_size )
158
+ self .count_bp (prediction , references )
159
+
160
+ def count_ngram (self , prediction , references , n_size ):
161
+ prediction_ngram = get_ngram (prediction , n_size )
162
+ refs_ngram = []
163
+ for ref in references :
164
+ refs_ngram .append (get_ngram (ref , n_size ))
165
+ if n_size not in self .match_ngram :
166
+ self .match_ngram [n_size ] = 0
167
+ self .prediction_ngram [n_size ] = 0
168
+ match_size , prediction_size = get_match_size (prediction_ngram , refs_ngram )
169
+
170
+ self .match_ngram [n_size ] += match_size
171
+ self .prediction_ngram [n_size ] += prediction_size
172
+
173
+ def count_bp (self , prediction , references ):
174
+ self .bp_c += len (prediction )
175
+ self .bp_r += min ([(abs (len (prediction ) - len (ref )), len (ref )) for ref in references ])[1 ]
176
+
177
+ def reset (self ):
178
+ self .match_ngram = {}
179
+ self .prediction_ngram = {}
180
+ self .bp_r = 0
181
+ self .bp_c = 0
182
+
183
+ def accumulate (self ):
184
+ """
185
+ Calculates and returns the final bleu metric.
186
+
187
+ Returns:
188
+ Tensor: Returns the accumulated metric `bleu` and its data type is float64.
189
+ """
190
+ prob_list = []
191
+ for n_size in range (self .n_size ):
192
+ try :
193
+ if self .prediction_ngram [n_size ] == 0 :
194
+ _score = 0.0
195
+ else :
196
+ _score = self .match_ngram [n_size ] / float (self .prediction_ngram [n_size ])
197
+ except :
198
+ _score = 0
199
+ if _score == 0 :
200
+ _score = sys .float_info .min
201
+ prob_list .append (_score )
202
+
203
+ logs = math .fsum (w_i * math .log (p_i ) for w_i , p_i in zip (self .weights , prob_list ))
204
+ bp = math .exp (min (1 - self .bp_r / float (self .bp_c ), 0 ))
205
+ bleu = bp * math .exp (logs )
206
+ return bleu
207
+
208
+ def compute (self ):
209
+ return self .accumulate ()
210
+
211
+ def name (self ):
212
+ return self ._name
213
+
214
+
215
+ if __name__ == '__main__' :
216
+ blue = BLEU (n_size = 1 )
217
+ prediction = list ("猫坐在椅子上" )
218
+ references = [list ("猫坐在树上" )]
219
+ blue .add_instance (prediction = prediction , references = references )
220
+ print (blue .compute ())
0 commit comments