gridcv - tecator - Rr

using Jchemo, JchemoData
using JLD2, CairoMakie

Data importation

## Locate the example datasets shipped with JchemoData
path_jdat = dirname(dirname(pathof(JchemoData)))
db = joinpath(path_jdat, "data/tecator.jld2") 
@load db dat    # loads object 'dat' from the JLD2 archive
@names dat
(:X, :Y)
X = dat.X       # NIR spectra: 178 samples x 100 wavelengths
@head X
... (178, 100)
3×100 DataFrame
Row8508528548568588608628648668688708728748768788808828848868888908928948968989009029049069089109129149169189209229249269289309329349369389409429449469489509529549569589609629649669689709729749769789809829849869889909929949969981000100210041006100810101012101410161018102010221024102610281030103210341036103810401042104410461048
Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64
12.617762.618142.618592.619122.619812.620712.621862.623342.625112.627222.629642.632452.635652.639332.643532.648252.65352.659372.665852.672812.680082.687332.694272.700732.706842.712812.719142.726282.734622.744162.754662.765682.776792.78792.799492.812252.827062.843562.861062.878572.894972.909242.920852.930152.938462.947712.960192.978313.003063.035063.074283.119633.168683.217713.262543.299883.328473.348993.363423.373793.381523.387413.391643.394183.39493.393663.390453.385413.378693.370413.360733.349793.337693.324433.310133.294873.278913.262323.245423.228283.21083.192873.174333.155033.134753.113393.091163.06853.045963.023933.002472.981452.960722.940132.919782.899662.879642.85962.83942.8192
22.834542.838712.842832.847052.851382.855872.86062.865662.870932.876612.882642.888982.895772.903082.910972.919532.928732.938632.949292.960722.972722.984932.99693.008333.01923.02993.041013.053453.067773.084163.102213.121063.139833.15813.176233.195193.215843.237473.258893.278353.293843.303623.306813.303933.2973.289253.284093.285053.293263.309233.332673.362513.396613.431883.464923.492953.514583.530043.540673.547973.553063.556753.559213.560453.560343.558763.555713.551323.545853.53953.532353.524423.515833.506683.4973.486833.476263.465523.455013.444813.434773.424653.414193.403033.390823.377313.362653.347453.332453.318183.304733.291863.279213.266553.253693.240453.226593.211813.1963.17942
32.582842.584582.586292.588082.589962.591922.594012.596272.598732.601312.604142.607142.610292.613612.617142.620892.624862.629092.633612.638352.64332.648382.653542.65872.663752.66882.673832.678922.684112.689372.69472.700122.705632.711412.717752.72492.733442.743272.754332.766422.779312.792722.806492.820642.835412.851212.868722.889052.912892.940882.973253.009463.04783.085543.119473.146963.166773.179383.186313.189243.18953.188013.184983.180393.174113.166113.156413.145123.132413.118433.103293.087143.070143.052373.033933.015042.995692.976122.956422.93662.916672.896552.876222.855632.834742.813612.792352.771132.750152.729562.709342.689512.670092.651122.632622.614612.597182.580342.564042.54816
Y = dat.Y       # responses: water, fat, protein + sample type (typ)
@head Y
... (178, 4)
3×4 DataFrame
Rowwaterfatproteintyp
Float64Float64Float64String
160.522.516.7train
246.040.113.5train
371.08.420.5train
typ = Y.typ     # sample type: "train" / "val" / "test"
tab(typ)        # tabulate the class frequencies
OrderedCollections.OrderedDict{String, Int64} with 3 entries:
  "test"  => 31
  "train" => 115
  "val"   => 32
wlst = names(X)    # column names = wavelengths, as strings
wl = parse.(Float64, wlst)    # numeric wavelengths (nm), 850-1048
#plotsp(X, wl; xlabel = "Wavelength (nm)", ylabel = "Absorbance").f
100-element Vector{Float64}:
  850.0
  852.0
  854.0
  856.0
  858.0
  860.0
  862.0
  864.0
  866.0
  868.0
    ⋮
 1032.0
 1034.0
 1036.0
 1038.0
 1040.0
 1042.0
 1044.0
 1046.0
 1048.0

Preprocessing

## Preprocessing pipeline: SNV followed by a Savitzky-Golay
## 2nd derivative (15-point window, 3rd-degree polynomial)
model = pip(snv(), savgol(npoint = 15, deriv = 2, degree = 3))
fit!(model, X)
@head Xp = transf(model, X)
#plotsp(Xp, wl; xlabel = "Wavelength (nm)", ylabel = "Absorbance").f
3×100 Matrix{Float64}:
 0.000397076  0.000495203  0.000598623  …  0.00827138  0.00917311  0.00946072
 0.00242055   0.00244366   0.00234233      0.00580631  0.00689249  0.00749316
 0.0011927    0.00122721   0.00120098      0.0101019   0.0108142   0.0108444
... (178, 100)

Split Tot to Train/Test

## Split the total set into Train/Test, driven by column 'typ'
is_train = typ .== "train"
Xtrain = Xp[is_train, :]
Ytrain = Y[is_train, :]
Xtest = rmrow(Xp, is_train)    # everything not flagged as train
Ytest = rmrow(Y, is_train)
ntrain, ntest = nro(Xtrain), nro(Xtest)
ntot = ntrain + ntest
(ntot = ntot, ntrain, ntest)
(ntot = 178, ntrain = 115, ntest = 63)

Working response y

namy = names(Y)[1:3]    # quantitative responses: water, fat, protein
j = 2  
nam = namy[j]    # work on the j-th y-variable (here "fat")
ytrain = Ytrain[:, nam]
ytest = Ytest[:, nam]
63-element Vector{Float64}:
 29.8
  1.4
  4.6
 11.0
 17.0
 22.4
 27.9
 46.5
  6.1
  2.0
  ⋮
 18.1
 19.4
 24.8
 27.2
 28.4
 31.3
 33.8
 35.5
 42.5

CV-Segments for model tuning

Replicated K-fold CV

K = 3     # nb. folds (segments)
rep = 10  # nb. replications
segm = segmkf(ntrain, K; rep = rep)    # replicated K-fold segments over the training rows
10-element Vector{Vector{Vector{Int64}}}:
 [[2, 3, 4, 6, 14, 17, 20, 24, 28, 29  …  87, 88, 91, 93, 95, 99, 105, 107, 109, 112], [5, 10, 12, 16, 18, 22, 23, 27, 31, 36  …  101, 102, 103, 104, 106, 108, 110, 111, 114, 115], [1, 7, 8, 9, 11, 13, 15, 19, 21, 25  …  75, 78, 80, 82, 83, 92, 94, 98, 100, 113]]
 [[4, 9, 11, 13, 14, 18, 19, 23, 25, 26  …  86, 90, 94, 95, 96, 99, 107, 108, 110, 113], [1, 2, 3, 7, 8, 10, 12, 16, 17, 20  …  83, 93, 98, 100, 101, 103, 104, 105, 106, 109], [5, 6, 15, 21, 24, 27, 30, 33, 35, 37  …  88, 89, 91, 92, 97, 102, 111, 112, 114, 115]]
 [[1, 8, 12, 14, 16, 21, 23, 25, 27, 28  …  96, 97, 101, 102, 103, 105, 107, 108, 113, 114], [3, 4, 10, 11, 17, 18, 19, 22, 30, 32  …  78, 79, 80, 81, 86, 94, 98, 100, 106, 115], [2, 5, 6, 7, 9, 13, 15, 20, 24, 26  …  88, 89, 90, 95, 99, 104, 109, 110, 111, 112]]
 [[3, 6, 7, 8, 9, 17, 22, 27, 28, 32  …  78, 85, 86, 87, 88, 90, 93, 104, 110, 115], [4, 5, 10, 11, 12, 13, 14, 15, 18, 19  …  83, 89, 97, 102, 103, 106, 108, 109, 112, 113], [1, 2, 16, 20, 21, 23, 26, 29, 30, 35  …  95, 96, 98, 99, 100, 101, 105, 107, 111, 114]]
 [[3, 10, 13, 17, 18, 19, 21, 22, 26, 27  …  77, 79, 80, 82, 83, 95, 96, 98, 108, 113], [2, 4, 5, 8, 11, 12, 15, 24, 25, 30  …  85, 86, 89, 90, 93, 100, 103, 104, 105, 109], [1, 6, 7, 9, 14, 16, 20, 23, 37, 40  …  99, 101, 102, 106, 107, 110, 111, 112, 114, 115]]
 [[4, 7, 8, 9, 11, 12, 13, 18, 20, 24  …  96, 97, 98, 101, 106, 109, 111, 112, 113, 115], [1, 5, 6, 14, 15, 16, 21, 22, 23, 25  …  84, 87, 88, 92, 94, 95, 102, 103, 104, 114], [2, 3, 10, 17, 19, 27, 30, 31, 32, 33  …  81, 83, 85, 90, 99, 100, 105, 107, 108, 110]]
 [[5, 10, 13, 15, 16, 19, 20, 24, 25, 26  …  84, 88, 94, 95, 99, 100, 104, 106, 109, 110], [1, 6, 12, 17, 18, 22, 23, 27, 30, 32  …  89, 96, 97, 98, 102, 103, 107, 111, 113, 114], [2, 3, 4, 7, 8, 9, 11, 14, 21, 33  …  87, 90, 91, 92, 93, 101, 105, 108, 112, 115]]
 [[18, 20, 22, 24, 27, 30, 35, 39, 42, 43  …  88, 91, 94, 95, 98, 103, 107, 111, 112, 115], [1, 7, 10, 14, 15, 16, 17, 32, 38, 41  …  90, 92, 97, 100, 101, 104, 105, 109, 113, 114], [2, 3, 4, 5, 6, 8, 9, 11, 12, 13  …  80, 82, 84, 93, 96, 99, 102, 106, 108, 110]]
 [[4, 5, 6, 8, 14, 16, 22, 24, 27, 28  …  84, 89, 90, 93, 95, 97, 101, 104, 111, 112], [1, 10, 12, 15, 19, 21, 23, 35, 36, 38  …  91, 92, 94, 100, 103, 106, 107, 108, 113, 115], [2, 3, 7, 9, 11, 13, 17, 18, 20, 25  …  85, 86, 96, 98, 99, 102, 105, 109, 110, 114]]
 [[3, 5, 6, 7, 9, 11, 13, 14, 15, 17  …  87, 90, 91, 99, 100, 104, 105, 107, 109, 113], [1, 19, 20, 22, 23, 26, 28, 32, 33, 36  …  76, 88, 92, 96, 97, 98, 106, 110, 112, 114], [2, 4, 8, 10, 12, 16, 18, 21, 24, 34  …  89, 93, 94, 95, 101, 102, 103, 108, 111, 115]]

Grid-search

The best syntax to use function gridcv for regularization-based functions (e.g. rr, krr, etc.) is to set parameter lb outside argument pars defining the grid (in general built with function mpar). In that case, the computation time is reduced [see the naïve and non-optimized syntax at the end of this script]. The same principle applies when defining the parameter nlv for LV-based functions (e.g. plskern, kplsr, lwplsr, etc.).

lb = 10.0.^(-15:.1:3)    # log-spaced grid of ridge regularization values
model = rr()
rescv = gridcv(model, Xtrain, ytrain; segm, score = rmsep, lb)
@names rescv 
res = rescv.res          # CV error per lb, aggregated over replications/folds
res_rep = rescv.res_rep  # CV error per replication x fold x lb
5430×4 DataFrame
5405 rows omitted
Rowrepsegmlby1
Int64Int64Float64Float64
1111.0e-153.82928
2111.25893e-153.82928
3111.58489e-153.82928
4111.99526e-153.82928
5112.51189e-153.82928
6113.16228e-153.82928
7113.98107e-153.82928
8115.01187e-153.82928
9116.30957e-153.82928
10117.94328e-153.82928
11111.0e-143.82928
12111.25893e-143.82928
13111.58489e-143.82928
541910379.432814.6815
5420103100.014.6815
5421103125.89314.6815
5422103158.48914.6815
5423103199.52614.6815
5424103251.18914.6815
5425103316.22814.6815
5426103398.10714.6815
5427103501.18714.6815
5428103630.95714.6815
5429103794.32814.6815
54301031000.014.6815
## CV error curve vs. lambda on a log10 scale
loglb = round.(log10.(res.lb), digits = 3)
plotgrid(loglb, res.y1; step = 2, xlabel ="lambda (log scale)", ylabel = "RMSEP-Val").f
## Overlay each replication/fold curve (grey) behind the aggregated curve (red).
## Note: labels fixed — this is a ridge lambda grid, not a "Nb. LVs" axis.
f, ax = plotgrid(loglb, res.y1; step = 2, xlabel = "lambda (log scale)", ylabel = "RMSEP-Val")
for i = 1:rep, j = 1:K
    ## CV error of fold j within replication i
    zres = res_rep[res_rep.rep .== i .&& res_rep.segm .== j, :]
    lines!(ax, loglb, zres.y1; color = (:grey, .2))
end
lines!(ax, loglb, res.y1; color = :red, linewidth = 1)
f

If other parameters have to be defined in the grid, they have to be set in argument pars built with function mpar, such as in the example below.

## Extra grid parameters go in 'pars' (built with mpar);
## 'lb' stays outside 'pars' to keep the fast computation path
pars = mpar(scal = [false; true])
lb = 10.0.^(-15:.1:3)
#lb = logrange(1e-15, 1e3, 50)   # alternative syntax
model = rr()
res = gridcv(model, Xtrain, ytrain; segm, score = rmsep, pars, lb).res
362×3 DataFrame
337 rows omitted
Rowlbscaly1
Float64BoolFloat64
11.0e-15false7.13373
21.25893e-15false7.13373
31.58489e-15false7.13373
41.99526e-15false7.13373
52.51189e-15false7.13373
63.16228e-15false7.13373
73.98107e-15false7.13373
85.01187e-15false7.13373
96.30957e-15false7.13373
107.94328e-15false7.13373
111.0e-14false7.13373
121.25893e-14false7.13373
131.58489e-14false7.13373
35179.4328true12.5418
352100.0true12.5869
353125.893true12.6155
354158.489true12.6336
355199.526true12.6451
356251.189true12.6524
357316.228true12.6569
358398.107true12.6598
359501.187true12.6617
360630.957true12.6628
361794.328true12.6635
3621000.0true12.664
loglb = round.(log.(10, res.lb), digits = 3)    # lambda on a log10 scale
plotgrid(loglb, res.y1, res.scal; step = 2, xlabel ="lambda (log scale)", ylabel = "RMSEP-Val").f

Selection of the best parameter combination

## Row index of the minimal CV error. argmin returns the first index of
## the minimum — same result as findall(res.y1 .== minimum(res.y1))[1],
## but in one idiomatic call.
u = argmin(res.y1)
res[u, :]
DataFrameRow (3 columns)
Rowlbscaly1
Float64BoolFloat64
1130.000158489false2.16284

Final prediction (Test) using the optimal model

## Refit the ridge model on the full training set with the optimal
## parameter combination. Bug fix: rr has no 'nlv' argument — the tuned
## regularization parameter is 'lb' (lambda), as used throughout this script.
model = rr(lb = res.lb[u], scal = res.scal[u])
fit!(model, Xtrain, ytrain)
pred = predict(model, Xtest).pred    # Test-set predictions
63×1 Matrix{Float64}:
 27.876444002998458
  3.5904081173640385
  2.32548879045871
  9.169771876981347
 13.228383160517499
 20.88245631069842
 21.875931002536706
 47.68294255512646
  7.591807587883839
  2.0283951227110926
  ⋮
 18.070871456149778
 19.35569672742449
 21.75703948647831
 27.713990081253858
 27.393213664847824
 34.13829277337615
 33.47241171228406
 38.85986950484511
 44.4835756075456

Generalization error

rmsep(pred, ytest)    # generalization error on the Test set
1×1 Matrix{Float64}:
 2.8631045531296424

Plotting predictions vs. observed data

## Predicted vs. observed values on the Test set (with bisector line)
plotxy(pred, ytest; size = (500, 400), color = (:red, .5), bisect = true, title = "Test set - variable $nam", 
    xlabel = "Prediction", ylabel = "Observed").f

Naïve (not time-efficient) syntax to use gridcv for LV-based functions

Parameter lb can also be set in argument pars (which is the generic approach to define the grid). This is strictly equivalent (gives the same results) but the computations are slower.

## Naive syntax: 'lb' placed inside 'pars' — same results, slower computation
pars = mpar(lb = 10.0.^(-15:.1:3))
model = rr()
gridcv(model, Xtrain, ytrain; segm, score = rmsep, pars).res
181×2 DataFrame
156 rows omitted
Rowlby1
Float64Float64
11.0e-157.13373
21.25893e-157.13373
31.58489e-157.13373
41.99526e-157.13373
52.51189e-157.13373
63.16228e-157.13373
73.98107e-157.13373
85.01187e-157.13373
96.30957e-157.13373
107.94328e-157.13373
111.0e-147.13373
121.25893e-147.13373
131.58489e-147.13373
17079.432812.6648
171100.012.6648
172125.89312.6648
173158.48912.6648
174199.52612.6648
175251.18912.6648
176316.22812.6648
177398.10712.6648
178501.18712.6648
179630.95712.6648
180794.32812.6648
1811000.012.6648
## Naive syntax combining 'lb' and 'scal' in the same grid — equivalent but slower
pars = mpar(lb = 10.0.^(-15:.1:3), scal = [false; true])
model = rr()
gridcv(model, Xtrain, ytrain; segm, score = rmsep, pars).res
362×3 DataFrame
337 rows omitted
Rowlbscaly1
RealRealFloat64
11.0e-15false7.13373
21.25893e-15false7.13373
31.58489e-15false7.13373
41.99526e-15false7.13373
52.51189e-15false7.13373
63.16228e-15false7.13373
73.98107e-15false7.13373
85.01187e-15false7.13373
96.30957e-15false7.13373
107.94328e-15false7.13373
111.0e-14false7.13373
121.25893e-14false7.13373
131.58489e-14false7.13373
35179.4328true12.5418
352100.0true12.5869
353125.893true12.6155
354158.489true12.6336
355199.526true12.6451
356251.189true12.6524
357316.228true12.6569
358398.107true12.6598
359501.187true12.6617
360630.957true12.6628
361794.328true12.6635
3621000.0true12.664