Skip to content

Instantly share code, notes, and snippets.

@arghhhh
Created May 30, 2017 18:40
Show Gist options
  • Save arghhhh/abdb8b6d28039f683f5287e124903670 to your computer and use it in GitHub Desktop.
Save arghhhh/abdb8b6d28039f683f5287e124903670 to your computer and use it in GitHub Desktop.
Modified example ModInt code without promotions
# This file is a part of Julia. License is MIT: https://julialang.org/license
# This is a modified version of julia/examples/ModInts.jl
module ModInts
export ModInt
import Base: +, -, *, /, inv

# ModInt has to subtype at least Number, otherwise `transpose` becomes a
# problem: https://github.com/JuliaLang/julia/issues/20978

"""
    ModInt{n}(k)

An integer modulo `n`. The constructor reduces `k` with `mod(k, n)`, so the
stored field `k` is always the canonical representative in `0:n-1`.

This variant defines no `convert`/`promote_rule` methods for `Int`, so mixing
`ModInt` with plain integers (e.g. `ModInt{n}(1) + 1`) is unsupported; generic
code is expected to obtain constants through `zero` and `one` instead.
"""
struct ModInt{n} <: Integer
    k::Int
    # Reducing here keeps every value canonical. No `==` is defined for ModInt,
    # so equality falls back to Base's same-type Number method (`x === y`),
    # which is only correct if e.g. ModInt{2}(2) is stored as ModInt{2}(0).
    ModInt{n}(k) where {n} = new(mod(k, n))
end

Base.show(io::IO, k::ModInt{n}) where {n} =
    print(io, get(io, :compact, false) ? k.k : "$(k.k) mod $n")

(+)(a::ModInt{n}, b::ModInt{n}) where {n} = ModInt{n}(a.k + b.k)
(-)(a::ModInt{n}, b::ModInt{n}) where {n} = ModInt{n}(a.k - b.k)
# `widemul` forms the product in a wider integer type, so it cannot overflow
# `Int` before the constructor reduces it (matters once n exceeds ~3*10^9).
(*)(a::ModInt{n}, b::ModInt{n}) where {n} = ModInt{n}(widemul(a.k, b.k))
(-)(a::ModInt{n}) where {n} = ModInt{n}(-a.k)
# `invmod` throws when gcd(a.k, n) != 1, i.e. when `a` has no inverse mod n.
inv(a::ModInt{n}) where {n} = ModInt{n}(invmod(a.k, n))
(/)(a::ModInt{n}, b::ModInt{n}) where {n} = a * inv(b) # broaden for non-coprime?

# Conversions and promotions from Int (promote_rule / convert) are deliberately
# not defined; zero() and one() supply the constants generic algorithms need.
Base.zero(::Type{ModInt{n}}) where {n} = ModInt{n}(0)
Base.zero(::ModInt{n}) where {n} = ModInt{n}(0)
Base.one(::Type{ModInt{n}}) where {n} = ModInt{n}(1)
Base.one(::ModInt{n}) where {n} = ModInt{n}(1)

# Functions needed for pivoting. Ideally pivoting would depend only on
# equality with zero() and be independent of magnitude; here `abs` is the
# identity and `<` orders the canonical representatives.
Base.abs(a::ModInt{n}) where {n} = a
Base.:<(a::ModInt{n}, b::ModInt{n}) where {n} = a.k < b.k
Base.transpose(a::ModInt{n}) where {n} = a

end # module
# ModInts is defined in this file, so it must be loaded with a relative path.
using .ModInts

# This system is already lower triangular, so pivoting is not necessary.
A = ModInt{2}.([1 0 0 0;
                1 1 0 0;
                0 0 1 0;
                0 1 1 1])
b = ModInt{2}.([1, 1, 0, 1])
x = A \ b   # expect [1, 0, 0, 1], all mod 2
# Always-run check (an @assert may be disabled at higher optimization levels).
A * x == b || error("A \\ b did not solve the system mod 2")
#=
# this one needs pivoting
# example taken from https://github.com/andrewcooke/IntModN.jl
# without the zfield macro
A = [ ModInt{2}(1) ModInt{2}(1) ModInt{2}(1) ModInt{2}(0) ;
ModInt{2}(1) ModInt{2}(1) ModInt{2}(0) ModInt{2}(1) ;
ModInt{2}(1) ModInt{2}(0) ModInt{2}(1) ModInt{2}(1) ;
ModInt{2}(0) ModInt{2}(1) ModInt{2}(1) ModInt{2}(1) ]
b = [ ModInt{2}(1), ModInt{2}(1), ModInt{2}(0), ModInt{2}(1) ]
lufact( A, Val{true} )
x = A \ b # this fails because it doesn't use pivoting
@assert A * x == b
# expect [ 0, 1, 0, 0 ] all mod 2
=#
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment