gridscore - mnistpct20 - kNN-Lwplsrda

using Jchemo, JchemoData
using JLD2, CairoMakie
using FreqTables

Data importation

path_jdat = dirname(dirname(pathof(JchemoData)))
db = joinpath(path_jdat, "data/mnist20pct.jld2") 
@load db dat
@names dat
(:Xtrain, :ytrain, :Xtest, :ytest)
Xtrain = dat.Xtrain
ytrain = dat.ytrain
Xtest = dat.Xtest
ytest = dat.ytest
@head Xtrain
@head Xtest
tab(ytrain)
tab(ytest)
ntrain, p = size(Xtrain)
ntest = nro(Xtest)
ntot = ntrain + ntest
(ntot = ntot, ntrain, ntest)
... (12000, 784)
 
... (2000, 784)
 
(ntot = 14000, ntrain = 12000, ntest = 2000)
3×784 DataFrame
684 columns omitted
Row1x11x21x31x41x51x61x71x81x91x101x111x121x131x141x151x161x171x181x191x201x211x221x231x241x251x261x271x282x12x22x32x42x52x62x72x82x92x102x112x122x132x142x152x162x172x182x192x202x212x222x232x242x252x262x272x283x13x23x33x43x53x63x73x83x93x103x113x123x133x143x153x163x173x183x193x203x213x223x233x243x253x263x273x284x14x24x34x44x54x64x74x84x94x104x114x124x134x144x154x16
Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32
10.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.0
20.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.0
30.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.0
3×784 DataFrame
684 columns omitted
Row1x11x21x31x41x51x61x71x81x91x101x111x121x131x141x151x161x171x181x191x201x211x221x231x241x251x261x271x282x12x22x32x42x52x62x72x82x92x102x112x122x132x142x152x162x172x182x192x202x212x222x232x242x252x262x272x283x13x23x33x43x53x63x73x83x93x103x113x123x133x143x153x163x173x183x193x203x213x223x233x243x253x263x273x284x14x24x34x44x54x64x74x84x94x104x114x124x134x144x154x16
Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32
10.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.0
20.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.0
30.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.0

Grey levels 0-255 are standardized between 0-1 (the standardization is not required here but used when fitting deep learning models).

Xtrain = Matrix(Xtrain) / 255
Xtest = Matrix(Xtest) / 255
2000×784 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 ⋮                        ⋮              ⋱                 ⋮              
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0

Example of one observation (= one unfolded image):

plotsp(Xtrain, 1:p; nsamp = 1, xlabel = "Pixel", ylabel = "Grey level").f

Split Train to Cal/Val for model tuning

Below, Cal and Val sets are built by random sampling (other designs could be used).

nval = 1000
nval / ntrain # sampling proportion 
s = samprand(ntrain, nval)
(train = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10  …  11990, 11991, 11992, 11993, 11994, 11995, 11996, 11997, 11998, 12000], test = [17, 21, 29, 32, 38, 43, 77, 88, 111, 113  …  11896, 11913, 11919, 11928, 11929, 11942, 11943, 11967, 11985, 11999])
Xcal = Xtrain[s.train, :]
ycal = ytrain[s.train]
Xval = Xtrain[s.test, :]
yval = ytrain[s.test]
ncal = ntrain - nval 
(ntot = ntot, ntrain, ncal, nval, ntest)
(ntot = 14000, ntrain = 12000, ncal = 11000, nval = 1000, ntest = 2000)

Grid-search

nlvdis = [10; 20]; metric = [:mah]
h = [1; 2; 5; Inf]; k = [200; 300; 500; 1000]  
nlv = 0:15
pars = mpar(nlvdis = nlvdis, metric = metric, h = h, k = k) 
length(pars[1])
32
model = lwplsrda()
res = gridscore(model, Xcal, ycal, Xval, yval; score = errp, pars, nlv)
512×6 DataFrame
487 rows omitted
Rownlvdismetrichknlvy1
AnyAnyAnyAnyInt64Float32
110mah1.020000.083
210mah1.020010.054
310mah1.020020.043
410mah1.020030.041
510mah1.020040.033
610mah1.020050.034
710mah1.020060.033
810mah1.020070.032
910mah1.020080.034
1010mah1.020090.035
1110mah1.0200100.037
1210mah1.0200110.038
1310mah1.0200120.039
50120mahInf100040.061
50220mahInf100050.059
50320mahInf100060.055
50420mahInf100070.054
50520mahInf100080.054
50620mahInf100090.053
50720mahInf1000100.05
50820mahInf1000110.049
50920mahInf1000120.049
51020mahInf1000130.052
51120mahInf1000140.053
51220mahInf1000150.054
group = string.("nvldis=", res.nlvdis, " h=", res.h, " k=", res.k)
plotgrid(res.nlv, res.y1, group; step = 2, xlabel = "Nb. LVs", ylabel = "ERRP-Val").f

Selection of the best parameter combination

u = findall(res.y1 .== minimum(res.y1))[1] 
res[u, :]
DataFrameRow (6 columns)
Rownlvdismetrichknlvy1
AnyAnyAnyAnyInt64Float32
15120mah1.030060.025

Final prediction (Test) using the optimal model

model = lwplsrda(nlvdis = res.nlvdis[u], metric = res.metric[u], h = res.h[u], k = res.k[u], nlv = res.nlv[u])
fit!(model, Xtrain, ytrain)
@head pred = predict(model, Xtest).pred
3×1 Matrix{Float32}:
 1.0
 0.0
 4.0
... (2000, 1)

Generalization error

errp(pred, ytest)
1×1 Matrix{Float64}:
 0.0275
merrp(pred, ytest)
1×1 Matrix{Float64}:
 0.0280975117313055
cf = conf(pred, ytest)
@names cf
(:cnt, :pct, :A, :Apct, :diagpct, :accpct, :lev)
cf.cnt
10×11 DataFrame
Rowypred_0.0pred_1.0pred_2.0pred_3.0pred_4.0pred_5.0pred_6.0pred_7.0pred_8.0pred_9.0
StringInt64Int64Int64Int64Int64Int64Int64Int64Int64Int64
10.0195000001000
21.0022601000000
32.0012050000100
43.0002194020030
54.0111019200002
65.0200201674111
76.0100011189000
87.0133000019602
98.0000211011891
109.0100332010192
cf.pct
10×11 DataFrame
Rowlevelspred_0.0pred_1.0pred_2.0pred_3.0pred_4.0pred_5.0pred_6.0pred_7.0pred_8.0pred_9.0
StringFloat64Float64Float64Float64Float64Float64Float64Float64Float64Float64
10.099.50.00.00.00.00.00.50.00.00.0
21.00.099.60.00.40.00.00.00.00.00.0
32.00.00.599.00.00.00.00.00.50.00.0
43.00.00.01.096.50.01.00.00.01.50.0
54.00.50.50.50.097.50.00.00.00.01.0
65.01.10.00.01.10.093.82.20.60.60.6
76.00.50.00.00.00.50.598.40.00.00.0
87.00.51.51.50.00.00.00.095.60.01.0
98.00.00.00.01.00.50.50.00.596.90.5
109.00.50.00.01.51.51.00.00.50.095.0
plotconf(cf).f
plotconf(cf; cnt = false).f