-
Notifications
You must be signed in to change notification settings - Fork 246
/
Copy pathRingSolver.agda
388 lines (321 loc) · 14.1 KB
/
RingSolver.agda
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
------------------------------------------------------------------------
-- The Agda standard library
--
-- A solver that uses reflection to automatically obtain and solve
-- equations over rings.
------------------------------------------------------------------------
{-# OPTIONS --cubical-compatible --safe #-}
module Tactic.RingSolver where
open import Algebra
open import Data.Fin.Base as Fin using (Fin)
open import Data.Vec.Base as Vec using (Vec; _∷_; [])
open import Data.List.Base as List using (List; _∷_; [])
open import Data.Maybe.Base as Maybe using (Maybe; just; nothing; fromMaybe)
open import Data.Nat.Base using (ℕ; suc; zero; _<ᵇ_)
open import Data.Bool.Base using (Bool; if_then_else_; true; false)
open import Data.Unit.Base using (⊤)
open import Data.String.Base as String using (String; _++_; parens)
open import Data.Product.Base using (_,_; proj₁)
open import Function.Base
open import Reflection
open import Reflection.AST.Argument
open import Reflection.AST.Term as Term
open import Reflection.AST.AlphaEquality
open import Reflection.AST.Name as Name
open import Reflection.TCM.Syntax
open import Data.Nat.Reflection
open import Data.List.Reflection
import Data.Vec.Reflection as Vec
open import Tactic.RingSolver.NonReflective renaming (solve to solver)
open import Tactic.RingSolver.Core.AlmostCommutativeRing
open import Tactic.RingSolver.Core.NatSet as NatSet
open AlmostCommutativeRing
------------------------------------------------------------------------
-- Utilities
private
  -- Maps a de Bruijn variable index to the term that should replace it
  -- when converting a goal to an `Expr`; `nothing` means the variable is
  -- not one of the solver's variables (it is then treated as a constant
  -- by `convertTerm`).
  VarMap : Set
  VarMap = ℕ → Maybe Term

  -- Returns the argument's term if, and only if, the argument is visible.
  getVisible : Arg Term → Maybe Term
  getVisible (arg (arg-info visible _) x) = just x
  getVisible _ = nothing

  -- For a `def` application, returns its last `n` visible arguments in
  -- their original order. Returns `nothing` when there are fewer than `n`
  -- visible arguments, or when the term is not a `def` application.
  --
  -- The fold walks the visible arguments left to right, building a
  -- function that, given `n`, conses the most recently seen argument onto
  -- the result for `n - 1`; the vector therefore comes out reversed and
  -- is flipped back with `Vec.reverse`.
  getVisibleArgs : ∀ n → Term → Maybe (Vec Term n)
  getVisibleArgs n (def _ xs) = Maybe.map Vec.reverse
    (List.foldl f c (List.mapMaybe getVisible xs) n)
    where
    f : (∀ n → Maybe (Vec Term n)) → Term → ∀ n → Maybe (Vec Term n)
    f xs x zero = just []
    f xs x (suc n) = Maybe.map (x ∷_) (xs n)
    -- Base case (no arguments left): succeeds only if no more are needed.
    c : ∀ n → Maybe (Vec Term n)
    c zero = just []
    c (suc _ ) = nothing
  getVisibleArgs _ _ = nothing

  -- Builds the reflected vector `var x₁ [] ∷ … ∷ []` containing one
  -- variable for each index in the set, in `NatSet.toList` order.
  curriedTerm : NatSet → Term
  curriedTerm = List.foldr go Vec.`[] ∘ NatSet.toList
    where
    go : ℕ → Term → Term
    go x xs = var x [] Vec.`∷ xs
------------------------------------------------------------------------
-- Reflection utilities for rings
-- The reflected type `AlmostCommutativeRing _ _`: both explicit
-- arguments are left as `unknown` for Agda to infer.
`AlmostCommutativeRing : Term
`AlmostCommutativeRing =
  def (quote AlmostCommutativeRing) (unknown ⟨∷⟩ unknown ⟨∷⟩ [])
-- The normal forms of the ring's five operators, as computed by
-- `getRingOperatorTerms`. `convertTerm` compares unknown definitions
-- against these to recognise operators appearing under other names.
record RingOperatorTerms : Set where
  constructor add⇒_mul⇒_pow⇒_neg⇒_sub⇒_
  field
    add : Term
    mul : Term
    pow : Term
    neg : Term
    sub : Term
-- Type-checks the given term against `AlmostCommutativeRing _ _` and
-- returns the term produced by `checkType`.
checkIsRing : Term → TC Term
checkIsRing candidate = checkType candidate `AlmostCommutativeRing
module RingReflection (`ring : Term) where
  -- Applies the name of a function that takes the ring as its first
  -- explicit argument to the terms of its remaining arguments, inserting
  -- the required ring arguments: two hidden `unknown`s, then the ring.
  -- e.g. "_+_" $ʳ xs = "_+_ {_} {_} ring xs"
  infixr 6 _$ʳ_
  _$ʳ_ : Name → Args Term → Term
  fn $ʳ rest = def fn (2 ⋯⟅∷⟆ `ring ⟨∷⟩ rest)

  -- The ring's carrier.
  `Carrier : Term
  `Carrier = quote Carrier $ʳ []

  -- Reflected `refl`, `sym` and `trans`, applied to the ring via `$ʳ`
  -- with their remaining hidden arguments left as `unknown`.
  `refl : Term
  `refl = quote refl $ʳ (1 ⋯⟅∷⟆ [])

  `sym : Term → Term
  `sym eq = quote sym $ʳ (2 ⋯⟅∷⟆ eq ⟨∷⟩ [])

  `trans : Term → Term → Term
  `trans eq₁ eq₂ = quote trans $ʳ (3 ⋯⟅∷⟆ eq₁ ⟨∷⟩ eq₂ ⟨∷⟩ [])

  -- Normalises each of the ring's operators, so that the results can be
  -- compared against the normalised definitions encountered while
  -- converting the term passed to the macro.
  getRingOperatorTerms : TC RingOperatorTerms
  getRingOperatorTerms = do
    plus     ← normalise (quote _+_ $ʳ [])
    times    ← normalise (quote _*_ $ʳ [])
    power    ← normalise (quote _^_ $ʳ [])
    negation ← normalise (quote (-_) $ʳ [])
    minus    ← normalise (quote _-_ $ʳ [])
    pure (add⇒ plus mul⇒ times pow⇒ power neg⇒ negation sub⇒ minus)
------------------------------------------------------------------------
-- Reflection utilities for ring solver
module RingSolverReflection (ring : Term) (numberOfVariables : ℕ) where
  open RingReflection ring

  -- The number of variables, as a reflected natural number.
  `numberOfVariables : Term
  `numberOfVariables = toTerm numberOfVariables

  -- Applies a constructor of `Expr` to the hidden arguments it needs,
  -- followed by the given arguments. The first hidden argument is the
  -- universe level, the second is the type the expression contains, and
  -- the third is the number of variables it is indexed by. All three could
  -- likely be inferred, but to make things easier the level is left as
  -- `unknown` while the second and third are supplied, since we already
  -- know them.
  infix -1 _$ᵉ_
  _$ᵉ_ : Name → List (Arg Term) → Term
  e $ᵉ xs = con e (1 ⋯⟅∷⟆ `Carrier ⟅∷⟆ `numberOfVariables ⟅∷⟆ xs)

  -- A constant expression.
  `Κ : Term → Term
  `Κ x = quote Κ $ᵉ (x ⟨∷⟩ [])

  -- A variable expression. `constructSolution` passes a reflected `Fin`
  -- index, built with `toFinTerm`.
  `I : Term → Term
  `I x = quote Ι $ᵉ (x ⟨∷⟩ [])

  -- The reflected equation `x ⊜ y` between two expressions.
  infixl 6 _`⊜_
  _`⊜_ : Term → Term → Term
  x `⊜ y = quote _⊜_ $ʳ (`numberOfVariables ⟅∷⟆ x ⟨∷⟩ y ⟨∷⟩ [])

  -- A reflected call to `Ops.correct`, applied to the expression `x` and
  -- the reflected vector of variable values `ρ`.
  `correct : Term → Term → Term
  `correct x ρ = quote Ops.correct $ʳ (1 ⋯⟅∷⟆ x ⟨∷⟩ ρ ⟨∷⟩ [])

  -- A reflected call to the non-reflective solver (`solve`, imported here
  -- as `solver`).
  `solver : Term → Term → Term
  `solver `f `eq = quote solver $ʳ (`numberOfVariables ⟨∷⟩ `f ⟨∷⟩ `eq ⟨∷⟩ [])

  -- Converts the raw terms provided by the macro into the `Expr`s
  -- used internally by the solver.
  --
  -- When trying to figure out the shape of an expression, one of
  -- the difficult tasks is recognizing where constants in the
  -- underlying ring are used. If we were only dealing with ℕ, we
  -- might look for its constructors: however, we want to deal with
  -- arbitrary types which implement AlmostCommutativeRing. If the
  -- Term type contained type information we might be able to
  -- recognize it there, but it doesn't.
  --
  -- We're in luck, though, because all other cases in the following
  -- function *are* recognizable. As a result, the "catch-all" case
  -- will just assume that it has a constant expression.
  convertTerm : RingOperatorTerms → VarMap → Term → TC Term
  convertTerm operatorTerms varMap = convert
    where
    open RingOperatorTerms operatorTerms

    mutual
      convert : Term → TC Term
      -- First try and match directly against the fields
      convert (def (quote _+_) xs) = convertOp₂ (quote _⊕_) xs
      convert (def (quote _*_) xs) = convertOp₂ (quote _⊗_) xs
      convert (def (quote -_) xs) = convertOp₁ (quote ⊝_) xs
      convert (def (quote _^_) xs) = convertExp xs
      convert (def (quote _-_) xs) = convertSub xs
      -- Any other definition may be the underlying implementation of one
      -- of the ring's fields: compare its normal form against them
      convert (def nm xs) = convertUnknownName nm xs
      -- Variables: use the variable map, falling back to a constant
      convert v@(var x _) = pure $ fromMaybe (`Κ v) (varMap x)
      -- Special case to recognise "suc" for naturals
      convert (`suc x) = convertSuc x
      -- Otherwise we're forced to treat it as a constant
      convert t = pure $ `Κ t

      -- Application of a ring operator often doesn't have a type as
      -- simple as "Carrier → Carrier → Carrier": there may be hidden
      -- arguments, etc. Here, we do our best to handle those cases by
      -- skipping leading arguments until exactly two remain, both of
      -- which must be explicit.
      -- NOTE(review): if no such pair is found (e.g. a partial
      -- application), the result is `unknown` rather than a constant or
      -- an error, which may only surface later as a confusing failure.
      convertOp₂ : Name → Args Term → TC Term
      convertOp₂ nm (x ⟨∷⟩ y ⟨∷⟩ []) = do
        x' ← convert x
        y' ← convert y
        pure (nm $ᵉ (x' ⟨∷⟩ y' ⟨∷⟩ []))
      convertOp₂ nm (x ∷ xs) = convertOp₂ nm xs
      convertOp₂ _ _ = pure unknown

      -- Unary analogue of `convertOp₂`: converts the last argument, which
      -- must be explicit, and yields `unknown` if there is none.
      convertOp₁ : Name → Args Term → TC Term
      convertOp₁ nm (x ⟨∷⟩ []) = do
        x' ← convert x
        pure (nm $ᵉ (x' ⟨∷⟩ []))
      convertOp₁ nm (x ∷ xs) = convertOp₁ nm xs
      convertOp₁ _ _ = pure unknown

      -- Exponentiation: only the base is converted; the exponent term is
      -- passed through unchanged (presumably a ℕ term — the argument of
      -- `_⊛_`).
      convertExp : Args Term → TC Term
      convertExp (x ⟨∷⟩ y ⟨∷⟩ []) = do
        x' ← convert x
        pure (quote _⊛_ $ᵉ (x' ⟨∷⟩ y ⟨∷⟩ []))
      convertExp (x ∷ xs) = convertExp xs
      convertExp _ = pure unknown

      -- Subtraction `x - y` is encoded as `x ⊕ (⊝ y)`.
      convertSub : Args Term → TC Term
      convertSub (x ⟨∷⟩ y ⟨∷⟩ []) = do
        x' ← convert x
        -y' ← convertOp₁ (quote (⊝_)) (y ⟨∷⟩ [])
        pure (quote _⊕_ $ᵉ x' ⟨∷⟩ -y' ⟨∷⟩ [])
      convertSub (x ∷ xs) = convertSub xs
      convertSub _ = pure unknown

      -- Normalises the definition and compares it (up to α-equality)
      -- with each of the ring's normalised operators; if none matches,
      -- the whole application is treated as a constant.
      convertUnknownName : Name → Args Term → TC Term
      convertUnknownName nm xs = do
        nameTerm ← normalise (def nm [])
        if (nameTerm =α= add) then convertOp₂ (quote _⊕_) xs else
          if (nameTerm =α= mul) then convertOp₂ (quote _⊗_) xs else
          if (nameTerm =α= neg) then convertOp₁ (quote ⊝_) xs else
          if (nameTerm =α= pow) then convertExp xs else
          if (nameTerm =α= sub) then convertSub xs else
          pure (`Κ (def nm xs))

      -- `suc x` is encoded as `Κ 1 ⊕ x`, with `1` the reflected number 1.
      convertSuc : Term → TC Term
      convertSuc x = do x' ← convert x; pure (quote _⊕_ $ᵉ (`Κ (toTerm 1) ⟨∷⟩ x' ⟨∷⟩ []))
------------------------------------------------------------------------
-- Macros
------------------------------------------------------------------------
-- Quantified macro
-- Bring the parameterised helper modules into scope. Their definitions
-- take the module parameters as extra leading arguments, e.g.
-- `solver `ring numVars f eq` or `_`⊜_ `ring numVars lhs rhs`.
open RingReflection
open RingSolverReflection
-- Fails with a type error explaining that the goal given to `solve-∀`
-- does not have the expected `∀ x y → lhs ≈ rhs` shape, and showing the
-- term that was found instead.
malformedForallTypeError : ∀ {a} {A : Set a} → Term → TC A
malformedForallTypeError found = typeError $
  strErr "Malformed call to solve." ∷
  strErr "Expected target type to be like: ∀ x y → x + y ≈ y + x." ∷
  strErr "Instead: " ∷
  termErr found ∷
  []
-- Variable map for `solve-∀`: the first `numVars` de Bruijn indices are
-- the quantified variables and are kept as plain variables (they are
-- re-bound by the lambdas built in `constructCallToSolver`). Any other
-- index maps to `nothing` and is therefore treated as a constant.
quantifiedVarMap : ℕ → VarMap
quantifiedVarMap numVars i = select (i <ᵇ numVars)
  where
  select : Bool → Maybe Term
  select true  = just (var i [])
  select false = nothing
-- Builds the reflected call used by `solve-∀`: both sides of the goal are
-- converted to `Expr`s, the equation `lhs ⊜ rhs` is wrapped with
-- `prependVLams` over the variable names, and `refl` is wrapped with
-- `prependHLams` over the same names.
constructCallToSolver : Term → RingOperatorTerms → List String → Term → Term → TC Term
constructCallToSolver `ring opTerms variables `lhs `rhs = do
  lhsExpr ← toExpr `lhs
  rhsExpr ← toExpr `rhs
  let equation = _`⊜_ `ring numVars lhsExpr rhsExpr
      proof    = prependHLams variables (`refl `ring)
  pure (`solver `ring numVars (prependVLams variables equation) proof)
  where
  numVars : ℕ
  numVars = List.length variables

  -- Quantified variables stay variables; everything else is a constant.
  toExpr : Term → TC Term
  toExpr = convertTerm `ring numVars opTerms (quantifiedVarMap numVars)
-- This is the main macro, which solves equations in which the
-- variables are universally quantified:
--
--   lemma : ∀ x y → x + y ≈ y + x
--   lemma = solve-∀ ring
--
-- where ring is your implementation of AlmostCommutativeRing (the record
-- is defined in Tactic.RingSolver.Core.AlmostCommutativeRing). Older
-- documentation pointed to Polynomial.Solver.Ring.AlmostCommutativeRing.Instances
-- for example implementations; that module path may be out of date.
solve-∀-macro : Name → Term → TC ⊤
solve-∀-macro ring hole = do
  -- Check that `ring` names an AlmostCommutativeRing, then normalise
  -- its operators for later comparison
  `ring ← checkIsRing (def ring [])
  commitTC
  operatorTerms ← getRingOperatorTerms `ring

  -- Obtain and sanitise the goal type: strip the leading Π-binders to get
  -- the variable names and the underlying equation
  `hole ← inferType hole >>= reduce
  let variablesAndTypes , equation = stripPis `hole
  let variables = List.map proj₁ variablesAndTypes
  -- The two sides are the equation's last two visible arguments
  just (lhs ∷ rhs ∷ []) ← pure (getVisibleArgs 2 equation)
    where nothing → malformedForallTypeError `hole

  solverCall ← constructCallToSolver `ring operatorTerms variables lhs rhs
  unify hole solverCall
macro
  -- Solves goals of the form `∀ x₁ … xₙ → lhs ≈ rhs` over the named ring;
  -- all the work is done by `solve-∀-macro`.
  solve-∀ : Name → Term → TC ⊤
  solve-∀ ring hole = solve-∀-macro ring hole
------------------------------------------------------------------------
-- Unquantified macro
-- Fails with a type error explaining that the first argument to `solve`
-- is not a list of free variables, showing the term found instead.
malformedArgumentListError : ∀ {a} {A : Set a} → Term → TC A
malformedArgumentListError found = typeError $
  strErr "Malformed call to solve." ∷
  strErr "First argument should be a list of free variables." ∷
  strErr "Instead: " ∷
  termErr found ∷
  []
-- Fails with a type error explaining that the goal given to `solve` is
-- not of the form `LHS ≈ RHS`, showing the term found instead.
malformedGoalError : ∀ {a} {A : Set a} → Term → TC A
malformedGoalError found = typeError $
  strErr "Malformed call to solve." ∷
  strErr "Goal type should be of the form: LHS ≈ RHS" ∷
  strErr "Instead: " ∷
  termErr found ∷
  []
-- Type-checks `xs` as a list of elements of the ring's carrier and
-- returns the checked term in normal form.
checkIsListOfVariables : Term → Term → TC Term
checkIsListOfVariables `ring `xs = do
  checked ← checkType `xs (`List (`Carrier `ring))
  normalise checked
-- Collects the de Bruijn indices of a reflected list literal of
-- unapplied variables (`var i [] ∷ … ∷ []`) into a NatSet. Returns
-- `nothing` if any element is not an unapplied variable, or if the term
-- is not a list literal.
getVariableIndices : Term → Maybe NatSet
getVariableIndices term = collect [] term
  where
  collect : NatSet → Term → Maybe NatSet
  collect seen (var i [] `∷` rest) = collect (insert i seen) rest
  collect seen `[]`                = just seen
  collect _    _                   = nothing
-- Builds the proof term `trans (sym (correct lhs′ ρ)) (correct rhs′ ρ)`,
-- where `lhs′`/`rhs′` are the `Expr` forms of the two sides of the goal
-- and `ρ` is the reflected vector of the listed variables.
constructSolution : Term → RingOperatorTerms → NatSet → Term → Term → TC Term
constructSolution `ring opTerms variables `lhs `rhs = do
  lhsProof ← correctFor `lhs
  rhsProof ← correctFor `rhs
  pure (`trans `ring (`sym `ring lhsProof) rhsProof)
  where
  numVars : ℕ
  numVars = List.length variables

  -- Listed variables become `Ι` expressions, indexed by the position
  -- that `lookup` reports for them; any other variable is a constant.
  varMap : VarMap
  varMap i = Maybe.map (`I `ring numVars ∘ toFinTerm) (lookup variables i)

  ρ : Term
  ρ = curriedTerm variables

  -- Converts one side of the goal and wraps it in a `correct` proof.
  correctFor : Term → TC Term
  correctFor t = convertTerm `ring numVars opTerms varMap t >>= λ expr →
    pure (`correct `ring numVars expr ρ)
-- Use this macro when you want to solve something *under* a lambda.
-- For example: say you have a long proof, and you just want the solver
-- to deal with an intermediate step. Call it like so:
--
-- lemma₃ : ∀ x y → x + y * 1 + 3 ≈ 2 + 1 + y + x
-- lemma₃ x y = begin
-- x + y * 1 + 3 ≈⟨ +-comm x (y * 1) ⟨ +-cong ⟩ refl ⟩
-- y * 1 + x + 3 ≈⟨ solve (x ∷ y ∷ []) Int.ring ⟩
-- 3 + y + x ≡⟨ refl ⟩
-- 2 + 1 + y + x ∎
--
-- The first argument is the list of free variables, and the second is the
-- ring implementation (as for solve-∀).
solve-macro : Term → Name → Term → TC ⊤
solve-macro variables ring hole = do
  -- Check that `ring` names an AlmostCommutativeRing, then normalise
  -- its operators for later comparison
  `ring ← checkIsRing (def ring [])
  commitTC
  operatorTerms ← getRingOperatorTerms `ring

  -- Obtain and sanitise the list of variables: it must type-check as a
  -- list of carrier elements and normalise to a literal list of variables
  listOfVariables′ ← checkIsListOfVariables `ring variables
  commitTC
  just variableIndices ← pure (getVariableIndices listOfVariables′)
    where nothing → malformedArgumentListError listOfVariables′

  -- Obtain and sanitise the goal type: its last two visible arguments
  -- are taken as the two sides of the equation
  hole′ ← inferType hole >>= reduce
  just (lhs ∷ rhs ∷ []) ← pure (getVisibleArgs 2 hole′)
    where nothing → malformedGoalError hole′

  solution ← constructSolution `ring operatorTerms variableIndices lhs rhs
  unify hole solution
macro
  -- Solves an `lhs ≈ rhs` goal in the current context, treating the
  -- listed variables as the solver's variables; see `solve-macro`.
  solve : Term → Name → Term → TC ⊤
  solve variables ring hole = solve-macro variables ring hole