using Jchemo, JchemoData
using JLD2, CairoMakie
using FreqTables
using CodecZlib  # required since mnist20pct.jld2 is compressed
path_jdat = dirname(dirname(pathof(JchemoData)))
db = joinpath(path_jdat, "data/mnist20pct.jld2")
@load db dat
@names dat
(:Xtrain, :ytrain, :Xtest, :ytest)
Xtrain = dat.Xtrain
ytrain = dat.ytrain
Xtest = dat.Xtest
ytest = dat.ytest
ntrain, p = size(Xtrain)
ntest = nro(Xtest)
ntot = ntrain + ntest
(ntot = ntot, ntrain, ntest)
(ntot = 14000, ntrain = 12000, ntest = 2000)
tab(ytrain)
OrderedCollections.OrderedDict{Float32, Int64} with 10 entries:
0.0 => 1185
1.0 => 1348
2.0 => 1192
3.0 => 1226
4.0 => 1168
5.0 => 1084
6.0 => 1184
7.0 => 1253
8.0 => 1170
9.0 => 1190
tab(ytest)
OrderedCollections.OrderedDict{Float32, Int64} with 10 entries:
0.0 => 196
1.0 => 227
2.0 => 207
3.0 => 201
4.0 => 197
5.0 => 178
6.0 => 192
7.0 => 205
8.0 => 195
9.0 => 202
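The two tables above count the observations in each class. A minimal sketch of such a frequency table in base Julia follows, with class proportions added so that the class balance of the training and test sets can be compared. The functions `classcounts` and `classprops` are hypothetical helpers written for illustration; they are not part of Jchemo.

```julia
# Minimal sketch of a class-frequency table (illustrative; not Jchemo's `tab`)
function classcounts(y::AbstractVector)
    return [lev => count(==(lev), y) for lev in sort(unique(y))]
end

# Class proportions, to compare the class balance of two sets
classprops(y::AbstractVector) = [lev => cnt / length(y) for (lev, cnt) in classcounts(y)]

ytoy = Float32[1, 0, 1, 2, 1]   # hypothetical labels
classcounts(ytoy)               # one pair `class => count` per class, sorted by class
classprops(ytoy)                # same, with proportions instead of counts
```

Applied to the counts above, digit 1 makes up 1348/12000 ≈ 11.2% of the training set and 227/2000 = 11.35% of the test set, so the two sets have a similar class balance.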
Grey levels (0 to 255) are rescaled to the interval [0, 1] by dividing by 255. This scaling is not required for the model fitted here, but it is commonly used when fitting deep learning models.
@head Xtrain = Matrix(Xtrain) / 255
3×784 Matrix{Float32}:
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 … 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
... (12000, 784)
@head Xtest = Matrix(Xtest) / 255
3×784 Matrix{Float32}:
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 … 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
... (2000, 784)
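Dividing by 255 is a rescaling by the known range of the grey levels, not a centering and scaling by column statistics. A minimal sketch on a small hypothetical matrix of `UInt8` grey levels shows that the result stays in `Float32` and spans [0, 1]:

```julia
# Hypothetical grey levels stored as UInt8 (0 to 255); not actual MNIST values
X = UInt8[0 128 255; 64 32 16]

# Rescale to [0, 1]; Float32 / Int promotes to Float32, so memory use stays low
Xs = Float32.(X) ./ 255
```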
Example of one observation (i.e. one 28 × 28 image unfolded into a vector of p = 784 pixels):
plotsp(Xtrain, 1:p; nsamp = 1, xlabel = "Pixel", ylabel = "Grey level").f
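Each MNIST image is 28 × 28 pixels, so a row of `Xtrain` can be folded back into an image. The sketch below uses a hypothetical matrix whose values are simply 0 to 783 and shows both possible unfolding orders. Which order was used when building `mnist20pct.jld2` is an assumption to check: the original MNIST files store pixels row by row, while Julia's `vec` unfolds column by column.

```julia
# Hypothetical 28 × 28 "image" whose pixel values are just 0, 1, ..., 783
img = reshape(collect(0:783), 28, 28)

# Column-by-column unfolding (Julia's native column-major order) and back
xc = vec(img)
imgc = reshape(xc, 28, 28)

# Row-by-row unfolding (the pixel order of the original MNIST files) and back
xr = vec(permutedims(img))
imgr = permutedims(reshape(xr, 28, 28))
```

For a real observation, the folded 28 × 28 matrix can then be displayed as an image, for instance with a heatmap.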
model = plsqda(nlv = 25)
fit!(model, Xtrain, ytrain)
pred = predict(model, Xtest).pred
2000×1 Matrix{Float32}:
1.0
0.0
4.0
9.0
6.0
0.0
2.0
7.0
1.0
5.0
⋮
6.0
7.0
8.0
1.0
3.0
4.0
9.0
2.0
6.0
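In broad terms, PLS-QDA reduces X to a few PLS latent variables computed against a dummy (one-hot) coding of the classes, then runs a quadratic discriminant analysis (QDA) on those scores. The sketch below illustrates that general construction on simulated data; it is not Jchemo's implementation. The function and type names are hypothetical, the small ridge added to the class covariances is an assumption for numerical stability, and the priors are set to the class proportions (the priors used by `plsqda` may differ).

```julia
using LinearAlgebra, Statistics, Random

# PLS2 weights by deflation of X only (X must be column-centered).
# Returns the p × nlv matrix R such that the scores are T = X * R.
function pls_projection(X::AbstractMatrix, Y::AbstractMatrix, nlv::Int)
    p = size(X, 2)
    Xa = Matrix{Float64}(X)
    W = zeros(p, nlv)
    P = zeros(p, nlv)
    for a in 1:nlv
        w = svd(Xa' * Y).U[:, 1]    # weights: dominant left singular vector of Xa'Y
        t = Xa * w                  # scores of latent variable a
        pa = Xa' * t / dot(t, t)    # X loadings
        Xa -= t * pa'               # deflate X
        W[:, a] = w
        P[:, a] = pa
    end
    return W / (P' * W)             # R = W (P'W)^(-1)
end

struct PlsQdaSketch
    xmeans::Vector{Float64}
    R::Matrix{Float64}
    classes::Vector
    mus::Vector{Vector{Float64}}      # class means of the PLS scores
    sigmas::Vector{Matrix{Float64}}   # class covariances of the PLS scores
    logpriors::Vector{Float64}
end

function fit_plsqda(X::AbstractMatrix, y::AbstractVector, nlv::Int)
    xmeans = vec(mean(X; dims = 1))
    Xc = X .- xmeans'
    classes = sort(unique(y))
    Y = Float64[yi == c for yi in y, c in classes]   # dummy (one-hot) coding of y
    R = pls_projection(Xc, Y, nlv)
    T = Xc * R
    mus = [vec(mean(T[y .== c, :]; dims = 1)) for c in classes]
    sigmas = [cov(T[y .== c, :]) + 1e-6 * I for c in classes]        # small ridge (assumption)
    logpriors = [log(count(==(c), y) / length(y)) for c in classes]  # class proportions as priors
    return PlsQdaSketch(xmeans, R, classes, mus, sigmas, logpriors)
end

# Quadratic discriminant score of class k for the score vector t
function qda_score(m::PlsQdaSketch, k::Int, t::AbstractVector)
    d = t - m.mus[k]
    return m.logpriors[k] - 0.5 * logdet(m.sigmas[k]) - 0.5 * dot(d, m.sigmas[k] \ d)
end

function predict_plsqda(m::PlsQdaSketch, X::AbstractMatrix)
    T = (X .- m.xmeans') * m.R
    pred = similar(m.classes, size(T, 1))
    for i in axes(T, 1)
        t = T[i, :]
        pred[i] = m.classes[argmax([qda_score(m, k, t) for k in eachindex(m.classes)])]
    end
    return pred
end

# Toy usage on simulated data: three well separated classes in 5 dimensions
Random.seed!(1)
n = 100
X = vcat(randn(n, 5), randn(n, 5) .+ [6.0 0 0 0 0], randn(n, 5) .+ [0 6.0 0 0 0])
y = vcat(fill(0, n), fill(1, n), fill(2, n))
model_sketch = fit_plsqda(X, y, 2)
acc = mean(predict_plsqda(model_sketch, X) .== y)
```

Reducing X first matters here: with 784 pixels, class covariance matrices in the original space would be large and poorly estimated, whereas QDA on 25 PLS scores only needs 25 × 25 covariances per class.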
Error rates (proportions)
errp(pred, ytest) # overall
1×1 Matrix{Float64}:
0.052
merrp(pred, ytest) # average by class
1×1 Matrix{Float64}:
0.052147630315370674
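The two rates differ only in how errors are averaged: `errp` is the proportion of misclassified observations, while `merrp` first computes the error rate within each true class and then averages over classes. A minimal base-Julia sketch of both definitions follows; `errp_sketch` and `merrp_sketch` are hypothetical names, not Jchemo functions. `vec` is applied because `pred` is returned as a 2000×1 matrix.

```julia
using Statistics

# Overall error rate: proportion of misclassified observations
errp_sketch(pred, y) = mean(vec(pred) .!= vec(y))

# Mean error rate by class: error rate within each true class, then averaged over classes
function merrp_sketch(pred, y)
    pred, y = vec(pred), vec(y)
    return mean(mean(pred[y .== c] .!= c) for c in unique(y))
end

ytoy = [0, 0, 0, 1]       # hypothetical true classes
predtoy = [0, 0, 1, 1]    # hypothetical predictions
errp_sketch(predtoy, ytoy)    # 1 error out of 4 observations
merrp_sketch(predtoy, ytoy)   # mean of 1/3 (class 0) and 0 (class 1)
```

On this unbalanced toy example the two rates differ (0.25 vs 1/6), whereas on the MNIST test set they nearly coincide because the classes have similar sizes.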
Using function freqtable of package FreqTables
Counts
res = freqtable(ytest, vec(pred))
10×10 Named Matrix{Int64}
Dim1 ╲ Dim2 │ 0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0
────────────┼─────────────────────────────────────────────────
0.0 │ 195 0 0 0 0 0 0 0 1 0
1.0 │ 0 219 3 2 0 0 0 0 3 0
2.0 │ 1 0 197 1 1 0 2 0 4 1
3.0 │ 2 0 2 185 0 3 0 0 8 1
4.0 │ 1 0 1 0 191 0 0 2 0 2
5.0 │ 1 0 0 6 0 165 1 1 2 2
6.0 │ 0 0 1 0 0 3 185 0 3 0
7.0 │ 1 1 11 0 1 1 0 186 2 2
8.0 │ 0 0 1 4 0 2 0 0 185 3
9.0 │ 2 0 0 3 2 1 0 2 4 188
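In this table, rows are the true classes (`ytest`) and columns the predicted classes, so each row sums to the size of a test class. The same counts can be built in base Julia; the sketch below uses a hypothetical helper `confcounts` and toy labels.

```julia
# Confusion counts: rows = true classes, columns = predicted classes
function confcounts(y::AbstractVector, pred::AbstractVector)
    levels = sort(union(y, pred))
    pos = Dict(lev => i for (i, lev) in enumerate(levels))   # class -> row/column index
    C = zeros(Int, length(levels), length(levels))
    for (yk, pk) in zip(y, pred)
        C[pos[yk], pos[pk]] += 1
    end
    return C, levels
end

C, levels = confcounts([0, 0, 1, 1, 2], [0, 1, 1, 1, 2])   # hypothetical labels
```

With the data above, `confcounts(ytest, vec(pred))` would follow the same orientation as `freqtable(ytest, vec(pred))`.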
Row %
round.(100 * res ./ rowsum(res); digits = 1)
10×10 Named Matrix{Float64}
Dim1 ╲ Dim2 │ 0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0
────────────┼───────────────────────────────────────────────────────────
0.0 │ 99.5 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.5 0.0
1.0 │ 0.0 96.5 1.3 0.9 0.0 0.0 0.0 0.0 1.3 0.0
2.0 │ 0.5 0.0 95.2 0.5 0.5 0.0 1.0 0.0 1.9 0.5
3.0 │ 1.0 0.0 1.0 92.0 0.0 1.5 0.0 0.0 4.0 0.5
4.0 │ 0.5 0.0 0.5 0.0 97.0 0.0 0.0 1.0 0.0 1.0
5.0 │ 0.6 0.0 0.0 3.4 0.0 92.7 0.6 0.6 1.1 1.1
6.0 │ 0.0 0.0 0.5 0.0 0.0 1.6 96.4 0.0 1.6 0.0
7.0 │ 0.5 0.5 5.4 0.0 0.5 0.5 0.0 90.7 1.0 1.0
8.0 │ 0.0 0.0 0.5 2.1 0.0 1.0 0.0 0.0 94.9 1.5
9.0 │ 1.0 0.0 0.0 1.5 1.0 0.5 0.0 1.0 2.0 93.1
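The row percentages divide each count by the total of its row. In Julia broadcasting, a vector of row totals (which `rowsum` is assumed to return here) or an n×1 column of row totals is aligned with the rows of the matrix, so each row is divided by its own total. A minimal sketch on hypothetical counts:

```julia
C = [8 2 0
     1 9 0
     0 0 5]                              # hypothetical counts, rows = true classes

# sum(C; dims = 2) is a 3×1 column of row totals, broadcast across each row
rowpct = 100 .* C ./ sum(C; dims = 2)
```

Each row of `rowpct` sums to 100, and its diagonal gives the percentage of well classified observations in each true class.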
Using function conf of package Jchemo
cf = conf(pred, ytest)
@names cf
(:cnt, :pct, :A, :Apct, :diagpct, :accpct, :lev)
Counts
cf.cnt
| Row | y | pred_0.0 | pred_1.0 | pred_2.0 | pred_3.0 | pred_4.0 | pred_5.0 | pred_6.0 | pred_7.0 | pred_8.0 | pred_9.0 |
|---|---|---|---|---|---|---|---|---|---|---|---|
| | String | Int64 | Int64 | Int64 | Int64 | Int64 | Int64 | Int64 | Int64 | Int64 | Int64 |
| 1 | 0.0 | 195 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
| 2 | 1.0 | 0 | 219 | 3 | 2 | 0 | 0 | 0 | 0 | 3 | 0 |
| 3 | 2.0 | 1 | 0 | 197 | 1 | 1 | 0 | 2 | 0 | 4 | 1 |
| 4 | 3.0 | 2 | 0 | 2 | 185 | 0 | 3 | 0 | 0 | 8 | 1 |
| 5 | 4.0 | 1 | 0 | 1 | 0 | 191 | 0 | 0 | 2 | 0 | 2 |
| 6 | 5.0 | 1 | 0 | 0 | 6 | 0 | 165 | 1 | 1 | 2 | 2 |
| 7 | 6.0 | 0 | 0 | 1 | 0 | 0 | 3 | 185 | 0 | 3 | 0 |
| 8 | 7.0 | 1 | 1 | 11 | 0 | 1 | 1 | 0 | 186 | 2 | 2 |
| 9 | 8.0 | 0 | 0 | 1 | 4 | 0 | 2 | 0 | 0 | 185 | 3 |
| 10 | 9.0 | 2 | 0 | 0 | 3 | 2 | 1 | 0 | 2 | 4 | 188 |
Row %
cf.pct
| Row | levels | pred_0.0 | pred_1.0 | pred_2.0 | pred_3.0 | pred_4.0 | pred_5.0 | pred_6.0 | pred_7.0 | pred_8.0 | pred_9.0 |
|---|---|---|---|---|---|---|---|---|---|---|---|
| | String | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 |
| 1 | 0.0 | 99.5 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.5 | 0.0 |
| 2 | 1.0 | 0.0 | 96.5 | 1.3 | 0.9 | 0.0 | 0.0 | 0.0 | 0.0 | 1.3 | 0.0 |
| 3 | 2.0 | 0.5 | 0.0 | 95.2 | 0.5 | 0.5 | 0.0 | 1.0 | 0.0 | 1.9 | 0.5 |
| 4 | 3.0 | 1.0 | 0.0 | 1.0 | 92.0 | 0.0 | 1.5 | 0.0 | 0.0 | 4.0 | 0.5 |
| 5 | 4.0 | 0.5 | 0.0 | 0.5 | 0.0 | 97.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 |
| 6 | 5.0 | 0.6 | 0.0 | 0.0 | 3.4 | 0.0 | 92.7 | 0.6 | 0.6 | 1.1 | 1.1 |
| 7 | 6.0 | 0.0 | 0.0 | 0.5 | 0.0 | 0.0 | 1.6 | 96.4 | 0.0 | 1.6 | 0.0 |
| 8 | 7.0 | 0.5 | 0.5 | 5.4 | 0.0 | 0.5 | 0.5 | 0.0 | 90.7 | 1.0 | 1.0 |
| 9 | 8.0 | 0.0 | 0.0 | 0.5 | 2.1 | 0.0 | 1.0 | 0.0 | 0.0 | 94.9 | 1.5 |
| 10 | 9.0 | 1.0 | 0.0 | 0.0 | 1.5 | 1.0 | 0.5 | 0.0 | 1.0 | 2.0 | 93.1 |
Error rate (%) within each class
cf.diagpct
| Row | lev | errp_pct |
|---|---|---|
| | Float32 | Float64 |
| 1 | 0.0 | 0.5 |
| 2 | 1.0 | 3.5 |
| 3 | 2.0 | 4.8 |
| 4 | 3.0 | 8.0 |
| 5 | 4.0 | 3.0 |
| 6 | 5.0 | 7.3 |
| 7 | 6.0 | 3.6 |
| 8 | 7.0 | 9.3 |
| 9 | 8.0 | 5.1 |
| 10 | 9.0 | 6.9 |
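Each class error rate is 100 minus the diagonal of the row-percentage table. The sketch below recomputes it from the counts copied from the confusion table above, using only base Julia and the standard libraries; it is a check of the definition, not Jchemo's code.

```julia
using LinearAlgebra, Statistics

# Counts copied from the confusion table above (rows = true class 0 to 9, columns = predicted)
C = [195   0   0   0   0   0   0   0   1   0
       0 219   3   2   0   0   0   0   3   0
       1   0 197   1   1   0   2   0   4   1
       2   0   2 185   0   3   0   0   8   1
       1   0   1   0 191   0   0   2   0   2
       1   0   0   6   0 165   1   1   2   2
       0   0   1   0   0   3 185   0   3   0
       1   1  11   0   1   1   0 186   2   2
       0   0   1   4   0   2   0   0 185   3
       2   0   0   3   2   1   0   2   4 188]

# Error rate (%) in each true class: 100 * (1 - well classified / class size)
errpct = 100 .* (1 .- diag(C) ./ vec(sum(C; dims = 2)))
```

Averaging `errpct` over the ten classes and dividing by 100 gives back the value returned by `merrp` above.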
Accuracy (%)
cf.accpct
| Row | typ | accuracy_pct |
|---|---|---|
| | String | Float64 |
| 1 | Overall | 94.8 |
| 2 | Mean by class | 94.8 |
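The overall accuracy weights each class by its size, while the mean by class gives every class the same weight. They coincide here because the test classes have similar sizes (178 to 227 observations). A small hypothetical, strongly unbalanced confusion matrix shows how far apart they can be:

```julia
using LinearAlgebra, Statistics

C = [90 10
      5  5]    # hypothetical counts: class 1 has 100 observations, class 2 only 10

overall = 100 * tr(C) / sum(C)                            # 100 * 95 / 110
byclass = mean(100 .* diag(C) ./ vec(sum(C; dims = 2)))   # mean of 90% and 50%
```

Here the overall accuracy is about 86.4% while the mean by class is 70%, because the poorly classified second class counts for only 10 of the 110 observations in the overall figure.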
Plotting
plotconf(cf).f
plotconf(cf; cnt = false).f