gridcv - tecator - Plsr

using Jchemo, JchemoData
using JLD2, CairoMakie

Data importation

path_jdat = dirname(dirname(pathof(JchemoData)))
db = joinpath(path_jdat, "data/tecator.jld2") 
@load db dat
@names dat
(:X, :Y)
X = dat.X
@head X
... (178, 100)
3×100 DataFrame
 Row │ 850      852      854      …  1044     1046     1048
     │ Float64  Float64  Float64     Float64  Float64  Float64
─────┼─────────────────────────────────────────────────────────
   1 │ 2.61776  2.61814  2.61859  …  2.8596   2.8394   2.8192
   2 │ 2.83454  2.83871  2.84283  …  3.21181  3.196    3.17942
   3 │ 2.58284  2.58458  2.58629  …  2.58034  2.56404  2.54816
Y = dat.Y 
@head Y
... (178, 4)
3×4 DataFrame
 Row │ water    fat      protein  typ
     │ Float64  Float64  Float64  String
─────┼───────────────────────────────────
   1 │    60.5     22.5     16.7  train
   2 │    46.0     40.1     13.5  train
   3 │    71.0      8.4     20.5  train
typ = Y.typ
tab(typ)
OrderedCollections.OrderedDict{String, Int64} with 3 entries:
  "test"  => 31
  "train" => 115
  "val"   => 32
wlst = names(X)
wl = parse.(Float64, wlst) 
#plotsp(X, wl; xlabel = "Wavelength (nm)", ylabel = "Absorbance").f
100-element Vector{Float64}:
  850.0
  852.0
  854.0
  856.0
  858.0
  860.0
  862.0
  864.0
  866.0
  868.0
    ⋮
 1032.0
 1034.0
 1036.0
 1038.0
 1040.0
 1042.0
 1044.0
 1046.0
 1048.0

Preprocessing

model1 = snv()
model2 = savgol(npoint = 15, deriv = 2, degree = 3)
model = pip(model1, model2)
fit!(model, X)
@head Xp = transf(model, X) 
#plotsp(Xp, wl; xlabel = "Wavelength (nm)", ylabel = "Absorbance").f
3×100 Matrix{Float64}:
 0.000397076  0.000495203  0.000598623  …  0.00827138  0.00917311  0.00946072
 0.00242055   0.00244366   0.00234233      0.00580631  0.00689249  0.00749316
 0.0011927    0.00122721   0.00120098      0.0101019   0.0108142   0.0108444
... (178, 100)
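For intuition on what the first preprocessing step does: SNV standardizes each spectrum (each row of X) by its own mean and standard deviation. A minimal base-Julia sketch (not the actual Jchemo implementation, which wraps this in a fittable model):

```julia
using Statistics

# SNV: center and scale each row by its own mean and standard deviation
snv_manual(X) = (X .- mean(X; dims = 2)) ./ std(X; dims = 2)

X = [1.0 2.0 3.0; 4.0 6.0 8.0]
Z = snv_manual(X)   # each row of Z now has mean 0 and std 1
```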

Split Tot to Train/Test

The model is fitted on Train, and the generalization error is estimated on Test. In this example, Train is already defined by variable typ of the dataset, and Test consists of the remaining samples. But Tot could also be split a posteriori, for instance by sampling (random, systematic, or other designs); see for instance functions samprand, sampsys, etc.
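For reference, an a posteriori random split can be sketched in base Julia as follows (Jchemo's samprand and sampsys provide ready-made sampling designs; this is only the underlying idea):

```julia
using Random

Random.seed!(123)   # for reproducibility of the random draw
ntot = 178
strain = sort(randperm(ntot)[1:115])   # 115 Train indices drawn at random
stest = setdiff(1:ntot, strain)        # remaining samples form Test
```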

s = typ .== "train"
Xtrain = Xp[s, :] 
Ytrain = Y[s, :]
Xtest = rmrow(Xp, s)
Ytest = rmrow(Y, s)
ntrain = nro(Xtrain)
ntest = nro(Xtest)
ntot = ntrain + ntest
(ntot = ntot, ntrain, ntest)
(ntot = 178, ntrain = 115, ntest = 63)

Working response y

namy = names(Y)[1:3]
j = 2  
nam = namy[j]    # work on the j-th y-variable
ytrain = Ytrain[:, nam]
ytest = Ytest[:, nam]
63-element Vector{Float64}:
 29.8
  1.4
  4.6
 11.0
 17.0
 22.4
 27.9
 46.5
  6.1
  2.0
  ⋮
 18.1
 19.4
 24.8
 27.2
 28.4
 31.3
 33.8
 35.5
 42.5

CV-Segments for model tuning

Two methods can be used to build the CV segments within Train (for the same total number of segments, these two methods return close results):

  • (1) Replicated K-fold CV

    • Train is split into K folds (segments),

    • and this split can be replicated (==> replicated K-fold CV).

K = 3     # nb. folds (segments)
rep = 25  # nb. replications
segm = segmkf(ntrain, K; rep = rep)
25-element Vector{Vector{Vector{Int64}}}:
 [[3, 4, 5, 6, 12, 16, 22, 24, 26, 28  …  91, 93, 94, 95, 98, 100, 104, 105, 107, 115], [1, 2, 7, 8, 13, 14, 15, 17, 18, 25  …  90, 92, 99, 106, 108, 109, 110, 111, 112, 113], [9, 10, 11, 19, 20, 21, 23, 27, 30, 31  …  81, 84, 85, 86, 96, 97, 101, 102, 103, 114]]
 [[4, 6, 9, 11, 13, 14, 16, 18, 19, 20  …  77, 85, 89, 94, 97, 103, 108, 111, 112, 114], [1, 2, 5, 7, 12, 15, 24, 25, 26, 29  …  90, 92, 96, 98, 99, 101, 106, 107, 110, 113], [3, 8, 10, 17, 21, 22, 27, 35, 37, 40  …  84, 91, 93, 95, 100, 102, 104, 105, 109, 115]]
 [[3, 4, 13, 14, 16, 18, 24, 28, 30, 32  …  93, 99, 101, 105, 107, 109, 110, 111, 114, 115], [5, 6, 7, 10, 11, 19, 20, 21, 22, 25  …  84, 85, 86, 89, 92, 94, 96, 102, 103, 106], [1, 2, 8, 9, 12, 15, 17, 23, 27, 31  …  90, 91, 95, 97, 98, 100, 104, 108, 112, 113]]
 [[1, 7, 10, 12, 14, 16, 22, 24, 26, 28  …  76, 82, 83, 84, 85, 94, 96, 101, 108, 115], [2, 3, 13, 15, 17, 18, 21, 23, 36, 37  …  95, 97, 98, 100, 102, 103, 104, 105, 109, 110], [4, 5, 6, 8, 9, 11, 19, 20, 25, 27  …  89, 90, 92, 99, 106, 107, 111, 112, 113, 114]]
 [[2, 4, 6, 10, 19, 22, 25, 30, 31, 32  …  88, 90, 99, 101, 103, 104, 105, 107, 108, 109], [1, 3, 7, 8, 9, 11, 14, 15, 16, 17  …  80, 82, 89, 95, 106, 110, 112, 113, 114, 115], [5, 12, 13, 18, 20, 21, 23, 26, 28, 29  …  91, 92, 93, 94, 96, 97, 98, 100, 102, 111]]
 [[2, 4, 5, 18, 21, 26, 30, 31, 34, 35  …  90, 95, 98, 102, 105, 107, 112, 113, 114, 115], [8, 9, 10, 12, 13, 14, 15, 16, 17, 24  …  91, 94, 97, 99, 103, 104, 106, 108, 110, 111], [1, 3, 6, 7, 11, 19, 20, 22, 23, 28  …  78, 79, 81, 89, 92, 93, 96, 100, 101, 109]]
 [[1, 3, 7, 14, 16, 17, 18, 19, 23, 24  …  80, 81, 83, 86, 91, 92, 93, 94, 99, 107], [4, 5, 6, 8, 10, 11, 20, 28, 31, 32  …  97, 103, 104, 106, 109, 110, 112, 113, 114, 115], [2, 9, 12, 13, 15, 21, 22, 29, 33, 36  …  88, 90, 96, 98, 100, 101, 102, 105, 108, 111]]
 [[1, 3, 5, 9, 10, 16, 19, 21, 26, 28  …  86, 89, 93, 96, 100, 101, 105, 107, 109, 115], [2, 7, 11, 13, 14, 22, 24, 27, 30, 31  …  88, 91, 92, 97, 98, 99, 104, 112, 113, 114], [4, 6, 8, 12, 15, 17, 18, 20, 23, 25  …  85, 90, 94, 95, 102, 103, 106, 108, 110, 111]]
 [[5, 7, 10, 12, 13, 18, 20, 21, 22, 35  …  102, 103, 104, 105, 106, 109, 110, 111, 113, 114], [3, 4, 8, 9, 11, 16, 23, 25, 26, 27  …  77, 85, 90, 91, 94, 95, 107, 108, 112, 115], [1, 2, 6, 14, 15, 17, 19, 24, 28, 33  …  76, 79, 80, 81, 83, 84, 92, 96, 98, 101]]
 [[1, 3, 5, 12, 16, 17, 20, 22, 36, 38  …  92, 94, 95, 98, 99, 100, 103, 108, 109, 112], [4, 7, 11, 18, 19, 28, 29, 31, 33, 34  …  88, 93, 96, 97, 104, 105, 107, 110, 113, 115], [2, 6, 8, 9, 10, 13, 14, 15, 21, 23  …  79, 80, 83, 85, 90, 101, 102, 106, 111, 114]]
 ⋮
 [[2, 3, 6, 15, 16, 23, 24, 25, 30, 31  …  87, 90, 93, 94, 97, 98, 101, 111, 114, 115], [1, 4, 8, 11, 12, 14, 17, 19, 20, 22  …  77, 81, 83, 86, 88, 89, 96, 100, 108, 113], [5, 7, 9, 10, 13, 18, 21, 27, 33, 34  …  99, 102, 103, 104, 105, 106, 107, 109, 110, 112]]
 [[1, 5, 11, 14, 18, 20, 21, 25, 43, 44  …  99, 100, 101, 105, 106, 108, 109, 110, 111, 112], [2, 3, 4, 7, 9, 10, 16, 23, 24, 29  …  88, 90, 92, 93, 94, 95, 96, 103, 113, 115], [6, 8, 12, 13, 15, 17, 19, 22, 26, 27  …  78, 81, 82, 84, 86, 97, 102, 104, 107, 114]]
 [[3, 4, 7, 8, 11, 12, 16, 22, 29, 30  …  96, 98, 99, 100, 102, 106, 107, 108, 110, 115], [1, 9, 10, 13, 15, 17, 18, 19, 20, 23  …  81, 82, 84, 89, 92, 95, 97, 101, 109, 112], [2, 5, 6, 14, 21, 24, 25, 26, 27, 28  …  88, 90, 93, 94, 103, 104, 105, 111, 113, 114]]
 [[4, 5, 6, 7, 13, 19, 25, 27, 29, 31  …  89, 90, 92, 96, 102, 106, 108, 110, 111, 114], [10, 11, 12, 14, 15, 16, 17, 18, 22, 23  …  80, 82, 87, 91, 98, 99, 100, 104, 105, 109], [1, 2, 3, 8, 9, 20, 21, 28, 30, 35  …  93, 94, 95, 97, 101, 103, 107, 112, 113, 115]]
 [[6, 8, 10, 14, 15, 26, 30, 33, 34, 37  …  88, 93, 98, 99, 101, 104, 105, 110, 111, 115], [2, 9, 11, 12, 13, 17, 18, 20, 21, 23  …  81, 85, 89, 94, 95, 96, 102, 103, 107, 112], [1, 3, 4, 5, 7, 16, 19, 22, 25, 27  …  90, 91, 92, 97, 100, 106, 108, 109, 113, 114]]
 [[1, 4, 6, 8, 12, 13, 14, 21, 27, 29  …  85, 88, 91, 96, 98, 100, 101, 106, 112, 113], [5, 9, 16, 18, 19, 24, 25, 34, 35, 38  …  94, 95, 99, 105, 107, 108, 109, 110, 111, 114], [2, 3, 7, 10, 11, 15, 17, 20, 22, 23  …  77, 78, 83, 87, 92, 97, 102, 103, 104, 115]]
 [[1, 3, 6, 10, 12, 14, 20, 21, 23, 24  …  86, 87, 88, 89, 90, 96, 103, 106, 107, 112], [2, 4, 7, 9, 16, 18, 26, 29, 31, 32  …  91, 93, 94, 95, 97, 99, 102, 104, 109, 113], [5, 8, 11, 13, 15, 17, 19, 22, 25, 28  …  92, 98, 100, 101, 105, 108, 110, 111, 114, 115]]
 [[1, 3, 9, 13, 20, 21, 22, 27, 29, 31  …  87, 90, 98, 99, 100, 101, 106, 111, 113, 115], [2, 5, 8, 11, 12, 18, 23, 24, 25, 26  …  93, 95, 96, 97, 103, 104, 105, 107, 110, 112], [4, 6, 7, 10, 14, 15, 16, 17, 19, 28  …  86, 88, 89, 91, 92, 94, 102, 108, 109, 114]]
 [[2, 4, 9, 12, 16, 17, 18, 22, 25, 27  …  87, 88, 89, 94, 104, 105, 107, 108, 110, 111], [1, 3, 7, 8, 11, 15, 21, 29, 33, 36  …  85, 91, 95, 97, 99, 100, 101, 106, 109, 115], [5, 6, 10, 13, 14, 19, 20, 23, 24, 26  …  90, 92, 93, 96, 98, 102, 103, 112, 113, 114]]
  • (2) Replicated "Test-set" CV

    • Train is split into Cal/Val (e.g. Cal = 70% of Train, Val = 30% of Train), and this split is replicated.

#pct = .30
#m = Int(round(pct * ntrain))
#segm = segmts(ntrain, m; rep = 30)
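As a rough illustration of what these two schemes produce, here is a base-Julia sketch (not the actual segmkf/segmts implementations, which handle additional options) returning the same nested vector layout, i.e. one vector of segments per replication:

```julia
using Random

n, K, rep = 115, 3, 25

# (1) Replicated K-fold: shuffle 1:n and cut it into K folds, rep times
segm_kf = map(1:rep) do _
    perm = shuffle(1:n)
    [sort(perm[k:K:n]) for k = 1:K]
end

# (2) Replicated "test-set": draw m validation indices per replication
m = Int(round(.30 * n))
segm_ts = [[sort(randperm(n)[1:m])] for _ = 1:30]
```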

Illustration of segments:

i = 1  
segm[i]      # the K segments of replication 'i'
3-element Vector{Vector{Int64}}:
 [3, 4, 5, 6, 12, 16, 22, 24, 26, 28  …  91, 93, 94, 95, 98, 100, 104, 105, 107, 115]
 [1, 2, 7, 8, 13, 14, 15, 17, 18, 25  …  90, 92, 99, 106, 108, 109, 110, 111, 112, 113]
 [9, 10, 11, 19, 20, 21, 23, 27, 30, 31  …  81, 84, 85, 86, 96, 97, 101, 102, 103, 114]
k = 1
segm[i][k]   # segment 'k' of replication 'i'
39-element Vector{Int64}:
   3
   4
   5
   6
  12
  16
  22
  24
  26
  28
   ⋮
  93
  94
  95
  98
 100
 104
 105
 107
 115

Grid-search

The best syntax to use function gridcv for LV-based functions (e.g. plskern, kplsr, lwplsr, etc.) is to set parameter nlv outside argument pars that defines the grid (pars is in general built with function mpar). In that case, the computation time is reduced, since the predictions for all the nlv values can be derived from a single model fit per segment [see the naïve, non-optimized syntax at the end of this script]. The same principle applies when defining parameter lb in ridge-based functions (rr, krr, etc.).

nlv = 0:20
model = plskern()
rescv = gridcv(model, Xtrain, ytrain; segm, score = rmsep, nlv)
@names rescv 
res = rescv.res
res_rep = rescv.res_rep
1575×4 DataFrame
1550 rows omitted
  Row │ rep    segm   nlv    y1
      │ Int64  Int64  Int64  Float64
──────┼──────────────────────────────
    1 │     1      1      0  13.5666
    2 │     1      1      1   2.44251
    3 │     1      1      2   2.24555
    4 │     1      1      3   2.42385
    5 │     1      1      4   2.58804
    6 │     1      1      5   2.22739
    7 │     1      1      6   2.28135
    8 │     1      1      7   2.23079
    9 │     1      1      8   2.19531
   10 │     1      1      9   2.17222
   11 │     1      1     10   2.16421
   12 │     1      1     11   2.17249
   13 │     1      1     12   2.23543
    ⋮ │   ⋮      ⋮      ⋮       ⋮
 1564 │    25      3      9   2.23427
 1565 │    25      3     10   2.15717
 1566 │    25      3     11   2.13504
 1567 │    25      3     12   2.20721
 1568 │    25      3     13   2.64451
 1569 │    25      3     14   2.48442
 1570 │    25      3     15   2.40507
 1571 │    25      3     16   2.04081
 1572 │    25      3     17   1.88561
 1573 │    25      3     18   1.66322
 1574 │    25      3     19   1.56811
 1575 │    25      3     20   1.55932
plotgrid(res.nlv, res.y1; step = 2, xlabel = "Nb. LVs", ylabel = "RMSEP-CV").f
f, ax = plotgrid(res.nlv, res.y1; step = 2, xlabel = "Nb. LVs", ylabel = "RMSEP-CV")
for i = 1:rep, j = 1:K
    zres = res_rep[res_rep.rep .== i .&& res_rep.segm .== j, :]
    lines!(ax, zres.nlv, zres.y1; color = (:grey, .2))
end
lines!(ax, res.nlv, res.y1; color = :red, linewidth = 1)
f

Selection of the best parameter combination

u = findall(res.y1 .== minimum(res.y1))[1] 
res[u, :]
DataFrameRow (2 columns)
 Row │ nlv    y1
     │ Int64  Float64
─────┼────────────────
   7 │     6  2.2196
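The same selection (first index of the minimum) can also be written with base Julia's argmin, shown here on a toy vector of CV scores:

```julia
# argmin returns the index of the first minimal element, equivalent to
# findall(y .== minimum(y))[1]
y1 = [12.7, 2.5, 2.3, 2.2196, 2.28]   # toy CV scores
u = argmin(y1)                         # index of the best score
```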

Final prediction (Test) using the optimal model

model = plskern(nlv = res.nlv[u])
fit!(model, Xtrain, ytrain)
pred = predict(model, Xtest).pred
63×1 Matrix{Float64}:
 27.396108728712
  4.4912779253048605
  4.9900980331906455
 11.016752894644581
 13.6730456055306
 20.679552141109767
 25.814043422530865
 54.06157064233883
  9.87588864713091
  5.04840120243277
  ⋮
 17.006189347070176
 17.53307437947076
 20.679445265823034
 26.951106243297836
 25.25250542344898
 32.719328178913216
 33.196189093807185
 37.01967890654217
 46.24659766456356

Generalization error

rmsep(pred, ytest)
1×1 Matrix{Float64}:
 2.075776908672637
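RMSEP is simply the root mean squared error of prediction. A one-line sketch for a single response (Jchemo's rmsep itself returns a 1×q matrix for q responses):

```julia
# Root mean squared error of prediction for one response variable
rmsep_manual(pred, y) = sqrt(sum((vec(pred) .- vec(y)).^2) / length(y))
```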

Plotting predictions vs. observed data

plotxy(pred, ytest; size = (500, 400), color = (:red, .5), bisect = true, title = "Test set - variable $nam", 
    xlabel = "Prediction", ylabel = "Observed").f

Additional parameters in the grid

If other parameters have to be defined in the grid, they have to be set in argument pars built with function mpar, such as in the example below.

pars = mpar(scal = [false; true])
nlv = 0:20
model = plskern()
res = gridcv(model, Xtrain, ytrain; segm, score = rmsep, pars, nlv).res
42×3 DataFrame
17 rows omitted
 Row │ nlv    scal   y1
     │ Int64  Bool   Float64
─────┼───────────────────────
   1 │     0  false  12.6693
   2 │     0   true  12.6693
   3 │     1  false   2.5128
   4 │     1   true   2.76265
   5 │     2  false   2.34379
   6 │     2   true   2.35703
   7 │     3  false   2.31464
   8 │     3   true   2.42326
   9 │     4  false   2.24089
  10 │     4   true   2.47409
  11 │     5  false   2.22243
  12 │     5   true   2.39829
  13 │     6  false   2.2196
   ⋮ │   ⋮      ⋮       ⋮
  31 │    15  false   2.82116
  32 │    15   true   2.83473
  33 │    16  false   2.80642
  34 │    16   true   2.84314
  35 │    17  false   2.82318
  36 │    17   true   2.8986
  37 │    18  false   2.82288
  38 │    18   true   3.02593
  39 │    19  false   2.90825
  40 │    19   true   3.2283
  41 │    20  false   3.00598
  42 │    20   true   3.3288
plotgrid(res.nlv, res.y1, res.scal; step = 2, xlabel = "Nb. LVs", ylabel = "RMSEP-Val").f
u = findall(res.y1 .== minimum(res.y1))[1] 
res[u, :]
DataFrameRow (3 columns)
 Row │ nlv    scal   y1
     │ Int64  Bool   Float64
─────┼───────────────────────
  13 │     6  false  2.2196
model = plskern(nlv = res.nlv[u], scal = res.scal[u])
fit!(model, Xtrain, ytrain)
pred = predict(model, Xtest).pred
63×1 Matrix{Float64}:
 27.396108728712
  4.4912779253048605
  4.9900980331906455
 11.016752894644581
 13.6730456055306
 20.679552141109767
 25.814043422530865
 54.06157064233883
  9.87588864713091
  5.04840120243277
  ⋮
 17.006189347070176
 17.53307437947076
 20.679445265823034
 26.951106243297836
 25.25250542344898
 32.719328178913216
 33.196189093807185
 37.01967890654217
 46.24659766456356

Naïve (not time-efficient) syntax to use gridcv for LV-based functions

Parameter nlv can also be set inside argument pars (which is the generic approach to defining the grid). This is strictly equivalent (it gives the same results), but the computations are slower.
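For clarity on what the grid contains in this generic approach: mpar expands the parameter values into all their combinations (a Cartesian product), conceptually similar to this base-Julia sketch (not the actual mpar implementation):

```julia
# All combinations of nlv = 0:20 and scal = [false; true],
# ordered as in the result table below (scal varies fastest)
grid = [(nlv = nlv, scal = scal) for nlv in 0:20 for scal in (false, true)]
length(grid)   # 21 × 2 = 42 parameter combinations
```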

pars = mpar(nlv = 0:20)
model = plskern()
gridcv(model, Xtrain, ytrain; segm, score = rmsep, pars).res
21×2 DataFrame
 Row │ nlv    y1
     │ Int64  Float64
─────┼────────────────
   1 │     0  12.6693
   2 │     1   2.5128
   3 │     2   2.34379
   4 │     3   2.31464
   5 │     4   2.24089
   6 │     5   2.22243
   7 │     6   2.2196
   8 │     7   2.27627
   9 │     8   2.27375
  10 │     9   2.22622
  11 │    10   2.26324
  12 │    11   2.3036
  13 │    12   2.47049
  14 │    13   2.6224
  15 │    14   2.73301
  16 │    15   2.82116
  17 │    16   2.80642
  18 │    17   2.82318
  19 │    18   2.82288
  20 │    19   2.90825
  21 │    20   3.00598
pars = mpar(nlv = 0:20, scal = [false; true])
model = plskern()
gridcv(model, Xtrain, ytrain; segm, score = rmsep, pars).res
42×3 DataFrame
17 rows omitted
 Row │ nlv      scal     y1
     │ Integer  Integer  Float64
─────┼───────────────────────────
   1 │       0    false  12.6693
   2 │       0     true  12.6693
   3 │       1    false   2.5128
   4 │       1     true   2.76265
   5 │       2    false   2.34379
   6 │       2     true   2.35703
   7 │       3    false   2.31464
   8 │       3     true   2.42326
   9 │       4    false   2.24089
  10 │       4     true   2.47409
  11 │       5    false   2.22243
  12 │       5     true   2.39829
  13 │       6    false   2.2196
   ⋮ │     ⋮        ⋮        ⋮
  31 │      15    false   2.82116
  32 │      15     true   2.83473
  33 │      16    false   2.80642
  34 │      16     true   2.84314
  35 │      17    false   2.82318
  36 │      17     true   2.8986
  37 │      18    false   2.82288
  38 │      18     true   3.02593
  39 │      19    false   2.90825
  40 │      19     true   3.2283
  41 │      20    false   3.00598
  42 │      20     true   3.3288