Skip to content

Commit 243d896

Browse files
Merge pull request #79 from oscarbenjamin/pr_nmod_pow
fix(nmod): ZeroDivisionError instead of coredump
2 parents 9a86d4e + 630ea01 commit 243d896

File tree

2 files changed

+52
-17
lines changed

2 files changed

+52
-17
lines changed

src/flint/test/test.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -1279,6 +1279,7 @@ def test_nmod():
12791279
assert G(1,2) != G(0,2)
12801280
assert G(0,2) != G(0,3)
12811281
assert G(3,5) == G(8,5)
1282+
assert G(1,2) != (1,2)
12821283
assert isinstance(hash(G(3, 5)), int)
12831284
assert raises(lambda: G([], 3), TypeError)
12841285
#assert G(3,5) == 8 # do we want this?
@@ -1304,14 +1305,20 @@ def test_nmod():
13041305
assert G(0,3) / G(1,3) == G(0,3)
13051306
assert G(3,17) * flint.fmpq(11,5) == G(10,17)
13061307
assert G(3,17) / flint.fmpq(11,5) == G(6,17)
1308+
assert raises(lambda: G(flint.fmpq(2, 3), 3), ZeroDivisionError)
1309+
assert raises(lambda: G(2,5) / G(0,5), ZeroDivisionError)
1310+
assert raises(lambda: G(2,5) / 0, ZeroDivisionError)
1311+
assert G(1,6) / G(5,6) == G(5,6)
1312+
assert raises(lambda: G(1,6) / G(3,6), ZeroDivisionError)
13071313
assert G(1,3) ** 2 == G(1,3)
13081314
assert G(2,3) ** flint.fmpz(2) == G(1,3)
1315+
assert ~G(2,7) == G(2,7) ** -1 == G(4,7)
1316+
assert raises(lambda: G(3,6) ** -1, ZeroDivisionError)
1317+
assert raises(lambda: ~G(3,6), ZeroDivisionError)
1318+
assert raises(lambda: pow(G(1,3), 2, 7), TypeError)
13091319
assert G(flint.fmpq(2, 3), 5) == G(4,5)
13101320
assert raises(lambda: G(2,5) ** G(2,5), TypeError)
13111321
assert raises(lambda: flint.fmpz(2) ** G(2,5), TypeError)
1312-
assert raises(lambda: G(flint.fmpq(2, 3), 3), ZeroDivisionError)
1313-
assert raises(lambda: G(2,5) / G(0,5), ZeroDivisionError)
1314-
assert raises(lambda: G(2,5) / 0, ZeroDivisionError)
13151322
assert raises(lambda: G(2,5) + G(2,7), ValueError)
13161323
assert raises(lambda: G(2,5) - G(2,7), ValueError)
13171324
assert raises(lambda: G(2,5) * G(2,7), ValueError)

src/flint/types/nmod.pyx

+42-14
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@ from flint.types.fmpz cimport any_as_fmpz
55
from flint.types.fmpz cimport fmpz
66
from flint.types.fmpq cimport fmpq
77

8+
from flint.flintlib.flint cimport ulong
89
from flint.flintlib.fmpz cimport fmpz_t
910
from flint.flintlib.nmod cimport nmod_pow_fmpz, nmod_inv
1011
from flint.flintlib.nmod_vec cimport *
1112
from flint.flintlib.fmpz cimport fmpz_fdiv_ui, fmpz_init, fmpz_clear
1213
from flint.flintlib.fmpz cimport fmpz_set_ui, fmpz_get_ui
1314
from flint.flintlib.fmpq cimport fmpq_mod_fmpz
15+
from flint.flintlib.ulong_extras cimport n_gcdinv
1416

1517
cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1:
1618
cdef int success
@@ -64,9 +66,6 @@ cdef class nmod(flint_scalar):
6466
def __int__(self):
6567
return int(self.val)
6668

67-
def __long__(self):
68-
return self.val
69-
7069
def modulus(self):
7170
return self.mod.n
7271

@@ -170,6 +169,8 @@ cdef class nmod(flint_scalar):
170169
cdef nmod r
171170
cdef mp_limb_t sval, tval, x
172171
cdef nmod_t mod
172+
cdef ulong tinvval
173+
173174
if typecheck(s, nmod):
174175
mod = (<nmod>s).mod
175176
sval = (<nmod>s).val
@@ -180,17 +181,19 @@ cdef class nmod(flint_scalar):
180181
tval = (<nmod>t).val
181182
if not any_as_nmod(&sval, s, mod):
182183
return NotImplemented
184+
183185
if tval == 0:
184186
raise ZeroDivisionError("%s is not invertible mod %s" % (tval, mod.n))
185187
if not s:
186188
return s
187-
# XXX: check invertibility?
188-
x = nmod_div(sval, tval, mod)
189-
if x == 0:
189+
190+
g = n_gcdinv(&tinvval, <ulong>tval, <ulong>mod.n)
191+
if g != 1:
190192
raise ZeroDivisionError("%s is not invertible mod %s" % (tval, mod.n))
193+
191194
r = nmod.__new__(nmod)
192195
r.mod = mod
193-
r.val = x
196+
r.val = nmod_mul(sval, <mp_limb_t>tinvval, mod)
194197
return r
195198

196199
def __truediv__(s, t):
@@ -200,18 +203,43 @@ cdef class nmod(flint_scalar):
200203
return nmod._div_(t, s)
201204

202205
def __invert__(self):
203-
return (1 / self) # XXX: speed up
206+
cdef nmod r
207+
cdef ulong g, inv, sval
208+
sval = <ulong>(<nmod>self).val
209+
g = n_gcdinv(&inv, sval, self.mod.n)
210+
if g != 1:
211+
raise ZeroDivisionError("%s is not invertible mod %s" % (sval, self.mod.n))
212+
r = nmod.__new__(nmod)
213+
r.mod = self.mod
214+
r.val = <mp_limb_t>inv
215+
return r
204216

205-
def __pow__(self, exp):
217+
def __pow__(self, exp, modulus=None):
206218
cdef nmod r
219+
cdef mp_limb_t rval, mod
220+
cdef ulong g, rinv
221+
222+
if modulus is not None:
223+
raise TypeError("three-argument pow() not supported by nmod")
224+
207225
e = any_as_fmpz(exp)
208226
if e is NotImplemented:
209227
return NotImplemented
210-
r = nmod.__new__(nmod)
211-
r.mod = self.mod
212-
r.val = self.val
228+
229+
rval = (<nmod>self).val
230+
mod = (<nmod>self).mod.n
231+
232+
# XXX: It is not clear that it is necessary to special case negative
233+
# exponents here. The nmod_pow_fmpz function seems to handle this fine
234+
# but the Flint docs say that the exponent must be nonnegative.
213235
if e < 0:
214-
r.val = nmod_inv(r.val, self.mod)
236+
g = n_gcdinv(&rinv, <ulong>rval, <ulong>mod)
237+
if g != 1:
238+
raise ZeroDivisionError("%s is not invertible mod %s" % (rval, mod))
239+
rval = <mp_limb_t>rinv
215240
e = -e
216-
r.val = nmod_pow_fmpz(r.val, (<fmpz>e).val, self.mod)
241+
242+
r = nmod.__new__(nmod)
243+
r.mod = self.mod
244+
r.val = nmod_pow_fmpz(rval, (<fmpz>e).val, self.mod)
217245
return r

0 commit comments

Comments
 (0)