воскресенье, 30 декабря 2018 г.

Интерпретация моделей машинного обучения при помощи LIME

Очень часто для применения моделей машинного обучения недостаточно знать, какое качество эти модели демонстрируют в среднем. Также необходимо понимать причины, почему модель выдает то или иное предсказание для каждого конкретного примера. Для решения этой задачи применительно к любой модели был предложен алгоритм LIME (Local Interpretable Model-Agnostic Explanations - локально интерпретируемые объяснения, не зависящие от устройства модели), описанный в публикации [https://arxiv.org/abs/1602.04938]. Оригинальную реализацию на языке Python можно найти по ссылке [https://github.com/marcotcr/lime], также доступен порт на языке R: [https://github.com/thomasp85/lime].
Авторы алгоритма LIME сформулировали следующие характеристики, желаемые для подобных объясняющих алгоритмов:
  1. Интерпретируемость - способность объяснить предсказания модели понятным человеку образом.
  2. Локальная верность - в окрестностях отдельного наблюдения алгоритм должен вести себя так же, как объясняемая им модель; требуется находить компромисс между локальной верностью и интерпретируемостью.
  3. Независимость от внутреннего устройства модели - объясняемая модель может быть использована как “черный ящик”.
  4. Объяснение работы модели в целом путем рассмотрения ее предсказаний на наборе наблюдений.
Сама процедура алгоритма LIME выглядит следующим образом:
  1. Выбрать интерпретируемое представление данных, которые подаются на вход объясняемой модели. Мы обозначим исходное признаковое описание наблюдения как \(x \in \mathbb{R}^{d}\) и интерпретируемое представление как \(x^{'} \in \{0, 1\}^{d^{'}}\), т.е. как бинарный вектор. Понять, как это выглядит на практике, можно на примере изображений: исходные признаки для цветной (RGB) картинки представляют собой трехмерный массив с интенсивностями каждого из трех цветовых каналов для кажого пикселя, а интерпретируемым представлением будет бинарный вектор определенной длины, который служит индикатором наличия некоторой группы пикселей. Еще проще все выглядит для обработки текста, представленного в виде “мешка слов”: никакой переход к новым признакам не требуется, данные и так представлены в виде бинарных векторов. А вот для обработки TF-IDF представления переход к бинарным векторами придется произвести.
  2. Из окрестностей (имеется ввиде близость в пространстве признаков) наблюдения \(x^{'}\) взять искусственные наблюдения, у которых случайным образом сохраняется часть ненулевых значений из \(x^{'}\), а остальные значения обнуляются. Например, для вектора {0, 1, 1, 1, 1} такое искусственное наблюдение \(z^{'} \in \{0, 1\}^{d^{'}}\) может выглядеть как {0, 1, 0, 1, 0}.
  3. С помощью процедуры K-Lasso отобрать K бинарных признаков, которые будут далее использоваться (иными словами, длина вектора \(z^{'}\) для каждого наблюдения будет равна K). Обучается модель лассо-регрессии с признаками \(z^{'}\) и целевой переменной \(f(z)\).
  4. Для искусственных наблюдений восстановить исходное признаковое описание \(z \in \mathbb{R}^{d}\).
  5. Получить \(f(z)\), то есть прогноз объясняемой модели на искусственном наблюдении; использовать полученные прогнозы как метки, т.е. как переменную отклика в объясняющей модели.
  6. Построить объясняющую модель - например, линейную вида \(g(z^{'})=w_{g} \cdot z^{'}\) с локально взвешенной квадратичной функцией потерь (чтобы штраф от удаленных наблюдений был меньше, чем от наблюдений, находящихся ближе всего к объясняемому наблюдению).
  7. Веса в полученной линейной модели интерпретируются как оценки важности соответствующих признаков в исходном признаковом описании.
Разумеется, использование низкоразмерных представлений имеет свои ограничения. Например, описанным способом нельзя объяснить работу нейронной сети, которая умеет классифицировать изображения по их тону (в этом примере предсказание не зависит от отдельных групп пикселей). Также описанная процедура довольно затратна с вычислительной точки зрения: авторы оригинальной публикации сообщают о том, что объяснение одного предсказания при классификации изображения с использованием архитектуры Inception занимает 10 минут (но не указаны ни используемые вычислительные мощности, ни размеры изображения).
Продемонстрируем работу алгоритма LIME на примере табличных данных и модели “случайного леса”, используя реализацию для языка R. Установим и загрузим соответствующий пакет:
# install.packages("lime")
library(lime)
Используем классический набор данных, содержащий информацию о недвижимости Бостона, и модель случайного леса при помощи пакета caret. Целевой переменной является медианная цена дома (количественная переменная), то есть решается задача регрессии.
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
library(MASS)

df <- Boston

model <- train(medv ~ ., data = df[1:500, ], method = "ranger")
За создание функции, объясняющей модель, отвечает функция lime(). Она принимает на вход данные (в виде таблицы, символьного вектора или изображения) и объект с объясняемой моделью, а также ряд дополнительных параметров в зависимости от формата входных данных.
explainer <- lime(df, model)
Функция explain выполняет объяснение предсказаний для отдельных наблюдений. Помимо таблицы или другого объекта, содержащего наблюдения, важными параметрами являются n_features (число признаков, используемых для объяснения) и n_permutations (число генерируемых искусственных наблюдений). Параметр feature_select задает способ отбора указанного числа признаков: по умолчанию используется либо метод последовательного включения предикторов (при n_features <= 6), либо отбираются предикторы с наибольшими абсолютными значениями весов в модели гребневой регрессии.
explanation <- explain(df[501:504, ], 
                       explainer, 
                       n_features = 6,
                       n_permutations = 1000)
head(explanation)
##   model_type case  model_r2 model_intercept model_prediction feature
## 1 regression  501 0.1041348        26.89184         20.52547    crim
## 2 regression  501 0.1041348        26.89184         20.52547      rm
## 3 regression  501 0.1041348        26.89184         20.52547   lstat
## 4 regression  501 0.1041348        26.89184         20.52547 ptratio
## 5 regression  501 0.1041348        26.89184         20.52547     nox
## 6 regression  501 0.1041348        26.89184         20.52547      zn
##   feature_value feature_weight           feature_desc
## 1       0.22438    -0.09725208  0.082 < crim <= 0.257
## 2       6.02700    -2.73938740      5.89 < rm <= 6.21
## 3      14.33000    -2.65658156 11.36 < lstat <= 16.96
## 4      19.20000    -0.93032567 19.1 < ptratio <= 20.2
## 5       0.58500     0.41714729   0.538 < nox <= 0.624
## 6       0.00000    -0.35996841             zn <= 12.5
##                                                                                                                                   data
## 1 0.22438, 0.00000, 9.69000, 0.00000, 0.58500, 6.02700, 79.70000, 2.49820, 6.00000, 391.00000, 19.20000, 396.90000, 14.33000, 16.80000
## 2 0.22438, 0.00000, 9.69000, 0.00000, 0.58500, 6.02700, 79.70000, 2.49820, 6.00000, 391.00000, 19.20000, 396.90000, 14.33000, 16.80000
## 3 0.22438, 0.00000, 9.69000, 0.00000, 0.58500, 6.02700, 79.70000, 2.49820, 6.00000, 391.00000, 19.20000, 396.90000, 14.33000, 16.80000
## 4 0.22438, 0.00000, 9.69000, 0.00000, 0.58500, 6.02700, 79.70000, 2.49820, 6.00000, 391.00000, 19.20000, 396.90000, 14.33000, 16.80000
## 5 0.22438, 0.00000, 9.69000, 0.00000, 0.58500, 6.02700, 79.70000, 2.49820, 6.00000, 391.00000, 19.20000, 396.90000, 14.33000, 16.80000
## 6 0.22438, 0.00000, 9.69000, 0.00000, 0.58500, 6.02700, 79.70000, 2.49820, 6.00000, 391.00000, 19.20000, 396.90000, 14.33000, 16.80000
##   prediction
## 1   19.93304
## 2   19.93304
## 3   19.93304
## 4   19.93304
## 5   19.93304
## 6   19.93304
Полученные результаты можно представить графически:
plot_features(explanation)

Обратите внимание на следующий ожидаемый и интуитивно понятный факт: для трех наблюдений значения предиктора rm (среднее количество комнат) в диапазоне от 5,89 до 6,21 или от 6,21 до 6,62 было ассоциировано с меньшим значением целевой переменной, в то время как значения >6,62, наоборот, указывает на большую медианную цену.

Комментариев нет:

Отправить комментарий