using Jchemo, JchemoData
using JLD2, CairoMakie
path_jdat = dirname(dirname(pathof(JchemoData)))
db = joinpath(path_jdat, "data/tecator.jld2")
@load db dat
@names dat
(:X, :Y)
X = dat.X
@head X
... (178, 100)
| Row | 850 | 852 | 854 | 856 | 858 | 860 | 862 | 864 | 866 | 868 | 870 | 872 | 874 | 876 | 878 | 880 | 882 | 884 | 886 | 888 | 890 | 892 | 894 | 896 | 898 | 900 | 902 | 904 | 906 | 908 | 910 | 912 | 914 | 916 | 918 | 920 | 922 | 924 | 926 | 928 | 930 | 932 | 934 | 936 | 938 | 940 | 942 | 944 | 946 | 948 | 950 | 952 | 954 | 956 | 958 | 960 | 962 | 964 | 966 | 968 | 970 | 972 | 974 | 976 | 978 | 980 | 982 | 984 | 986 | 988 | 990 | 992 | 994 | 996 | 998 | 1000 | 1002 | 1004 | 1006 | 1008 | 1010 | 1012 | 1014 | 1016 | 1018 | 1020 | 1022 | 1024 | 1026 | 1028 | 1030 | 1032 | 1034 | 1036 | 1038 | 1040 | 1042 | 1044 | 1046 | 1048 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | |
| 1 | 2.61776 | 2.61814 | 2.61859 | 2.61912 | 2.61981 | 2.62071 | 2.62186 | 2.62334 | 2.62511 | 2.62722 | 2.62964 | 2.63245 | 2.63565 | 2.63933 | 2.64353 | 2.64825 | 2.6535 | 2.65937 | 2.66585 | 2.67281 | 2.68008 | 2.68733 | 2.69427 | 2.70073 | 2.70684 | 2.71281 | 2.71914 | 2.72628 | 2.73462 | 2.74416 | 2.75466 | 2.76568 | 2.77679 | 2.7879 | 2.79949 | 2.81225 | 2.82706 | 2.84356 | 2.86106 | 2.87857 | 2.89497 | 2.90924 | 2.92085 | 2.93015 | 2.93846 | 2.94771 | 2.96019 | 2.97831 | 3.00306 | 3.03506 | 3.07428 | 3.11963 | 3.16868 | 3.21771 | 3.26254 | 3.29988 | 3.32847 | 3.34899 | 3.36342 | 3.37379 | 3.38152 | 3.38741 | 3.39164 | 3.39418 | 3.3949 | 3.39366 | 3.39045 | 3.38541 | 3.37869 | 3.37041 | 3.36073 | 3.34979 | 3.33769 | 3.32443 | 3.31013 | 3.29487 | 3.27891 | 3.26232 | 3.24542 | 3.22828 | 3.2108 | 3.19287 | 3.17433 | 3.15503 | 3.13475 | 3.11339 | 3.09116 | 3.0685 | 3.04596 | 3.02393 | 3.00247 | 2.98145 | 2.96072 | 2.94013 | 2.91978 | 2.89966 | 2.87964 | 2.8596 | 2.8394 | 2.8192 |
| 2 | 2.83454 | 2.83871 | 2.84283 | 2.84705 | 2.85138 | 2.85587 | 2.8606 | 2.86566 | 2.87093 | 2.87661 | 2.88264 | 2.88898 | 2.89577 | 2.90308 | 2.91097 | 2.91953 | 2.92873 | 2.93863 | 2.94929 | 2.96072 | 2.97272 | 2.98493 | 2.9969 | 3.00833 | 3.0192 | 3.0299 | 3.04101 | 3.05345 | 3.06777 | 3.08416 | 3.10221 | 3.12106 | 3.13983 | 3.1581 | 3.17623 | 3.19519 | 3.21584 | 3.23747 | 3.25889 | 3.27835 | 3.29384 | 3.30362 | 3.30681 | 3.30393 | 3.297 | 3.28925 | 3.28409 | 3.28505 | 3.29326 | 3.30923 | 3.33267 | 3.36251 | 3.39661 | 3.43188 | 3.46492 | 3.49295 | 3.51458 | 3.53004 | 3.54067 | 3.54797 | 3.55306 | 3.55675 | 3.55921 | 3.56045 | 3.56034 | 3.55876 | 3.55571 | 3.55132 | 3.54585 | 3.5395 | 3.53235 | 3.52442 | 3.51583 | 3.50668 | 3.497 | 3.48683 | 3.47626 | 3.46552 | 3.45501 | 3.44481 | 3.43477 | 3.42465 | 3.41419 | 3.40303 | 3.39082 | 3.37731 | 3.36265 | 3.34745 | 3.33245 | 3.31818 | 3.30473 | 3.29186 | 3.27921 | 3.26655 | 3.25369 | 3.24045 | 3.22659 | 3.21181 | 3.196 | 3.17942 |
| 3 | 2.58284 | 2.58458 | 2.58629 | 2.58808 | 2.58996 | 2.59192 | 2.59401 | 2.59627 | 2.59873 | 2.60131 | 2.60414 | 2.60714 | 2.61029 | 2.61361 | 2.61714 | 2.62089 | 2.62486 | 2.62909 | 2.63361 | 2.63835 | 2.6433 | 2.64838 | 2.65354 | 2.6587 | 2.66375 | 2.6688 | 2.67383 | 2.67892 | 2.68411 | 2.68937 | 2.6947 | 2.70012 | 2.70563 | 2.71141 | 2.71775 | 2.7249 | 2.73344 | 2.74327 | 2.75433 | 2.76642 | 2.77931 | 2.79272 | 2.80649 | 2.82064 | 2.83541 | 2.85121 | 2.86872 | 2.88905 | 2.91289 | 2.94088 | 2.97325 | 3.00946 | 3.0478 | 3.08554 | 3.11947 | 3.14696 | 3.16677 | 3.17938 | 3.18631 | 3.18924 | 3.1895 | 3.18801 | 3.18498 | 3.18039 | 3.17411 | 3.16611 | 3.15641 | 3.14512 | 3.13241 | 3.11843 | 3.10329 | 3.08714 | 3.07014 | 3.05237 | 3.03393 | 3.01504 | 2.99569 | 2.97612 | 2.95642 | 2.9366 | 2.91667 | 2.89655 | 2.87622 | 2.85563 | 2.83474 | 2.81361 | 2.79235 | 2.77113 | 2.75015 | 2.72956 | 2.70934 | 2.68951 | 2.67009 | 2.65112 | 2.63262 | 2.61461 | 2.59718 | 2.58034 | 2.56404 | 2.54816 |
Y = dat.Y
@head Y
... (178, 4)
| Row | water | fat | protein | typ |
|---|---|---|---|---|
| | Float64 | Float64 | Float64 | String |
| 1 | 60.5 | 22.5 | 16.7 | train |
| 2 | 46.0 | 40.1 | 13.5 | train |
| 3 | 71.0 | 8.4 | 20.5 | train |
typ = Y.typ
tab(typ)
OrderedCollections.OrderedDict{String, Int64} with 3 entries:
"test" => 31
"train" => 115
"val" => 32
wlst = names(X)
wl = parse.(Float64, wlst)
#plotsp(X, wl; xlabel = "Wavelength (nm)", ylabel = "Absorbance").f
100-element Vector{Float64}:
850.0
852.0
854.0
856.0
858.0
860.0
862.0
864.0
866.0
868.0
⋮
1032.0
1034.0
1036.0
1038.0
1040.0
1042.0
1044.0
1046.0
1048.0
Preprocessing
model1 = snv()
model2 = savgol(npoint = 15, deriv = 2, degree = 3)
model = pip(model1, model2)
fit!(model, X)
@head Xp = transf(model, X)
#plotsp(Xp, wl; xlabel = "Wavelength (nm)", ylabel = "Absorbance").f
3×100 Matrix{Float64}:
0.000397076 0.000495203 0.000598623 … 0.00827138 0.00917311 0.00946072
0.00242055 0.00244366 0.00234233 0.00580631 0.00689249 0.00749316
0.0011927 0.00122721 0.00120098 0.0101019 0.0108142 0.0108444
... (178, 100)
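To make the first step of the pipeline more concrete, here is a minimal sketch of a standard normal variate (SNV) transform in base Julia: each spectrum (row) is centered and scaled by its own mean and standard deviation. The function name snv_sketch and the data Xdemo are made up for illustration. This is not Jchemo's implementation, and the standard-deviation convention used by snv may differ.

```julia
using Statistics

# SNV sketch: center and scale each row (spectrum) by its own mean and
# standard deviation (corrected std, an assumption of this sketch).
function snv_sketch(X::AbstractMatrix)
    mu = mean(X; dims = 2)
    sd = std(X; dims = 2)
    (X .- mu) ./ sd
end

Xdemo = [1.0 2.0 3.0; 10.0 20.0 60.0]   # two made-up "spectra"
Xsnv = snv_sketch(Xdemo)
```

After the transform, every row has mean 0 and unit standard deviation, which removes additive and multiplicative effects specific to each spectrum.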
The model is fitted on Train, and the generalization error is estimated on Test. In this example, Train is already defined by variable typ of the dataset (typ = "train"), and Test is made of the remaining samples (typ = "test" or "val"). The total dataset (Tot) could also be split a posteriori, for instance by sampling (random, systematic, or any other design); see for instance functions samprand and sampsys.
s = typ .== "train"
Xtrain = Xp[s, :]
Ytrain = Y[s, :]
Xtest = rmrow(Xp, s)
Ytest = rmrow(Y, s)
ntrain = nro(Xtrain)
ntest = nro(Xtest)
ntot = ntrain + ntest
(ntot = ntot, ntrain, ntest)
(ntot = 178, ntrain = 115, ntest = 63)
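When the dataset does not come with a predefined partition, a random split can be sketched in base Julia as below. This is only an illustration of the idea, not the samprand function (whose signature is not shown here). The names ntest_demo, s_train and s_test and the seed are arbitrary choices made for the sketch.

```julia
using Random

n = 178                     # total number of observations, as in the output above
ntest_demo = 63             # hypothetical test-set size
rng = MersenneTwister(1)    # fixed seed, only to make the sketch reproducible
idx = randperm(rng, n)      # random permutation of 1:n
s_test = sort(idx[1:ntest_demo])            # row indices assigned to Test
s_train = sort(idx[(ntest_demo + 1):end])   # remaining row indices form Train
```

The two index vectors are disjoint and together cover 1:n, so they can be used directly to subset the rows of X and Y.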
Working response y
namy = names(Y)[1:3]
j = 2
nam = namy[j] # work on the j-th y-variable
ytrain = Ytrain[:, nam]
ytest = Ytest[:, nam]
63-element Vector{Float64}:
29.8
1.4
4.6
11.0
17.0
22.4
27.9
46.5
6.1
2.0
⋮
18.1
19.4
24.8
27.2
28.4
31.3
33.8
35.5
42.5
Two methods can be used to build the CV segments within Train (for the same total number of segments, the two methods return close results):
(1) Replicated K-fold CV
Train is split into K folds (segments), and this split can be replicated (hence replicated K-fold CV).
K = 3 # nb. folds (segments)
rep = 25 # nb. replications
segm = segmkf(ntrain, K; rep = rep)
25-element Vector{Vector{Vector{Int64}}}:
[[3, 4, 5, 6, 12, 16, 22, 24, 26, 28 … 91, 93, 94, 95, 98, 100, 104, 105, 107, 115], [1, 2, 7, 8, 13, 14, 15, 17, 18, 25 … 90, 92, 99, 106, 108, 109, 110, 111, 112, 113], [9, 10, 11, 19, 20, 21, 23, 27, 30, 31 … 81, 84, 85, 86, 96, 97, 101, 102, 103, 114]]
[[4, 6, 9, 11, 13, 14, 16, 18, 19, 20 … 77, 85, 89, 94, 97, 103, 108, 111, 112, 114], [1, 2, 5, 7, 12, 15, 24, 25, 26, 29 … 90, 92, 96, 98, 99, 101, 106, 107, 110, 113], [3, 8, 10, 17, 21, 22, 27, 35, 37, 40 … 84, 91, 93, 95, 100, 102, 104, 105, 109, 115]]
[[3, 4, 13, 14, 16, 18, 24, 28, 30, 32 … 93, 99, 101, 105, 107, 109, 110, 111, 114, 115], [5, 6, 7, 10, 11, 19, 20, 21, 22, 25 … 84, 85, 86, 89, 92, 94, 96, 102, 103, 106], [1, 2, 8, 9, 12, 15, 17, 23, 27, 31 … 90, 91, 95, 97, 98, 100, 104, 108, 112, 113]]
[[1, 7, 10, 12, 14, 16, 22, 24, 26, 28 … 76, 82, 83, 84, 85, 94, 96, 101, 108, 115], [2, 3, 13, 15, 17, 18, 21, 23, 36, 37 … 95, 97, 98, 100, 102, 103, 104, 105, 109, 110], [4, 5, 6, 8, 9, 11, 19, 20, 25, 27 … 89, 90, 92, 99, 106, 107, 111, 112, 113, 114]]
[[2, 4, 6, 10, 19, 22, 25, 30, 31, 32 … 88, 90, 99, 101, 103, 104, 105, 107, 108, 109], [1, 3, 7, 8, 9, 11, 14, 15, 16, 17 … 80, 82, 89, 95, 106, 110, 112, 113, 114, 115], [5, 12, 13, 18, 20, 21, 23, 26, 28, 29 … 91, 92, 93, 94, 96, 97, 98, 100, 102, 111]]
[[2, 4, 5, 18, 21, 26, 30, 31, 34, 35 … 90, 95, 98, 102, 105, 107, 112, 113, 114, 115], [8, 9, 10, 12, 13, 14, 15, 16, 17, 24 … 91, 94, 97, 99, 103, 104, 106, 108, 110, 111], [1, 3, 6, 7, 11, 19, 20, 22, 23, 28 … 78, 79, 81, 89, 92, 93, 96, 100, 101, 109]]
[[1, 3, 7, 14, 16, 17, 18, 19, 23, 24 … 80, 81, 83, 86, 91, 92, 93, 94, 99, 107], [4, 5, 6, 8, 10, 11, 20, 28, 31, 32 … 97, 103, 104, 106, 109, 110, 112, 113, 114, 115], [2, 9, 12, 13, 15, 21, 22, 29, 33, 36 … 88, 90, 96, 98, 100, 101, 102, 105, 108, 111]]
[[1, 3, 5, 9, 10, 16, 19, 21, 26, 28 … 86, 89, 93, 96, 100, 101, 105, 107, 109, 115], [2, 7, 11, 13, 14, 22, 24, 27, 30, 31 … 88, 91, 92, 97, 98, 99, 104, 112, 113, 114], [4, 6, 8, 12, 15, 17, 18, 20, 23, 25 … 85, 90, 94, 95, 102, 103, 106, 108, 110, 111]]
[[5, 7, 10, 12, 13, 18, 20, 21, 22, 35 … 102, 103, 104, 105, 106, 109, 110, 111, 113, 114], [3, 4, 8, 9, 11, 16, 23, 25, 26, 27 … 77, 85, 90, 91, 94, 95, 107, 108, 112, 115], [1, 2, 6, 14, 15, 17, 19, 24, 28, 33 … 76, 79, 80, 81, 83, 84, 92, 96, 98, 101]]
[[1, 3, 5, 12, 16, 17, 20, 22, 36, 38 … 92, 94, 95, 98, 99, 100, 103, 108, 109, 112], [4, 7, 11, 18, 19, 28, 29, 31, 33, 34 … 88, 93, 96, 97, 104, 105, 107, 110, 113, 115], [2, 6, 8, 9, 10, 13, 14, 15, 21, 23 … 79, 80, 83, 85, 90, 101, 102, 106, 111, 114]]
⋮
[[2, 3, 6, 15, 16, 23, 24, 25, 30, 31 … 87, 90, 93, 94, 97, 98, 101, 111, 114, 115], [1, 4, 8, 11, 12, 14, 17, 19, 20, 22 … 77, 81, 83, 86, 88, 89, 96, 100, 108, 113], [5, 7, 9, 10, 13, 18, 21, 27, 33, 34 … 99, 102, 103, 104, 105, 106, 107, 109, 110, 112]]
[[1, 5, 11, 14, 18, 20, 21, 25, 43, 44 … 99, 100, 101, 105, 106, 108, 109, 110, 111, 112], [2, 3, 4, 7, 9, 10, 16, 23, 24, 29 … 88, 90, 92, 93, 94, 95, 96, 103, 113, 115], [6, 8, 12, 13, 15, 17, 19, 22, 26, 27 … 78, 81, 82, 84, 86, 97, 102, 104, 107, 114]]
[[3, 4, 7, 8, 11, 12, 16, 22, 29, 30 … 96, 98, 99, 100, 102, 106, 107, 108, 110, 115], [1, 9, 10, 13, 15, 17, 18, 19, 20, 23 … 81, 82, 84, 89, 92, 95, 97, 101, 109, 112], [2, 5, 6, 14, 21, 24, 25, 26, 27, 28 … 88, 90, 93, 94, 103, 104, 105, 111, 113, 114]]
[[4, 5, 6, 7, 13, 19, 25, 27, 29, 31 … 89, 90, 92, 96, 102, 106, 108, 110, 111, 114], [10, 11, 12, 14, 15, 16, 17, 18, 22, 23 … 80, 82, 87, 91, 98, 99, 100, 104, 105, 109], [1, 2, 3, 8, 9, 20, 21, 28, 30, 35 … 93, 94, 95, 97, 101, 103, 107, 112, 113, 115]]
[[6, 8, 10, 14, 15, 26, 30, 33, 34, 37 … 88, 93, 98, 99, 101, 104, 105, 110, 111, 115], [2, 9, 11, 12, 13, 17, 18, 20, 21, 23 … 81, 85, 89, 94, 95, 96, 102, 103, 107, 112], [1, 3, 4, 5, 7, 16, 19, 22, 25, 27 … 90, 91, 92, 97, 100, 106, 108, 109, 113, 114]]
[[1, 4, 6, 8, 12, 13, 14, 21, 27, 29 … 85, 88, 91, 96, 98, 100, 101, 106, 112, 113], [5, 9, 16, 18, 19, 24, 25, 34, 35, 38 … 94, 95, 99, 105, 107, 108, 109, 110, 111, 114], [2, 3, 7, 10, 11, 15, 17, 20, 22, 23 … 77, 78, 83, 87, 92, 97, 102, 103, 104, 115]]
[[1, 3, 6, 10, 12, 14, 20, 21, 23, 24 … 86, 87, 88, 89, 90, 96, 103, 106, 107, 112], [2, 4, 7, 9, 16, 18, 26, 29, 31, 32 … 91, 93, 94, 95, 97, 99, 102, 104, 109, 113], [5, 8, 11, 13, 15, 17, 19, 22, 25, 28 … 92, 98, 100, 101, 105, 108, 110, 111, 114, 115]]
[[1, 3, 9, 13, 20, 21, 22, 27, 29, 31 … 87, 90, 98, 99, 100, 101, 106, 111, 113, 115], [2, 5, 8, 11, 12, 18, 23, 24, 25, 26 … 93, 95, 96, 97, 103, 104, 105, 107, 110, 112], [4, 6, 7, 10, 14, 15, 16, 17, 19, 28 … 86, 88, 89, 91, 92, 94, 102, 108, 109, 114]]
[[2, 4, 9, 12, 16, 17, 18, 22, 25, 27 … 87, 88, 89, 94, 104, 105, 107, 108, 110, 111], [1, 3, 7, 8, 11, 15, 21, 29, 33, 36 … 85, 91, 95, 97, 99, 100, 101, 106, 109, 115], [5, 6, 10, 13, 14, 19, 20, 23, 24, 26 … 90, 92, 93, 96, 98, 102, 103, 112, 113, 114]]
(2) Replicated "Test-set" CV
Train is split into Cal and Val (e.g. Cal = 70% of Train, Val = 30% of Train), and this split is replicated.
#pct = .30
#m = Int(round(pct * ntrain))
#segm = segmts(ntrain, m; rep = 30)
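The two segment designs can be sketched in base Julia as follows. The helpers kfold_sketch and testset_sketch are hypothetical names written for this illustration; they only mimic the nested structure (replications, then segments, then row indices) shown in the output of segmkf, and are not the code of segmkf or segmts.

```julia
using Random

# One replication of K-fold CV: shuffle 1:n, then deal the indices into K folds.
function kfold_sketch(rng, n, K)
    idx = randperm(rng, n)
    [sort(idx[k:K:n]) for k in 1:K]
end

# One replication of "test-set" CV: a single validation segment of m indices.
testset_sketch(rng, n, m) = [sort(randperm(rng, n)[1:m])]

rng = MersenneTwister(1)    # fixed seed, only for reproducibility
n = 115                     # size of Train in this example
segm_kf = [kfold_sketch(rng, n, 3) for _ in 1:25]
m = round(Int, 0.30 * n)    # size of each validation segment
segm_ts = [testset_sketch(rng, n, m) for _ in 1:30]
```

In the K-fold design every observation of Train is predicted exactly once per replication, whereas in the test-set design only the m sampled observations are predicted in a given replication.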
Illustration of segments:
i = 1
segm[i] # the K segments of replication 'i'
3-element Vector{Vector{Int64}}:
[3, 4, 5, 6, 12, 16, 22, 24, 26, 28 … 91, 93, 94, 95, 98, 100, 104, 105, 107, 115]
[1, 2, 7, 8, 13, 14, 15, 17, 18, 25 … 90, 92, 99, 106, 108, 109, 110, 111, 112, 113]
[9, 10, 11, 19, 20, 21, 23, 27, 30, 31 … 81, 84, 85, 86, 96, 97, 101, 102, 103, 114]
k = 1
segm[i][k] # segment 'k' of replication 'i'
39-element Vector{Int64}:
3
4
5
6
12
16
22
24
26
28
⋮
93
94
95
98
100
104
105
107
115
The recommended syntax for using function gridcv with LV-based functions (e.g. plskern, kplsr, lwplsr) is to set parameter nlv outside of argument pars, which defines the rest of the grid (in general built with function mpar). This reduces the computation time (the naive, non-optimized syntax is shown at the end of this script). The same principle applies when defining parameter lb in ridge-based functions (rr, krr, etc.).
nlv = 0:20
model = plskern()
rescv = gridcv(model, Xtrain, ytrain; segm, score = rmsep, nlv)
@names rescv
res = rescv.res
res_rep = rescv.res_rep
| Row | rep | segm | nlv | y1 |
|---|---|---|---|---|
| | Int64 | Int64 | Int64 | Float64 |
| 1 | 1 | 1 | 0 | 13.5666 |
| 2 | 1 | 1 | 1 | 2.44251 |
| 3 | 1 | 1 | 2 | 2.24555 |
| 4 | 1 | 1 | 3 | 2.42385 |
| 5 | 1 | 1 | 4 | 2.58804 |
| 6 | 1 | 1 | 5 | 2.22739 |
| 7 | 1 | 1 | 6 | 2.28135 |
| 8 | 1 | 1 | 7 | 2.23079 |
| 9 | 1 | 1 | 8 | 2.19531 |
| 10 | 1 | 1 | 9 | 2.17222 |
| 11 | 1 | 1 | 10 | 2.16421 |
| 12 | 1 | 1 | 11 | 2.17249 |
| 13 | 1 | 1 | 12 | 2.23543 |
| ⋮ | ⋮ | ⋮ | ⋮ | ⋮ |
| 1564 | 25 | 3 | 9 | 2.23427 |
| 1565 | 25 | 3 | 10 | 2.15717 |
| 1566 | 25 | 3 | 11 | 2.13504 |
| 1567 | 25 | 3 | 12 | 2.20721 |
| 1568 | 25 | 3 | 13 | 2.64451 |
| 1569 | 25 | 3 | 14 | 2.48442 |
| 1570 | 25 | 3 | 15 | 2.40507 |
| 1571 | 25 | 3 | 16 | 2.04081 |
| 1572 | 25 | 3 | 17 | 1.88561 |
| 1573 | 25 | 3 | 18 | 1.66322 |
| 1574 | 25 | 3 | 19 | 1.56811 |
| 1575 | 25 | 3 | 20 | 1.55932 |
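The table above holds one score per replication, segment and number of LVs. Assuming that res is the average of res_rep over replications and segments (which is how the two tables appear to relate, but is not stated explicitly here), the aggregation can be sketched in base Julia with made-up scores. The names res_rep_demo and res_demo are hypothetical.

```julia
using Statistics

# Made-up per-segment scores, with the same columns as res_rep
res_rep_demo = [
    (rep = 1, segm = 1, nlv = 0, y1 = 4.0),
    (rep = 1, segm = 2, nlv = 0, y1 = 6.0),
    (rep = 1, segm = 1, nlv = 1, y1 = 2.0),
    (rep = 1, segm = 2, nlv = 1, y1 = 3.0),
]

# Average the score over all replications and segments, for each nlv value
nlv_values = sort(unique(r.nlv for r in res_rep_demo))
res_demo = [(nlv = v, y1 = mean(r.y1 for r in res_rep_demo if r.nlv == v))
            for v in nlv_values]
```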
plotgrid(res.nlv, res.y1; step = 2, xlabel = "Nb. LVs", ylabel = "RMSEP-CV").f
f, ax = plotgrid(res.nlv, res.y1; step = 2, xlabel = "Nb. LVs", ylabel = "RMSEP-CV")
for i = 1:rep, j = 1:K
    zres = res_rep[res_rep.rep .== i .&& res_rep.segm .== j, :]
    lines!(ax, zres.nlv, zres.y1; color = (:grey, .2))
end
lines!(ax, res.nlv, res.y1; color = :red, linewidth = 1)
f
Selection of the best parameter combination
u = findall(res.y1 .== minimum(res.y1))[1]
res[u, :]
| Row | nlv | y1 |
|---|---|---|
| | Int64 | Float64 |
| 7 | 6 | 2.2196 |
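The selection step picks the first row whose score equals the minimum. In base Julia, argmin returns the same index, as the small sketch below illustrates with made-up scores containing a tie (y1_demo is a hypothetical vector).

```julia
y1_demo = [12.67, 2.51, 2.34, 2.22, 2.22, 2.28]   # made-up scores, with a tie

u1 = findall(y1_demo .== minimum(y1_demo))[1]   # syntax used in this script
u2 = argmin(y1_demo)                            # index of the first minimum
```

When rows are sorted by increasing nlv, both forms resolve a tie in favor of the smaller number of LVs, i.e. the more parsimonious model.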
model = plskern(nlv = res.nlv[u])
fit!(model, Xtrain, ytrain)
pred = predict(model, Xtest).pred
63×1 Matrix{Float64}:
27.396108728712
4.4912779253048605
4.9900980331906455
11.016752894644581
13.6730456055306
20.679552141109767
25.814043422530865
54.06157064233883
9.87588864713091
5.04840120243277
⋮
17.006189347070176
17.53307437947076
20.679445265823034
26.951106243297836
25.25250542344898
32.719328178913216
33.196189093807185
37.01967890654217
46.24659766456356
Generalization error
rmsep(pred, ytest)
1×1 Matrix{Float64}:
2.075776908672637
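For reference, the root mean squared error of prediction is the square root of the mean squared difference between predictions and observations. A minimal base-Julia sketch for a single response is given below. The function rmsep_sketch is a hypothetical helper that returns a scalar, whereas rmsep above returns a 1×1 matrix.

```julia
using Statistics

# RMSEP for one response: sqrt of the mean squared prediction error
rmsep_sketch(pred, y) = sqrt(mean((vec(pred) .- vec(y)) .^ 2))

pred_demo = [1.0, 2.0, 3.0]   # made-up predictions
y_demo = [1.0, 2.0, 5.0]      # made-up observations
err = rmsep_sketch(pred_demo, y_demo)
```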
Plotting predictions vs. observed data
plotxy(pred, ytest; size = (500, 400), color = (:red, .5), bisect = true, title = "Test set - variable $nam", xlabel = "Prediction", ylabel = "Observed").f
If other parameters have to be defined in the grid, they have to be set in argument pars built with function mpar, such as in the example below.
pars = mpar(scal = [false; true])
nlv = 0:20
model = plskern()
res = gridcv(model, Xtrain, ytrain; segm, score = rmsep, pars, nlv).res
| Row | nlv | scal | y1 |
|---|---|---|---|
| | Int64 | Bool | Float64 |
| 1 | 0 | false | 12.6693 |
| 2 | 0 | true | 12.6693 |
| 3 | 1 | false | 2.5128 |
| 4 | 1 | true | 2.76265 |
| 5 | 2 | false | 2.34379 |
| 6 | 2 | true | 2.35703 |
| 7 | 3 | false | 2.31464 |
| 8 | 3 | true | 2.42326 |
| 9 | 4 | false | 2.24089 |
| 10 | 4 | true | 2.47409 |
| 11 | 5 | false | 2.22243 |
| 12 | 5 | true | 2.39829 |
| 13 | 6 | false | 2.2196 |
| ⋮ | ⋮ | ⋮ | ⋮ |
| 31 | 15 | false | 2.82116 |
| 32 | 15 | true | 2.83473 |
| 33 | 16 | false | 2.80642 |
| 34 | 16 | true | 2.84314 |
| 35 | 17 | false | 2.82318 |
| 36 | 17 | true | 2.8986 |
| 37 | 18 | false | 2.82288 |
| 38 | 18 | true | 3.02593 |
| 39 | 19 | false | 2.90825 |
| 40 | 19 | true | 3.2283 |
| 41 | 20 | false | 3.00598 |
| 42 | 20 | true | 3.3288 |
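The grid evaluated above is the set of all combinations of the parameter values. As a rough illustration in base Julia (not the code of mpar), the combinations can be enumerated with a nested comprehension. The names nlv_grid, scal_grid and grid are hypothetical.

```julia
nlv_grid = 0:2
scal_grid = [false, true]

# All combinations; scal varies fastest, matching the row order of the table above
grid = [(nlv = a, scal = b) for a in nlv_grid for b in scal_grid]

# With the full range used in this script: 21 nlv values x 2 scal values
ngrid_full = length([(a, b) for a in 0:20 for b in scal_grid])
```

The full grid has 42 combinations, which matches the 42 rows of the result table.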
plotgrid(res.nlv, res.y1, res.scal; step = 2, xlabel = "Nb. LVs", ylabel = "RMSEP-CV").f
u = findall(res.y1 .== minimum(res.y1))[1]
res[u, :]
| Row | nlv | scal | y1 |
|---|---|---|---|
| | Int64 | Bool | Float64 |
| 13 | 6 | false | 2.2196 |
model = plskern(nlv = res.nlv[u], scal = res.scal[u])
fit!(model, Xtrain, ytrain)
pred = predict(model, Xtest).pred
63×1 Matrix{Float64}:
27.396108728712
4.4912779253048605
4.9900980331906455
11.016752894644581
13.6730456055306
20.679552141109767
25.814043422530865
54.06157064233883
9.87588864713091
5.04840120243277
⋮
17.006189347070176
17.53307437947076
20.679445265823034
26.951106243297836
25.25250542344898
32.719328178913216
33.196189093807185
37.01967890654217
46.24659766456356
Parameter nlv can also be set in argument pars (which is the generic approach to define the grid). This is strictly equivalent (it gives the same results), but the computations are slower.
pars = mpar(nlv = 0:20)
model = plskern()
gridcv(model, Xtrain, ytrain; segm, score = rmsep, pars).res
| Row | nlv | y1 |
|---|---|---|
| | Int64 | Float64 |
| 1 | 0 | 12.6693 |
| 2 | 1 | 2.5128 |
| 3 | 2 | 2.34379 |
| 4 | 3 | 2.31464 |
| 5 | 4 | 2.24089 |
| 6 | 5 | 2.22243 |
| 7 | 6 | 2.2196 |
| 8 | 7 | 2.27627 |
| 9 | 8 | 2.27375 |
| 10 | 9 | 2.22622 |
| 11 | 10 | 2.26324 |
| 12 | 11 | 2.3036 |
| 13 | 12 | 2.47049 |
| 14 | 13 | 2.6224 |
| 15 | 14 | 2.73301 |
| 16 | 15 | 2.82116 |
| 17 | 16 | 2.80642 |
| 18 | 17 | 2.82318 |
| 19 | 18 | 2.82288 |
| 20 | 19 | 2.90825 |
| 21 | 20 | 3.00598 |
pars = mpar(nlv = 0:20, scal = [false; true])
model = plskern()
gridcv(model, Xtrain, ytrain; segm, score = rmsep, pars).res
| Row | nlv | scal | y1 |
|---|---|---|---|
| | Integer | Integer | Float64 |
| 1 | 0 | false | 12.6693 |
| 2 | 0 | true | 12.6693 |
| 3 | 1 | false | 2.5128 |
| 4 | 1 | true | 2.76265 |
| 5 | 2 | false | 2.34379 |
| 6 | 2 | true | 2.35703 |
| 7 | 3 | false | 2.31464 |
| 8 | 3 | true | 2.42326 |
| 9 | 4 | false | 2.24089 |
| 10 | 4 | true | 2.47409 |
| 11 | 5 | false | 2.22243 |
| 12 | 5 | true | 2.39829 |
| 13 | 6 | false | 2.2196 |
| ⋮ | ⋮ | ⋮ | ⋮ |
| 31 | 15 | false | 2.82116 |
| 32 | 15 | true | 2.83473 |
| 33 | 16 | false | 2.80642 |
| 34 | 16 | true | 2.84314 |
| 35 | 17 | false | 2.82318 |
| 36 | 17 | true | 2.8986 |
| 37 | 18 | false | 2.82288 |
| 38 | 18 | true | 3.02593 |
| 39 | 19 | false | 2.90825 |
| 40 | 19 | true | 3.2283 |
| 41 | 20 | false | 3.00598 |
| 42 | 20 | true | 3.3288 |
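The speed difference between the two syntaxes can be made concrete by counting model fits. The sketch below assumes that, when nlv is passed outside pars, a single model with the maximal number of LVs is fitted per segment and the predictions for the smaller numbers of LVs are derived from it, while the generic syntax refits one model per nlv value. This is an assumption about the implementation, and the variable names are hypothetical.

```julia
rep, K = 25, 3      # replications and folds used in this script
nlv_grid = 0:20

# Generic syntax (nlv inside pars): one fit per segment and per nlv value
nfits_pars = rep * K * length(nlv_grid)

# Optimized syntax (nlv outside pars): one fit per segment, at maximum(nlv_grid)
nfits_nlv = rep * K
```

Under this assumption the generic syntax requires 1575 fits against 75, i.e. 21 times more, while both produce the same 1575 per-segment scores.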