Skip to content

Commit 4ddc8c9

Browse files
committed
Add function to an 'id' attrib when using the DecisionTreeClassifier
Something on the type stability for truecolor.
1 parent 1297eb8 commit 4ddc8c9

1 file changed

Lines changed: 35 additions & 13 deletions

File tree

src/utils.jl

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ end
119119

120120
# ----------------------------------------------------------------------------------------------------------
121121
# This method fall into the description of the general 'truecolor(bndR, bndG, bndB)'
122-
function truecolor(bandR::T, bandG::T, bandB::T) where {T<:Union{GMTgrid{<:AbstractFloat, 2},Matrix{<:AbstractFloat}}}
122+
function truecolor(bandR::T, bandG::T, bandB::T) where {T<:Union{GMT.GMTgrid{<:AbstractFloat, 2},Matrix{<:AbstractFloat}}}
123123
@assert size(bandR) == size(bandG) == size(bandB)
124124
n1 = length(bandR); n2 = 2 * n1; n3 = 3 * n1
125125
img = Array{UInt8}(undef, size(bandR,1), size(bandR,2), 3)
@@ -140,39 +140,42 @@ function truecolor(bandR::T, bandG::T, bandB::T) where {T<:Union{GMTgrid{<:Abstr
140140
return isa(bandR, GMTgrid) ? mat2img(img, bandR) : mat2img(img)
141141
end
142142

143-
truecolor(cube::GMT.GMTimage{UInt16, 3}, layers::Vector{Int}) = truecolor(cube, layers=layers)
144-
function truecolor(cube::GMT.GMTimage{UInt16, 3}; layers::Vector{Int}=Int[])
143+
truecolor(cube::GMT.GMTimage{UInt16, 3}, layers::Vector{Int}; stretch=true) = truecolor(cube, layers=layers, stretch=stretch)
144+
function truecolor(cube::GMT.GMTimage{UInt16, 3}; layers::Vector{Int}=Int[], stretch=true)
145145
(length(layers) != 3) && error("For an RGB composition 'bands' must be a 3 elements array and not $(length(layers))")
146146
(cube.layout[3] != 'B') && error("For an RGB composition the image object must be Band interleaved and not $(cube.layout)")
147147
img = Array{UInt8, 3}(undef, size(cube,1), size(cube,2), 3)
148148
layers = find_layers(cube, layers, 3)
149-
_ = mat2img(@view(cube[:,:,layers[1]]), stretch=true, img8=view(img,:,:,1), scale_only=1)
150-
_ = mat2img(@view(cube[:,:,layers[2]]), stretch=true, img8=view(img,:,:,2), scale_only=1)
151-
_ = mat2img(@view(cube[:,:,layers[3]]), stretch=true, img8=view(img,:,:,3), scale_only=1)
149+
stch = (stretch == 1) ? true : stretch # This bloody type unstable and does not test stupid inputs
150+
_ = mat2img(@view(cube[:,:,layers[1]]), stretch=stch, img8=view(img,:,:,1), scale_only=1)
151+
_ = mat2img(@view(cube[:,:,layers[2]]), stretch=stch, img8=view(img,:,:,2), scale_only=1)
152+
_ = mat2img(@view(cube[:,:,layers[3]]), stretch=stch, img8=view(img,:,:,3), scale_only=1)
152153
Io = mat2img(img, cube); Io.layout = "TRBa"
153154
Io
154155
end
155156

156-
truecolor(cube::GMT.GMTgrid{Float32, 3}, layers::Vector{Int}) = truecolor(cube, layers=layers)
157+
truecolor(cube::GMT.GMTgrid{Float32, 3}, layers::Vector{Int}; stretch=true) = truecolor(cube, layers=layers, stretch=stretch)
157158
function truecolor(cube::GMT.GMTgrid{Float32, 3}; bands::Vector{Int}=Int[], layers::Vector{Int}=Int[],
158-
bandnames::Vector{String}=String[], type::DataType=UInt8)
159+
bandnames::Vector{String}=String[], stretch=true, type::DataType=UInt8)
159160
(isempty(bands) && isempty(bandnames) && isempty(layers)) && (bandnames = ["red", "green", "blue"])
160161
isempty(layers) && (layers = find_layers(cube, bandnames, bands))
161162
(length(layers) != 3) && error("For an RGB composition 'bands' must be a 3 elements array and not $(length(layers))")
163+
stch = (stretch == 1) ? true : stretch # This bloody type unstable and does not test stupid inputs
162164
img = Array{type, 3}(undef, size(cube,1), size(cube,2), 3)
163-
img[:,:,1] = rescale(@view(cube[:,:,layers[1]]), stretch=true, type=type)
164-
img[:,:,2] = rescale(@view(cube[:,:,layers[2]]), stretch=true, type=type)
165-
img[:,:,3] = rescale(@view(cube[:,:,layers[3]]), stretch=true, type=type)
165+
img[:,:,1] = rescale(@view(cube[:,:,layers[1]]), stretch=stch, type=type)
166+
img[:,:,2] = rescale(@view(cube[:,:,layers[2]]), stretch=stch, type=type)
167+
img[:,:,3] = rescale(@view(cube[:,:,layers[3]]), stretch=stch, type=type)
166168
Io = mat2img(img, cube); Io.layout = "TRBa"
167169
Io
168170
end
169171

170-
function truecolor(cube::String; bands::Vector{Int}=Int[], layers::Vector{Int}=Int[], bandnames::Vector{String}=String[], raw::Bool=false)
172+
function truecolor(cube::String; bands::Vector{Int}=Int[], layers::Vector{Int}=Int[], bandnames::Vector{String}=String[],
173+
raw::Bool=false, stretch=true)
171174
# The `raw` option returns a GMTimage{UInt16, 3} and does not convert to UInt8 with auto-stretch (default)
172175
(isempty(bands) && isempty(bandnames) && isempty(layers)) && (bandnames = ["red", "green", "blue"])
173176
rgb = subcube(cube, bands=bands, layers=layers, bandnames=bandnames)
174177
(raw) && return rgb
175-
return (eltype(rgb) <: AbstractFloat) ? truecolor(rgb, layers=[1,2,3]) : mat2img(rgb, stretch=:auto)
178+
return (eltype(rgb) <: AbstractFloat) ? truecolor(rgb, layers=[1,2,3], stretch=stretch) : mat2img(rgb, stretch=stretch)
176179
end
177180

178181
# ----------------------------------------------------------------------------------------------------------
@@ -850,6 +853,9 @@ Returns the trained model and the class names.
850853
"""
851854
function train_raster(cube::GItype, train::Union{Vector{<:GMTdataset}, String}; np::Int=0, density=0.1, max_depth=3)
852855
samples = isa(train, String) ? gmtread(train) : train
856+
get(samples[1].attrib, "class", "") == "" && error("The datasets used for training MUST have an attribute called 'class'.")
857+
(get(samples[1].attrib, "id", "") == "") && add_class_id!(samples) # If no 'id' attribute, create one from 'class'
858+
853859
pts = randinpolygon(samples, np=np, density=density)
854860
LCsamp = grdinterpolate(cube, S=pts, nocoords=true)
855861
features = GMT.ds2ds(LCsamp)
@@ -861,6 +867,22 @@ function train_raster(cube::GItype, train::Union{Vector{<:GMTdataset}, String};
861867
return model, classes
862868
end
863869

870+
function add_class_id!(D::Vector{<:GMTdataset})
871+
seen = Dict{String,Int}()
872+
n = 0
873+
for d in D
874+
v = d.attrib["class"]
875+
cls = v isa Vector ? first(v) : v
876+
haskey(seen, cls) || (n += 1; seen[cls] = n)
877+
end
878+
for d in D
879+
v = d.attrib["class"]
880+
cls = v isa Vector ? first(v) : v
881+
d.attrib["id"] = string(seen[cls])
882+
end
883+
return nothing
884+
end
885+
864886
# ----------------------------------------------------------------------------------------------------------
865887
#=
866888
From https://rspatial.org/rs/5-supclassification.html

0 commit comments

Comments
 (0)