gridcv - forages2 - kNN-Lwplsrda
using Jchemo, JchemoData
using JLD2, CairoMakie
using FreqTables
using Random
Data import
# Locate the JchemoData package root (two directory levels above the file
# returned by `pathof`) and load the `dat` object stored in data/forages2.jld2.
path_jdat = pathof(JchemoData) |> dirname |> dirname
db = joinpath(path_jdat, "data/forages2.jld2")
@load db dat
@names dat
(:X, :Y)
X = dat.X    # X-data (spectra); columns are named by wavelength
@head X
... (485, 700)
3×700 DataFrame
600 columns omitted
1 | -0.000231591 | -0.000175945 | -8.48176e-5 | 2.05217e-5 | 0.000110094 | 0.000161757 | 0.000154953 | 0.000163754 | 0.000187602 | 0.00021499 | 0.000242479 | 0.000265498 | 0.000282141 | 0.000281442 | 0.000271025 | 0.000261075 | 0.000257284 | 0.000252177 | 0.00024293 | 0.000228295 | 0.000219097 | 0.000214136 | 0.000215612 | 0.000218982 | 0.000228004 | 0.000236081 | 0.000236017 | 0.000220327 | 0.000187096 | 0.000137138 | 7.68593e-5 | 1.13679e-5 | -5.00951e-5 | -9.54664e-5 | -0.000119199 | -0.000131897 | -0.000142349 | -0.000161489 | -0.00019387 | -0.000244808 | -0.000303259 | -0.000366904 | -0.000416738 | -0.000451535 | -0.00046995 | -0.000478637 | -0.000477348 | -0.000478142 | -0.000476719 | -0.000479701 | -0.000482037 | -0.000496769 | -0.000511959 | -0.000532094 | -0.000542661 | -0.000540188 | -0.000512715 | -0.00045798 | -0.000370395 | -0.000256256 | -0.000126907 | 1.13716e-6 | 0.000119047 | 0.000212745 | 0.000275685 | 0.000307863 | 0.000313547 | 0.000296977 | 0.000269661 | 0.000247818 | 0.000233944 | 0.000228773 | 0.000224567 | 0.000221256 | 0.000218893 | 0.000217741 | 0.000210144 | 0.00019664 | 0.000181949 | 0.000169774 | 0.000151691 | 0.00012385 | 9.23378e-5 | 5.9959e-5 | 2.58352e-5 | -4.77314e-6 | -3.21835e-5 | -5.53154e-5 | -6.71707e-5 | -6.54166e-5 | -5.16448e-5 | -2.43366e-5 | 1.12255e-5 | 4.68917e-5 | 7.773e-5 | 0.000106785 | 0.000133173 | 0.000153607 | 0.000168518 | 0.000182591 | ⋯ |
2 | -9.66352e-5 | -3.30928e-5 | 5.64966e-5 | 0.000154135 | 0.000237725 | 0.000295789 | 0.000319587 | 0.000357405 | 0.000404611 | 0.000447996 | 0.000479786 | 0.000488339 | 0.000465929 | 0.000402301 | 0.000313648 | 0.000220226 | 0.000138483 | 7.35084e-5 | 3.50018e-5 | 2.83293e-5 | 6.05478e-5 | 0.000118272 | 0.000187726 | 0.000249842 | 0.00029697 | 0.000315062 | 0.000298828 | 0.000251643 | 0.000187055 | 0.000118243 | 5.60849e-5 | 3.8727e-6 | -3.28778e-5 | -4.84688e-5 | -4.38912e-5 | -3.34954e-5 | -2.72637e-5 | -3.65483e-5 | -6.62949e-5 | -0.000121833 | -0.000193587 | -0.000280244 | -0.000362132 | -0.000434981 | -0.000494461 | -0.000546531 | -0.000590606 | -0.000638514 | -0.000684688 | -0.000734688 | -0.000783664 | -0.000842714 | -0.000892596 | -0.000930301 | -0.000938118 | -0.000913585 | -0.000846217 | -0.000737781 | -0.000588122 | -0.000410395 | -0.000220611 | -3.69382e-5 | 0.000131072 | 0.000266078 | 0.000358377 | 0.000408684 | 0.000424528 | 0.000412147 | 0.000383896 | 0.000357957 | 0.000338385 | 0.000326749 | 0.000315572 | 0.00030542 | 0.000293671 | 0.000280005 | 0.000259482 | 0.000233697 | 0.0002044 | 0.000177199 | 0.000147989 | 0.000112325 | 7.33317e-5 | 3.48779e-5 | -2.5229e-6 | -3.27922e-5 | -5.52233e-5 | -7.06412e-5 | -7.49675e-5 | -6.44041e-5 | -4.04393e-5 | -6.50489e-6 | 3.09196e-5 | 6.87358e-5 | 0.000105202 | 0.000142313 | 0.000177182 | 0.000206652 | 0.000230788 | 0.000253703 | ⋯ |
3 | -0.000131769 | -7.8398e-5 | 7.92223e-7 | 8.90044e-5 | 0.000160022 | 0.000198435 | 0.000196598 | 0.000212225 | 0.000241109 | 0.000271235 | 0.000301045 | 0.000324921 | 0.000337619 | 0.000325857 | 0.00029979 | 0.000277167 | 0.00027018 | 0.00027165 | 0.000277606 | 0.000287722 | 0.000308203 | 0.000324847 | 0.000328573 | 0.000310806 | 0.00027728 | 0.000226898 | 0.000160474 | 8.30948e-5 | 7.98825e-6 | -5.32827e-5 | -9.57157e-5 | -0.000123438 | -0.0001371 | -0.000134382 | -0.00011527 | -9.07963e-5 | -6.97458e-5 | -6.29138e-5 | -7.14491e-5 | -9.85941e-5 | -0.000137562 | -0.000192678 | -0.000248177 | -0.000303993 | -0.000356125 | -0.000407616 | -0.0004553 | -0.000507819 | -0.000555473 | -0.000603436 | -0.000647099 | -0.000701763 | -0.000754429 | -0.000806879 | -0.000838493 | -0.000842167 | -0.000803445 | -0.000720829 | -0.000592138 | -0.000428566 | -0.000245567 | -6.43964e-5 | 0.000101193 | 0.000232242 | 0.000322133 | 0.000373605 | 0.000391817 | 0.000379332 | 0.000347829 | 0.000316495 | 0.000292236 | 0.000278431 | 0.000264621 | 0.000250305 | 0.000239387 | 0.000234504 | 0.000224633 | 0.000205684 | 0.000180408 | 0.000157615 | 0.000135108 | 0.000106871 | 7.3258e-5 | 3.90321e-5 | 7.34127e-6 | -1.78231e-5 | -3.94282e-5 | -5.6427e-5 | -6.15935e-5 | -5.19038e-5 | -2.96367e-5 | 3.09722e-6 | 3.98752e-5 | 7.62892e-5 | 0.000108271 | 0.000137632 | 0.000165624 | 0.000191182 | 0.000211586 | 0.000229586 | ⋯ |
Y = dat.Y    # sample table; its columns `typ` (class) and `test` (0/1 flag) are used below
@head Y
... (485, 4)
1 | 92.23 | 37.58 | Legume forages | 1 |
2 | 93.26 | 49.6462 | Legume forages | 0 |
3 | 92.9 | 63.2939 | Forage trees | 0 |
# Class labels to predict, and the 0/1 indicator that flags the Test samples
y, test = Y.typ, Y.test
tab(y)
OrderedCollections.OrderedDict{String, Int64} with 3 entries:
"Cereal and grass forages" => 160
"Forage trees" => 101
"Legume forages" => 224
freqtable(y, test)    # class counts crossed with the Train (0) / Test (1) flag
3×2 Named Matrix{Int64}
Dim1 ╲ Dim2 │ 0 1
─────────────────────────┼─────────
Cereal and grass forages │ 100 60
Forage trees │ 56 45
Legume forages │ 167 57
# Column names of X are the wavelengths (nm), stored as strings
wlst = names(X)
wl = [parse(Int, w) for w in wlst]
# Optional: uncomment to plot the spectra
#plotsp(X, wl; xlabel = "Wavelength (nm)", ylabel = "Absorbance").f
700-element Vector{Int64}:
1100
1102
1104
1106
1108
1110
1112
1114
1116
1118
⋮
2482
2484
2486
2488
2490
2492
2494
2496
2498
Note: X-data are already preprocessed (SNV + Savitzky-Golay 2nd derivative).
Split Tot into Train/Test
The model is fitted on Train, and the generalization error is estimated on Test. In this example, Test is already defined by variable `test` of the dataset (samples with `test = 1`), and Train is defined by the remaining samples. But Tot could also be split a posteriori, for instance by sampling (random, systematic or any other design). See for instance functions `samprand`, `sampsys`, etc.
# Rows flagged `test == 1` form Test; all the remaining rows form Train.
s = Bool.(test)
Xtest = X[s, :]
ytest = y[s]
Xtrain = rmrow(X, s)
ytrain = rmrow(y, s)
ntot, ntrain, ntest = nro(X), nro(Xtrain), nro(Xtest)
(; ntot, ntrain, ntest)
(ntot = 485, ntrain = 323, ntest = 162)
tab(ytrain)    # class counts in Train
OrderedCollections.OrderedDict{String, Int64} with 3 entries:
"Cereal and grass forages" => 100
"Forage trees" => 56
"Legume forages" => 167
tab(ytest)    # class counts in Test
OrderedCollections.OrderedDict{String, Int64} with 3 entries:
"Cereal and grass forages" => 60
"Forage trees" => 45
"Legume forages" => 57
K-fold CV
# segmkf draws the K folds at random. Seeding the default RNG (assumed to be the
# one segmkf uses — verify against the Jchemo version in use) makes the partition,
# and therefore all the CV results below, reproducible between runs.
Random.seed!(123)
K = 3       # nb. folds (segments)
rep = 1     # nb. replications
segm = segmkf(ntrain, K; rep = rep)
1-element Vector{Vector{Vector{Int64}}}:
[[3, 5, 7, 9, 12, 14, 15, 16, 17, 21 … 286, 290, 291, 292, 294, 296, 299, 300, 317, 322], [1, 8, 10, 13, 20, 22, 23, 26, 27, 29 … 308, 310, 311, 312, 314, 316, 318, 319, 320, 321], [2, 4, 6, 11, 18, 19, 25, 31, 33, 36 … 289, 293, 298, 303, 304, 307, 309, 313, 315, 323]]
# Hyperparameter grid. mpar combines the vectors below; with 2 × 1 × 5 × 3 values
# this gives 30 combinations (see `length` below). `nlv` is not part of `pars`:
# it is passed separately to gridcv.
nlvdis = [15, 25]
metric = [:mah]
h = [1, 2, 4, 6, Inf]
k = [30, 50, 100]
nlv = 0:15
pars = mpar(nlvdis = nlvdis, metric = metric, h = h, k = k)
length(pars[1])
30
# Grid-search CV on Train: every combination of `pars` is evaluated for each value
# of `nlv` (30 × 16 = 480 rows in `res`), scored with merrp.
model = lwplsrda()
res = gridcv(model, Xtrain, ytrain;
    segm = segm, score = merrp, pars = pars, nlv = nlv).res
480×6 DataFrame
455 rows omitted
1 | 0 | 15 | mah | 1.0 | 30 | 0.127536 |
2 | 1 | 15 | mah | 1.0 | 30 | 0.122971 |
3 | 2 | 15 | mah | 1.0 | 30 | 0.12443 |
4 | 3 | 15 | mah | 1.0 | 30 | 0.133611 |
5 | 4 | 15 | mah | 1.0 | 30 | 0.120539 |
6 | 5 | 15 | mah | 1.0 | 30 | 0.105979 |
7 | 6 | 15 | mah | 1.0 | 30 | 0.11591 |
8 | 7 | 15 | mah | 1.0 | 30 | 0.109777 |
9 | 8 | 15 | mah | 1.0 | 30 | 0.0992729 |
10 | 9 | 15 | mah | 1.0 | 30 | 0.1028 |
11 | 10 | 15 | mah | 1.0 | 30 | 0.104784 |
12 | 11 | 15 | mah | 1.0 | 30 | 0.104764 |
13 | 12 | 15 | mah | 1.0 | 30 | 0.111248 |
⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ |
469 | 4 | 25 | mah | Inf | 100 | 0.151948 |
470 | 5 | 25 | mah | Inf | 100 | 0.138172 |
471 | 6 | 25 | mah | Inf | 100 | 0.119151 |
472 | 7 | 25 | mah | Inf | 100 | 0.121399 |
473 | 8 | 25 | mah | Inf | 100 | 0.12778 |
474 | 9 | 25 | mah | Inf | 100 | 0.129056 |
475 | 10 | 25 | mah | Inf | 100 | 0.125711 |
476 | 11 | 25 | mah | Inf | 100 | 0.116371 |
477 | 12 | 25 | mah | Inf | 100 | 0.112023 |
478 | 13 | 25 | mah | Inf | 100 | 0.12035 |
479 | 14 | 25 | mah | Inf | 100 | 0.125206 |
480 | 15 | 25 | mah | Inf | 100 | 0.124933 |
# One curve per (nlvdis, h, k) combination: CV score vs. nb. of LVs.
# The score computed by gridcv above is merrp, hence the y-axis label.
group = string.("nlvdis=", res.nlvdis, " h=", res.h, " k=", res.k)
plotgrid(res.nlv, res.y1, group; step = 2, xlabel = "Nb. LVs", ylabel = "MERRP-CV",
    leg_title = "Continuum").f
Selection of the best parameter combination
# Row of `res` with the lowest CV score; argmin returns the first one in case of ties
u = argmin(res.y1)
res[u, :]
Final prediction (Test) using the optimal model
# Refit on the whole Train set with the selected hyperparameters, then predict Test
model = let best = res[u, :]
    lwplsrda(nlvdis = best.nlvdis, metric = best.metric, h = best.h,
        k = best.k, nlv = best.nlv)
end
fit!(model, Xtrain, ytrain)
pred = predict(model, Xtest).pred
162×1 Matrix{String}:
"Legume forages"
"Cereal and grass forages"
"Cereal and grass forages"
"Legume forages"
"Cereal and grass forages"
"Cereal and grass forages"
"Legume forages"
"Forage trees"
"Forage trees"
"Forage trees"
⋮
"Cereal and grass forages"
"Cereal and grass forages"
"Forage trees"
"Forage trees"
"Cereal and grass forages"
"Legume forages"
"Legume forages"
"Legume forages"
"Legume forages"
Generalization error
errp(pred, ytest)    # overall proportion of misclassified Test samples
1×1 Matrix{Float64}:
0.08641975308641975
merrp(pred, ytest)    # mean of the per-class error rates on Test
1×1 Matrix{Float64}:
0.08849902534113059
cf = conf(pred, ytest)    # confusion matrix between predicted and observed Test classes
@names cf
(:cnt, :pct, :A, :Apct, :diagpct, :accpct, :lev)
cf.cnt    # confusion table in counts (row totals match the class counts of ytest)
1 | Cereal and grass forages | 56 | 1 | 3 |
2 | Forage trees | 1 | 40 | 4 |
3 | Legume forages | 5 | 0 | 52 |
cf.pct    # same table in percentages (each row sums to 100)
1 | Cereal and grass forages | 93.3 | 1.7 | 5.0 |
2 | Forage trees | 2.2 | 88.9 | 8.9 |
3 | Legume forages | 8.8 | 0.0 | 91.2 |