Created
May 30, 2017 18:40
-
-
Save arghhhh/abdb8b6d28039f683f5287e124903670 to your computer and use it in GitHub Desktop.
Modified example ModInt code without promotions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This file is a part of Julia. License is MIT: https://julialang.org/license
# This is a modified version of julia/examples/ModInts.jl
module ModInts

export ModInt

import Base: +, -, *, /, inv

"""
    ModInt{n}(k)

Integer `k` modulo `n`.

The constructor stores `mod(k, n)`, so for positive `n` the stored
representative always lies in `0:n-1`. Keeping residues canonical is what
makes arithmetic wrap around. It also makes `==` agree with congruence
modulo `n`: for two values of the same `Number` type, Base's default `==`
falls back to `===`, which compares the stored field.

The type has to be a subtype of at least `Number` to avoid `transpose`
being a problem: https://github.com/JuliaLang/julia/issues/20978
"""
struct ModInt{n} <: Integer
    # Concretely typed so the field is not stored as a boxed `Any`.
    k::Int
    # Reduce on construction. Without this, e.g. ModInt{2}(1) + ModInt{2}(1)
    # would store 2 and compare unequal to ModInt{2}(0).
    ModInt{n}(k) where {n} = new(mod(k, n))
end

# Re-wrapping a residue of the same modulus is the identity. Generic Base code
# such as `oneunit(x)` calls `typeof(x)(one(x))`. Without this method, that
# call would hit the inner constructor with a `ModInt` argument.
ModInt{n}(a::ModInt{n}) where {n} = a

# In `:compact` IO contexts show only the residue; otherwise show "k mod n".
Base.show(io::IO, a::ModInt{n}) where {n} =
    print(io, get(io, :compact, false) ? a.k : "$(a.k) mod $n")

# Arithmetic is defined only between residues of the same modulus `n`;
# the constructor reduces every result back into range.
(+)(a::ModInt{n}, b::ModInt{n}) where {n} = ModInt{n}(a.k + b.k)
(-)(a::ModInt{n}, b::ModInt{n}) where {n} = ModInt{n}(a.k - b.k)
(*)(a::ModInt{n}, b::ModInt{n}) where {n} = ModInt{n}(a.k * b.k)
(-)(a::ModInt{n}) where {n} = ModInt{n}(-a.k)

# Multiplicative inverse; `invmod` errors when `a.k` and `n` are not coprime.
inv(a::ModInt{n}) where {n} = ModInt{n}(invmod(a.k, n))
(/)(a::ModInt{n}, b::ModInt{n}) where {n} = a * inv(b)  # broaden for non-coprime?

# Conversions and promotions from `Int` are intentionally NOT defined:
# Base.promote_rule(::Type{ModInt{n}}, ::Type{Int}) where {n} = ModInt{n}
# Base.convert(::Type{ModInt{n}}, i::Int) where {n} = ModInt{n}(i)
# `zero()` and `one()` are implemented directly instead of relying on
# promotions/conversions.
Base.zero(::Type{ModInt{n}}) where {n} = ModInt{n}(0)
Base.zero(::ModInt{n}) where {n} = ModInt{n}(0)
Base.one(::Type{ModInt{n}}) where {n} = ModInt{n}(1)
Base.one(::ModInt{n}) where {n} = ModInt{n}(1)

# Functions needed for pivoting. Ideally pivoting would depend only on
# equality with zero() and be independent of magnitude. Until then, `abs` is
# the identity and `<` orders residues by their canonical representatives.
Base.abs(a::ModInt{n}) where {n} = a
Base.:<(a::ModInt{n}, b::ModInt{n}) where {n} = a.k < b.k
Base.transpose(a::ModInt{n}) where {n} = a

end # module
# NOTE(review): on Julia >= 0.7 a module defined in this same file must be
# loaded with `using .ModInts`; `using ModInts` looks for an installed package.
using ModInts

# Solve A * x = b over the integers mod 2, then check the solution.
# A is already lower triangular, so no pivoting is needed here.
# Both operands are built by broadcasting the ModInt{2} constructor over
# plain integer literals; the result is a Matrix/Vector of ModInt{2}.
A = ModInt{2}.([1 0 0 0;
                1 1 0 0;
                0 0 1 0;
                0 1 1 1])
b = ModInt{2}.([1, 1, 0, 1])

x = A \ b
@assert A * x == b
#=
# This one needs pivoting.
# Example taken from https://github.com/andrewcooke/IntModN.jl
# (without the zfield macro).
A = [ ModInt{2}(1) ModInt{2}(1) ModInt{2}(1) ModInt{2}(0) ;
ModInt{2}(1) ModInt{2}(1) ModInt{2}(0) ModInt{2}(1) ;
ModInt{2}(1) ModInt{2}(0) ModInt{2}(1) ModInt{2}(1) ;
ModInt{2}(0) ModInt{2}(1) ModInt{2}(1) ModInt{2}(1) ]
b = [ ModInt{2}(1), ModInt{2}(1), ModInt{2}(0), ModInt{2}(1) ]
lufact( A, Val{true} )
x = A \ b # this fails because it doesn't use pivoting
@assert A * x == b
# expect [ 0, 1, 0, 0 ] all mod 2
=#
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment