# Hold out 30% of iris for testing; sampling is stratified by Species so the
# class proportions are kept in both partitions.
iris_split <- iris |>
  initial_split(prop = 0.7, strata = Species)

# Extract the two partitions from the split object.
iris_tr <- iris_split |> training()
iris_tst <- iris_split |> testing()


# Recipe with Species as the outcome and every other column as a predictor.
# The training partition is used as the template so the preprocessing
# specification is built from the same data the workflow is later fit on,
# never from held-out test rows.
iris_rec <- recipe(Species ~ ., data = iris_tr)

Model Spec

# Multinomial regression specification, fitted through the "nnet" engine.
multinom_sim <- multinom_reg() |>
  set_engine("nnet")


# Bundle the recipe (preprocessor) and the model specification in one workflow.
iris_wf <- workflow(
  preprocessor = iris_rec,
  spec = multinom_sim
)

Fit & Predict

Fit Model

# Train the workflow on the training partition only.
iris_fit <- iris_wf |>
  fit(data = iris_tr)
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: multinom_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
0 Recipe Steps

── Model ───────────────────────────────────────────────────────────────────────
nnet::multinom(formula = ..y ~ ., data = data, trace = FALSE)

           (Intercept) Sepal.Length Sepal.Width Petal.Length Petal.Width
versicolor    53.35203     3.845927   -31.93694     10.02870   -3.275975
virginica    -57.85131   -23.541980   -41.71563     61.61759   31.272313

Residual Deviance: 0.1554973 
AIC: 20.1555 


# Attach predictions for the held-out test rows to the test data
# (adds .pred_class and one .pred_<class> probability column per class).
iris_res <- iris_fit |>
  broom::augment(new_data = iris_tst)
# A tibble: 6 × 9
  Sepal.Length Sepal.Width Petal.Length Petal.Width Species .pred_class
         <dbl>       <dbl>        <dbl>       <dbl> <fct>   <fct>      
1          4.9         3            1.4         0.2 setosa  setosa     
2          5           3.6          1.4         0.2 setosa  setosa     
3          5.4         3.7          1.5         0.2 setosa  setosa     
4          4.8         3            1.4         0.1 setosa  setosa     
5          5.7         4.4          1.5         0.4 setosa  setosa     
6          5.4         3.9          1.3         0.4 setosa  setosa     
# ℹ 3 more variables: .pred_setosa <dbl>, .pred_versicolor <dbl>,
#   .pred_virginica <dbl>

Multi-class Performance Matrix

Confusion Matrix

# Cross-tabulate predicted class against the true Species.
iris_res |>
  conf_mat(truth = Species, estimate = .pred_class)
Prediction   setosa versicolor virginica
  setosa         14          0         0
  versicolor      1         13         0
  virginica       0          2        15

Count the number of observations in each class

# Number of test observations per true class (`totals`) and each class's
# share of all observations (`class_wts`), used later as averaging weights.
class_totals <-
  iris_res %>%
  count(Species, name = "totals") %>%
  mutate(class_wts = totals / sum(totals))

# A tibble: 3 × 3
  Species    totals class_wts
  <fct>       <int>     <dbl>
1 setosa         15     0.333
2 versicolor     15     0.333
3 virginica      15     0.333
# Count the observations in every (true class, predicted class) cell.
# count() with both columns returns an ungrouped tibble, so later
# summaries collapse to a single row rather than one row per group.
cell_counts <-
  iris_res %>%
  count(Species, .pred_class)

# A tibble: 5 × 3
  Species    .pred_class     n
  <fct>      <fct>       <int>
1 setosa     setosa         14
2 setosa     versicolor      1
3 versicolor versicolor     13
4 versicolor virginica       2
5 virginica  virginica      15
# Compute the three per-class sensitivities using 1-vs-all:
# keep the correctly classified cells (true class == predicted class),
# attach the per-class totals, and divide.
one_versus_all <-
  cell_counts %>%
  filter(Species == .pred_class) %>%
  full_join(class_totals, by = "Species") %>%
  mutate(
    # A class with no correct predictions has no diagonal cell, so the
    # full join leaves n as NA; its true-positive count is 0, not missing.
    n = coalesce(n, 0L),
    sens = n / totals
  )

# A tibble: 3 × 6
  Species    .pred_class     n totals class_wts  sens
  <fct>      <fct>       <int>  <int>     <dbl> <dbl>
1 setosa     setosa         14     15     0.333 0.933
2 versicolor versicolor     13     15     0.333 0.867
3 virginica  virginica      15     15     0.333 1    
# Three different estimates of overall sensitivity:
#   macro     - unweighted mean of the per-class sensitivities
#   macro_wts - mean of the per-class sensitivities weighted by class_wts
#   micro     - pooled correct predictions divided by pooled class totals
one_versus_all %>%
  summarize(
    macro = mean(sens),
    macro_wts = weighted.mean(sens, class_wts),
    micro = sum(n) / sum(totals)
  )
# A tibble: 1 × 3
  macro macro_wts micro
  <dbl>     <dbl> <dbl>
1 0.933     0.933 0.933
  • Macro-averaging: computes a set of one-versus-all metrics using the standard two-class statistics. These are averaged.
  • Macro-weighted averaging: does the same but the average is weighted by the number of samples in each class.
  • Micro-averaging: computes the contribution for each class, aggregates them, then computes a single metric from the aggregates.

