gridscore - mnistpct20 - kNN-Lwplsrda

using Jchemo, JchemoData
using JLD2, CairoMakie
using FreqTables

Data importation

path_jdat = dirname(dirname(pathof(JchemoData)))
db = joinpath(path_jdat, "data/mnist20pct.jld2") 
@load db dat
@names dat
(:Xtrain, :ytrain, :Xtest, :ytest)
Xtrain = dat.Xtrain
ytrain = dat.ytrain
Xtest = dat.Xtest
ytest = dat.ytest
@head Xtrain
@head Xtest
tab(ytrain)
tab(ytest)
ntrain, p = size(Xtrain)
ntest = nro(Xtest)
ntot = ntrain + ntest
(ntot = ntot, ntrain, ntest)
... (12000, 784)
 
... (2000, 784)
 
(ntot = 14000, ntrain = 12000, ntest = 2000)
3×784 DataFrame
684 columns omitted
Row1x11x21x31x41x51x61x71x81x91x101x111x121x131x141x151x161x171x181x191x201x211x221x231x241x251x261x271x282x12x22x32x42x52x62x72x82x92x102x112x122x132x142x152x162x172x182x192x202x212x222x232x242x252x262x272x283x13x23x33x43x53x63x73x83x93x103x113x123x133x143x153x163x173x183x193x203x213x223x233x243x253x263x273x284x14x24x34x44x54x64x74x84x94x104x114x124x134x144x154x16
Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32
10.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.0
20.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.0
30.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.0
3×784 DataFrame
684 columns omitted
Row1x11x21x31x41x51x61x71x81x91x101x111x121x131x141x151x161x171x181x191x201x211x221x231x241x251x261x271x282x12x22x32x42x52x62x72x82x92x102x112x122x132x142x152x162x172x182x192x202x212x222x232x242x252x262x272x283x13x23x33x43x53x63x73x83x93x103x113x123x133x143x153x163x173x183x193x203x213x223x233x243x253x263x273x284x14x24x34x44x54x64x74x84x94x104x114x124x134x144x154x16
Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32Float32
10.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.0
20.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.0
30.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.0

Grey levels 0-255 standardized between 0-1 (not required here but used when fitting deep learning models)

Xtrain = Matrix(Xtrain) / 255
Xtest = Matrix(Xtest) / 255
2000×784 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 ⋮                        ⋮              ⋱                 ⋮              
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0

Example of one sample (= one unfolded image)

plotsp(Xtrain, 1:p; nsamp = 1, xlabel = "Pixel", ylabel = "Grey level").f

Split Train to Cal/Val for model tuning

Below, Cal and Val are built by random sampling (other designs could be used)

nval = 1000
nval / ntrain # sampling proportion 
s = samprand(ntrain, nval)
(train = [1, 2, 3, 4, 5, 6, 8, 10, 11, 12  …  11991, 11992, 11993, 11994, 11995, 11996, 11997, 11998, 11999, 12000], test = [7, 9, 20, 30, 42, 55, 70, 84, 93, 96  …  11839, 11845, 11853, 11869, 11876, 11877, 11900, 11905, 11925, 11934])
Xcal = Xtrain[s.train, :]
ycal = ytrain[s.train]
Xval = Xtrain[s.test, :]
yval = ytrain[s.test]
ncal = ntrain - nval 
(ntot = ntot, ntrain, ncal, nval, ntest)
(ntot = 14000, ntrain = 12000, ncal = 11000, nval = 1000, ntest = 2000)

Grid-search

nlvdis = [10; 20]; metric = [:mah]
h = [1; 2; 5; Inf]; k = [200; 300; 500; 1000]  
nlv = 0:15
pars = mpar(nlvdis = nlvdis, metric = metric, h = h, k = k) 
length(pars[1])
32
model = lwplsrda()
res = gridscore(model, Xcal, ycal, Xval, yval; score = errp, pars, nlv)
512×6 DataFrame
487 rows omitted
Rownlvdismetrichknlvy1
AnyAnyAnyAnyInt64Float32
110mah1.020000.083
210mah1.020010.048
310mah1.020020.041
410mah1.020030.036
510mah1.020040.031
610mah1.020050.031
710mah1.020060.037
810mah1.020070.037
910mah1.020080.039
1010mah1.020090.039
1110mah1.0200100.042
1210mah1.0200110.041
1310mah1.0200120.037
50120mahInf100040.064
50220mahInf100050.062
50320mahInf100060.061
50420mahInf100070.061
50520mahInf100080.056
50620mahInf100090.056
50720mahInf1000100.06
50820mahInf1000110.06
50920mahInf1000120.06
51020mahInf1000130.057
51120mahInf1000140.062
51220mahInf1000150.061
group = string.("nvldis=", res.nlvdis, " h=", res.h, " k=", res.k)
plotgrid(res.nlv, res.y1, group; step = 2, xlabel = "Nb. LVs", ylabel = "ERRP-Val").f

Selection of the best parameter combination

u = findall(res.y1 .== minimum(res.y1))[1] 
res[u, :]
DataFrameRow (6 columns)
Rownlvdismetrichknlvy1
AnyAnyAnyAnyInt64Float32
15320mah1.030080.026

Final prediction (Test) using the optimal model

model = lwplsrda(nlvdis = res.nlvdis[u], metric = res.metric[u], h = res.h[u], 
    k = res.k[u], nlv = res.nlv[u])
fit!(model, Xtrain, ytrain)
@head pred = predict(model, Xtest).pred
3×1 Matrix{Float32}:
 1.0
 0.0
 4.0
... (2000, 1)

Generalization error

errp(pred, ytest)
1×1 Matrix{Float64}:
 0.0275
merrp(pred, ytest)
1×1 Matrix{Float64}:
 0.02810010015224467
cf = conf(pred, ytest)
@names cf
(:cnt, :pct, :A, :Apct, :diagpct, :accpct, :lev)
cf.cnt
10×11 DataFrame
Rowypred_0.0pred_1.0pred_2.0pred_3.0pred_4.0pred_5.0pred_6.0pred_7.0pred_8.0pred_9.0
StringInt64Int64Int64Int64Int64Int64Int64Int64Int64Int64
10.0195000001000
21.0022502000000
32.0012041000100
43.0001196010030
54.0111019200002
65.0200201663131
76.0100011189000
87.0124000019701
98.0000210101892
109.0100232020192
cf.pct
10×11 DataFrame
Rowlevelspred_0.0pred_1.0pred_2.0pred_3.0pred_4.0pred_5.0pred_6.0pred_7.0pred_8.0pred_9.0
StringFloat64Float64Float64Float64Float64Float64Float64Float64Float64Float64
10.099.50.00.00.00.00.00.50.00.00.0
21.00.099.10.00.90.00.00.00.00.00.0
32.00.00.598.60.50.00.00.00.50.00.0
43.00.00.00.597.50.00.50.00.01.50.0
54.00.50.50.50.097.50.00.00.00.01.0
65.01.10.00.01.10.093.31.70.61.70.6
76.00.50.00.00.00.50.598.40.00.00.0
87.00.51.02.00.00.00.00.096.10.00.5
98.00.00.00.01.00.50.00.50.096.91.0
109.00.50.00.01.01.51.00.01.00.095.0
plotconf(cf).f
plotconf(cf; cnt = false).f