gridcv - tecator - PLSR

using Jchemo, JchemoData
using JLD2, CairoMakie

Data importation

path_jdat = dirname(dirname(pathof(JchemoData)))
db = joinpath(path_jdat, "data/tecator.jld2") 
@load db dat
@names dat
(:X, :Y)
X = dat.X
@head X
... (178, 100)
3×100 DataFrame
 Row │ 850      852      854      …  1046     1048
     │ Float64  Float64  Float64     Float64  Float64
─────┼─────────────────────────────────────────────────
   1 │ 2.61776  2.61814  2.61859  …  2.8394   2.8192
   2 │ 2.83454  2.83871  2.84283  …  3.196    3.17942
   3 │ 2.58284  2.58458  2.58629  …  2.56404  2.54816
Y = dat.Y 
@head Y
... (178, 4)
3×4 DataFrame
 Row │ water    fat      protein  typ
     │ Float64  Float64  Float64  String
─────┼───────────────────────────────────
   1 │    60.5     22.5     16.7  train
   2 │    46.0     40.1     13.5  train
   3 │    71.0      8.4     20.5  train
typ = Y.typ
tab(typ)
OrderedCollections.OrderedDict{String, Int64} with 3 entries:
  "test"  => 31
  "train" => 115
  "val"   => 32
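Function tab is essentially a frequency count of the levels of a variable. A minimal pure-Julia sketch of the same idea (count_levels is a hypothetical helper, not Jchemo's implementation):

```julia
# Count the occurrences of each level in a vector, as `tab` does conceptually.
# `count_levels` is an illustrative helper, not part of Jchemo.
function count_levels(v)
    d = Dict{eltype(v), Int}()
    for x in v
        d[x] = get(d, x, 0) + 1
    end
    d
end

typ = vcat(fill("train", 115), fill("val", 32), fill("test", 31))
count_levels(typ)   # Dict("test" => 31, "train" => 115, "val" => 32)
```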
wlst = names(X)
wl = parse.(Float64, wlst) 
#plotsp(X, wl; xlabel = "Wavelength (nm)", ylabel = "Absorbance").f
100-element Vector{Float64}:
  850.0
  852.0
  854.0
  856.0
  858.0
  860.0
  862.0
  864.0
  866.0
  868.0
    ⋮
 1032.0
 1034.0
 1036.0
 1038.0
 1040.0
 1042.0
 1044.0
 1046.0
 1048.0

Preprocessing

model1 = snv()
model2 = savgol(npoint = 15, deriv = 2, degree = 3)
model = pip(model1, model2)
fit!(model, X)
@head Xp = transf(model, X) 
#plotsp(Xp, wl; xlabel = "Wavelength (nm)", ylabel = "Absorbance").f
3×100 Matrix{Float64}:
 0.000397076  0.000495203  0.000598623  …  0.00827138  0.00917311  0.00946072
 0.00242055   0.00244366   0.00234233      0.00580631  0.00689249  0.00749316
 0.0011927    0.00122721   0.00120098      0.0101019   0.0108142   0.0108444
... (178, 100)
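SNV (standard normal variate) simply centers and scales each spectrum (row) by its own mean and standard deviation. A minimal pure-Julia sketch of the transform (illustrative only; Jchemo's snv may differ in options and conventions):

```julia
using Statistics

# SNV sketch: each row is centered by its mean and scaled by its std
# (corrected std, the default of Statistics.std). Illustrative only.
snv_sketch(X::AbstractMatrix) = (X .- mean(X, dims = 2)) ./ std(X, dims = 2)

X = [1.0 2.0 3.0; 10.0 20.0 30.0]
Xs = snv_sketch(X)
# After SNV, every row of Xs has mean ≈ 0 and std ≈ 1.
```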

Split Tot to Train/Test

The model is fitted on Train, and the generalization error is estimated on Test. In this example, Train is already defined by the variable typ of the dataset, and Test is defined by the remaining samples. But Tot could also be split a posteriori, for instance by sampling (random, systematic, or any other design); see for instance functions samprand, sampsys, etc.
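For instance, a simple random split of Tot into Train/Test (the principle behind samprand) can be sketched in base Julia as follows (illustrative sketch, not Jchemo's implementation):

```julia
using Random

# Random split of n samples: p indices go to Train, the rest to Test.
# Illustrative sketch only; Jchemo's `samprand` may differ in details.
function random_split(n::Int, p::Int; rng = Random.default_rng())
    perm = randperm(rng, n)
    (train = sort(perm[1:p]), test = sort(perm[p+1:end]))
end

s = random_split(178, 115)
# s.train and s.test are disjoint index vectors that together cover 1:178.
```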

s = typ .== "train"
Xtrain = Xp[s, :] 
Ytrain = Y[s, :]
Xtest = rmrow(Xp, s)
Ytest = rmrow(Y, s)
ntrain = nro(Xtrain)
ntest = nro(Xtest)
ntot = ntrain + ntest
(ntot = ntot, ntrain, ntest)
(ntot = 178, ntrain = 115, ntest = 63)

Working response y

namy = names(Y)[1:3]
j = 2  
nam = namy[j]    # work on the j-th y-variable
ytrain = Ytrain[:, nam]
ytest = Ytest[:, nam]
63-element Vector{Float64}:
 29.8
  1.4
  4.6
 11.0
 17.0
 22.4
 27.9
 46.5
  6.1
  2.0
  ⋮
 18.1
 19.4
 24.8
 27.2
 28.4
 31.3
 33.8
 35.5
 42.5

CV-Segments for model tuning

Two methods can be used to build the CV segments within Train (for the same total number of segments, these two methods return close results):

  • (1) Replicated K-fold CV

    • Train is split into K folds (segments),

    • and this split can be replicated (==> replicated K-fold CV).

K = 3     # nb. folds (segments)
rep = 25  # nb. replications
segm = segmkf(ntrain, K; rep = rep)
25-element Vector{Vector{Vector{Int64}}}:
 [[3, 4, 5, 12, 14, 16, 19, 25, 28, 31  …  95, 96, 97, 103, 104, 105, 106, 109, 112, 114], [6, 7, 8, 11, 15, 18, 20, 23, 26, 29  …  84, 86, 94, 98, 102, 107, 108, 110, 113, 115], [1, 2, 9, 10, 13, 17, 21, 22, 24, 27  …  82, 85, 87, 88, 89, 92, 99, 100, 101, 111]]
 [[4, 6, 17, 18, 19, 21, 23, 26, 27, 29  …  92, 93, 96, 104, 107, 110, 112, 113, 114, 115], [12, 13, 14, 16, 20, 24, 25, 28, 34, 39  …  91, 95, 97, 99, 100, 101, 102, 105, 106, 108], [1, 2, 3, 5, 7, 8, 9, 10, 11, 15  …  82, 83, 85, 88, 89, 94, 98, 103, 109, 111]]
 [[2, 8, 13, 15, 22, 24, 26, 28, 29, 30  …  90, 93, 96, 100, 103, 105, 106, 107, 108, 115], [3, 4, 5, 6, 9, 11, 12, 14, 16, 17  …  69, 81, 82, 86, 91, 95, 102, 110, 111, 112], [1, 7, 10, 19, 20, 23, 27, 33, 35, 45  …  92, 94, 97, 98, 99, 101, 104, 109, 113, 114]]
 [[1, 3, 9, 11, 12, 13, 16, 17, 19, 22  …  84, 85, 91, 92, 94, 98, 100, 101, 112, 114], [5, 7, 18, 21, 24, 30, 31, 34, 38, 40  …  89, 90, 93, 96, 97, 99, 102, 104, 107, 115], [2, 4, 6, 8, 10, 14, 15, 20, 23, 26  …  87, 95, 103, 105, 106, 108, 109, 110, 111, 113]]
 [[2, 4, 6, 12, 14, 16, 17, 23, 29, 30  …  85, 87, 93, 94, 96, 97, 99, 105, 110, 115], [1, 3, 7, 8, 18, 24, 25, 26, 32, 37  …  90, 91, 95, 98, 100, 104, 107, 109, 111, 113], [5, 9, 10, 11, 13, 15, 19, 20, 21, 22  …  86, 89, 92, 101, 102, 103, 106, 108, 112, 114]]
 [[4, 6, 7, 8, 9, 11, 12, 14, 18, 21  …  79, 94, 96, 98, 101, 102, 103, 111, 114, 115], [2, 3, 10, 13, 17, 19, 24, 25, 26, 27  …  66, 67, 78, 80, 87, 100, 104, 107, 108, 113], [1, 5, 15, 16, 20, 34, 36, 37, 39, 45  …  92, 93, 95, 97, 99, 105, 106, 109, 110, 112]]
 [[3, 4, 8, 17, 18, 19, 21, 25, 26, 28  …  99, 100, 101, 103, 105, 108, 110, 112, 114, 115], [1, 2, 10, 13, 14, 16, 20, 22, 24, 34  …  86, 88, 89, 90, 91, 92, 102, 106, 111, 113], [5, 6, 7, 9, 11, 12, 15, 23, 27, 31  …  71, 73, 74, 75, 79, 80, 87, 104, 107, 109]]
 [[5, 12, 16, 18, 24, 27, 28, 29, 31, 35  …  92, 94, 96, 101, 102, 104, 109, 111, 112, 113], [3, 4, 6, 7, 10, 11, 13, 15, 21, 26  …  85, 86, 93, 98, 100, 103, 107, 108, 110, 114], [1, 2, 8, 9, 14, 17, 19, 20, 22, 23  …  87, 89, 90, 91, 95, 97, 99, 105, 106, 115]]
 [[1, 7, 9, 12, 13, 15, 16, 21, 28, 29  …  86, 87, 92, 93, 94, 97, 100, 102, 111, 113], [2, 8, 10, 11, 14, 19, 23, 24, 25, 27  …  95, 98, 101, 104, 106, 107, 108, 110, 114, 115], [3, 4, 5, 6, 17, 18, 20, 22, 26, 30  …  78, 79, 85, 89, 96, 99, 103, 105, 109, 112]]
 [[1, 3, 4, 6, 8, 10, 16, 18, 21, 22  …  84, 88, 90, 93, 99, 102, 104, 106, 111, 115], [5, 11, 17, 19, 20, 25, 27, 30, 39, 40  …  94, 95, 96, 98, 103, 107, 108, 110, 113, 114], [2, 7, 9, 12, 13, 14, 15, 26, 28, 31  …  82, 86, 87, 91, 97, 100, 101, 105, 109, 112]]
 ⋮
 [[5, 6, 8, 12, 13, 21, 22, 27, 28, 30  …  89, 91, 97, 98, 104, 108, 109, 110, 113, 114], [1, 4, 11, 16, 17, 19, 23, 25, 26, 32  …  84, 90, 93, 94, 100, 105, 106, 107, 111, 115], [2, 3, 7, 9, 10, 14, 15, 18, 20, 24  …  86, 87, 92, 95, 96, 99, 101, 102, 103, 112]]
 [[5, 10, 11, 12, 13, 15, 26, 28, 30, 31  …  79, 80, 84, 90, 91, 95, 104, 111, 112, 115], [4, 6, 14, 17, 18, 20, 22, 27, 32, 36  …  89, 93, 94, 97, 98, 99, 100, 101, 106, 107], [1, 2, 3, 7, 8, 9, 16, 19, 21, 23  …  92, 96, 102, 103, 105, 108, 109, 110, 113, 114]]
 [[10, 17, 23, 25, 26, 29, 31, 33, 34, 37  …  92, 97, 100, 101, 103, 105, 107, 109, 114, 115], [3, 4, 5, 6, 8, 9, 12, 24, 27, 30  …  80, 85, 89, 96, 98, 99, 108, 111, 112, 113], [1, 2, 7, 11, 13, 14, 15, 16, 18, 19  …  84, 86, 91, 93, 94, 95, 102, 104, 106, 110]]
 [[2, 4, 12, 15, 19, 26, 27, 29, 30, 37  …  85, 93, 94, 95, 99, 100, 104, 110, 114, 115], [5, 6, 7, 8, 9, 11, 20, 21, 22, 25  …  77, 89, 92, 96, 102, 103, 105, 108, 109, 112], [1, 3, 10, 13, 14, 16, 17, 18, 23, 24  …  88, 90, 91, 97, 98, 101, 106, 107, 111, 113]]
 [[3, 7, 11, 12, 15, 21, 22, 23, 25, 27  …  97, 100, 101, 103, 104, 106, 107, 109, 113, 114], [1, 2, 6, 8, 10, 13, 18, 19, 24, 28  …  85, 88, 89, 94, 96, 98, 105, 111, 112, 115], [4, 5, 9, 14, 16, 17, 20, 26, 30, 32  …  78, 81, 86, 91, 93, 95, 99, 102, 108, 110]]
 [[5, 7, 11, 12, 15, 16, 17, 24, 28, 30  …  84, 85, 91, 93, 96, 101, 104, 109, 113, 115], [3, 8, 9, 18, 19, 22, 23, 26, 27, 29  …  82, 86, 88, 95, 98, 100, 103, 105, 108, 114], [1, 2, 4, 6, 10, 13, 14, 20, 21, 25  …  92, 94, 97, 99, 102, 106, 107, 110, 111, 112]]
 [[1, 3, 4, 7, 8, 12, 14, 19, 20, 21  …  75, 81, 82, 89, 95, 98, 99, 104, 110, 114], [10, 11, 16, 18, 24, 31, 33, 34, 35, 40  …  96, 100, 103, 105, 107, 108, 109, 111, 112, 115], [2, 5, 6, 9, 13, 15, 17, 22, 25, 27  …  86, 90, 92, 93, 94, 97, 101, 102, 106, 113]]
 [[4, 8, 14, 26, 27, 28, 32, 35, 37, 41  …  95, 96, 97, 99, 100, 102, 109, 111, 113, 114], [3, 5, 6, 7, 9, 11, 18, 20, 21, 22  …  78, 79, 82, 90, 91, 94, 106, 110, 112, 115], [1, 2, 10, 12, 13, 15, 16, 17, 19, 24  …  83, 92, 93, 98, 101, 103, 104, 105, 107, 108]]
 [[2, 12, 13, 14, 16, 21, 23, 24, 25, 27  …  88, 90, 96, 97, 99, 101, 104, 105, 108, 114], [1, 4, 5, 7, 9, 11, 15, 18, 31, 33  …  86, 87, 91, 93, 94, 100, 103, 106, 110, 115], [3, 6, 8, 10, 17, 19, 20, 22, 26, 28  …  89, 92, 95, 98, 102, 107, 109, 111, 112, 113]]
  • (2) Replicated "Test-set" CV

    • Train is split into Cal/Val (e.g. Cal = 70% of Train, Val = 30% of Train), and this split is replicated.

#pct = .30
#m = Int(round(pct * ntrain))
#segm = segmts(ntrain, m; rep = 30)
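Replicated test-set segments (the principle behind segmts) amount to drawing m validation indices per replication. A sketch in base Julia (illustrative only, not Jchemo's implementation):

```julia
using Random

# Replicated "test-set" CV: each replication draws m indices without
# replacement to serve as the single validation segment. Illustrative sketch.
function testset_segments(n::Int, m::Int; rep::Int = 1, rng = Random.default_rng())
    [[sort(randperm(rng, n)[1:m])] for _ in 1:rep]
end

segm_ts = testset_segments(115, 34; rep = 30)
# 30 replications, each holding one segment of 34 unique indices.
```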

Illustration of segments:

i = 1  
segm[i]      # the K segments of replication 'i'
3-element Vector{Vector{Int64}}:
 [3, 4, 5, 12, 14, 16, 19, 25, 28, 31  …  95, 96, 97, 103, 104, 105, 106, 109, 112, 114]
 [6, 7, 8, 11, 15, 18, 20, 23, 26, 29  …  84, 86, 94, 98, 102, 107, 108, 110, 113, 115]
 [1, 2, 9, 10, 13, 17, 21, 22, 24, 27  …  82, 85, 87, 88, 89, 92, 99, 100, 101, 111]
k = 1
segm[i][k]   # segment 'k' of replication 'i'
39-element Vector{Int64}:
   3
   4
   5
  12
  14
  16
  19
  25
  28
  31
   ⋮
  96
  97
 103
 104
 105
 106
 109
 112
 114
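The segment structure illustrated above (a vector of replications, each holding K disjoint index vectors) can be mimicked in a few lines of base Julia. This is an illustrative sketch, not Jchemo's segmkf implementation:

```julia
using Random

# Replicated K-fold segments: for each replication, shuffle 1:n and cut
# the permutation into K (nearly) equal folds by stride slicing.
function kfold_segments(n::Int, K::Int; rep::Int = 1, rng = Random.default_rng())
    map(1:rep) do _
        perm = randperm(rng, n)
        [sort(perm[i:K:n]) for i in 1:K]
    end
end

segm_kf = kfold_segments(115, 3; rep = 25)
# Each replication's K folds are disjoint and together cover 1:115.
```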

Grid-search

The best syntax to use function gridcv for LV-based functions (e.g. plskern, kplsr, lwplsr, etc.) is to set parameter nlv outside the argument pars that defines the grid (pars is in general built with function mpar). In that case, the computation time is reduced [see the naïve, non-optimized syntax at the end of this script]. The same principle applies when defining the parameter lb in ridge-based functions (rr, krr, etc.).

nlv = 0:20
model = plskern()
rescv = gridcv(model, Xtrain, ytrain; segm, score = rmsep, nlv)
@names rescv 
res = rescv.res
res_rep = rescv.res_rep
1575×4 DataFrame
1550 rows omitted
  Row │ rep    segm   nlv    y1
      │ Int64  Int64  Int64  Float64
──────┼───────────────────────────────
    1 │     1      1      0  12.4628
    2 │     1      1      1   2.25499
    3 │     1      1      2   2.31123
    4 │     1      1      3   2.12503
    5 │     1      1      4   2.07529
    6 │     1      1      5   2.01878
    7 │     1      1      6   1.96414
    8 │     1      1      7   2.1926
    9 │     1      1      8   2.56157
   10 │     1      1      9   2.57482
   11 │     1      1     10   2.58039
   12 │     1      1     11   2.51361
   13 │     1      1     12   2.85064
  ⋮   │   ⋮      ⋮      ⋮       ⋮
 1564 │    25      3      9   2.9
 1565 │    25      3     10   2.98456
 1566 │    25      3     11   3.33903
 1567 │    25      3     12   3.50815
 1568 │    25      3     13   3.7351
 1569 │    25      3     14   4.2545
 1570 │    25      3     15   4.18228
 1571 │    25      3     16   4.213
 1572 │    25      3     17   4.50353
 1573 │    25      3     18   4.9846
 1574 │    25      3     19   5.95916
 1575 │    25      3     20   5.53443
plotgrid(res.nlv, res.y1; step = 2, xlabel = "Nb. LVs", ylabel = "RMSEP-CV").f
f, ax = plotgrid(res.nlv, res.y1; step = 2, xlabel = "Nb. LVs", ylabel = "RMSEP-CV")
for i = 1:rep, j = 1:K
    zres = res_rep[res_rep.rep .== i .&& res_rep.segm .== j, :]
    lines!(ax, zres.nlv, zres.y1; color = (:grey, .2))
end
lines!(ax, res.nlv, res.y1; color = :red, linewidth = 1)
f

Selection of the best parameter combination

u = findall(res.y1 .== minimum(res.y1))[1] 
res[u, :]
DataFrameRow (2 columns)
 Row │ nlv    y1
     │ Int64  Float64
─────┼────────────────
   7 │     6  2.20461
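Since res.y1 is a plain vector, the first index of the minimum can equivalently be taken with argmin (sketch below uses the first y1 values of the result table above):

```julia
# argmin returns the first index of the minimum, so for a vector it is
# equivalent to findall(y1 .== minimum(y1))[1].
y1 = [12.6275, 2.50954, 2.33918, 2.30427, 2.24238, 2.20793, 2.20461, 2.24701]
u = argmin(y1)   # -> 7, i.e. nlv = 6 in the grid above
```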

Final prediction (Test) using the optimal model

model = plskern(nlv = res.nlv[u])
fit!(model, Xtrain, ytrain)
pred = predict(model, Xtest).pred
63×1 Matrix{Float64}:
 27.396108728712
  4.4912779253048605
  4.9900980331906455
 11.016752894644581
 13.6730456055306
 20.679552141109767
 25.814043422530865
 54.06157064233883
  9.87588864713091
  5.04840120243277
  ⋮
 17.006189347070176
 17.53307437947076
 20.679445265823034
 26.951106243297836
 25.25250542344898
 32.719328178913216
 33.196189093807185
 37.01967890654217
 46.24659766456356

Generalization error

rmsep(pred, ytest)
1×1 Matrix{Float64}:
 2.075776908672637
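For reference, RMSEP is the root mean squared error of prediction. A pure-Julia sketch, assuming the usual definition (Jchemo's rmsep returns a 1×q matrix, one value per y-column; this sketch handles a single column):

```julia
using Statistics

# RMSEP sketch: square root of the mean squared difference between
# predicted and observed values, for one response column.
rmsep_sketch(pred, y) = sqrt(mean((vec(pred) .- vec(y)).^2))

# Hypothetical small example (not the tecator data):
pred = [27.4, 4.5, 5.0]
y = [29.8, 1.4, 4.6]
rmsep_sketch(pred, y)
```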

Plotting predictions vs. observed data

plotxy(pred, ytest; size = (500, 400), color = (:red, .5), bisect = true, title = string("Test set - variable ", nam), 
    xlabel = "Prediction", ylabel = "Observed").f

Additional parameters in the grid

If other parameters have to be defined in the grid, they must be set in the argument pars, built with function mpar, as in the example below.
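Conceptually, the grid explored by gridcv is the Cartesian product of the candidate values. A base-Julia sketch with Iterators.product (illustrative only, not Jchemo's mpar):

```julia
# All combinations of two parameters, as the grid is built conceptually.
# Here nlv would be in the grid only with the naïve syntax shown at the
# end of this script.
scal_vals = [false, true]
nlv_vals = 0:20
grid = vec(collect(Iterators.product(nlv_vals, scal_vals)))
length(grid)   # 42 combinations, matching the 42-row result table
```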

pars = mpar(scal = [false; true])
nlv = 0:20
model = plskern()
res = gridcv(model, Xtrain, ytrain; segm, score = rmsep, pars, nlv).res
42×3 DataFrame
17 rows omitted
 Row │ nlv    scal   y1
     │ Int64  Bool   Float64
─────┼────────────────────────
   1 │     0  false  12.6275
   2 │     0  true   12.6275
   3 │     1  false   2.50954
   4 │     1  true    2.76545
   5 │     2  false   2.33918
   6 │     2  true    2.3499
   7 │     3  false   2.30427
   8 │     3  true    2.39444
   9 │     4  false   2.24238
  10 │     4  true    2.43105
  11 │     5  false   2.20793
  12 │     5  true    2.36566
  13 │     6  false   2.20461
  ⋮  │   ⋮      ⋮       ⋮
  31 │    15  false   2.98664
  32 │    15  true    3.01468
  33 │    16  false   2.93406
  34 │    16  true    2.99211
  35 │    17  false   2.90731
  36 │    17  true    3.05365
  37 │    18  false   2.96206
  38 │    18  true    3.17214
  39 │    19  false   3.06006
  40 │    19  true    3.40002
  41 │    20  false   3.21341
  42 │    20  true    3.59739
plotgrid(res.nlv, res.y1, res.scal; step = 2, xlabel = "Nb. LVs", ylabel = "RMSEP-Val").f
u = findall(res.y1 .== minimum(res.y1))[1] 
res[u, :]
DataFrameRow (3 columns)
 Row │ nlv    scal   y1
     │ Int64  Bool   Float64
─────┼───────────────────────
  13 │     6  false  2.20461
model = plskern(nlv = res.nlv[u], scal = res.scal[u])
fit!(model, Xtrain, ytrain)
pred = predict(model, Xtest).pred
63×1 Matrix{Float64}:
 27.396108728712
  4.4912779253048605
  4.9900980331906455
 11.016752894644581
 13.6730456055306
 20.679552141109767
 25.814043422530865
 54.06157064233883
  9.87588864713091
  5.04840120243277
  ⋮
 17.006189347070176
 17.53307437947076
 20.679445265823034
 26.951106243297836
 25.25250542344898
 32.719328178913216
 33.196189093807185
 37.01967890654217
 46.24659766456356

Naïve syntax to use gridcv for LV-based functions

Parameter nlv can also be set in argument pars (which is the generic approach to define the grid). This is strictly equivalent (it gives the same results), but the computations are slower.

pars = mpar(nlv = 0:20)
model = plskern()
gridcv(model, Xtrain, ytrain; segm, score = rmsep, pars).res
21×2 DataFrame
 Row │ nlv    y1
     │ Int64  Float64
─────┼─────────────────
   1 │     0  12.6275
   2 │     1   2.50954
   3 │     2   2.33918
   4 │     3   2.30427
   5 │     4   2.24238
   6 │     5   2.20793
   7 │     6   2.20461
   8 │     7   2.24701
   9 │     8   2.27596
  10 │     9   2.2262
  11 │    10   2.25345
  12 │    11   2.32985
  13 │    12   2.48538
  14 │    13   2.68772
  15 │    14   2.89272
  16 │    15   2.98664
  17 │    16   2.93406
  18 │    17   2.90731
  19 │    18   2.96206
  20 │    19   3.06006
  21 │    20   3.21341
pars = mpar(nlv = 0:20, scal = [false; true])
model = plskern()
gridcv(model, Xtrain, ytrain; segm, score = rmsep, pars).res
42×3 DataFrame
17 rows omitted
 Row │ nlv      scal     y1
     │ Integer  Integer  Float64
─────┼────────────────────────────
   1 │       0    false  12.6275
   2 │       0    true   12.6275
   3 │       1    false   2.50954
   4 │       1    true    2.76545
   5 │       2    false   2.33918
   6 │       2    true    2.3499
   7 │       3    false   2.30427
   8 │       3    true    2.39444
   9 │       4    false   2.24238
  10 │       4    true    2.43105
  11 │       5    false   2.20793
  12 │       5    true    2.36566
  13 │       6    false   2.20461
  ⋮  │    ⋮        ⋮        ⋮
  31 │      15    false   2.98664
  32 │      15    true    3.01468
  33 │      16    false   2.93406
  34 │      16    true    2.99211
  35 │      17    false   2.90731
  36 │      17    true    3.05365
  37 │      18    false   2.96206
  38 │      18    true    3.17214
  39 │      19    false   3.06006
  40 │      19    true    3.40002
  41 │      20    false   3.21341
  42 │      20    true    3.59739