@@ -5,12 +5,14 @@ from flint.types.fmpz cimport any_as_fmpz
5
5
from flint.types.fmpz cimport fmpz
6
6
from flint.types.fmpq cimport fmpq
7
7
8
+ from flint.flintlib.flint cimport ulong
8
9
from flint.flintlib.fmpz cimport fmpz_t
9
10
from flint.flintlib.nmod cimport nmod_pow_fmpz, nmod_inv
10
11
from flint.flintlib.nmod_vec cimport *
11
12
from flint.flintlib.fmpz cimport fmpz_fdiv_ui, fmpz_init, fmpz_clear
12
13
from flint.flintlib.fmpz cimport fmpz_set_ui, fmpz_get_ui
13
14
from flint.flintlib.fmpq cimport fmpq_mod_fmpz
15
+ from flint.flintlib.ulong_extras cimport n_gcdinv
14
16
15
17
cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except - 1 :
16
18
cdef int success
@@ -64,9 +66,6 @@ cdef class nmod(flint_scalar):
64
66
def __int__ (self ):
65
67
return int (self .val)
66
68
67
- def __long__ (self ):
68
- return self .val
69
-
70
69
def modulus (self ):
71
70
return self .mod.n
72
71
@@ -170,6 +169,8 @@ cdef class nmod(flint_scalar):
170
169
cdef nmod r
171
170
cdef mp_limb_t sval, tval, x
172
171
cdef nmod_t mod
172
+ cdef ulong tinvval
173
+
173
174
if typecheck(s, nmod):
174
175
mod = (< nmod> s).mod
175
176
sval = (< nmod> s).val
@@ -180,17 +181,19 @@ cdef class nmod(flint_scalar):
180
181
tval = (< nmod> t).val
181
182
if not any_as_nmod(& sval, s, mod):
182
183
return NotImplemented
184
+
183
185
if tval == 0 :
184
186
raise ZeroDivisionError (" %s is not invertible mod %s " % (tval, mod.n))
185
187
if not s:
186
188
return s
187
- # XXX: check invertibility?
188
- x = nmod_div(sval, tval, mod)
189
- if x == 0 :
189
+
190
+ g = n_gcdinv( & tinvval, < ulong > tval, < ulong > mod.n )
191
+ if g ! = 1 :
190
192
raise ZeroDivisionError (" %s is not invertible mod %s " % (tval, mod.n))
193
+
191
194
r = nmod.__new__ (nmod)
192
195
r.mod = mod
193
- r.val = x
196
+ r.val = nmod_mul(sval, < mp_limb_t > tinvval, mod)
194
197
return r
195
198
196
199
def __truediv__ (s , t ):
@@ -200,18 +203,43 @@ cdef class nmod(flint_scalar):
200
203
return nmod._div_(t, s)
201
204
202
205
def __invert__ (self ):
203
- return (1 / self ) # XXX: speed up
206
+ cdef nmod r
207
+ cdef ulong g, inv, sval
208
+ sval = < ulong> (< nmod> self ).val
209
+ g = n_gcdinv(& inv, sval, self .mod.n)
210
+ if g != 1 :
211
+ raise ZeroDivisionError (" %s is not invertible mod %s " % (sval, self .mod.n))
212
+ r = nmod.__new__ (nmod)
213
+ r.mod = self .mod
214
+ r.val = < mp_limb_t> inv
215
+ return r
204
216
205
- def __pow__ (self , exp ):
217
+ def __pow__ (self , exp , modulus = None ):
206
218
cdef nmod r
219
+ cdef mp_limb_t rval, mod
220
+ cdef ulong g, rinv
221
+
222
+ if modulus is not None :
223
+ raise TypeError (" three-argument pow() not supported by nmod" )
224
+
207
225
e = any_as_fmpz(exp)
208
226
if e is NotImplemented :
209
227
return NotImplemented
210
- r = nmod.__new__ (nmod)
211
- r.mod = self .mod
212
- r.val = self .val
228
+
229
+ rval = (< nmod> self ).val
230
+ mod = (< nmod> self ).mod.n
231
+
232
+ # XXX: It is not clear that it is necessary to special case negative
233
+ # exponents here. The nmod_pow_fmpz function seems to handle this fine
234
+ # but the Flint docs say that the exponent must be nonnegative.
213
235
if e < 0 :
214
- r.val = nmod_inv(r.val, self .mod)
236
+ g = n_gcdinv(& rinv, < ulong> rval, < ulong> mod)
237
+ if g != 1 :
238
+ raise ZeroDivisionError (" %s is not invertible mod %s " % (rval, mod))
239
+ rval = < mp_limb_t> rinv
215
240
e = - e
216
- r.val = nmod_pow_fmpz(r.val, (< fmpz> e).val, self .mod)
241
+
242
+ r = nmod.__new__ (nmod)
243
+ r.mod = self .mod
244
+ r.val = nmod_pow_fmpz(rval, (< fmpz> e).val, self .mod)
217
245
return r
0 commit comments