-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdep-graph.lisp
215 lines (195 loc) · 10.4 KB
/
dep-graph.lisp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
;; -*- mode: Lisp; coding: utf-8-unix; -*-
;; Copyright (c) 2024, April & May
;; SPDX-License-Identifier: 0BSD
(in-package aprnlp)
(defclass dep-graph-parser (dep-parser)
  ((name
    :initform "unnamed-dep-graph-parser"))
  (:documentation
   "Graph-based dependency parser: every candidate head -> dependent arc is
scored from DEP-GRAPH-FEATURES and a sentence is decoded with
MAX-SPANNING-TREE.  Other slots (weights etc.) are presumably inherited
from DEP-PARSER."))
(defun dep-graph-features (head word)
  "Return a vector of feature keys describing the candidate arc HEAD -> WORD.
Each key is a list naming a template followed by the values it conjoins; T
fills a slot the template leaves open.  DISTANCE is HEAD's id minus WORD's id,
or 0 when HEAD has id 0 (the root)."
  (let* ((head-id (word-id head))
         (head-pos (word-upos head))
         (word-pos (word-upos word))
         (head-suffix (word-suffix head))
         (distance (if (zerop head-id)
                       0
                       (- head-id (word-id word)))))
    (vector
     ;; part-of-speech pair and its two halves
     (list :pos head-pos word-pos)
     (list :pos head-pos t)
     (list :pos t word-pos)
     ;; word shapes
     (list :shape (word-shape head) (word-shape word))
     ;; lexical form crossed with the other side's POS
     (list :word-pos (word-form head) word-pos)
     (list :pos-word head-pos (word-form word))
     ;; arc length, alone and with the POS pair
     (list :distance t distance)
     (list head-pos word-pos distance)
     ;; suffix templates
     (list :suffix head-pos (word-suffix word))
     (list :suffix head-suffix word-pos)
     (list :suffix head-suffix (word-suffix word)))))
(defun generate-vertexes (parser sentence)
  "Build the scored graph of candidate arcs for SENTENCE.
Return vertexes := ((vertex (income score) (income score) ...) (vertex ...) ...):
one entry per word of SENTENCE whose edge list starts with the edge from
*ROOT-WORD*, followed by one edge from every other word of SENTENCE.  An
edge's score is the sum of PARSER's :UAS weights over the DEP-GRAPH-FEATURES
of that arc."
  (let ((weights (slot-value parser 'weights)))
    (flet ((arc-score (income vertex)
             (iter (for feature :in-vector (dep-graph-features income vertex))
                   (sum (get-weight weights feature :uas)))))
      (iter (for vertex :in-vector sentence)
            (collect
                (list* vertex
                       (list *root-word* (arc-score *root-word* vertex))
                       ;; A word cannot head itself: skip it before paying
                       ;; for feature extraction and scoring.
                       (iter (for income :in-vector sentence)
                             (unless (eq income vertex)
                               (collect (list income (arc-score income vertex)))))))))))
(defun adjust-weights (vertexes)
  "Normalise the incoming edge scores of every vertex, in place.
VERTEXES is a list of (vertex (income score ...) ...) entries.  The highest
score among a vertex's incoming edges is subtracted from each of them, so the
best edge into every vertex ends up scoring 0 and all others score <= 0.
Vertexes without edges are left untouched.  Return VERTEXES."
  (dolist (entry vertexes vertexes)
    (let ((edges (rest entry)))
      (when edges
        (let ((best (loop for edge in edges maximize (second edge))))
          ;; Update each edge cons directly; re-walking the lists with NTH
          ;; made this quadratic in the number of edges.
          (dolist (edge edges)
            (decf (second edge) best)))))))
(defun select-best-score (vertexes)
  "Pick the highest-scoring incoming edge of every vertex.
VERTEXES is a list of (vertex (income score ...) ...) entries.  Return a list
of (vertex income score ...) entries: each vertex consed onto its chosen edge
object itself (shared, not copied).  On ties the earliest edge wins; a vertex
without edges yields (vertex)."
  (mapcar (lambda (entry)
            (let ((winner nil))
              (dolist (edge (rest entry))
                (when (or (null winner)
                          (> (second edge) (second winner)))
                  (setf winner edge)))
              (cons (first entry) winner)))
          vertexes))
(defun spanning-tree-p (vertexes)
  "Test whether the head assignment VERTEXES is a tree, i.e. contains no cycle.
VERTEXES is a list of (vertex income score ...) entries as returned by
SELECT-BEST-SCORE.  Following INCOME links from any vertex must end at a
vertex without an entry of its own (the root).  Cycles of any length are
detected, not only two-vertex ones.
Return T when there is no cycle.  Otherwise return NIL and, as the second
value, the list of entries of VERTEXES that form the first cycle found."
  (let ((entry-of (make-hash-table :test #'eq))
        ;; vertex -> the starting entry of the walk that first reached it
        (visited-by (make-hash-table :test #'eq)))
    (dolist (entry vertexes)
      (setf (gethash (first entry) entry-of) entry))
    (dolist (start vertexes t)
      (let ((vertex (first start))
            (path '()))
        ;; Follow income links until reaching the root (no entry), a vertex
        ;; explored by an earlier walk, or a vertex already seen in this walk.
        (loop for entry = (gethash vertex entry-of)
              while (and entry (null (gethash vertex visited-by)))
              do (setf (gethash vertex visited-by) start)
                 (push entry path)
                 (setf vertex (second entry)))
        (when (and (gethash vertex entry-of)
                   (eq (gethash vertex visited-by) start))
          ;; VERTEX was reached twice during this walk: the cycle is made of
          ;; the most recent entries of PATH back to VERTEX's own entry.
          (return (values nil
                          (loop for entry in path
                                collect entry
                                until (eq (first entry) vertex)))))))))

(defun contract (vertexes cycled collapsed-vertex)
  "Collapse the cycle CYCLED of VERTEXES into the single vertex COLLAPSED-VERTEX.
VERTEXES is a list of (vertex (income score ...) ...) entries already
normalised by ADJUST-WEIGHTS, so every cycle edge scores 0 and edges entering
the cycle can keep their scores.  CYCLED is the second value of
SPANNING-TREE-P.  Return a fresh vertex list whose edges are all written
(income score real-vertex original-edge), where REAL-VERTEX is the vertex of
VERTEXES the edge really enters and ORIGINAL-EDGE is the edge of VERTEXES it
stands for.  Edges inside the cycle are dropped, edges entering the cycle
enter COLLAPSED-VERTEX, and edges leaving the cycle leave from
COLLAPSED-VERTEX.  VERTEXES itself is not modified."
  (let ((members (mapcar #'first cycled))
        (entering '())
        (others '()))
    (flet ((in-cycle-p (vertex)
             (member vertex members :test #'eq)))
      (dolist (entry vertexes)
        (let ((vertex (first entry)))
          (if (in-cycle-p vertex)
              (dolist (edge (rest entry))
                (unless (in-cycle-p (first edge))
                  (push (list (first edge) (second edge) vertex edge) entering)))
              (push (cons vertex
                          (mapcar (lambda (edge)
                                    (list (if (in-cycle-p (first edge))
                                              collapsed-vertex
                                              (first edge))
                                          (second edge)
                                          vertex
                                          edge))
                                  (rest entry)))
                    others))))
      (cons (cons collapsed-vertex (nreverse entering))
            (nreverse others)))))

(defun expand (best-vertexes cycled collapsed-vertex)
  "Undo CONTRACT on BEST-VERTEXES, the tree found for the contracted graph.
Each entry of BEST-VERTEXES is (vertex income score real-vertex original-edge)
built from CONTRACT's edges and is mapped back to (real-vertex . original-edge).
The cycle member entered by COLLAPSED-VERTEX's chosen edge takes that edge;
the other members of CYCLED keep their in-cycle edges.  Return the resulting
list of (vertex income score ...) entries."
  (let ((tree '())
        (entry-vertex nil))
    (dolist (best best-vertexes)
      (let ((edge (rest best)))
        (cond ((null edge)
               ;; vertex without any incoming edge: nothing to map back
               (push best tree))
              (t
               (when (eq (first best) collapsed-vertex)
                 (setf entry-vertex (third edge)))
               (push (cons (third edge) (fourth edge)) tree)))))
    ;; Break the cycle at ENTRY-VERTEX; every other member keeps its cycle edge.
    (dolist (cycle-entry cycled)
      (unless (eq (first cycle-entry) entry-vertex)
        (push cycle-entry tree)))
    (nreverse tree)))

(defun max-spanning-tree (vertexes)
  "Return the highest-scoring spanning tree of VERTEXES (Chu-Liu/Edmonds).
VERTEXES is a list of (vertex (income score) ...) entries as built by
GENERATE-VERTEXES; its edge scores are normalised in place.  Return a list of
(vertex income score ...) entries, one per vertex, where INCOME is the head
chosen for VERTEX and SCORE is that edge's score after normalisation."
  (let* ((adjusted (adjust-weights vertexes))
         (best (select-best-score adjusted)))
    (multiple-value-bind (spanningp cycled)
        (spanning-tree-p best)
      (if spanningp
          best
          ;; Contract the cycle, solve the smaller problem, then expand it.
          (let ((collapsed (gensym "COLLAPSED-")))
            (expand (max-spanning-tree (contract adjusted cycled collapsed))
                    cycled collapsed))))))
(defmethod process ((parser dep-graph-parser) sentence)
  "Parse SENTENCE in place: set WORD-HEAD of each word to the WORD-ID of the
head chosen for it by MAX-SPANNING-TREE.  Return NIL."
  (loop for (word head) in (max-spanning-tree (generate-vertexes parser sentence))
        do (setf (word-head word) (word-id head))))
(defun dep-graph-parser-update (parser word head-truth head-guess)
  "Apply one perceptron step for the arc into WORD.
Always increments PARSER's UPDATE-COUNT.  When HEAD-GUESS differs from
HEAD-TRUTH, add 1 to the :UAS weight of every feature of HEAD-TRUTH -> WORD
and, if HEAD-GUESS is non-NIL, subtract 1 from every feature of
HEAD-GUESS -> WORD."
  (declare (optimize (speed 3) (space 0) (safety 0) #+lispworks (hcl:fixnum-safety 0)))
  (with-slots (weights living-weights last-updates update-count) parser
    (flet ((upd (feature val)
             ;; Before changing FEATURE's weight, credit LIVING-WEIGHTS with the
             ;; current weight times the number of updates since it last
             ;; changed (presumably for weight averaging).
             (let ((lived-cycle (- update-count (get-weight last-updates feature :uas))))
               (add-weight living-weights feature :uas
                           (* lived-cycle (get-weight weights feature :uas)))
               (set-weight last-updates feature :uas update-count)
               (add-weight weights feature :uas val))))
      ;; The INLINE declaration belongs here, in the FLET body; inside UPD's
      ;; own body it named a global function and had no effect.
      (declare (inline upd))
      (incf update-count)
      (unless (eq head-truth head-guess)
        (iter (for feature :in-vector (dep-graph-features head-truth word))
              (upd feature 1))
        (when head-guess
          (iter (for feature :in-vector (dep-graph-features head-guess word))
                (upd feature -1)))))))
(defun dep-graph-parser-train-sentence (parser sentence)
  "Decode SENTENCE with PARSER and update its weights once per word.
The gold head of a word is *ROOT-WORD* when its WORD-HEAD is 0, otherwise the
word of SENTENCE with that id.  Return two values: the number of words whose
predicted head was the gold one, and the number of words processed."
  (let ((tree (max-spanning-tree (generate-vertexes parser sentence)))
        (hits 0))
    (dolist (arc tree)
      (let* ((vertex (first arc))
             (head-guess (second arc))
             (gold-id (word-head vertex))
             (head-truth (if (= gold-id 0)
                             *root-word*
                             (find gold-id sentence :key #'word-id))))
        (when (eq head-truth head-guess)
          (incf hits))
        (dep-graph-parser-update parser vertex head-truth head-guess)))
    (values hits (length tree))))
(defmethod test ((parser dep-graph-parser) sentences)
  "Parse every sentence of the vector SENTENCES with PARSER and measure arc accuracy.
A word counts as correct when the head chosen by MAX-SPANNING-TREE has the id
stored in the word's WORD-HEAD.  Log the result with the elapsed time and
return the accuracy as a float percentage (0.0 when there are no words)."
  (let ((correct-count 0)
        (total-count 0)
        (start-time (get-internal-real-time)))
    (iter (for sentence :in-vector sentences)
          (iter (for (vertex head-guess) :in (max-spanning-tree (generate-vertexes parser sentence)))
                (when (= (word-head vertex) (word-id head-guess))
                  (incf correct-count))
                (incf total-count)))
    ;; Guard against empty input instead of signalling DIVISION-BY-ZERO.
    (let ((accuracy (if (zerop total-count)
                        0
                        (* 100 (/ correct-count total-count)))))
      ;; INTERNAL-TIME-UNITS-PER-SECOND is implementation-specific, so use it
      ;; rather than hard-coded per-implementation divisors.
      (log-info "Test ~D sentences using ~,2Fs, result: ~D/~D (~,2F%)"
                (length sentences)
                (/ (- (get-internal-real-time) start-time)
                   internal-time-units-per-second)
                correct-count total-count accuracy)
      (float accuracy))))
(defmethod train ((parser dep-graph-parser) sentences &key (cycles 5) (save-dir (asdf:system-source-directory :aprnlp)))
  "Train PARSER on the vector SENTENCES for CYCLES passes, then save it to SAVE-DIR.
Each pass trains on every sentence in turn and logs its arc accuracy and
elapsed time.  SHUFFLE is then called on SENTENCES; its result is discarded,
so it is presumably relied on to shuffle in place."
  (log-info "Start training with ~D sentences, ~D cycles. ~A"
            (length sentences) cycles
            #+lispworks (lw:string-append "Heap size: " (print-size (getf (sys:room-values) :total-size)))
            #-lispworks "")
  (iter (for cycle :range cycles)
        (let ((correct-count 0)
              (total-count 0)
              (cycle-start-time (get-internal-real-time)))
          (iter (for sentence :in-vector sentences)
                (multiple-value-bind (correct total)
                    (dep-graph-parser-train-sentence parser sentence)
                  (incf correct-count correct)
                  (incf total-count total)))
          ;; Convert with INTERNAL-TIME-UNITS-PER-SECOND (a fixed 1000 is wrong
          ;; on implementations with finer clocks) and avoid dividing by zero
          ;; when no word was processed.
          (log-info "Cycle ~D/~D completed using ~,2Fs with ~D/~D (~,2F%) correct. ~A"
                    (1+ cycle) cycles
                    (/ (- (get-internal-real-time) cycle-start-time)
                       internal-time-units-per-second)
                    correct-count total-count
                    (if (zerop total-count)
                        0.0
                        (* 100.0 (/ correct-count total-count)))
                    #+lispworks (lw:string-append "Heap size: " (print-size (getf (sys:room-values) :total-size)))
                    #-lispworks ""))
        (shuffle sentences))
  ;; Weight averaging is currently disabled:
  ;(dep-parser-average-weights parser)
  (save-processor parser save-dir))
(defmethod test-training ((class (eql 'dep-graph-parser)))
  "Train a fresh DEP-GRAPH-PARSER for 5 cycles on the UD English-GUM training
split, evaluate it on the test split, store it in *LOADED-DEP-PARSER* and
return it."
  (let* ((treebank-root (merge-pathnames "ud-treebanks-v2.14/"
                                         (asdf:system-source-directory :aprnlp)))
         (parser (make-instance 'dep-graph-parser)))
    (flet ((load-treebank (relative-path)
             (read-conllu-files (merge-pathnames relative-path treebank-root))))
      (train parser (load-treebank "UD_English-GUM/en_gum-ud-train.conllu")
             :cycles 5)
      (test parser (load-treebank "UD_English-GUM/en_gum-ud-test.conllu"))
      (setf *loaded-dep-parser* parser))))
;(test-training 'dep-graph-parser)