using Jchemo, JchemoData
using JLD2, CairoMakie
using FreqTables
using CodecZlib  # required since mnist20pct.jld2 is compressed
path_jdat = dirname(dirname(pathof(JchemoData)))
db = joinpath(path_jdat, "data/mnist20pct.jld2")
@load db dat
@names dat
(:Xtrain, :ytrain, :Xtest, :ytest)
Xtrain = dat.Xtrain
ytrain = dat.ytrain
Xtest = dat.Xtest
ytest = dat.ytest
ntrain, p = size(Xtrain)
ntest = nro(Xtest)
ntot = ntrain + ntest
(ntot = ntot, ntrain, ntest)
(ntot = 14000, ntrain = 12000, ntest = 2000)
tab(ytrain)
OrderedCollections.OrderedDict{Float32, Int64} with 10 entries:
  0.0 => 1185
  1.0 => 1348
  2.0 => 1192
  3.0 => 1226
  4.0 => 1168
  5.0 => 1084
  6.0 => 1184
  7.0 => 1253
  8.0 => 1170
  9.0 => 1190
tab(ytest)
OrderedCollections.OrderedDict{Float32, Int64} with 10 entries:
  0.0 => 196
  1.0 => 227
  2.0 => 207
  3.0 => 201
  4.0 => 197
  5.0 => 178
  6.0 => 192
  7.0 => 205
  8.0 => 195
  9.0 => 202
Grey levels (0-255) are rescaled to the range 0-1 (not required here, but commonly done when fitting deep learning models)
@head Xtrain = Matrix(Xtrain) / 255
3×784 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
... (12000, 784)
@head Xtest = Matrix(Xtest) / 255
3×784 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
... (2000, 784)
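As a minimal sketch of the rescaling step, independent of the MNIST data, the toy matrix below (made-up values, not taken from the dataset) is divided by 255 and its range is checked.

```julia
# Toy grey levels in 0-255 (made-up values, for illustration only)
X = Float32[0 51 255; 102 204 153]

# Same operation as above: divide every pixel by the maximum grey level
Xs = X / 255

# Smallest and largest scaled values; both should lie within 0-1
extrema(Xs)
```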
Example of one sample (= one unfolded image: 28 × 28 pixels flattened into 784 values)
plotsp(Xtrain, 1:p; nsamp = 1, xlabel = "Pixel", ylabel = "Grey level").f
model = plsqda(nlv = 25)
fit!(model, Xtrain, ytrain)
pred = predict(model, Xtest).pred
2000×1 Matrix{Float32}:
 1.0
 0.0
 4.0
 9.0
 6.0
 0.0
 2.0
 7.0
 1.0
 5.0
 ⋮
 6.0
 7.0
 8.0
 1.0
 3.0
 4.0
 9.0
 2.0
 6.0
Error rates (proportions)
errp(pred, ytest) # overall
1×1 Matrix{Float64}:
 0.052
merrp(pred, ytest) # average by class
1×1 Matrix{Float64}:
 0.052147630315370674
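The two rates differ in how they weight the classes: the overall rate counts every sample once, while the class average gives each class the same weight whatever its size. The sketch below illustrates this in base Julia on made-up labels. The helpers `overall_error` and `mean_class_error` are hypothetical names introduced here for illustration; they are not Jchemo functions and only mirror what `errp` and `merrp` are described as returning above.

```julia
using Statistics

# Hypothetical helpers, for illustration only (not part of Jchemo)
# Overall error: proportion of misclassified samples
overall_error(pred, y) = mean(pred .!= y)

# Class-averaged error: error proportion within each observed class,
# then an unweighted mean over the classes
function mean_class_error(pred, y)
    classes = sort(unique(y))
    return mean(mean(pred[y .== c] .!= c) for c in classes)
end

# Made-up, unbalanced example: 4 samples of class 0, 2 samples of class 1
y_toy    = [0, 0, 0, 0, 1, 1]
pred_toy = [0, 0, 0, 1, 1, 0]

e_all  = overall_error(pred_toy, y_toy)     # 2 errors out of 6 samples
e_mean = mean_class_error(pred_toy, y_toy)  # mean of 1/4 and 1/2
```

On the MNIST test set the classes have similar sizes, which is consistent with the two rates being nearly equal (0.052 and 0.0521).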
Using function freqtable of package FreqTables
Counts (rows = observed classes, columns = predicted classes)
res = freqtable(ytest, vec(pred))
10×10 Named Matrix{Int64}
Dim1 ╲ Dim2 │ 0.0  1.0  2.0  3.0  4.0  5.0  6.0  7.0  8.0  9.0
────────────┼─────────────────────────────────────────────────
0.0         │ 195    0    0    0    0    0    0    0    1    0
1.0         │   0  219    3    2    0    0    0    0    3    0
2.0         │   1    0  197    1    1    0    2    0    4    1
3.0         │   2    0    2  185    0    3    0    0    8    1
4.0         │   1    0    1    0  191    0    0    2    0    2
5.0         │   1    0    0    6    0  165    1    1    2    2
6.0         │   0    0    1    0    0    3  185    0    3    0
7.0         │   1    1   11    0    1    1    0  186    2    2
8.0         │   0    0    1    4    0    2    0    0  185    3
9.0         │   2    0    0    3    2    1    0    2    4  188
Row %
round.(100 * res ./ rowsum(res); digits = 1)
10×10 Named Matrix{Float64}
Dim1 ╲ Dim2 │  0.0   1.0   2.0   3.0   4.0   5.0   6.0   7.0   8.0   9.0
────────────┼───────────────────────────────────────────────────────────
0.0         │ 99.5   0.0   0.0   0.0   0.0   0.0   0.0   0.0   0.5   0.0
1.0         │  0.0  96.5   1.3   0.9   0.0   0.0   0.0   0.0   1.3   0.0
2.0         │  0.5   0.0  95.2   0.5   0.5   0.0   1.0   0.0   1.9   0.5
3.0         │  1.0   0.0   1.0  92.0   0.0   1.5   0.0   0.0   4.0   0.5
4.0         │  0.5   0.0   0.5   0.0  97.0   0.0   0.0   1.0   0.0   1.0
5.0         │  0.6   0.0   0.0   3.4   0.0  92.7   0.6   0.6   1.1   1.1
6.0         │  0.0   0.0   0.5   0.0   0.0   1.6  96.4   0.0   1.6   0.0
7.0         │  0.5   0.5   5.4   0.0   0.5   0.5   0.0  90.7   1.0   1.0
8.0         │  0.0   0.0   0.5   2.1   0.0   1.0   0.0   0.0  94.9   1.5
9.0         │  1.0   0.0   0.0   1.5   1.0   0.5   0.0   1.0   2.0  93.1
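For readers who want to see what the cross-tabulation and the row percentages compute, here is a minimal base-Julia sketch on made-up labels. The function `confusion_counts` is a hypothetical helper written for this illustration; it is not how FreqTables implements `freqtable`.

```julia
# Hypothetical helper, for illustration only (not part of FreqTables or Jchemo)
# Rows = observed classes, columns = predicted classes
function confusion_counts(y, pred)
    lev = sort(unique(vcat(y, pred)))              # class levels
    idx = Dict(l => i for (i, l) in enumerate(lev))
    cnt = zeros(Int, length(lev), length(lev))
    for (obs, prd) in zip(y, pred)
        cnt[idx[obs], idx[prd]] += 1
    end
    return lev, cnt
end

# Made-up labels
y_obs  = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 1, 1, 0]

lev, cnt = confusion_counts(y_obs, y_pred)

# Row %: each row is divided by its own total, so every row sums to 100
rowpct = round.(100 .* cnt ./ sum(cnt; dims = 2); digits = 1)
```

Dividing by the row totals means each row describes how the samples of one observed class were distributed over the predicted classes.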
Using function conf of package Jchemo
cf = conf(pred, ytest)
@names cf
(:cnt, :pct, :A, :Apct, :diagpct, :accpct, :lev)
Counts
cf.cnt
Row | y | pred_0.0 | pred_1.0 | pred_2.0 | pred_3.0 | pred_4.0 | pred_5.0 | pred_6.0 | pred_7.0 | pred_8.0 | pred_9.0 |
---|---|---|---|---|---|---|---|---|---|---|---|
| | String | Int64 | Int64 | Int64 | Int64 | Int64 | Int64 | Int64 | Int64 | Int64 | Int64 |
1 | 0.0 | 195 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
2 | 1.0 | 0 | 219 | 3 | 2 | 0 | 0 | 0 | 0 | 3 | 0 |
3 | 2.0 | 1 | 0 | 197 | 1 | 1 | 0 | 2 | 0 | 4 | 1 |
4 | 3.0 | 2 | 0 | 2 | 185 | 0 | 3 | 0 | 0 | 8 | 1 |
5 | 4.0 | 1 | 0 | 1 | 0 | 191 | 0 | 0 | 2 | 0 | 2 |
6 | 5.0 | 1 | 0 | 0 | 6 | 0 | 165 | 1 | 1 | 2 | 2 |
7 | 6.0 | 0 | 0 | 1 | 0 | 0 | 3 | 185 | 0 | 3 | 0 |
8 | 7.0 | 1 | 1 | 11 | 0 | 1 | 1 | 0 | 186 | 2 | 2 |
9 | 8.0 | 0 | 0 | 1 | 4 | 0 | 2 | 0 | 0 | 185 | 3 |
10 | 9.0 | 2 | 0 | 0 | 3 | 2 | 1 | 0 | 2 | 4 | 188 |
Row %
cf.pct
Row | levels | pred_0.0 | pred_1.0 | pred_2.0 | pred_3.0 | pred_4.0 | pred_5.0 | pred_6.0 | pred_7.0 | pred_8.0 | pred_9.0 |
---|---|---|---|---|---|---|---|---|---|---|---|
| | String | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 |
1 | 0.0 | 99.5 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.5 | 0.0 |
2 | 1.0 | 0.0 | 96.5 | 1.3 | 0.9 | 0.0 | 0.0 | 0.0 | 0.0 | 1.3 | 0.0 |
3 | 2.0 | 0.5 | 0.0 | 95.2 | 0.5 | 0.5 | 0.0 | 1.0 | 0.0 | 1.9 | 0.5 |
4 | 3.0 | 1.0 | 0.0 | 1.0 | 92.0 | 0.0 | 1.5 | 0.0 | 0.0 | 4.0 | 0.5 |
5 | 4.0 | 0.5 | 0.0 | 0.5 | 0.0 | 97.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 |
6 | 5.0 | 0.6 | 0.0 | 0.0 | 3.4 | 0.0 | 92.7 | 0.6 | 0.6 | 1.1 | 1.1 |
7 | 6.0 | 0.0 | 0.0 | 0.5 | 0.0 | 0.0 | 1.6 | 96.4 | 0.0 | 1.6 | 0.0 |
8 | 7.0 | 0.5 | 0.5 | 5.4 | 0.0 | 0.5 | 0.5 | 0.0 | 90.7 | 1.0 | 1.0 |
9 | 8.0 | 0.0 | 0.0 | 0.5 | 2.1 | 0.0 | 1.0 | 0.0 | 0.0 | 94.9 | 1.5 |
10 | 9.0 | 1.0 | 0.0 | 0.0 | 1.5 | 1.0 | 0.5 | 0.0 | 1.0 | 2.0 | 93.1 |
Error rate (%) in each class
cf.diagpct
Row | lev | errp_pct |
---|---|---|
| | Float32 | Float64 |
1 | 0.0 | 0.5 |
2 | 1.0 | 3.5 |
3 | 2.0 | 4.8 |
4 | 3.0 | 8.0 |
5 | 4.0 | 3.0 |
6 | 5.0 | 7.3 |
7 | 6.0 | 3.6 |
8 | 7.0 | 9.3 |
9 | 8.0 | 5.1 |
10 | 9.0 | 6.9 |
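These per-class rates can be recovered from the count table: for each observed class, the error is 100 minus the percentage on the diagonal. The sketch below redoes that arithmetic in base Julia from the counts printed above. It is a check of the numbers, not a description of how `conf` is implemented.

```julia
using LinearAlgebra  # for diag

# Confusion counts copied from the count table above
# (rows = observed classes 0-9, columns = predicted classes 0-9)
cnt = [
    195   0   0   0   0   0   0   0   1   0
      0 219   3   2   0   0   0   0   3   0
      1   0 197   1   1   0   2   0   4   1
      2   0   2 185   0   3   0   0   8   1
      1   0   1   0 191   0   0   2   0   2
      1   0   0   6   0 165   1   1   2   2
      0   0   1   0   0   3 185   0   3   0
      1   1  11   0   1   1   0 186   2   2
      0   0   1   4   0   2   0   0 185   3
      2   0   0   3   2   1   0   2   4 188
]

# Number of test samples in each observed class (row totals)
nclass = vec(sum(cnt; dims = 2))

# Per-class error (%): share of each row that falls off the diagonal
errp_pct = round.(100 .* (1 .- diag(cnt) ./ nclass); digits = 1)
```

For example, digit 7 has 186 correct predictions out of 205 samples, hence an error of about 9.3 %, mostly due to the 11 samples predicted as 2.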
Accuracy (%)
cf.accpct
Row | typ | accuracy_pct |
---|---|---|
| | String | Float64 |
1 | Overall | 94.8 |
2 | Mean by class | 94.8 |
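The two accuracies are the complements of the two error rates computed earlier (0.052 and 0.0521). As a rough check in base Julia, using only the diagonal counts and the class sizes reported in this document (the variable names are chosen here for illustration):

```julia
using Statistics

# Correctly classified test samples per class (diagonal of the count table)
correct = [195, 219, 197, 185, 191, 165, 185, 186, 185, 188]
# Test samples per class, as given by tab(ytest)
nclass  = [196, 227, 207, 201, 197, 178, 192, 205, 195, 202]

# Overall accuracy (%): all correct predictions over all samples
acc_overall = 100 * sum(correct) / sum(nclass)

# Mean accuracy by class (%): unweighted mean of the per-class accuracies
acc_by_class = 100 * mean(correct ./ nclass)
```

The two values are close but not identical (about 94.80 and 94.79); they coincide only after rounding to one decimal.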
Plotting
plotconf(cf).f
plotconf(cf; cnt = false).f