In this blog post you’ll learn about the TrustyAI explainability library and how to use it in order to provide explanations of “predictions” generated by decision services and plain machine learning models.
The need for explainability
Nowadays AI-based systems and decision services are widely used in industry across a wide range of domains, such as financial services and health care. Such large adoption raises concerns about how much we, as humans, can rely on these systems when they are involved in making important decisions. Are such services fair when making decisions? What are such systems giving importance to when providing recommendations to decision makers? All in all, can we trust them?
The above issues particularly affect machine learning techniques like deep learning which, more than others, are considered “opaque” as they’re hard to interpret and understand. By just looking at the inner workings (e.g. neurons’ activations) of a neural network it is not easy to guess what the network is giving more importance to when performing a prediction. Say an AI system for detecting SARS-CoV-2 in chest x-ray images predicts that Covid-19 is present in one such image: can a doctor trust that enough to decide upon further treatment? If the system also highlights the portions of the image that are responsible for the final prediction, a doctor can more easily double check whether what the system is basing its output on makes sense from the clinical perspective too, and decide whether the prediction can be trusted. The highlighted portions of such an image are an example of an explanation: an interface between the human and the system, a human-understandable description of the internals of an (AI) system prediction. Depending on the system (and user) at hand, different kinds of explanations might work best. For example a bank having an automated credit card request approval system may need to provide explanations too. Imagine that the user fills in a form with information about their financial situation and family. If the user’s credit card approval request gets rejected, they might want to know the rationale behind the rejection.
An explanation of a credit card request having been rejected
A useful explanation in this case would show how the input data typed into the form by the user influenced the decision (rejection). As an example, in the picture above the explanation is telling us that the user’s age, the fact that they own (or don’t own) a car and their income had a negative impact on the decision (hence being the major causes for rejection). On the other hand, other information such as the number of children and the fact that they own a realty (or don’t own it) might have favoured an approval.
Explainable Artificial Intelligence (aka XAI or Explainability) is a research area of Artificial Intelligence that focuses on making opaque AI systems more interpretable and understandable for different stakeholders, so that humans can trust them, especially in sensitive processes that affect people in their real lives. In the remainder of this post we’ll consider the systems used to make human-impacting decisions (often based on AI techniques) as black boxes; you’ll learn about a way to provide explanations for their outputs. Note also that the insights provided by explanations are helpful in identifying what to fix in a system. For example you might find out that the credit card approval system is unnecessarily biased against users who own a car.
Local Interpretable Model-agnostic Explanations
A quite well known explainability technique is called Local Interpretable Model-agnostic Explanations (aka LIME). LIME provides a tool to explain the predictions of a black box classifier (it also works with regression models). It can be used seamlessly on any type of classifier: for example, a text classifier would get explanations based on which words contributed most to the prediction output, while an image classifier would get explanations based on which image patches contributed most to it.
The way LIME generates explanations is by perturbing inputs, for example by randomly dropping features from an input, and performing classification with such slightly modified inputs. The outputs and a sparse version of the perturbed inputs (e.g. a binary feature presence vector) are used to train a linear regression model that is local with respect to the original input. Once training is done, the sparse feature weights are analyzed in order to correlate them to the predictions; for example, positive outputs are correlated with the features whose weights are larger.
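To make this concrete, here is a deliberately simplified, hypothetical sketch of the procedure for a text input: it perturbs the input by dropping words, queries the black box, and fits a plain linear model whose weights act as importance scores. This is not the library’s implementation (real LIME also adds an intercept and weights samples by their proximity to the original input), and the explainWithLime and blackBox names are made up for illustration.

// toy sketch of the LIME idea; assumes java.util.List, ArrayList, Random
// and java.util.function.Function are imported
double[] explainWithLime(List<String> words, Function<List<String>, Double> blackBox, int samples) {
    Random random = new Random(0);
    int d = words.size();
    double[][] presence = new double[samples][d]; // sparse (binary) representation of each perturbed input
    double[] outputs = new double[samples];       // black box prediction for each perturbed input
    for (int s = 0; s < samples; s++) {
        List<String> perturbed = new ArrayList<>();
        for (int j = 0; j < d; j++) {
            if (random.nextBoolean()) {           // randomly keep or drop each word
                presence[s][j] = 1;
                perturbed.add(words.get(j));
            }
        }
        outputs[s] = blackBox.apply(perturbed);
    }
    // fit outputs ≈ presence · weights with plain gradient descent; the learned
    // weight of a word is its estimated importance for the original prediction
    double[] weights = new double[d];
    for (int epoch = 0; epoch < 1000; epoch++) {
        double[] gradient = new double[d];
        for (int s = 0; s < samples; s++) {
            double predicted = 0;
            for (int j = 0; j < d; j++) {
                predicted += presence[s][j] * weights[j];
            }
            double error = predicted - outputs[s];
            for (int j = 0; j < d; j++) {
                gradient[j] += error * presence[s][j];
            }
        }
        for (int j = 0; j < d; j++) {
            weights[j] -= 0.01 * gradient[j] / samples;
        }
    }
    return weights;
}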
In summary, you can use LIME to generate explanations like the one in the picture above, which explains a single output (the credit card approval system’s rejection). This kind of explanation is usually referred to as a local explanation. Explanations that provide feature-level importance scores for a single prediction are usually referred to as saliency explanations.
TrustyAI Explainability
Kogito is a cloud-native business automation technology for building cloud-ready business applications, with a focus on hybrid cloud scenarios. Kogito allows you to build and deploy a cloud application as the composition of different domain specific services, and it also provides a set of addons that you can include in a Kogito application. The TrustyAI initiative in Kogito provides a set of services that aim to add monitoring and explainability capabilities to such applications; we implemented an explainability library to deliver such services.
You can use our own implementation of LIME by instantiating a LimeExplainer.
int noOfSamples = 100; // number of perturbed samples to be generated
int noOfPerturbations = 1; // min number of perturbations to be performed for each input
LimeExplainer limeExplainer = new LimeExplainer(noOfSamples, noOfPerturbations);
LimeExplainer takes an input PredictionProvider and a Prediction to be explained and (asynchronously) generates a saliency map (Map<String, Saliency>). The saliency map contains a Saliency object for each output produced in a single Prediction (e.g. it is common for DMN models to have multiple outputs). A Saliency contains a FeatureImportance object for each Feature seen in the PredictionInput attached to the input Prediction. A FeatureImportance holds the score (the importance) attached to the related Feature.
List<Feature> features = new ArrayList<>();
...
PredictionInput input = new PredictionInput(features);
PredictionProvider predictionProvider = ...
// run the model on the input and take the first (and only) PredictionOutput
PredictionOutput output = predictionProvider.predictAsync(List.of(input))
        .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())
        .get(0);
// pair the input with the output it generated and explain that prediction
Prediction prediction = new Prediction(input, output);
Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, predictionProvider)
        .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
Unlike the original LIME implementation, ours doesn’t require training data when the input is in the form of tabular data. In fact, our LIME implementation works seamlessly with plain text, tabular data and complex nested objects.
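For instance, different kinds of Features can live side by side in a single PredictionInput. The sketch below is illustrative: it assumes FeatureFactory also exposes newTextFeature and newCompositeFeature factory methods alongside the numerical and boolean ones used later in this post, and the values are made up.

List<Feature> mixedFeatures = new ArrayList<>();
mixedFeatures.add(FeatureFactory.newNumericalFeature("income", 1000));       // tabular data
mixedFeatures.add(FeatureFactory.newTextFeature("note", "self employed"));   // plain text
mixedFeatures.add(FeatureFactory.newCompositeFeature("address", List.of(     // complex nested object
        FeatureFactory.newTextFeature("city", "Milan"),
        FeatureFactory.newNumericalFeature("zipCode", 20121))));
PredictionInput mixedInput = new PredictionInput(mixedFeatures);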
Explaining a credit card approval decision service
Suppose we have a DMN model for the credit card approval system we discussed at the beginning of this post. Leveraging Kogito we can instantiate such a DecisionModel from its definition (a file stored under /dmn/cc.dmn).
DMNRuntime dmnRuntime = DMNKogito.createGenericDMNRuntime(new InputStreamReader(getClass().getResourceAsStream("/dmn/cc.dmn")));
DecisionModel decisionModel = new DmnDecisionModel(dmnRuntime, "kiegroup", "cc");
Now suppose this DMN model takes the inputs we discussed previously; for the sake of simplicity we focus on income, number of children, whether the user owns a car, whether the user owns a realty, and the user’s age. We therefore build a PredictionInput using those five Features.
List<Feature> features = new ArrayList<>();
features.add(FeatureFactory.newNumericalFeature("income", 1000));
features.add(FeatureFactory.newNumericalFeature("noOfChildren", 3));
features.add(FeatureFactory.newBooleanFeature("ownsCar", true));
features.add(FeatureFactory.newBooleanFeature("ownsRealty", true));
features.add(FeatureFactory.newNumericalFeature("age", 18));
PredictionInput input = new PredictionInput(features);
We wrap the DMN model with the DecisionModelWrapper convenience class provided by our library and run the model on our sample input.
PredictionProvider model = new DecisionModelWrapper(decisionModel);
PredictionOutput output = model.predictAsync(List.of(input))
        .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())
        .get(0);
Our DMN model only generates one output, named Approved, therefore the PredictionOutput has only one underlying Output: the approval / rejection result, as a boolean value. In this case we know it is false, as the credit card request was rejected.
boolean approved = output.getOutputs().get(0).getValue().asBoolean(); // <-- rejected
We’re now ready to pack the PredictionInput and the PredictionOutput into a Prediction and generate the explanation using LIME.
Prediction prediction = new Prediction(input, output);
Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
        .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
Now we can look at the top n Features that were most important in having our request rejected.
List<FeatureImportance> negativeFeatures = saliencyMap.get("Approved").getNegativeFeatures(2);
for (FeatureImportance featureImportance : negativeFeatures) {
    System.out.println(featureImportance.getFeature().getName() + ": " + featureImportance.getScore());
}
age: -0.61
ownsCar: -0.55
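Symmetrically, the same Saliency also exposes the features that pushed toward an approval (like the number of children in the earlier example) through getPositiveFeatures, the counterpart of getNegativeFeatures that we’ll also use in the next example.

List<FeatureImportance> positiveFeatures = saliencyMap.get("Approved").getPositiveFeatures(2);
for (FeatureImportance featureImportance : positiveFeatures) {
    System.out.println(featureImportance.getFeature().getName() + ": " + featureImportance.getScore());
}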
Explaining a language detection machine learning model
Let’s consider another example: explaining a machine learning classifier trained to detect the language of an input text. To this end we load a pretrained Apache OpenNLP LanguageDetector model (from a file called langdetect-183.bin).
InputStream is = new FileInputStream("langdetect-183.bin");
LanguageDetectorModel languageDetectorModel = new LanguageDetectorModel(is);
LanguageDetector languageDetector = new LanguageDetectorME(languageDetectorModel);
In order to use it with the explainability API we wrap it as a PredictionProvider. We assume that all input Features are textual, and we combine them into a single String to be passed to the underlying LanguageDetector.
PredictionProvider model = inputs -> CompletableFuture.supplyAsync(() -> {
    List<PredictionOutput> results = new LinkedList<>();
    for (PredictionInput predictionInput : inputs) {
        // concatenate all textual features into a single space-separated String
        StringBuilder builder = new StringBuilder();
        for (Feature f : predictionInput.getFeatures()) {
            if (builder.length() > 0) {
                builder.append(' ');
            }
            builder.append(f.getValue().asString());
        }
        // run the OpenNLP language detector and expose its result as a single Output
        Language language = languageDetector.predictLanguage(builder.toString());
        PredictionOutput predictionOutput = new PredictionOutput(
                List.of(new Output("lang", Type.TEXT, new Value<>(language.getLang()), language.getConfidence())));
        results.add(predictionOutput);
    }
    return results;
});
Now let’s take an example input text, for example “italiani spaghetti pizza mandolino”, and create a new PredictionInput to be passed to the model to obtain the Prediction. Note that we have a single String of text, but since we’d like to understand which words influence the detected language the most, we create a full text Feature so that each word in the input text becomes a separate Feature (hence the actual input will contain four Features, one for each word).
String inputText = "italiani spaghetti pizza mandolino";
List<Feature> features = new LinkedList<>();
features.add(FeatureFactory.newFulltextFeature("text", inputText));
PredictionInput input = new PredictionInput(features);
The detected language is Italian, with a confidence score of ~0.03.
PredictionOutput output = model.predictAsync(List.of(input)).get().get(0);
Output lang = output.getOutputs().get(0);
System.out.println(lang.getValue().asString() + ": " + lang.getScore());
ita: 0.029
Finally we can explain the Prediction using LimeExplainer and iterate through the most positively influencing Features.
Prediction prediction = new Prediction(input, output);
Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
        .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
for (FeatureImportance featureImportance : saliencyMap.get("lang").getPositiveFeatures(2)) {
    System.out.println(featureImportance.getFeature().getName() + ": " + featureImportance.getScore());
}
spaghetti: 0.021
pizza: 0.019