Skip to content

Commit b30ceac

Browse files
committed
成功的版本
0 parents  commit b30ceac

File tree

5 files changed

+743
-0
lines changed

5 files changed

+743
-0
lines changed

bleu_metrics.py

+220
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
#!/usr/bin/env python3
2+
"""
3+
==== No Bugs in code, just some Random Unexpected FEATURES ====
4+
┌─────────────────────────────────────────────────────────────┐
5+
│┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐│
6+
││Esc│!1 │@2 │#3 │$4 │%5 │^6 │&7 │*8 │(9 │)0 │_- │+= │|\ │`~ ││
7+
│├───┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴───┤│
8+
││ Tab │ Q │ W │ E │ R │ T │ Y │ U │ I │ O │ P │{[ │}] │ BS ││
9+
│├─────┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴─────┤│
10+
││ Ctrl │ A │ S │ D │ F │ G │ H │ J │ K │ L │: ;│" '│ Enter ││
11+
│├──────┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴────┬───┤│
12+
││ Shift │ Z │ X │ C │ V │ B │ N │ M │< ,│> .│? /│Shift │Fn ││
13+
│└─────┬──┴┬──┴──┬┴───┴───┴───┴───┴───┴──┬┴───┴┬──┴┬─────┴───┘│
14+
│ │Fn │ Alt │ Space │ Alt │Win│ HHKB │
15+
│ └───┴─────┴───────────────────────┴─────┴───┘ │
16+
└─────────────────────────────────────────────────────────────┘
17+
18+
BLEU指标。
19+
20+
Reference:
21+
https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/metrics/bleu.py
22+
23+
Author: pankeyu
24+
Date: 2022/1/5
25+
"""
26+
import sys
27+
import math
28+
from typing import List
29+
from collections import defaultdict
30+
31+
import numpy as np
32+
33+
34+
def default_trans_func(output, label, seq_mask, vocab):
    """
    Convert raw network output and labels into token lists for BLEU.

    Args:
        output: Array indexed on three axes; the arg-max over axis 2 is used
            as the predicted token id.
            (presumably (batch, seq_len, vocab_size) -- confirm with caller)
        label: Array with a size-1 axis 2 holding the reference token ids.
        seq_mask: 2-D array; a 0 entry marks the end of the valid part of a
            row, for both the prediction and the label.
        vocab: Object indexable by token id, returning the token.

    Returns:
        tuple: `(prediction, references)`. `prediction` holds one token list
        per row; `references` holds, per row, a list containing exactly one
        reference token list.
    """
    class_count = output.shape[2]
    # Broadcast the mask over the last axis so it can zero out padded steps.
    seq_mask = np.expand_dims(seq_mask, axis=2).repeat(class_count, axis=2)
    predicted_ids = np.argmax(output * seq_mask, axis=2)

    def _decode(id_rows):
        # Map ids to tokens row by row, stopping at the first masked position.
        decoded_rows = []
        for row_no in range(id_rows.shape[0]):
            tokens = []
            for col_no in range(id_rows.shape[1]):
                if seq_mask[row_no][col_no][0] == 0:
                    break
                tokens.append(vocab[id_rows[row_no][col_no]])
            decoded_rows.append(tokens)
        return decoded_rows

    prediction = _decode(predicted_ids)
    # Every sample has a single reference, wrapped in its own list.
    references = [[tokens] for tokens in _decode(np.squeeze(label, axis=2))]
    return prediction, references
57+
58+
59+
def get_match_size(prediction_ngram, refs_ngram):
    """
    Count the clipped n-gram matches between a prediction and its references.

    Args:
        prediction_ngram (list): N-grams of the prediction.
        refs_ngram (list): One n-gram list per reference.

    Returns:
        tuple: `(match_size, prediction_size)`. `match_size` sums, over the
        distinct prediction n-grams, the prediction count clipped to the
        highest count of that n-gram in any single reference.
        `prediction_size` is the number of prediction n-grams.
    """
    # Highest per-reference count seen for each n-gram.
    best_ref_counts = {}
    for single_ref in refs_ngram:
        counts_in_ref = {}
        for gram in single_ref:
            key = tuple(gram)
            counts_in_ref[key] = counts_in_ref.get(key, 0) + 1
        for key, seen in counts_in_ref.items():
            if seen > best_ref_counts.get(key, 0):
                best_ref_counts[key] = seen

    predicted_counts = {}
    for gram in prediction_ngram:
        key = tuple(gram)
        predicted_counts[key] = predicted_counts.get(key, 0) + 1

    matched = sum(
        min(seen, best_ref_counts.get(key, 0))
        for key, seen in predicted_counts.items()
    )
    return matched, len(prediction_ngram)
75+
76+
77+
def get_ngram(sent, n_size, label=None):
    """
    Slice a sequence into its contiguous n-grams.

    Args:
        sent: Sliceable sequence of tokens (a list or a string).
        n_size (int): Zero-based n-gram order; each n-gram has
            `n_size + 1` items, so 0 yields unigrams.
        label (str, optional): When given, `"_" + label` is appended to each
            n-gram with `+`.

    Returns:
        list: The n-grams in left-to-right order. The list is empty when the
        sequence has `n_size` items or fewer.
    """
    width = n_size + 1
    grams = [sent[start:start + width] for start in range(len(sent) - n_size)]
    if label is None:
        return grams
    return [gram + "_" + label for gram in grams]
88+
89+
90+
class BLEU(object):
    """
    Corpus-level BLEU evaluator.

    N-gram and length statistics are accumulated over every
    (prediction, references) pair that is added and are combined into a
    single score by `accumulate` / `compute`.

    Reference:
        https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/metrics/bleu.py

    Examples:
        from bleu_metrics import BLEU

        bleu = BLEU(n_size=1)
        prediction = ["cat", "on", "the", "table"]
        references = [["cat", "on", "the", "tree"]]
        bleu.add_instance(prediction, references)
        print(bleu.compute())    # 0.75
    """

    def __init__(self, trans_func=None, vocab=None, n_size=4, name="bleu"):
        """
        Args:
            trans_func (callable, optional): Called by `update` as
                `trans_func(output, label, seq_mask)`; must return
                `(prediction_list, references)`.
            vocab (optional): Passed to `default_trans_func` when
                `trans_func` is None; `update` requires one of the two.
            n_size (int, optional): Highest n-gram order taken into account.
                Defaults to 4.
            name (str, optional): Value returned by `name()`.
                Defaults to "bleu".
        """
        super().__init__()
        self._name = name
        self.match_ngram = {}
        self.prediction_ngram = {}
        # Every n-gram order contributes equally to the final score.
        self.weights = [1 / n_size for _ in range(n_size)]
        self.bp_r = 0
        self.bp_c = 0
        self.n_size = n_size
        self.vocab = vocab
        self.trans_func = trans_func

    def update(self, output, label, seq_mask=None):
        """
        Convert a batch of network output to tokens and accumulate it.

        Args:
            output: Network output, forwarded to the transform function.
            label: Ground-truth labels, forwarded to the transform function.
            seq_mask (optional): Mask, forwarded to the transform function.

        Raises:
            AttributeError: If neither `trans_func` nor `vocab` was provided.
            ValueError: If the transform returns a different number of
                predictions and reference groups.
        """
        if self.trans_func is None:
            if self.vocab is None:
                raise AttributeError(
                    "The `update` method requires users to provide `trans_func` or `vocab` when initializing BLEU."
                )
            prediction_list, references = default_trans_func(output, label, seq_mask=seq_mask, vocab=self.vocab)
        else:
            prediction_list, references = self.trans_func(output, label, seq_mask)
        if len(prediction_list) != len(references):
            raise ValueError("Length error! Please check the output of network.")
        for prediction, refs in zip(prediction_list, references):
            self.add_instance(prediction, refs)

    def add_instance(self, prediction: List[str], references: List[List[str]]):
        """
        Update the states based on a pair of prediction and references.

        Args:
            prediction (list): Tokenized prediction sentence.
            references (list of list): List of tokenized ground truth sentences.
        """
        for n_size in range(self.n_size):
            self.count_ngram(prediction, references, n_size)
        self.count_bp(prediction, references)

    def count_ngram(self, prediction, references, n_size):
        """
        Accumulate matched and total prediction n-gram counts for one order.

        Args:
            prediction (list): Tokenized prediction sentence.
            references (list of list): Tokenized reference sentences.
            n_size (int): Zero-based order index, used as the key of the
                count dictionaries and passed on to `get_ngram`.
        """
        prediction_ngram = get_ngram(prediction, n_size)
        refs_ngram = [get_ngram(ref, n_size) for ref in references]
        match_size, prediction_size = get_match_size(prediction_ngram, refs_ngram)

        self.match_ngram[n_size] = self.match_ngram.get(n_size, 0) + match_size
        self.prediction_ngram[n_size] = self.prediction_ngram.get(n_size, 0) + prediction_size

    def count_bp(self, prediction, references):
        """
        Accumulate the lengths used by the brevity penalty.

        Adds the prediction length to `bp_c`, and to `bp_r` the length of the
        reference whose length is closest to the prediction's (the shorter
        reference wins a tie).
        """
        self.bp_c += len(prediction)
        self.bp_r += min([(abs(len(prediction) - len(ref)), len(ref)) for ref in references])[1]

    def reset(self):
        """Discard all accumulated statistics."""
        self.match_ngram = {}
        self.prediction_ngram = {}
        self.bp_r = 0
        self.bp_c = 0

    def accumulate(self):
        """
        Calculates and returns the final bleu metric.

        Returns:
            float: The accumulated BLEU score, or 0.0 when no prediction
            tokens have been accumulated yet.
        """
        if self.bp_c == 0:
            # Nothing was predicted, so the brevity penalty is undefined.
            return 0.0

        prob_list = []
        for n_size in range(self.n_size):
            predicted = self.prediction_ngram.get(n_size, 0)
            _score = self.match_ngram.get(n_size, 0) / float(predicted) if predicted else 0.0
            if _score == 0:
                # Keep math.log finite for orders without any match.
                _score = sys.float_info.min
            prob_list.append(_score)

        logs = math.fsum(w_i * math.log(p_i) for w_i, p_i in zip(self.weights, prob_list))
        bp = math.exp(min(1 - self.bp_r / float(self.bp_c), 0))
        return bp * math.exp(logs)

    def compute(self):
        """Return the current BLEU score; same as `accumulate`."""
        return self.accumulate()

    def name(self):
        """Return the metric name given at construction time."""
        return self._name
213+
214+
215+
if __name__ == '__main__':
    # Smoke test: unigram BLEU of one character-tokenized sentence pair.
    hypothesis_tokens = list("猫坐在椅子上")
    reference_token_lists = [list("猫坐在树上")]

    bleu_metric = BLEU(n_size=1)
    bleu_metric.add_instance(prediction=hypothesis_tokens, references=reference_token_lists)
    print(bleu_metric.compute())

iTrainingLogger.py

+157
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
#!/usr/bin/env python3
2+
"""
3+
==== No Bugs in code, just some Random Unexpected FEATURES ====
4+
┌─────────────────────────────────────────────────────────────┐
5+
│┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐│
6+
││Esc│!1 │@2 │#3 │$4 │%5 │^6 │&7 │*8 │(9 │)0 │_- │+= │|\ │`~ ││
7+
│├───┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴───┤│
8+
││ Tab │ Q │ W │ E │ R │ T │ Y │ U │ I │ O │ P │{[ │}] │ BS ││
9+
│├─────┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴─────┤│
10+
││ Ctrl │ A │ S │ D │ F │ G │ H │ J │ K │ L │: ;│" '│ Enter ││
11+
│├──────┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴────┬───┤│
12+
││ Shift │ Z │ X │ C │ V │ B │ N │ M │< ,│> .│? /│Shift │Fn ││
13+
│└─────┬──┴┬──┴──┬┴───┴───┴───┴───┴───┴──┬┴───┴┬──┴┬─────┴───┘│
14+
│ │Fn │ Alt │ Space │ Alt │Win│ HHKB │
15+
│ └───┴─────┴───────────────────────┴─────┴───┘ │
16+
└─────────────────────────────────────────────────────────────┘
17+
18+
训练过程中的记录器,类似于SummaryWriter功能,只是SummaryWriter需要依赖于tensorboard和浏览器做可视化,
19+
该工具依赖matplotlib采用静态本地图片存储的形式,便于服务器快速查看训练结果。
20+
21+
Authors: pankeyu
22+
Date: 2021/10/17
23+
"""
24+
import os
25+
26+
import numpy as np
27+
import matplotlib.pyplot as plt
28+
29+
30+
class iSummaryWriter(object):
    """
    SummaryWriter-like training logger backed by matplotlib.

    Each tracked scalar gets its own subplot, and `record` saves the whole
    figure as a single image file.
    """

    def __init__(self, log_path: str, log_name: str, params=None, extention='.png', max_columns=2,
                 log_title=None, figsize=None):
        """
        Create the logger and its output directory.

        Args:
            log_path (str): Directory the image is written to; created if missing.
            log_name (str): Image file name without extension.
            params (list, optional): Names of the scalars tracked from the
                start, e.g. ["loss", "reward"]. More can be added later
                through `add_scalar`.
            extention (str): File extension of the saved image (the parameter
                spelling is kept for backward compatibility).
            max_columns (int): Maximum number of subplots per row.
            log_title (str, optional): Figure title; defaults to
                '[Training Log] <log_name>'.
            figsize (tuple, optional): Figure size; derived from the subplot
                grid when omitted.
        """
        self.log_path = log_path
        os.makedirs(log_path, exist_ok=True)
        self.log_name = log_name
        self.extention = extention
        self.max_param_index = -1
        self.max_columns_threshold = max_columns
        self.max_columns = max_columns
        self.figsize = figsize
        self.log_title = log_title
        # Stay None until at least one scalar is tracked.
        self.fig = None
        self.axes = None
        self.params_dict = self.create_params_dict(params if params is not None else [])
        self.init_plt()
        self.update_ax_list()

    def init_plt(self) -> None:
        """Select the dark-grid seaborn style under whichever name is installed."""
        # The style was renamed to 'seaborn-v0_8-*' in matplotlib 3.6.
        for style_name in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid'):
            if style_name in plt.style.available:
                plt.style.use(style_name)
                break

    def create_params_dict(self, params: list) -> dict:
        """
        Build the tracking dictionary for the given scalar names.

        Args:
            params (list): Names of the scalars to track.

        Returns:
            dict: One entry per name -> {
                'loss': {'values': [0.44, 0.32, ...], 'epochs': [10, 20, ...], 'index': 0},
                'reward': {'values': [10.2, 13.2, ...], 'epochs': [10, 20, ...], 'index': 1},
                ...
            }
        """
        params_dict = {}
        for i, param in enumerate(params):
            params_dict[param] = {'values': [], 'epochs': [], 'index': i}
            self.max_param_index = i
        return params_dict

    def update_ax_list(self) -> None:
        """
        Rebuild the figure so that every tracked scalar has a subplot.
        """
        params_num = self.max_param_index + 1
        if params_num <= 0:
            return

        self.max_columns = min(params_num, self.max_columns_threshold)
        max_rows = (params_num - 1) // self.max_columns + 1    # rows needed for all scalars
        figsize = self.figsize if self.figsize else (self.max_columns * 6, max_rows * 3)

        # Close the figure being replaced so repeated rebuilds do not leak figures.
        if self.fig is not None:
            plt.close(self.fig)

        # squeeze=False always yields a 2-D axes array, whatever the grid shape.
        self.fig, self.axes = plt.subplots(max_rows, self.max_columns, figsize=figsize, squeeze=False)

        log_title = self.log_title if self.log_title else '[Training Log] {}'.format(
            self.log_name)
        self.fig.suptitle(log_title, fontsize=15)

    def add_scalar(self, param: str, value: float, epoch: int) -> None:
        """
        Append one observation of a scalar.

        Args:
            param (str): Scalar name, e.g. 'loss'.
            value (float): Value at this point.
            epoch (int): Epoch (or step) at this point.
        """
        # A name seen for the first time gets the next index and a new subplot grid.
        if param not in self.params_dict:
            self.max_param_index += 1
            self.params_dict[param] = {'values': [],
                                       'epochs': [], 'index': self.max_param_index}
            self.update_ax_list()

        self.params_dict[param]['values'].append(value)
        self.params_dict[param]['epochs'].append(epoch)

    def record(self, dpi=200) -> None:
        """
        Draw every tracked scalar and save the figure to
        `<log_path>/<log_name><extention>`.

        Nothing is written while no scalar is tracked.

        Args:
            dpi (int): Resolution of the saved image.
        """
        if self.fig is None:
            return

        for param, param_elements in self.params_dict.items():
            param_row, param_column = divmod(param_elements["index"], self.max_columns)
            ax = self.axes[param_row, param_column]
            # Redraw from scratch so repeated calls do not stack duplicate lines.
            ax.clear()
            ax.set_xlabel('Step or Epoch')
            ax.set_ylabel(param)
            ax.plot(param_elements['epochs'],
                    param_elements['values'],
                    color='darkorange')

        # Save this writer's own figure rather than pyplot's current figure.
        self.fig.savefig(os.path.join(self.log_path,
                                      self.log_name + self.extention), dpi=dpi)
140+
141+
142+
if __name__ == '__main__':
    # Demo: track three scalars for a few epochs and re-save the plot each time.
    import random
    import time

    n_epochs = 10
    log_path, log_name = './', 'test'
    writer = iSummaryWriter(log_path=log_path, log_name=log_name)
    for i in range(n_epochs):
        loss, reward = 100 - random.random() * i, random.random() * i
        writer.add_scalar('loss', loss, i)
        writer.add_scalar('reward', reward, i)
        writer.add_scalar('random', reward, i)
        writer.record()
        # Report the file that was actually written, extension included.
        print("Log has been saved at: {}".format(
            os.path.join(log_path, log_name + writer.extention)))
        time.sleep(3)

0 commit comments

Comments
 (0)