Skip to content

Instantly share code, notes, and snippets.

@tkf
Last active February 12, 2018 04:40
Show Gist options
  • Save tkf/16b06fd6a4943c8d524295152c231b13 to your computer and use it in GitHub Desktop.
Save tkf/16b06fd6a4943c8d524295152c231b13 to your computer and use it in GitHub Desktop.
using IterTools: shortest, longest
"""
    IterView(iter, indices)
    viewing(iter, indices)

Lazily walk `iter` and yield only the elements at the 1-based positions
produced by `indices`. Elements are reached by stepping through `iter`, never
by indexing into it. `indices` must yield strictly increasing integers that
are at least 1. The view ends when either `indices` or `iter` runs out.
"""
struct IterView{I, S}
    iter::I      # the iterator being viewed
    indices::S   # 1-based positions to pick out of `iter`
end

# Verb-like alias used as the user-facing constructor name.
const viewing = IterView
"""
    ViewState{IV, SV}

Mutable iteration state shared by `start`/`done`/`next` of `IterView`.

- `IV`: element type of the viewed iterator; this is the type of `iter_value`.
- `SV`: element type of the index iterator. No field stores it; it only
  selects which `check_index` method validates each index.
"""
mutable struct ViewState{IV, SV}
    index::Int         # position in `iter` of the last element taken (0 = none yet)
    iter_state         # current state of the viewed iterator
    indices_state      # current state of the index iterator
    iter_value::IV     # element found at `index`; set by `done`, not by the constructor
    # NOTE(review): the two state fields are untyped (`Any`). This presumably
    # keeps the type usable with iterators whose state type varies, at the cost
    # of type stability.
    #
    # `iter_value` is intentionally left unset here: it is only assigned once
    # the first requested element has been located.
    ViewState{IV, SV}(index, iter_state, indices_state) where {IV, SV} =
        new{IV, SV}(index, iter_state, indices_state)
end
# Number of elements produced when `indices` is a range. Only positions that
# exist in `it.iter` are counted:
# - a finite source contributes the part of the range inside
#   `1:length(it.iter)`;
# - an infinite source has no `length` but can reach every position, so the
#   whole range counts.
Base.length(it::IterView{I, <:OrdinalRange}) where {I} =
    _range_view_length(Base.iteratorsize(I), it.iter, it.indices)

# Infinite source: every position in the range is reachable.
_range_view_length(::Base.IsInfinite, iter, indices) = length(indices)
# Finite source: positions beyond `length(iter)` are never produced.
_range_view_length(::Base.IteratorSize, iter, indices) =
    length(intersect(1:length(iter), indices))

# With range indices the view is always finite, so it has a computable length
# whenever the source's size is known (finite or infinite). Only an
# unknown-size source, which may lack `length`, leaves it unknown. The view
# implements `length` but not `size`, so `HasLength` is reported rather than
# `HasShape`.
Base.iteratorsize(::Type{<:IterView{I, S}}) where {I, S <: OrdinalRange} =
    _range_view_size(Base.iteratorsize(I))

_range_view_size(::Base.SizeUnknown) = Base.SizeUnknown()
_range_view_size(::Base.IteratorSize) = Base.HasLength()

# Arbitrary index iterators. Indices past the end of `iter` make the count
# unpredictable, so `longest(SizeUnknown(), ...)` presumably keeps only an
# `IsInfinite` result (both inputs infinite) and otherwise reports
# `SizeUnknown`. This assumes IterTools' usual `shortest`/`longest` semantics.
Base.iteratorsize(::Type{<:IterView{I, S}}) where {I, S} =
    longest(Base.SizeUnknown(),
            shortest(Base.iteratorsize(I), Base.iteratorsize(S)))
# The view only skips elements of `iter` and never transforms them. Both the
# element-type trait and the element type are therefore inherited from the
# type of the viewed iterator.
function Base.iteratoreltype(::Type{<:IterView{I, S}}) where {I, S}
    return Base.iteratoreltype(I)
end

function Base.eltype(::Type{<:IterView{I, S}}) where {I, S}
    return eltype(I)
end
# Build the initial iteration state. The element type of `indices` is checked
# up front: it must be an `Integer` subtype, or `Any`. For `Any`, each index is
# validated individually by `check_index` during iteration.
function Base.start(it::IterView{I, S}) where {I, S}
    index_eltype = eltype(S)
    accepted = index_eltype <: Integer || index_eltype === Any
    accepted || error("Type $(index_eltype) cannot be used as indices.")
    iter_start = start(it.iter)
    indices_start = start(it.indices)
    # Position 0 means "no element of `iter` consumed yet".
    return ViewState{eltype(I), index_eltype}(0, iter_start, indices_start)
end
# Validate one index `n` taken from `indices` before it is used.
#
# Index collections with an `Integer` element type were already accepted by
# `start`, so there is nothing left to check.
function check_index(::ViewState{IV, <:Integer}, _) where {IV}
    return nothing
end

# For `Any`-typed index collections each element is checked individually:
# integers are accepted...
function check_index(::ViewState{IV, Any}, ::Integer) where {IV}
    return nothing
end

# ...and every other value is rejected.
function check_index(::ViewState{IV, Any}, n::T) where {IV, T}
    error("$n of type $T cannot be used as an index")
end
# Advance the view to its next element, or report that there is none.
#
# In this implementation `done` does the real work:
# 1. Peek the next index `n` from `it.indices` and validate it.
# 2. Step `it.iter` forward from position `state.index` up to position `n`.
# 3. Cache the element found there in `state.iter_value` for `next` to return.
#
# Returns `true` when `it.indices` is exhausted, or when `it.iter` ends before
# position `n` is reached; otherwise returns `false`.
# Raises an error if `check_index` rejects `n`, if the first index is smaller
# than 1, or if the indices are not strictly increasing.
#
# NOTE(review): this mutates `state`. Every call that returns `false` advances
# the index iterator and the source, so calling `done` twice for the same
# element skips one. Call it exactly once per element.
function Base.done(it::IterView, state)
    done(it.indices, state.indices_state) && return true
    # `i` is the position of the last element already taken from `it.iter`
    # (0 before the first one).
    i = state.index
    # Peek the next requested position without committing the new index
    # state yet.
    n, ss = next(it.indices, state.indices_state)
    check_index(state, n)
    if i >= n
        if i == 0
            # Nothing has been consumed yet (state.index == 0), so the first
            # element returned by it.indices is less than 1.
            error("Trying to access index (=$n) smaller than 1.")
        end
        error("Indices are not strictly increasing.",
              " Previous index: $i, Current index: $n")
    end
    # The index state is committed before the source is advanced. If the
    # source runs out below, iteration ends anyway.
    state.indices_state = ss
    si = state.iter_state
    local x
    # Skip forward until position `n`. The loop runs at least once because
    # i < n is guaranteed by the check above, so `x` is always assigned.
    while i < n
        if done(it.iter, si)
            return true
        end
        x, si = next(it.iter, si)
        i += 1
    end
    # Commit the source state, the cached element and its position.
    state.iter_state = si
    state.iter_value = x
    state.index = i
    return false
end
# `done` has already advanced to the requested element and cached it in the
# mutable state. `next` just hands that element back along with the same
# state object.
function Base.next(it::IterView, state)
    value = state.iter_value
    return (value, state)
end
using Base.Test
# Return the message of the `ErrorException` raised by `f()`, or `""` when `f`
# returns normally. Any other exception type is rethrown, so an unexpected
# failure is reported as itself. Reading `.msg` from it would otherwise fail,
# e.g. a `MethodError` has no `msg` field.
function thrown_message(f)
    try
        f()
    catch err
        isa(err, ErrorException) || rethrow()
        return err.msg
    end
    return ""
end

@testset "viewing" begin
    @testset "collect" begin
        @test collect(viewing(1:100, 1:2:5)) == collect(1:2:5)
        @test collect(viewing(1:5, 1:2:10)) == collect(1:2:5)
        @test collect(viewing(1:10, [1, 3, 7, 11])) == [1, 3, 7]
        @test collect(viewing(1:10, Any[1, 3, 7, 11])) == [1, 3, 7]
    end
    @testset "length: iter=$iter, indices=$indices" for (iter, indices) in [
            (1:10, 1:5),
            (1:10, 1:2:9),
            (1:10, 1:2:5),
        ]
        result = iter[indices]
        @test length(viewing(iter, indices)) == length(result)
    end
    @testset "errors" begin
        # Non-integer element type is rejected by `start`.
        @test contains(thrown_message(() -> collect(viewing(1:2, 1:0.1:2))),
                       "Type Float64 cannot be used")
        # Non-integer element inside an `Any` vector is rejected per element.
        @test contains(thrown_message(() -> collect(viewing(1:2, Any[1.0]))),
                       "type Float64 cannot be used")
        # Positions are 1-based.
        @test contains(thrown_message(() -> collect(viewing(1:2, [0]))),
                       "smaller than 1")
        # Positions must be strictly increasing.
        @test contains(thrown_message(() -> collect(viewing(1:2, [2, 1]))),
                       "not strictly increasing")
    end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment