Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix pow(int, int, fmpz) #93

Merged
merged 1 commit into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/flint/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,16 @@ def test_fmpz():
(2, 2, 3, 1),
(2, -1, 5, 3),
(2, 0, 5, 1),
(2, 5, 1000, 32),
]
for a, b, c, ab_mod_c in pow_mod_examples:
assert pow(a, b, c) == ab_mod_c
assert pow(flint.fmpz(a), b, c) == ab_mod_c
assert pow(a, flint.fmpz(b), c) == ab_mod_c
assert pow(a, b, flint.fmpz(c)) == ab_mod_c
assert pow(flint.fmpz(a), flint.fmpz(b), c) == ab_mod_c
assert pow(flint.fmpz(a), b, flint.fmpz(c)) == ab_mod_c
assert pow(a, flint.fmpz(b), flint.fmpz(c)) == ab_mod_c
assert pow(flint.fmpz(a), flint.fmpz(b), flint.fmpz(c)) == ab_mod_c

assert raises(lambda: pow(flint.fmpz(2), 2, 0), ValueError)
Expand Down
57 changes: 31 additions & 26 deletions src/flint/types/fmpz.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -360,53 +360,58 @@ cdef class fmpz(flint_scalar):
return u

def __pow__(s, t, m):
cdef fmpz_struct sval[1]
cdef fmpz_struct tval[1]
cdef fmpz_struct mval[1]
cdef int stype = FMPZ_UNKNOWN
cdef int ttype = FMPZ_UNKNOWN
cdef int mtype = FMPZ_UNKNOWN
cdef int success
u = NotImplemented
ttype = fmpz_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented

if m is None:
# fmpz_pow_fmpz throws if x is negative
if fmpz_sgn(tval) == -1:
if ttype == FMPZ_TMP: fmpz_clear(tval)
raise ValueError("negative exponent")
try:
stype = fmpz_set_any_ref(sval, s)
if stype == FMPZ_UNKNOWN:
return NotImplemented
ttype = fmpz_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
if m is None:
# fmpz_pow_fmpz throws if x is negative
if fmpz_sgn(tval) == -1:
raise ValueError("negative exponent")

u = fmpz.__new__(fmpz)
success = fmpz_pow_fmpz((<fmpz>u).val, (<fmpz>s).val, tval)
u = fmpz.__new__(fmpz)
success = fmpz_pow_fmpz((<fmpz>u).val, (<fmpz>s).val, tval)

if not success:
if ttype == FMPZ_TMP: fmpz_clear(tval)
raise OverflowError("fmpz_pow_fmpz: exponent too large")
else:
# Modular exponentiation
mtype = fmpz_set_any_ref(mval, m)
if mtype != FMPZ_UNKNOWN:
if not success:
raise OverflowError("fmpz_pow_fmpz: exponent too large")

return u
else:
# Modular exponentiation
mtype = fmpz_set_any_ref(mval, m)
if mtype == FMPZ_UNKNOWN:
return NotImplemented

if fmpz_is_zero(mval):
if ttype == FMPZ_TMP: fmpz_clear(tval)
if mtype == FMPZ_TMP: fmpz_clear(mval)
raise ValueError("pow(): modulus cannot be zero")

# The Flint docs say that fmpz_powm will throw if m is zero
# but it also throws if m is negative. Python generally allows
# e.g. pow(2, 2, -3) == (2^2) % (-3) == -2. We could implement
# that here as well but it is not clear how useful it is.
if fmpz_sgn(mval) == -1:
if ttype == FMPZ_TMP: fmpz_clear(tval)
if mtype == FMPZ_TMP: fmpz_clear(mval)
raise ValueError("pow(): negative modulua not supported")
raise ValueError("pow(): negative modulus not supported")

u = fmpz.__new__(fmpz)
fmpz_powm((<fmpz>u).val, (<fmpz>s).val, tval, mval)
fmpz_powm((<fmpz>u).val, sval, tval, mval)

if ttype == FMPZ_TMP: fmpz_clear(tval)
if mtype == FMPZ_TMP: fmpz_clear(mval)
return u
return u
finally:
if stype == FMPZ_TMP: fmpz_clear(sval)
if ttype == FMPZ_TMP: fmpz_clear(tval)
if mtype == FMPZ_TMP: fmpz_clear(mval)

def __rpow__(s, t, m):
t = any_as_fmpz(t)
Expand Down