{flashlight}

Overview

The goal of this package is to shed light on black box machine learning models.

The main features of {flashlight}:

  1. It is simple, yet flexible.
  2. It offers model-agnostic tools such as model performance, variable importance, global surrogate models, ICE profiles, partial dependence, ALE, and further effects plots, scatter plots, interaction strength, and variable contribution breakdown/SHAP for single observations.
  3. It allows you to assess multiple models side by side.
  4. It supports "group by" operations.
  5. It works with case weights.

Currently, models with a numeric or binary response are supported.

Installation

# From CRAN
install.packages("flashlight")

# Development version
devtools::install_github("mayer79/flashlight")

Usage

Let's start with an iris example. For simplicity, we do not split the data into training and testing/validation sets.

library(ggplot2)
library(MetricsWeighted)
library(flashlight)

fit_lm <- lm(Sepal.Length ~ ., data = iris)

# Make explainer object
fl_lm <- flashlight(
  model = fit_lm, 
  data = iris, 
  y = "Sepal.Length", 
  label = "lm",               
  metrics = list(RMSE = rmse, `R-squared` = r_squared)
)

Performance

fl_lm |> 
  light_performance() |> 
  plot(fill = "darkred") +
  labs(x = element_blank(), title = "Performance on training data")

fl_lm |> 
  light_performance(by = "Species") |> 
  plot(fill = "darkred") +
  ggtitle("Performance split by Species")
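To make the two metrics concrete, here is a minimal base-R sketch that recomputes them for the linear model, overall and per Species. It assumes the textbook, unweighted definitions (RMSE as the square root of the mean squared residual, R-squared as one minus the ratio of residual to total sum of squares) and uses the hypothetical helpers `rmse_manual()` and `r2_manual()`; it is not the MetricsWeighted implementation.

```r
# Hypothetical helpers using the textbook (unweighted) definitions
rmse_manual <- function(actual, predicted) {
  sqrt(mean((actual - predicted)^2))
}
r2_manual <- function(actual, predicted) {
  1 - sum((actual - predicted)^2) / sum((actual - mean(actual))^2)
}

fit_lm <- lm(Sepal.Length ~ ., data = iris)
pred <- predict(fit_lm, iris)

# Overall performance on the training data
perf <- c(
  RMSE = rmse_manual(iris$Sepal.Length, pred),
  `R-squared` = r2_manual(iris$Sepal.Length, pred)
)
print(perf)

# Same metrics within each Species, mimicking by = "Species"
perf_by <- sapply(split(seq_len(nrow(iris)), iris$Species), function(idx) {
  c(
    RMSE = rmse_manual(iris$Sepal.Length[idx], pred[idx]),
    `R-squared` = r2_manual(iris$Sepal.Length[idx], pred[idx])
  )
})
print(perf_by)
```

On the training data, the overall R-squared equals `summary(fit_lm)$r.squared`; within a group, R-squared is measured against that group's own mean response.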

Permutation importance with respect to the first metric

Error bars represent standard errors across the repeated permutations (m_repetitions), i.e., the uncertainty of the estimated importance.

fl_lm |>
  light_importance(m_repetitions = 4) |> 
  plot(fill = "darkred") +
  labs(title = "Permutation importance", y = "Increase in RMSE")
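The idea behind permutation importance fits in a few lines of base R: shuffle one column, recompute the metric, and record how much worse it gets than on the unshuffled data. The helper `perm_importance()` is hypothetical, and the standard error is taken here as the standard deviation over the repetitions divided by the square root of their number, which may differ from the exact formula flashlight uses.

```r
# Hypothetical helper: permutation importance measured as increase in RMSE
perm_importance <- function(model, data, y, vars, m_repetitions = 4, seed = 1) {
  set.seed(seed)
  rmse <- function(a, p) sqrt(mean((a - p)^2))
  # Baseline error with all columns intact
  base <- rmse(data[[y]], predict(model, data))
  rows <- lapply(vars, function(v) {
    # Increase in RMSE after shuffling column v, repeated m_repetitions times
    incr <- replicate(m_repetitions, {
      shuffled <- data
      shuffled[[v]] <- sample(shuffled[[v]])
      rmse(data[[y]], predict(model, shuffled)) - base
    })
    data.frame(
      variable = v,
      importance = mean(incr),
      se = sd(incr) / sqrt(m_repetitions)  # standard error of the mean
    )
  })
  out <- do.call(rbind, rows)
  out[order(out$importance, decreasing = TRUE), ]
}

fit_lm <- lm(Sepal.Length ~ ., data = iris)
imp <- perm_importance(
  fit_lm, iris,
  y = "Sepal.Length",
  vars = setdiff(names(iris), "Sepal.Length")
)
print(imp)
```

With `m_repetitions = 1`, each variable gets a single estimate and no standard error, since `sd()` of one value is `NA`.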

ICE curves for Sepal.Width

fl_lm |> 
  light_ice("Sepal.Width", n_max = 200) |> 
  plot(alpha = 0.3, color = "chartreuse4") +
  labs(title = "ICE curves for 'Sepal.Width'", y = "Prediction")

fl_lm |> 
  light_ice("Sepal.Width", n_max = 200, center = "middle") |> 
  plot(alpha = 0.3, color = "chartreuse4") +
  labs(title = "c-ICE curves for 'Sepal.Width'", y = "Prediction (centered)")
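An ICE curve keeps all other covariates of one observation fixed and varies only the feature of interest over a grid. The base-R sketch below assumes a hand-picked grid of 9 equally spaced values and a random subsample of 20 rows (flashlight builds its own grid and subsamples via `n_max`); `ice_curves()` is a hypothetical helper. Centering subtracts each curve's value at the middle grid point, in the spirit of `center = "middle"`.

```r
# Hypothetical helper: ICE curves, one row per observation, one column per grid value
ice_curves <- function(model, data, v, grid, center = FALSE) {
  curves <- sapply(grid, function(g) {
    data[[v]] <- g  # set v to g for every row, keep all other columns
    predict(model, data)
  })
  if (center) {
    # Subtract each curve's value at the middle grid point
    curves <- curves - curves[, ceiling(length(grid) / 2)]
  }
  curves
}

fit_lm <- lm(Sepal.Length ~ ., data = iris)
set.seed(1)
sub <- iris[sample(nrow(iris), 20), ]
grid <- seq(min(iris$Sepal.Width), max(iris$Sepal.Width), length.out = 9)

ice <- ice_curves(fit_lm, sub, "Sepal.Width", grid)
c_ice <- ice_curves(fit_lm, sub, "Sepal.Width", grid, center = TRUE)
```

`matplot(grid, t(c_ice), type = "l", lty = 1)` draws the centered curves. Because the linear model has no interactions, all centered curves coincide; diverging c-ICE curves would point to interaction effects.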

PDPs

fl_lm |> 
  light_profile("Sepal.Width", n_bins = 40) |> 
  plot() +
  ggtitle("PDP for 'Sepal.Width'")

fl_lm |> 
  light_profile("Sepal.Width", n_bins = 40, by = "Species") |> 
  plot() +
  ggtitle("Same grouped by 'Species'")
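A partial dependence profile is the average of the ICE curves over the data. The base-R sketch below evaluates it on 40 equally spaced values of Sepal.Width (an assumption; these need not match the points flashlight evaluates) and once per Species to mimic `by = "Species"`. `partial_dependence()` is a hypothetical helper.

```r
# Hypothetical helper: partial dependence = average ICE value per grid point
partial_dependence <- function(model, data, v, grid) {
  sapply(grid, function(g) {
    data[[v]] <- g
    mean(predict(model, data))
  })
}

fit_lm <- lm(Sepal.Length ~ ., data = iris)
grid <- seq(min(iris$Sepal.Width), max(iris$Sepal.Width), length.out = 40)

pd <- partial_dependence(fit_lm, iris, "Sepal.Width", grid)

# One profile per Species (columns), analogous to by = "Species"
pd_by <- sapply(
  split(iris, iris$Species), partial_dependence,
  model = fit_lm, v = "Sepal.Width", grid = grid
)
```

For the additive linear model, the profile is a straight line whose slope equals the Sepal.Width coefficient, and the three grouped profiles are parallel.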

2D PDP

fl_lm |> 
  light_profile2d(c("Petal.Width", "Petal.Length")) |> 
  plot()
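The two-dimensional version averages the predictions over the data for every combination of two grid values. A base-R sketch with a hypothetical `pd_2d()` helper and an arbitrary 5 x 5 grid spanning the observed ranges of Petal.Width (0.1 to 2.5) and Petal.Length (1 to 6.9):

```r
# Hypothetical helper: 2D partial dependence on a full grid of two features
pd_2d <- function(model, data, v1, v2, grid1, grid2) {
  grid <- expand.grid(g1 = grid1, g2 = grid2)
  grid$value <- mapply(function(a, b) {
    data[[v1]] <- a
    data[[v2]] <- b
    mean(predict(model, data))  # average prediction at (a, b)
  }, grid$g1, grid$g2)
  names(grid)[1:2] <- c(v1, v2)
  grid
}

fit_lm <- lm(Sepal.Length ~ ., data = iris)
pd2 <- pd_2d(
  fit_lm, iris, "Petal.Width", "Petal.Length",
  grid1 = seq(0.1, 2.5, length.out = 5),
  grid2 = seq(1, 6.9, length.out = 5)
)
print(head(pd2))
```

The result is in long format, so it could be drawn with, e.g., `ggplot(pd2, aes(Petal.Width, Petal.Length, fill = value)) + geom_tile()`.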

ALE

fl_lm |> 
  light_profile("Sepal.Width", type = "ale") |> 
  plot() +
  ggtitle("ALE plot for 'Sepal.Width'")
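Accumulated local effects avoid averaging over unrealistic feature combinations by only looking at prediction differences within small intervals of the feature. Below is a minimal base-R sketch of first-order ALE for one numeric feature, assuming quantile-based interval borders and a count-weighted centering; `ale_numeric()` is a hypothetical helper and its details may differ from flashlight's `type = "ale"` implementation.

```r
# Hypothetical helper: first-order ALE for one numeric feature
ale_numeric <- function(model, data, v, n_intervals = 10) {
  x <- data[[v]]
  # Interval borders at quantiles of x; ties can merge borders
  z <- unique(quantile(x, probs = seq(0, 1, length.out = n_intervals + 1),
                       names = FALSE))
  # Intervals (z[j], z[j + 1]], with the first one also closed on the left
  k <- findInterval(x, z, left.open = TRUE, rightmost.closed = TRUE)
  # Average local effect per interval: f(upper border) - f(lower border)
  local <- sapply(seq_len(length(z) - 1), function(j) {
    d <- data[k == j, , drop = FALSE]
    if (nrow(d) == 0) return(0)
    lo <- d
    lo[[v]] <- z[j]
    hi <- d
    hi[[v]] <- z[j + 1]
    mean(predict(model, hi) - predict(model, lo))
  })
  ale <- c(0, cumsum(local))  # accumulated effect at each border
  # Center: per-interval values (mean of the two borders), weighted by counts,
  # average to zero
  counts <- tabulate(k, nbins = length(z) - 1)
  mids <- (ale[-1] + ale[-length(ale)]) / 2
  data.frame(x = z, ale = ale - sum(mids * counts) / sum(counts))
}

fit_lm <- lm(Sepal.Length ~ ., data = iris)
ale <- ale_numeric(fit_lm, iris, "Sepal.Width")
print(ale)
```

For the linear model, each local effect is the coefficient times the interval width, so this ALE curve is again a straight line with the same slope as the PDP.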

Different profile plots in one

fl_lm |> 
  light_effects("Sepal.Width") |> 
  plot(use = "all") +
  ggtitle("Different types of profiles for 'Sepal.Width'")
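Besides partial dependence and ALE, the combined plot also contains descriptive profiles: the average observed response and the average prediction per bin of the feature. A base-R sketch of these two, assuming 5 equal-width bins (an arbitrary choice, not flashlight's binning):

```r
fit_lm <- lm(Sepal.Length ~ ., data = iris)
pred <- predict(fit_lm, iris)

# Arbitrary binning for this sketch: 5 equal-width bins of Sepal.Width
bins <- cut(iris$Sepal.Width, breaks = 5)

profiles <- data.frame(
  bin = levels(bins),
  n = as.vector(table(bins)),
  response = as.vector(tapply(iris$Sepal.Length, bins, mean)),  # observed
  predicted = as.vector(tapply(pred, bins, mean))               # model
)
print(profiles)
```

Since least-squares residuals with an intercept sum to zero, the count-weighted averages of both columns agree over all bins; within a bin, the two can differ.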

Variable contribution breakdown for a single observation

fl_lm |> 
  light_breakdown(new_obs = iris[1, ]) |> 
  plot()
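A sequential breakdown starts from the average prediction on the data and fixes one variable after the other at the value of the observation of interest, recording the change in the average prediction at each step. The sketch below visits the variables in column order, which is an assumption; flashlight may choose the order differently. `breakdown_fixed()` is a hypothetical helper.

```r
# Hypothetical helper: sequential breakdown with a fixed visiting order
breakdown_fixed <- function(model, data, new_obs, vars) {
  baseline <- mean(predict(model, data))  # average prediction on the data
  current <- data
  prev <- baseline
  contrib <- setNames(numeric(length(vars)), vars)
  for (v in vars) {
    current[[v]] <- new_obs[[v]]  # fix v at the observation's value in all rows
    now <- mean(predict(model, current))
    contrib[[v]] <- now - prev    # change in the average prediction
    prev <- now
  }
  list(baseline = baseline, contributions = contrib, prediction = prev)
}

fit_lm <- lm(Sepal.Length ~ ., data = iris)
vars <- setdiff(names(iris), "Sepal.Length")
bd <- breakdown_fixed(fit_lm, iris, iris[1, ], vars)
print(bd$contributions)
```

By construction, the baseline plus all contributions equals the prediction for `iris[1, ]`. For the additive linear model, the contributions do not depend on the visiting order.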

Global surrogate tree

fl_lm |> 
  light_global_surrogate() |> 
  plot()
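A global surrogate is an interpretable model fitted to the predictions of the original model rather than to the response. The sketch uses {rpart} with an arbitrary depth limit of 3 and measures fidelity as the R-squared between surrogate and model predictions; both choices are assumptions, not flashlight's defaults.

```r
library(rpart)

fit_lm <- lm(Sepal.Length ~ ., data = iris)
surrogate_data <- iris
surrogate_data$pred <- predict(fit_lm, iris)  # target: model predictions

# Shallow tree on the same covariates (maxdepth = 3 is arbitrary)
fit_sur <- rpart(
  pred ~ Sepal.Width + Petal.Length + Petal.Width + Species,
  data = surrogate_data,
  control = rpart.control(maxdepth = 3)
)

# Fidelity: how well the tree reproduces the model's predictions
sur_pred <- predict(fit_sur, surrogate_data)
fidelity_r2 <- 1 - sum((surrogate_data$pred - sur_pred)^2) /
  sum((surrogate_data$pred - mean(surrogate_data$pred))^2)
print(fit_sur)
print(fidelity_r2)
```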

Multiple models

Multiple flashlights can be combined into a multiflashlight.

library(rpart)

fit_tree <- rpart(
  Sepal.Length ~ ., 
  data = iris, 
  control = list(cp = 0, xval = 0, maxdepth = 5)
)

# Make explainer object
fl_tree <- flashlight(
  model = fit_tree, 
  data = iris, 
  y = "Sepal.Length", 
  label = "tree",               
  metrics = list(RMSE = rmse, `R-squared` = r_squared)
)

# Combine with other explainer
fls <- multiflashlight(list(fl_tree, fl_lm))

fls |> 
  light_performance() |> 
  plot(fill = "chartreuse4") +
  labs(x = "Model", title = "Performance")

fls |> 
  light_profile("Petal.Length", n_bins = 40, by = "Species") |> 
  plot() +
  ggtitle("PDP by Species")
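Conceptually, a multiflashlight applies the same tool to each model and stacks the results by label. A base-R sketch for the performance comparison, refitting both models so it runs on its own; `rmse_manual()` is a hypothetical helper. As elsewhere in this README, both models are evaluated on their training data, which tends to favour the more flexible tree.

```r
library(rpart)

rmse_manual <- function(actual, predicted) sqrt(mean((actual - predicted)^2))

# The two models from above, refitted so the sketch is self-contained
models <- list(
  tree = rpart(
    Sepal.Length ~ ., data = iris,
    control = list(cp = 0, xval = 0, maxdepth = 5)
  ),
  lm = lm(Sepal.Length ~ ., data = iris)
)

# Same metric for every model, one row per model label
comparison <- data.frame(
  label = names(models),
  RMSE = sapply(models, function(m) rmse_manual(iris$Sepal.Length, predict(m, iris))),
  row.names = NULL
)
print(comparison)
```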

More

Check out the vignette for more information and important references.