gridcv - tecator - Rr

using Jchemo, JchemoData
using JLD2, CairoMakie

Data importation

path_jdat = dirname(dirname(pathof(JchemoData)))
db = joinpath(path_jdat, "data/tecator.jld2") 
@load db dat
@names dat
(:X, :Y)
X = dat.X
@head X
3×100 DataFrame (first 3 rows of X; the 100 columns are the wavelengths 850, 852, ..., 1048 nm; Float64 absorbance values; display too wide to reproduce here)
... (178, 100)
Y = dat.Y 
@head Y
... (178, 4)
3×4 DataFrame
 Row │ water    fat      protein  typ
     │ Float64  Float64  Float64  String
─────┼───────────────────────────────────
   1 │    60.5     22.5     16.7  train
   2 │    46.0     40.1     13.5  train
   3 │    71.0      8.4     20.5  train
typ = Y.typ
tab(typ)
OrderedCollections.OrderedDict{String, Int64} with 3 entries:
  "test"  => 31
  "train" => 115
  "val"   => 32
wlst = names(X)
wl = parse.(Float64, wlst) 
#plotsp(X, wl; xlabel = "Wavelength (nm)", ylabel = "Absorbance").f
100-element Vector{Float64}:
  850.0
  852.0
  854.0
  856.0
  858.0
  860.0
  862.0
  864.0
  866.0
  868.0
    ⋮
 1032.0
 1034.0
 1036.0
 1038.0
 1040.0
 1042.0
 1044.0
 1046.0
 1048.0

Preprocessing

model1 = snv()
model2 = savgol(npoint = 15, deriv = 2, degree = 3)
model = pip(model1, model2)
fit!(model, X)
@head Xp = transf(model, X) 
#plotsp(Xp, wl; xlabel = "Wavelength (nm)", ylabel = "Absorbance").f
3×100 Matrix{Float64}:
 0.000397076  0.000495203  0.000598623  …  0.00827138  0.00917311  0.00946072
 0.00242055   0.00244366   0.00234233      0.00580631  0.00689249  0.00749316
 0.0011927    0.00122721   0.00120098      0.0101019   0.0108142   0.0108444
... (178, 100)
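The fitted preprocessing pipeline can be re-applied to any new spectra recorded on the same 100 wavelengths. A minimal sketch (here simply re-using the first five rows of X as if they were new data):

Xnew = X[1:5, :]
transf(model, Xnew)    # SNV + 2nd-derivative Savitzky-Golay preprocessed spectra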

Split of the total data (Tot) into Train/Test

s = typ .== "train"
Xtrain = Xp[s, :] 
Ytrain = Y[s, :]
Xtest = rmrow(Xp, s)
Ytest = rmrow(Y, s)
ntrain = nro(Xtrain)
ntest = nro(Xtest)
ntot = ntrain + ntest
(ntot = ntot, ntrain, ntest)
(ntot = 178, ntrain = 115, ntest = 63)
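Since s is a Boolean vector, the test rows can equivalently be taken by logical negation (a sketch giving the same split as rmrow):

Xtest = Xp[.!s, :]
Ytest = Y[.!s, :]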

Working response y

namy = names(Y)[1:3]
j = 2  
nam = namy[j]    # work on the j-th y-variable
ytrain = Ytrain[:, nam]
ytest = Ytest[:, nam]
63-element Vector{Float64}:
 29.8
  1.4
  4.6
 11.0
 17.0
 22.4
 27.9
 46.5
  6.1
  2.0
  ⋮
 18.1
 19.4
 24.8
 27.2
 28.4
 31.3
 33.8
 35.5
 42.5
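The same tuning and prediction workflow can be repeated for the other responses (water and protein) by looping over j, for instance (a sketch, shown commented out since the rest of the script works on a single response):

# for j in eachindex(namy)
#     nam = namy[j]
#     ytrain = Ytrain[:, nam]
#     ytest = Ytest[:, nam]
#     # ... run the tuning and final prediction steps below
# end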

CV-Segments for model tuning

Replicated K-fold CV

K = 3     # nb. folds (segments)
rep = 10  # nb. replications
segm = segmkf(ntrain, K; rep = rep)
10-element Vector{Vector{Vector{Int64}}}:
 [[4, 6, 13, 14, 15, 23, 24, 33, 34, 35  …  98, 99, 100, 102, 104, 105, 107, 108, 109, 114], [1, 2, 3, 7, 8, 9, 10, 18, 22, 25  …  86, 90, 92, 93, 94, 101, 110, 111, 112, 115], [5, 11, 12, 16, 17, 19, 20, 21, 26, 27  …  64, 65, 78, 79, 80, 87, 96, 103, 106, 113]]
 [[1, 7, 9, 10, 11, 15, 17, 18, 24, 29  …  85, 89, 101, 104, 105, 108, 111, 113, 114, 115], [2, 4, 5, 6, 12, 13, 14, 16, 19, 21  …  81, 86, 88, 93, 97, 98, 103, 107, 110, 112], [3, 8, 20, 22, 25, 26, 27, 28, 30, 32  …  91, 92, 94, 95, 96, 99, 100, 102, 106, 109]]
 [[3, 5, 12, 14, 23, 24, 29, 30, 32, 33  …  94, 96, 98, 99, 100, 103, 104, 111, 113, 114], [10, 13, 16, 17, 19, 20, 21, 22, 25, 27  …  80, 83, 84, 86, 91, 95, 97, 102, 105, 107], [1, 2, 4, 6, 7, 8, 9, 11, 15, 18  …  89, 92, 93, 101, 106, 108, 109, 110, 112, 115]]
 [[2, 6, 8, 10, 12, 16, 17, 20, 23, 24  …  98, 99, 101, 103, 105, 106, 107, 108, 110, 113], [3, 5, 7, 9, 11, 18, 21, 28, 29, 31  …  86, 88, 92, 102, 104, 109, 111, 112, 114, 115], [1, 4, 13, 14, 15, 19, 22, 25, 30, 33  …  75, 77, 79, 84, 87, 90, 91, 93, 96, 100]]
 [[5, 10, 11, 14, 15, 18, 19, 22, 28, 30  …  84, 86, 90, 91, 93, 100, 102, 104, 109, 115], [1, 2, 3, 6, 8, 16, 17, 20, 21, 23  …  89, 96, 98, 101, 103, 106, 107, 108, 111, 114], [4, 7, 9, 12, 13, 25, 26, 27, 29, 31  …  88, 92, 94, 95, 97, 99, 105, 110, 112, 113]]
 [[4, 5, 7, 8, 11, 13, 21, 29, 30, 31  …  86, 90, 99, 102, 108, 110, 112, 113, 114, 115], [3, 14, 15, 16, 17, 18, 19, 20, 22, 23  …  87, 89, 91, 92, 93, 94, 97, 103, 105, 111], [1, 2, 6, 9, 10, 12, 26, 28, 32, 33  …  88, 95, 96, 98, 100, 101, 104, 106, 107, 109]]
 [[7, 8, 11, 13, 15, 20, 25, 29, 30, 31  …  90, 93, 94, 97, 98, 101, 102, 104, 109, 114], [2, 5, 6, 10, 14, 16, 18, 19, 21, 22  …  84, 87, 95, 96, 105, 108, 110, 111, 113, 115], [1, 3, 4, 9, 12, 17, 23, 27, 28, 33  …  79, 83, 91, 92, 99, 100, 103, 106, 107, 112]]
 [[1, 2, 8, 9, 18, 19, 26, 29, 30, 31  …  83, 87, 90, 92, 94, 104, 107, 111, 112, 114], [3, 4, 10, 12, 13, 22, 23, 24, 25, 28  …  96, 97, 99, 101, 103, 105, 106, 108, 110, 115], [5, 6, 7, 11, 14, 15, 16, 17, 20, 21  …  84, 86, 89, 93, 95, 98, 100, 102, 109, 113]]
 [[2, 5, 7, 9, 11, 14, 20, 24, 27, 28  …  93, 94, 95, 97, 98, 103, 106, 111, 112, 115], [4, 8, 13, 15, 16, 17, 18, 19, 21, 25  …  82, 85, 88, 90, 101, 104, 108, 110, 113, 114], [1, 3, 6, 10, 12, 22, 23, 29, 33, 34  …  87, 89, 91, 96, 99, 100, 102, 105, 107, 109]]
 [[19, 20, 21, 22, 24, 28, 29, 31, 39, 40  …  85, 90, 92, 98, 101, 102, 103, 108, 114, 115], [1, 5, 6, 9, 10, 11, 13, 14, 18, 23  …  94, 95, 99, 104, 105, 106, 107, 109, 111, 112], [2, 3, 4, 7, 8, 12, 15, 16, 17, 25  …  86, 87, 89, 91, 93, 96, 97, 100, 110, 113]]
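Each replication contains K vectors of validation row indices; within a replication, the K folds partition the 115 training observations. A quick check (sketch):

length(segm)             # 10 replications
length(segm[1])          # 3 folds in the first replication
sum(length.(segm[1]))    # 115 = ntrain: the folds partition the training rows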

Grid-search

The most efficient syntax to use function gridcv with ridge-based functions (e.g. rr, krr, etc.) is to set parameter lb outside argument pars, the argument that defines the grid (in general built with function mpar). This reduces the computation time (see the naïve, non-optimized syntax at the end of this script). The same principle applies to parameter nlv in LV-based functions (e.g. plskern, kplsr, lwplsr, etc.).
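The gain can be checked directly by timing the two syntaxes, for instance with the rough sketch below (figures depend on the machine; the second call corresponds to the naïve syntax detailed at the end of the script):

lb = 10.0.^(-15:.1:3)
@time gridcv(rr(), Xtrain, ytrain; segm, score = rmsep, lb);                     # lb outside pars (fast syntax)
@time gridcv(rr(), Xtrain, ytrain; segm, score = rmsep, pars = mpar(lb = lb));   # lb inside pars (slower)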

lb = 10.0.^(-15:.1:3)
model = rr()
rescv = gridcv(model, Xtrain, ytrain; segm, score = rmsep, lb)
@names rescv 
res = rescv.res
res_rep = rescv.res_rep
5430×4 DataFrame
5405 rows omitted
  Row │ rep    segm   lb           y1
      │ Int64  Int64  Float64      Float64
──────┼──────────────────────────────────────
    1 │     1      1  1.0e-15       7.72914
    2 │     1      1  1.25893e-15   7.72914
    3 │     1      1  1.58489e-15   7.72914
    4 │     1      1  1.99526e-15   7.72914
    5 │     1      1  2.51189e-15   7.72914
    6 │     1      1  3.16228e-15   7.72914
    7 │     1      1  3.98107e-15   7.72914
    8 │     1      1  5.01187e-15   7.72914
    9 │     1      1  6.30957e-15   7.72914
   10 │     1      1  7.94328e-15   7.72914
   11 │     1      1  1.0e-14       7.72914
   12 │     1      1  1.25893e-14   7.72914
   13 │     1      1  1.58489e-14   7.72914
    ⋮ │    ⋮      ⋮        ⋮            ⋮
 5419 │    10      3  79.4328      10.47
 5420 │    10      3  100.0        10.47
 5421 │    10      3  125.893      10.47
 5422 │    10      3  158.489      10.47
 5423 │    10      3  199.526      10.47
 5424 │    10      3  251.189      10.47
 5425 │    10      3  316.228      10.47
 5426 │    10      3  398.107      10.47
 5427 │    10      3  501.187      10.47
 5428 │    10      3  630.957      10.47
 5429 │    10      3  794.328      10.47
 5430 │    10      3  1000.0       10.47
loglb = round.(log.(10, res.lb), digits = 3)
plotgrid(loglb, res.y1; step = 2, xlabel = "lambda (log scale)", ylabel = "RMSEP-Val").f
f, ax = plotgrid(loglb, res.y1; step = 2, xlabel = "lambda (log scale)", ylabel = "RMSEP-CV")
for i = 1:rep, j = 1:K
    zres = res_rep[res_rep.rep .== i .&& res_rep.segm .== j, :]
    lines!(ax, loglb, zres.y1; color = (:grey, .2))
end
lines!(ax, loglb, res.y1; color = :red, linewidth = 1)
f

Any other parameters to be tuned in the grid must be set in argument pars, as in the example below.

pars = mpar(scal = [false; true])
lb = 10.0.^(-15:.1:3)
model = rr()
res = gridcv(model, Xtrain, ytrain; segm, score = rmsep, pars, lb).res
362×3 DataFrame
337 rows omitted
 Row │ lb           scal   y1
     │ Float64      Bool   Float64
─────┼──────────────────────────────
   1 │ 1.0e-15      false   5.97984
   2 │ 1.25893e-15  false   5.97984
   3 │ 1.58489e-15  false   5.97984
   4 │ 1.99526e-15  false   5.97984
   5 │ 2.51189e-15  false   5.97984
   6 │ 3.16228e-15  false   5.97984
   7 │ 3.98107e-15  false   5.97984
   8 │ 5.01187e-15  false   5.97984
   9 │ 6.30957e-15  false   5.97984
  10 │ 7.94328e-15  false   5.97984
  11 │ 1.0e-14      false   5.97984
  12 │ 1.25893e-14  false   5.97984
  13 │ 1.58489e-14  false   5.97984
   ⋮ │      ⋮         ⋮        ⋮
 351 │ 79.4328      true   12.5259
 352 │ 100.0        true   12.5709
 353 │ 125.893      true   12.5996
 354 │ 158.489      true   12.6177
 355 │ 199.526      true   12.6291
 356 │ 251.189      true   12.6364
 357 │ 316.228      true   12.641
 358 │ 398.107      true   12.6439
 359 │ 501.187      true   12.6457
 360 │ 630.957      true   12.6468
 361 │ 794.328      true   12.6476
 362 │ 1000.0       true   12.648
loglb = round.(log.(10, res.lb), digits = 3)
plotgrid(loglb, res.y1, res.scal; step = 2, xlabel = "lambda (log scale)", ylabel = "RMSEP-Val").f

Selection of the best parameter combination

u = findall(res.y1 .== minimum(res.y1))[1] 
res[u, :]
DataFrameRow (3 columns)
 Row │ lb           scal   y1
     │ Float64      Bool   Float64
─────┼──────────────────────────────
 115 │ 0.000251189  false   2.13767
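Equivalently (a sketch): since res.y1 is a plain vector, Julia's argmin gives the same index directly.

u = argmin(res.y1)
res[u, :]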

Final prediction (Test) using the optimal model

model = rr(lb = res.lb[u], scal = res.scal[u])
fit!(model, Xtrain, ytrain)
pred = predict(model, Xtest).pred
63×1 Matrix{Float64}:
 27.876444002998458
  3.5904081173640385
  2.32548879045871
  9.169771876981347
 13.228383160517499
 20.88245631069842
 21.875931002536706
 47.68294255512646
  7.591807587883839
  2.0283951227110926
  ⋮
 18.070871456149778
 19.35569672742449
 21.75703948647831
 27.713990081253858
 27.393213664847824
 34.13829277337615
 33.47241171228406
 38.85986950484511
 44.4835756075456

Generalization error

rmsep(pred, ytest)
1×1 Matrix{Float64}:
 2.8631045531296424
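As a complementary check (a sketch in plain Julia, using the standard-library Statistics package), the squared correlation between the predictions and the observed test values can also be computed:

using Statistics
cor(vec(pred), ytest)^2    # squared correlation between predictions and observations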

Plotting predictions vs. observed data

plotxy(pred, ytest; size = (500, 400), color = (:red, .5), bisect = true, title = string("Test set - variable ", nam), 
    xlabel = "Prediction", ylabel = "Observed").f

Naïve syntax to use gridcv for ridge-based functions

Parameter lb can also be set inside argument pars (which is the generic way to define the grid). This is strictly equivalent (it gives the same results) but the computations are slower.

pars = mpar(lb = 10.0.^(-15:.1:3))
model = rr()
gridcv(model, Xtrain, ytrain; segm, score = rmsep, pars).res
181×2 DataFrame
156 rows omitted
 Row │ lb           y1
     │ Float64      Float64
─────┼───────────────────────
   1 │ 1.0e-15       5.97984
   2 │ 1.25893e-15   5.97984
   3 │ 1.58489e-15   5.97984
   4 │ 1.99526e-15   5.97984
   5 │ 2.51189e-15   5.97984
   6 │ 3.16228e-15   5.97984
   7 │ 3.98107e-15   5.97984
   8 │ 5.01187e-15   5.97984
   9 │ 6.30957e-15   5.97984
  10 │ 7.94328e-15   5.97984
  11 │ 1.0e-14       5.97984
  12 │ 1.25893e-14   5.97984
  13 │ 1.58489e-14   5.97984
   ⋮ │      ⋮           ⋮
 170 │ 79.4328      12.6488
 171 │ 100.0        12.6488
 172 │ 125.893      12.6488
 173 │ 158.489      12.6488
 174 │ 199.526      12.6488
 175 │ 251.189      12.6488
 176 │ 316.228      12.6488
 177 │ 398.107      12.6488
 178 │ 501.187      12.6488
 179 │ 630.957      12.6488
 180 │ 794.328      12.6488
 181 │ 1000.0       12.6488
pars = mpar(lb = 10.0.^(-15:.1:3), scal = [false; true])
model = rr()
gridcv(model, Xtrain, ytrain; segm, score = rmsep, pars).res
362×3 DataFrame
337 rows omitted
 Row │ lb           scal   y1
     │ Real         Real   Float64
─────┼──────────────────────────────
   1 │ 1.0e-15      false   5.97984
   2 │ 1.25893e-15  false   5.97984
   3 │ 1.58489e-15  false   5.97984
   4 │ 1.99526e-15  false   5.97984
   5 │ 2.51189e-15  false   5.97984
   6 │ 3.16228e-15  false   5.97984
   7 │ 3.98107e-15  false   5.97984
   8 │ 5.01187e-15  false   5.97984
   9 │ 6.30957e-15  false   5.97984
  10 │ 7.94328e-15  false   5.97984
  11 │ 1.0e-14      false   5.97984
  12 │ 1.25893e-14  false   5.97984
  13 │ 1.58489e-14  false   5.97984
   ⋮ │      ⋮         ⋮        ⋮
 351 │ 79.4328      true   12.5259
 352 │ 100.0        true   12.5709
 353 │ 125.893      true   12.5996
 354 │ 158.489      true   12.6177
 355 │ 199.526      true   12.6291
 356 │ 251.189      true   12.6364
 357 │ 316.228      true   12.641
 358 │ 398.107      true   12.6439
 359 │ 501.187      true   12.6457
 360 │ 630.957      true   12.6468
 361 │ 794.328      true   12.6476
 362 │ 1000.0       true   12.648