using Jchemo, JchemoData using JLD2, CairoMakie
# Locate the JchemoData package root (two levels above the file returned by `pathof`)
# and load the object `dat` stored in the tecator JLD2 file.
path_jdat = JchemoData |> pathof |> dirname |> dirname
db = joinpath(path_jdat, "data/tecator.jld2")
@load db dat
# List the components available in `dat`.
@names dat
(:X, :Y)
# Spectral data: one row per sample, one column per wavelength.
X = dat.X
# Preview the first rows and the dimensions.
@head X
... (178, 100)
Row | 850 | 852 | 854 | 856 | 858 | 860 | 862 | 864 | 866 | 868 | 870 | 872 | 874 | 876 | 878 | 880 | 882 | 884 | 886 | 888 | 890 | 892 | 894 | 896 | 898 | 900 | 902 | 904 | 906 | 908 | 910 | 912 | 914 | 916 | 918 | 920 | 922 | 924 | 926 | 928 | 930 | 932 | 934 | 936 | 938 | 940 | 942 | 944 | 946 | 948 | 950 | 952 | 954 | 956 | 958 | 960 | 962 | 964 | 966 | 968 | 970 | 972 | 974 | 976 | 978 | 980 | 982 | 984 | 986 | 988 | 990 | 992 | 994 | 996 | 998 | 1000 | 1002 | 1004 | 1006 | 1008 | 1010 | 1012 | 1014 | 1016 | 1018 | 1020 | 1022 | 1024 | 1026 | 1028 | 1030 | 1032 | 1034 | 1036 | 1038 | 1040 | 1042 | 1044 | 1046 | 1048 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | |
1 | 2.61776 | 2.61814 | 2.61859 | 2.61912 | 2.61981 | 2.62071 | 2.62186 | 2.62334 | 2.62511 | 2.62722 | 2.62964 | 2.63245 | 2.63565 | 2.63933 | 2.64353 | 2.64825 | 2.6535 | 2.65937 | 2.66585 | 2.67281 | 2.68008 | 2.68733 | 2.69427 | 2.70073 | 2.70684 | 2.71281 | 2.71914 | 2.72628 | 2.73462 | 2.74416 | 2.75466 | 2.76568 | 2.77679 | 2.7879 | 2.79949 | 2.81225 | 2.82706 | 2.84356 | 2.86106 | 2.87857 | 2.89497 | 2.90924 | 2.92085 | 2.93015 | 2.93846 | 2.94771 | 2.96019 | 2.97831 | 3.00306 | 3.03506 | 3.07428 | 3.11963 | 3.16868 | 3.21771 | 3.26254 | 3.29988 | 3.32847 | 3.34899 | 3.36342 | 3.37379 | 3.38152 | 3.38741 | 3.39164 | 3.39418 | 3.3949 | 3.39366 | 3.39045 | 3.38541 | 3.37869 | 3.37041 | 3.36073 | 3.34979 | 3.33769 | 3.32443 | 3.31013 | 3.29487 | 3.27891 | 3.26232 | 3.24542 | 3.22828 | 3.2108 | 3.19287 | 3.17433 | 3.15503 | 3.13475 | 3.11339 | 3.09116 | 3.0685 | 3.04596 | 3.02393 | 3.00247 | 2.98145 | 2.96072 | 2.94013 | 2.91978 | 2.89966 | 2.87964 | 2.8596 | 2.8394 | 2.8192 |
2 | 2.83454 | 2.83871 | 2.84283 | 2.84705 | 2.85138 | 2.85587 | 2.8606 | 2.86566 | 2.87093 | 2.87661 | 2.88264 | 2.88898 | 2.89577 | 2.90308 | 2.91097 | 2.91953 | 2.92873 | 2.93863 | 2.94929 | 2.96072 | 2.97272 | 2.98493 | 2.9969 | 3.00833 | 3.0192 | 3.0299 | 3.04101 | 3.05345 | 3.06777 | 3.08416 | 3.10221 | 3.12106 | 3.13983 | 3.1581 | 3.17623 | 3.19519 | 3.21584 | 3.23747 | 3.25889 | 3.27835 | 3.29384 | 3.30362 | 3.30681 | 3.30393 | 3.297 | 3.28925 | 3.28409 | 3.28505 | 3.29326 | 3.30923 | 3.33267 | 3.36251 | 3.39661 | 3.43188 | 3.46492 | 3.49295 | 3.51458 | 3.53004 | 3.54067 | 3.54797 | 3.55306 | 3.55675 | 3.55921 | 3.56045 | 3.56034 | 3.55876 | 3.55571 | 3.55132 | 3.54585 | 3.5395 | 3.53235 | 3.52442 | 3.51583 | 3.50668 | 3.497 | 3.48683 | 3.47626 | 3.46552 | 3.45501 | 3.44481 | 3.43477 | 3.42465 | 3.41419 | 3.40303 | 3.39082 | 3.37731 | 3.36265 | 3.34745 | 3.33245 | 3.31818 | 3.30473 | 3.29186 | 3.27921 | 3.26655 | 3.25369 | 3.24045 | 3.22659 | 3.21181 | 3.196 | 3.17942 |
3 | 2.58284 | 2.58458 | 2.58629 | 2.58808 | 2.58996 | 2.59192 | 2.59401 | 2.59627 | 2.59873 | 2.60131 | 2.60414 | 2.60714 | 2.61029 | 2.61361 | 2.61714 | 2.62089 | 2.62486 | 2.62909 | 2.63361 | 2.63835 | 2.6433 | 2.64838 | 2.65354 | 2.6587 | 2.66375 | 2.6688 | 2.67383 | 2.67892 | 2.68411 | 2.68937 | 2.6947 | 2.70012 | 2.70563 | 2.71141 | 2.71775 | 2.7249 | 2.73344 | 2.74327 | 2.75433 | 2.76642 | 2.77931 | 2.79272 | 2.80649 | 2.82064 | 2.83541 | 2.85121 | 2.86872 | 2.88905 | 2.91289 | 2.94088 | 2.97325 | 3.00946 | 3.0478 | 3.08554 | 3.11947 | 3.14696 | 3.16677 | 3.17938 | 3.18631 | 3.18924 | 3.1895 | 3.18801 | 3.18498 | 3.18039 | 3.17411 | 3.16611 | 3.15641 | 3.14512 | 3.13241 | 3.11843 | 3.10329 | 3.08714 | 3.07014 | 3.05237 | 3.03393 | 3.01504 | 2.99569 | 2.97612 | 2.95642 | 2.9366 | 2.91667 | 2.89655 | 2.87622 | 2.85563 | 2.83474 | 2.81361 | 2.79235 | 2.77113 | 2.75015 | 2.72956 | 2.70934 | 2.68951 | 2.67009 | 2.65112 | 2.63262 | 2.61461 | 2.59718 | 2.58034 | 2.56404 | 2.54816 |
# Reference table: water, fat and protein values plus the `typ` column
# (train / val / test label of each sample).
Y = dat.Y
# Preview the first rows and the dimensions.
@head Y
... (178, 4)
Row | water | fat | protein | typ |
---|---|---|---|---|
Float64 | Float64 | Float64 | String | |
1 | 60.5 | 22.5 | 16.7 | train |
2 | 46.0 | 40.1 | 13.5 | train |
3 | 71.0 | 8.4 | 20.5 | train |
# Sample-set label of each observation, then the count of samples per label.
typ = Y.typ
tab(typ)
OrderedCollections.OrderedDict{String, Int64} with 3 entries: "test" => 31 "train" => 115 "val" => 32
# The column names of X are the wavelengths stored as strings;
# convert them to numbers so they can be used as an x-axis.
wlst = names(X)
wl = map(w -> parse(Float64, w), wlst)
#plotsp(X, wl; xlabel = "Wavelength (nm)", ylabel = "Absorbance").f
100-element Vector{Float64}: 850.0 852.0 854.0 856.0 858.0 860.0 862.0 864.0 866.0 868.0 ⋮ 1032.0 1034.0 1036.0 1038.0 1040.0 1042.0 1044.0 1046.0 1048.0
Preprocessing
# Preprocessing pipeline: `snv` followed by `savgol`
# (15-point window, 2nd derivative, polynomial degree 3), chained with `pip`.
model1 = snv()
model2 = savgol(; npoint = 15, deriv = 2, degree = 3)
model = pip(model1, model2)
# Fit the pipeline on the full spectra, then transform them.
fit!(model, X)
@head Xp = transf(model, X)
#plotsp(Xp, wl; xlabel = "Wavelength (nm)", ylabel = "Absorbance").f
3×100 Matrix{Float64}: 0.000397076 0.000495203 0.000598623 … 0.00827138 0.00917311 0.00946072 0.00242055 0.00244366 0.00234233 0.00580631 0.00689249 0.00749316 0.0011927 0.00122721 0.00120098 0.0101019 0.0108142 0.0108444 ... (178, 100)
# Boolean mask of the samples labelled "train".
s = typ .== "train"

# Training set: rows flagged by the mask.
Xtrain = Xp[s, :]
Ytrain = Y[s, :]

# Test set: every remaining row (samples labelled "val" and "test").
Xtest = rmrow(Xp, s)
Ytest = rmrow(Y, s)

# Sizes of the two sets, for a quick sanity check.
ntrain = nro(Xtrain)
ntest = nro(Xtest)
ntot = ntrain + ntest
(ntot = ntot, ntrain, ntest)
(ntot = 178, ntrain = 115, ntest = 63)
Working response y
# Names of the three response columns (the 4th column, `typ`, is excluded).
namy = names(Y)[1:3]
# Work on the j-th y-variable.
j = 2
nam = namy[j]
# Response vectors for the training and test sets.
ytrain = Ytrain[:, nam]
ytest = Ytest[:, nam]
63-element Vector{Float64}: 29.8 1.4 4.6 11.0 17.0 22.4 27.9 46.5 6.1 2.0 ⋮ 18.1 19.4 24.8 27.2 28.4 31.3 33.8 35.5 42.5
Replicated K-fold CV
# Number of folds (segments).
K = 3
# Number of replications of the K-fold partition.
rep = 10
# One K-fold partition of the training row indices per replication.
segm = segmkf(ntrain, K; rep)
10-element Vector{Vector{Vector{Int64}}}: [[4, 6, 13, 14, 15, 23, 24, 33, 34, 35 … 98, 99, 100, 102, 104, 105, 107, 108, 109, 114], [1, 2, 3, 7, 8, 9, 10, 18, 22, 25 … 86, 90, 92, 93, 94, 101, 110, 111, 112, 115], [5, 11, 12, 16, 17, 19, 20, 21, 26, 27 … 64, 65, 78, 79, 80, 87, 96, 103, 106, 113]] [[1, 7, 9, 10, 11, 15, 17, 18, 24, 29 … 85, 89, 101, 104, 105, 108, 111, 113, 114, 115], [2, 4, 5, 6, 12, 13, 14, 16, 19, 21 … 81, 86, 88, 93, 97, 98, 103, 107, 110, 112], [3, 8, 20, 22, 25, 26, 27, 28, 30, 32 … 91, 92, 94, 95, 96, 99, 100, 102, 106, 109]] [[3, 5, 12, 14, 23, 24, 29, 30, 32, 33 … 94, 96, 98, 99, 100, 103, 104, 111, 113, 114], [10, 13, 16, 17, 19, 20, 21, 22, 25, 27 … 80, 83, 84, 86, 91, 95, 97, 102, 105, 107], [1, 2, 4, 6, 7, 8, 9, 11, 15, 18 … 89, 92, 93, 101, 106, 108, 109, 110, 112, 115]] [[2, 6, 8, 10, 12, 16, 17, 20, 23, 24 … 98, 99, 101, 103, 105, 106, 107, 108, 110, 113], [3, 5, 7, 9, 11, 18, 21, 28, 29, 31 … 86, 88, 92, 102, 104, 109, 111, 112, 114, 115], [1, 4, 13, 14, 15, 19, 22, 25, 30, 33 … 75, 77, 79, 84, 87, 90, 91, 93, 96, 100]] [[5, 10, 11, 14, 15, 18, 19, 22, 28, 30 … 84, 86, 90, 91, 93, 100, 102, 104, 109, 115], [1, 2, 3, 6, 8, 16, 17, 20, 21, 23 … 89, 96, 98, 101, 103, 106, 107, 108, 111, 114], [4, 7, 9, 12, 13, 25, 26, 27, 29, 31 … 88, 92, 94, 95, 97, 99, 105, 110, 112, 113]] [[4, 5, 7, 8, 11, 13, 21, 29, 30, 31 … 86, 90, 99, 102, 108, 110, 112, 113, 114, 115], [3, 14, 15, 16, 17, 18, 19, 20, 22, 23 … 87, 89, 91, 92, 93, 94, 97, 103, 105, 111], [1, 2, 6, 9, 10, 12, 26, 28, 32, 33 … 88, 95, 96, 98, 100, 101, 104, 106, 107, 109]] [[7, 8, 11, 13, 15, 20, 25, 29, 30, 31 … 90, 93, 94, 97, 98, 101, 102, 104, 109, 114], [2, 5, 6, 10, 14, 16, 18, 19, 21, 22 … 84, 87, 95, 96, 105, 108, 110, 111, 113, 115], [1, 3, 4, 9, 12, 17, 23, 27, 28, 33 … 79, 83, 91, 92, 99, 100, 103, 106, 107, 112]] [[1, 2, 8, 9, 18, 19, 26, 29, 30, 31 … 83, 87, 90, 92, 94, 104, 107, 111, 112, 114], [3, 4, 10, 12, 13, 22, 23, 24, 25, 28 … 96, 97, 99, 101, 103, 105, 106, 
108, 110, 115], [5, 6, 7, 11, 14, 15, 16, 17, 20, 21 … 84, 86, 89, 93, 95, 98, 100, 102, 109, 113]] [[2, 5, 7, 9, 11, 14, 20, 24, 27, 28 … 93, 94, 95, 97, 98, 103, 106, 111, 112, 115], [4, 8, 13, 15, 16, 17, 18, 19, 21, 25 … 82, 85, 88, 90, 101, 104, 108, 110, 113, 114], [1, 3, 6, 10, 12, 22, 23, 29, 33, 34 … 87, 89, 91, 96, 99, 100, 102, 105, 107, 109]] [[19, 20, 21, 22, 24, 28, 29, 31, 39, 40 … 85, 90, 92, 98, 101, 102, 103, 108, 114, 115], [1, 5, 6, 9, 10, 11, 13, 14, 18, 23 … 94, 95, 99, 104, 105, 106, 107, 109, 111, 112], [2, 3, 4, 7, 8, 12, 15, 16, 17, 25 … 86, 87, 89, 91, 93, 96, 97, 100, 110, 113]]
The best syntax for using function `gridcv` with ridge-based functions (e.g. `rr`, `krr`, etc.) is to set parameter `lb` outside the argument `pars` that defines the grid (in general built with function `mpar`). In that case, the computation time is reduced [see the naïve, non-optimized syntax at the end of this script]. The same principle applies when defining the parameter `nlv` in LV-based functions (e.g. `plskern`, `kplsr`, `lwplsr`, etc.).
# Grid of lambda values, evenly spaced on a log10 scale from 1e-15 to 1e3.
lb = 10.0 .^ (-15:0.1:3)
model = rr()
# `lb` is passed directly to gridcv rather than through `pars`.
rescv = gridcv(model, Xtrain, ytrain; segm = segm, score = rmsep, lb = lb)
@names rescv
# Summary table of the scores over the lambda grid.
res = rescv.res
# Detailed scores: one row per (replication, segment, lambda).
res_rep = rescv.res_rep
Row | rep | segm | lb | y1 |
---|---|---|---|---|
Int64 | Int64 | Float64 | Float64 | |
1 | 1 | 1 | 1.0e-15 | 7.72914 |
2 | 1 | 1 | 1.25893e-15 | 7.72914 |
3 | 1 | 1 | 1.58489e-15 | 7.72914 |
4 | 1 | 1 | 1.99526e-15 | 7.72914 |
5 | 1 | 1 | 2.51189e-15 | 7.72914 |
6 | 1 | 1 | 3.16228e-15 | 7.72914 |
7 | 1 | 1 | 3.98107e-15 | 7.72914 |
8 | 1 | 1 | 5.01187e-15 | 7.72914 |
9 | 1 | 1 | 6.30957e-15 | 7.72914 |
10 | 1 | 1 | 7.94328e-15 | 7.72914 |
11 | 1 | 1 | 1.0e-14 | 7.72914 |
12 | 1 | 1 | 1.25893e-14 | 7.72914 |
13 | 1 | 1 | 1.58489e-14 | 7.72914 |
⋮ | ⋮ | ⋮ | ⋮ | ⋮ |
5419 | 10 | 3 | 79.4328 | 10.47 |
5420 | 10 | 3 | 100.0 | 10.47 |
5421 | 10 | 3 | 125.893 | 10.47 |
5422 | 10 | 3 | 158.489 | 10.47 |
5423 | 10 | 3 | 199.526 | 10.47 |
5424 | 10 | 3 | 251.189 | 10.47 |
5425 | 10 | 3 | 316.228 | 10.47 |
5426 | 10 | 3 | 398.107 | 10.47 |
5427 | 10 | 3 | 501.187 | 10.47 |
5428 | 10 | 3 | 630.957 | 10.47 |
5429 | 10 | 3 | 794.328 | 10.47 |
5430 | 10 | 3 | 1000.0 | 10.47 |
# log10 of each lambda, rounded so that the axis ticks are readable.
loglb = round.(log10.(res.lb), digits = 3)
# The scores come from the replicated K-fold CV above, so the y-axis is
# labelled "RMSEP-CV" (it was "RMSEP-Val", which suggests a validation set).
plotgrid(loglb, res.y1; step = 2, xlabel = "lambda (log scale)", ylabel = "RMSEP-CV").f
# CV curve (red) with the individual replication/fold curves overlaid (grey).
# The x-axis holds log10(lambda), so it is labelled accordingly
# (the former label "Nb. LVs" did not match the plotted quantity).
f, ax = plotgrid(loglb, res.y1; step = 2, xlabel = "lambda (log scale)", ylabel = "RMSEP-CV")
for i = 1:rep, j = 1:K
    # Scores of replication i, segment j, over the whole lambda grid.
    zres = res_rep[res_rep.rep .== i .&& res_rep.segm .== j, :]
    lines!(ax, loglb, zres.y1; color = (:grey, .2))
end
# Redraw the summary curve on top so that it stays visible.
lines!(ax, loglb, res.y1; color = :red, linewidth = 1)
f
If other parameters have to be defined in the grid, they have to be set in argument `pars`, as in the example below.
# Additional tuning parameter placed in the grid: with / without scaling.
pars = mpar(; scal = [false; true])
# Lambda grid, still passed outside `pars`.
lb = 10.0 .^ (-15:0.1:3)
model = rr()
res = gridcv(model, Xtrain, ytrain; segm = segm, score = rmsep, pars = pars, lb = lb).res
Row | lb | scal | y1 |
---|---|---|---|
Float64 | Bool | Float64 | |
1 | 1.0e-15 | false | 5.97984 |
2 | 1.25893e-15 | false | 5.97984 |
3 | 1.58489e-15 | false | 5.97984 |
4 | 1.99526e-15 | false | 5.97984 |
5 | 2.51189e-15 | false | 5.97984 |
6 | 3.16228e-15 | false | 5.97984 |
7 | 3.98107e-15 | false | 5.97984 |
8 | 5.01187e-15 | false | 5.97984 |
9 | 6.30957e-15 | false | 5.97984 |
10 | 7.94328e-15 | false | 5.97984 |
11 | 1.0e-14 | false | 5.97984 |
12 | 1.25893e-14 | false | 5.97984 |
13 | 1.58489e-14 | false | 5.97984 |
⋮ | ⋮ | ⋮ | ⋮ |
351 | 79.4328 | true | 12.5259 |
352 | 100.0 | true | 12.5709 |
353 | 125.893 | true | 12.5996 |
354 | 158.489 | true | 12.6177 |
355 | 199.526 | true | 12.6291 |
356 | 251.189 | true | 12.6364 |
357 | 316.228 | true | 12.641 |
358 | 398.107 | true | 12.6439 |
359 | 501.187 | true | 12.6457 |
360 | 630.957 | true | 12.6468 |
361 | 794.328 | true | 12.6476 |
362 | 1000.0 | true | 12.648 |
# log10 of each lambda, rounded so that the axis ticks are readable.
loglb = round.(log10.(res.lb), digits = 3)
# One curve per value of `scal`. The scores are cross-validation scores, so the
# y-axis is labelled "RMSEP-CV" (it was "RMSEP-Val").
plotgrid(loglb, res.y1, res.scal; step = 2, xlabel = "lambda (log scale)", ylabel = "RMSEP-CV").f
Selection of the best parameter combination
# Index of the first grid row that reaches the minimum CV score.
u = first(findall(==(minimum(res.y1)), res.y1))
# Show the corresponding parameter combination.
res[u, :]
Row | lb | scal | y1 |
---|---|---|---|
Float64 | Bool | Float64 | |
115 | 0.000251189 | false | 2.13767 |
# Refit ridge regression on the whole training set with the selected parameters.
# The regularization parameter of `rr` is `lb` (the keyword used in the grid
# search); the selected lambda was previously passed as `nlv`, so it was not
# applied as the ridge penalty.
# NOTE(review): the predictions and RMSEP printed after this cell may have been
# produced before this fix; re-run the script to refresh them.
model = rr(lb = res.lb[u], scal = res.scal[u])
fit!(model, Xtrain, ytrain)
# Predictions for the test samples.
pred = predict(model, Xtest).pred
63×1 Matrix{Float64}: 27.876444002998458 3.5904081173640385 2.32548879045871 9.169771876981347 13.228383160517499 20.88245631069842 21.875931002536706 47.68294255512646 7.591807587883839 2.0283951227110926 ⋮ 18.070871456149778 19.35569672742449 21.75703948647831 27.713990081253858 27.393213664847824 34.13829277337615 33.47241171228406 38.85986950484511 44.4835756075456
Generalization error
# RMSEP of the predictions against the observed test-set values.
rmsep(pred, ytest)
1×1 Matrix{Float64}: 2.8631045531296424
Plotting predictions vs. observed data
# Scatter plot of predicted vs. observed values on the test set;
# `bisect = true` requests the bisector as a visual reference.
plotxy(
    pred, ytest;
    size = (500, 400),
    color = (:red, .5),
    bisect = true,
    title = string("Test set - variable ", nam),
    xlabel = "Prediction",
    ylabel = "Observed",
).f
Parameter `lb` can also be set in argument `pars` (which is the generic approach to define the grid). This is strictly equivalent (it gives the same results), but the computations are slower.
# Generic syntax: the lambda grid is placed inside `pars` instead of being
# passed as a separate keyword to gridcv.
pars = mpar(; lb = 10.0 .^ (-15:0.1:3))
model = rr()
gridcv(model, Xtrain, ytrain; segm = segm, score = rmsep, pars = pars).res
Row | lb | y1 |
---|---|---|
Float64 | Float64 | |
1 | 1.0e-15 | 5.97984 |
2 | 1.25893e-15 | 5.97984 |
3 | 1.58489e-15 | 5.97984 |
4 | 1.99526e-15 | 5.97984 |
5 | 2.51189e-15 | 5.97984 |
6 | 3.16228e-15 | 5.97984 |
7 | 3.98107e-15 | 5.97984 |
8 | 5.01187e-15 | 5.97984 |
9 | 6.30957e-15 | 5.97984 |
10 | 7.94328e-15 | 5.97984 |
11 | 1.0e-14 | 5.97984 |
12 | 1.25893e-14 | 5.97984 |
13 | 1.58489e-14 | 5.97984 |
⋮ | ⋮ | ⋮ |
170 | 79.4328 | 12.6488 |
171 | 100.0 | 12.6488 |
172 | 125.893 | 12.6488 |
173 | 158.489 | 12.6488 |
174 | 199.526 | 12.6488 |
175 | 251.189 | 12.6488 |
176 | 316.228 | 12.6488 |
177 | 398.107 | 12.6488 |
178 | 501.187 | 12.6488 |
179 | 630.957 | 12.6488 |
180 | 794.328 | 12.6488 |
181 | 1000.0 | 12.6488 |
# Generic syntax with two parameters: `mpar` builds every (lb, scal) combination.
pars = mpar(; lb = 10.0 .^ (-15:0.1:3), scal = [false; true])
model = rr()
gridcv(model, Xtrain, ytrain; segm = segm, score = rmsep, pars = pars).res
Row | lb | scal | y1 |
---|---|---|---|
Real | Real | Float64 | |
1 | 1.0e-15 | false | 5.97984 |
2 | 1.25893e-15 | false | 5.97984 |
3 | 1.58489e-15 | false | 5.97984 |
4 | 1.99526e-15 | false | 5.97984 |
5 | 2.51189e-15 | false | 5.97984 |
6 | 3.16228e-15 | false | 5.97984 |
7 | 3.98107e-15 | false | 5.97984 |
8 | 5.01187e-15 | false | 5.97984 |
9 | 6.30957e-15 | false | 5.97984 |
10 | 7.94328e-15 | false | 5.97984 |
11 | 1.0e-14 | false | 5.97984 |
12 | 1.25893e-14 | false | 5.97984 |
13 | 1.58489e-14 | false | 5.97984 |
⋮ | ⋮ | ⋮ | ⋮ |
351 | 79.4328 | true | 12.5259 |
352 | 100.0 | true | 12.5709 |
353 | 125.893 | true | 12.5996 |
354 | 158.489 | true | 12.6177 |
355 | 199.526 | true | 12.6291 |
356 | 251.189 | true | 12.6364 |
357 | 316.228 | true | 12.641 |
358 | 398.107 | true | 12.6439 |
359 | 501.187 | true | 12.6457 |
360 | 630.957 | true | 12.6468 |
361 | 794.328 | true | 12.6476 |
362 | 1000.0 | true | 12.648 |