Prediction - mnist20pct - Confusion matrix

using Jchemo, JchemoData
using JLD2, CairoMakie
using FreqTables 
using CodecZlib   # required since mnist20pct.jld2 is compressed

Data importation

path_jdat = dirname(dirname(pathof(JchemoData)))
db = joinpath(path_jdat, "data/mnist20pct.jld2") 
@load db dat
@names dat
(:Xtrain, :ytrain, :Xtest, :ytest)
Xtrain = dat.Xtrain
ytrain = dat.ytrain
Xtest = dat.Xtest
ytest = dat.ytest
ntrain, p = size(Xtrain)
ntest = nro(Xtest)
ntot = ntrain + ntest
(ntot = ntot, ntrain, ntest)
(ntot = 14000, ntrain = 12000, ntest = 2000)
tab(ytrain)
OrderedCollections.OrderedDict{Float32, Int64} with 10 entries:
  0.0 => 1185
  1.0 => 1348
  2.0 => 1192
  3.0 => 1226
  4.0 => 1168
  5.0 => 1084
  6.0 => 1184
  7.0 => 1253
  8.0 => 1170
  9.0 => 1190
tab(ytest)
OrderedCollections.OrderedDict{Float32, Int64} with 10 entries:
  0.0 => 196
  1.0 => 227
  2.0 => 207
  3.0 => 201
  4.0 => 197
  5.0 => 178
  6.0 => 192
  7.0 => 205
  8.0 => 195
  9.0 => 202

Grey levels (0-255) rescaled to [0, 1] by dividing by 255 (not required here, but used when fitting deep learning models)

@head Xtrain = Matrix(Xtrain) / 255
3×784 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
... (12000, 784)
@head Xtest = Matrix(Xtest) / 255
3×784 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
... (2000, 784)
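The rescaling above is a plain element-wise division. A minimal sketch on a small hypothetical matrix (not the MNIST data) shows that dividing a Float32 matrix by the integer 255 keeps the Float32 element type and maps the grey levels to [0, 1]:

```julia
# Hypothetical toy data: 3 "images" of 4 pixels, grey levels in 0-255,
# stored as Float32 (as Xtrain is in mnist20pct)
X = Float32[0 128 255 64; 255 0 32 200; 17 34 51 68]

# Dividing by 255 maps grey levels to [0, 1]; Float32 / Int stays Float32
Xs = X / 255

println(eltype(Xs))    # Float32
println(extrema(Xs))   # (0.0f0, 1.0f0)
```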

Example of one sample (= one image of 28 × 28 = 784 pixels, unfolded into a vector)

plotsp(Xtrain, 1:p; nsamp = 1, xlabel = "Pixel", ylabel = "Grey level").f
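Since p = 784 = 28 × 28, each row of Xtrain can be folded back into an image. A minimal sketch on a hypothetical vector standing in for one row; whether the pixels were unfolded row by row or column by column is an assumption to check on the data, because Julia's `reshape` fills a matrix column by column:

```julia
# Hypothetical unfolded image: p = 784 pixel values (stands in for one row of Xtrain)
x = Float32.(0:783) ./ 783

# Julia is column-major: reshape fills the 28×28 matrix column by column
img_col = reshape(x, 28, 28)

# If the pixels were unfolded row by row (assumption to check on the data),
# transposing the reshaped matrix recovers the image orientation
img_row = permutedims(reshape(x, 28, 28))

println(size(img_col))           # (28, 28)
println(img_col[2, 1] == x[2])   # true: 2nd pixel goes down the first column
println(img_row[1, 2] == x[2])   # true: 2nd pixel goes along the first row
```

With the real data, the same calls would apply to a vector such as `Xtrain[1, :]`.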

Fitting a PLS-QDA model and predicting the test set

model = plsqda(nlv = 25)
fit!(model, Xtrain, ytrain)
pred = predict(model, Xtest).pred
2000×1 Matrix{Float32}:
 1.0
 0.0
 4.0
 9.0
 6.0
 0.0
 2.0
 7.0
 1.0
 5.0
 ⋮
 6.0
 7.0
 8.0
 1.0
 3.0
 4.0
 9.0
 2.0
 6.0
  • Error rates (proportions)

errp(pred, ytest)  # overall
1×1 Matrix{Float64}:
 0.052
merrp(pred, ytest) # average by class
1×1 Matrix{Float64}:
 0.052147630315370674
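The two error rates differ in how classes are weighted. A minimal sketch with hypothetical helpers `errp_sketch` and `merrp_sketch` (not the Jchemo implementations), assuming `merrp` is the unweighted mean of the per-class error rates:

```julia
using Statistics  # mean

# Overall error rate: proportion of misclassified observations
errp_sketch(pred, y) = mean(vec(pred) .!= vec(y))

# Mean error rate by class: error rate computed within each observed class,
# then averaged (unweighted) over the classes
function merrp_sketch(pred, y)
    p, yv = vec(pred), vec(y)
    classes = sort(unique(yv))
    mean(mean(p[yv .== c] .!= c) for c in classes)
end

# Toy example: class 0 has 1 error out of 4, class 1 has 0 errors out of 2
y    = Float32[0, 0, 0, 0, 1, 1]
pred = Float32[0, 0, 0, 1, 1, 1]
errp_sketch(pred, y)    # 1/6 ≈ 0.1667
merrp_sketch(pred, y)   # (0.25 + 0.0) / 2 = 0.125
```

When class sizes are unequal the two rates diverge, as in this toy example; with test classes of similar sizes, as here, they stay close (0.052 vs 0.0521).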

Confusion matrix

Using the function freqtable of package FreqTables

  • Counts

res = freqtable(ytest, vec(pred))
10×10 Named Matrix{Int64}
Dim1 ╲ Dim2 │ 0.0  1.0  2.0  3.0  4.0  5.0  6.0  7.0  8.0  9.0
────────────┼─────────────────────────────────────────────────
0.0         │ 195    0    0    0    0    0    0    0    1    0
1.0         │   0  219    3    2    0    0    0    0    3    0
2.0         │   1    0  197    1    1    0    2    0    4    1
3.0         │   2    0    2  185    0    3    0    0    8    1
4.0         │   1    0    1    0  191    0    0    2    0    2
5.0         │   1    0    0    6    0  165    1    1    2    2
6.0         │   0    0    1    0    0    3  185    0    3    0
7.0         │   1    1   11    0    1    1    0  186    2    2
8.0         │   0    0    1    4    0    2    0    0  185    3
9.0         │   2    0    0    3    2    1    0    2    4  188
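The count table can be rebuilt with a simple loop over (observed, predicted) pairs. A minimal sketch with a hypothetical helper `confcounts`, using the same orientation as `freqtable(ytest, vec(pred))` (rows = observed classes, columns = predicted classes):

```julia
# Confusion matrix of counts: rows = observed classes, columns = predicted classes
function confcounts(y, pred)
    yv, pv = vec(y), vec(pred)
    lev = sort(unique(vcat(yv, pv)))               # all class levels
    idx = Dict(l => i for (i, l) in enumerate(lev)) # level -> row/column index
    cnt = zeros(Int, length(lev), length(lev))
    for (obs, prd) in zip(yv, pv)
        cnt[idx[obs], idx[prd]] += 1
    end
    return cnt, lev
end

# Hypothetical toy data with 3 classes
y    = Float32[0, 0, 1, 1, 2, 2]
pred = Float32[0, 1, 1, 1, 2, 0]
cnt, lev = confcounts(y, pred)   # cnt == [1 1 0; 0 2 0; 1 0 1]
```

Row sums give the class sizes (as returned by `tab(ytest)`), and the diagonal holds the correct predictions.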
  • Row %

round.(100 * res ./ rowsum(res); digits = 1)
10×10 Named Matrix{Float64}
Dim1 ╲ Dim2 │  0.0   1.0   2.0   3.0   4.0   5.0   6.0   7.0   8.0   9.0
────────────┼───────────────────────────────────────────────────────────
0.0         │ 99.5   0.0   0.0   0.0   0.0   0.0   0.0   0.0   0.5   0.0
1.0         │  0.0  96.5   1.3   0.9   0.0   0.0   0.0   0.0   1.3   0.0
2.0         │  0.5   0.0  95.2   0.5   0.5   0.0   1.0   0.0   1.9   0.5
3.0         │  1.0   0.0   1.0  92.0   0.0   1.5   0.0   0.0   4.0   0.5
4.0         │  0.5   0.0   0.5   0.0  97.0   0.0   0.0   1.0   0.0   1.0
5.0         │  0.6   0.0   0.0   3.4   0.0  92.7   0.6   0.6   1.1   1.1
6.0         │  0.0   0.0   0.5   0.0   0.0   1.6  96.4   0.0   1.6   0.0
7.0         │  0.5   0.5   5.4   0.0   0.5   0.5   0.0  90.7   1.0   1.0
8.0         │  0.0   0.0   0.5   2.1   0.0   1.0   0.0   0.0  94.9   1.5
9.0         │  1.0   0.0   0.0   1.5   1.0   0.5   0.0   1.0   2.0  93.1
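Row percentages divide each count by its row total, i.e. the number of test observations of the observed class. A minimal sketch in plain Julia, using `sum(cnt; dims = 2)` in place of `rowsum`, applied to the first row of the count table (class 0.0):

```julia
# Row percentages: each count divided by its row total, so each row sums
# to 100 up to rounding
rowpct(cnt) = round.(100 .* cnt ./ sum(cnt; dims = 2); digits = 1)

# First row (class 0.0) of the count table: 195 correct, 1 predicted as 8.0
cnt0 = [195 0 0 0 0 0 0 0 1 0]
rowpct(cnt0)   # 99.5 in column 1, 0.5 in column 9, 0.0 elsewhere
```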

Using the function conf of package Jchemo

cf = conf(pred, ytest)
@names cf
(:cnt, :pct, :A, :Apct, :diagpct, :accpct, :lev)
  • Counts

cf.cnt
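10×11 DataFrame
 Row │ y       pred_0.0  pred_1.0  pred_2.0  pred_3.0  pred_4.0  pred_5.0  pred_6.0  pred_7.0  pred_8.0  pred_9.0
     │ String  Int64     Int64     Int64     Int64     Int64     Int64     Int64     Int64     Int64     Int64
─────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────
   1 │ 0.0          195         0         0         0         0         0         0         0         1         0
   2 │ 1.0            0       219         3         2         0         0         0         0         3         0
   3 │ 2.0            1         0       197         1         1         0         2         0         4         1
   4 │ 3.0            2         0         2       185         0         3         0         0         8         1
   5 │ 4.0            1         0         1         0       191         0         0         2         0         2
   6 │ 5.0            1         0         0         6         0       165         1         1         2         2
   7 │ 6.0            0         0         1         0         0         3       185         0         3         0
   8 │ 7.0            1         1        11         0         1         1         0       186         2         2
   9 │ 8.0            0         0         1         4         0         2         0         0       185         3
  10 │ 9.0            2         0         0         3         2         1         0         2         4       188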
  • Row %

cf.pct
10×11 DataFrame
 Row │ levels  pred_0.0  pred_1.0  pred_2.0  pred_3.0  pred_4.0  pred_5.0  pred_6.0  pred_7.0  pred_8.0  pred_9.0
     │ String  Float64   Float64   Float64   Float64   Float64   Float64   Float64   Float64   Float64   Float64
─────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────
   1 │ 0.0         99.5       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.5       0.0
   2 │ 1.0          0.0      96.5       1.3       0.9       0.0       0.0       0.0       0.0       1.3       0.0
   3 │ 2.0          0.5       0.0      95.2       0.5       0.5       0.0       1.0       0.0       1.9       0.5
   4 │ 3.0          1.0       0.0       1.0      92.0       0.0       1.5       0.0       0.0       4.0       0.5
   5 │ 4.0          0.5       0.0       0.5       0.0      97.0       0.0       0.0       1.0       0.0       1.0
   6 │ 5.0          0.6       0.0       0.0       3.4       0.0      92.7       0.6       0.6       1.1       1.1
   7 │ 6.0          0.0       0.0       0.5       0.0       0.0       1.6      96.4       0.0       1.6       0.0
   8 │ 7.0          0.5       0.5       5.4       0.0       0.5       0.5       0.0      90.7       1.0       1.0
   9 │ 8.0          0.0       0.0       0.5       2.1       0.0       1.0       0.0       0.0      94.9       1.5
  10 │ 9.0          1.0       0.0       0.0       1.5       1.0       0.5       0.0       1.0       2.0      93.1
  • Total error rate (%) in each class

cf.diagpct
10×2 DataFrame
 Row │ lev      errp_pct
     │ Float32  Float64
─────┼──────────────────
   1 │     0.0       0.5
   2 │     1.0       3.5
   3 │     2.0       4.8
   4 │     3.0       8.0
   5 │     4.0       3.0
   6 │     5.0       7.3
   7 │     6.0       3.6
   8 │     7.0       9.3
   9 │     8.0       5.1
  10 │     9.0       6.9
  • Accuracy (%)

cf.accpct
2×2 DataFrame
 Row │ typ            accuracy_pct
     │ String         Float64
─────┼────────────────────────────
   1 │ Overall                94.8
   2 │ Mean by class          94.8
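The per-class error rates and both accuracies can be recomputed from the diagonal of the count table and the class sizes. A sketch in plain Julia (variable names are hypothetical), using the counts shown above:

```julia
using Statistics  # mean

# Diagonal (correct predictions) and row totals (observed class sizes, as
# given by tab(ytest)), taken from the count table above, classes 0 to 9
ncorrect = [195, 219, 197, 185, 191, 165, 185, 186, 185, 188]
nclass   = [196, 227, 207, 201, 197, 178, 192, 205, 195, 202]

# Error rate (%) in each class = 100 - row-% on the diagonal
errp_pct = round.(100 .* (1 .- ncorrect ./ nclass); digits = 1)

# Overall accuracy: all correct predictions over all observations
acc_overall = 100 * sum(ncorrect) / sum(nclass)   # 1896 / 2000 -> 94.8

# Mean accuracy by class: unweighted mean of the per-class accuracies
acc_mean = 100 * mean(ncorrect ./ nclass)          # ≈ 94.785
```

The mean accuracy by class equals 100 × (1 - merrp), about 94.785; it rounds to the same 94.8 as the overall accuracy, as expected when the class sizes are similar.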
  • Plotting

plotconf(cf).f
plotconf(cf; cnt = false).f