Local to global – Using LIME for feature importance

In a previous blog post we discussed how to leverage LIME to get more insight into specific predictions generated by a black-box decision service. LIME is mostly used to find out which input features were most important, according to that decision service, for the generation of a particular output.

Such explanations are called local because they describe the behavior of a decision service locally, around one specific input.
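To make "local" concrete, here is a simplified, LIME-style sketch in plain Java. It is not TrustyAI's LimeExplainer: it assumes Gaussian perturbations around the input, an exponential proximity kernel and a weighted least-squares linear surrogate, whose coefficients act as the local feature scores. The class and method names are hypothetical.

```java
import java.util.Random;
import java.util.function.ToDoubleFunction;

// Simplified LIME-style local surrogate (illustrative only, not TrustyAI's implementation):
// perturb the input, weight each perturbed sample by proximity, fit a weighted linear model.
class LocalLimeSketch {

    static double[] explain(ToDoubleFunction<double[]> model, double[] x,
                            int samples, double sigma, long seed) {
        int d = x.length;
        int p = d + 1;                      // d coefficients plus an intercept
        double[][] ata = new double[p][p];  // accumulates X^T W X
        double[] atb = new double[p];       // accumulates X^T W y
        Random rnd = new Random(seed);
        for (int s = 0; s < samples; s++) {
            // Perturb every feature with Gaussian noise around the original value
            double[] z = new double[d];
            double dist2 = 0;
            for (int j = 0; j < d; j++) {
                z[j] = x[j] + sigma * rnd.nextGaussian();
                dist2 += (z[j] - x[j]) * (z[j] - x[j]);
            }
            // Exponential kernel: samples closer to x get a larger weight
            double w = Math.exp(-dist2 / (2 * sigma * sigma));
            double y = model.applyAsDouble(z);
            // Design row: [1, z_0 - x_0, ..., z_{d-1} - x_{d-1}]
            double[] row = new double[p];
            row[0] = 1;
            for (int j = 0; j < d; j++) row[j + 1] = z[j] - x[j];
            for (int a = 0; a < p; a++) {
                atb[a] += w * row[a] * y;
                for (int b = 0; b < p; b++) ata[a][b] += w * row[a] * row[b];
            }
        }
        double[] beta = solve(ata, atb);
        double[] scores = new double[d];
        System.arraycopy(beta, 1, scores, 0, d); // drop the intercept
        return scores;
    }

    // Gaussian elimination with partial pivoting for the small normal-equation system
    static double[] solve(double[][] a, double[] b) {
        int n = b.length;
        for (int col = 0; col < n; col++) {
            int pivot = col;
            for (int r = col + 1; r < n; r++)
                if (Math.abs(a[r][col]) > Math.abs(a[pivot][col])) pivot = r;
            double[] tmpRow = a[col]; a[col] = a[pivot]; a[pivot] = tmpRow;
            double tmp = b[col]; b[col] = b[pivot]; b[pivot] = tmp;
            for (int r = col + 1; r < n; r++) {
                double f = a[r][col] / a[col][col];
                b[r] -= f * b[col];
                for (int c = col; c < n; c++) a[r][c] -= f * a[col][c];
            }
        }
        double[] sol = new double[n];
        for (int r = n - 1; r >= 0; r--) {
            double sum = b[r];
            for (int c = r + 1; c < n; c++) sum -= a[r][c] * sol[c];
            sol[r] = sum / a[r][r];
        }
        return sol;
    }

    public static void main(String[] args) {
        // Toy "black box": a linear function, so the surrogate should recover its weights
        ToDoubleFunction<double[]> model = v -> 3 * v[0] - 2 * v[1] + 0.5 * v[2];
        double[] scores = explain(model, new double[] {1.0, 2.0, 3.0}, 500, 0.5, 42L);
        System.out.println(java.util.Arrays.toString(scores));
    }
}
```

Because the toy model is linear, the surrogate recovers its weights (3, -2, 0.5) up to floating-point error; for a real black box the scores only describe its behavior near the chosen input.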

In this post we’ll see how to use TrustyAI’s LimeExplainer to generate an explanation for a decision service "as a whole". One such kind of explanation is called feature importance.

Such explanations help users understand the overall importance of each input feature on a global scale; informally, they answer the question "what does this decision service give more importance to when making decisions, in general?".

A very simple way to generalize local LIME explanations into global feature importance is to obtain local explanations for a large number of predictions and then average the scores assigned to each feature across all of them. You can do this with TrustyAI’s AggregatedLimeExplainer.
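The averaging step itself is easy to sketch without any TrustyAI dependency. In this hypothetical version each local explanation is a plain Map from feature name to score, standing in for TrustyAI's Saliency objects; it only illustrates the idea and is not the AggregatedLimeExplainer source.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// "Local to global": average each feature's score across many local explanations.
// Map<featureName, score> is a hypothetical stand-in for a TrustyAI Saliency.
class GlobalImportanceSketch {

    static Map<String, Double> aggregate(List<Map<String, Double>> localExplanations) {
        Map<String, Double> sums = new LinkedHashMap<>();
        Map<String, Integer> counts = new LinkedHashMap<>();
        for (Map<String, Double> local : localExplanations) {
            for (Map.Entry<String, Double> e : local.entrySet()) {
                sums.merge(e.getKey(), e.getValue(), Double::sum);   // running total per feature
                counts.merge(e.getKey(), 1, Integer::sum);           // how many explanations mention it
            }
        }
        Map<String, Double> means = new LinkedHashMap<>();
        sums.forEach((feature, sum) -> means.put(feature, sum / counts.get(feature)));
        return means;
    }

    public static void main(String[] args) {
        List<Map<String, Double>> locals = List.of(
                Map.of("petalWidth", 0.4, "sepalLength", 0.1),
                Map.of("petalWidth", 0.2, "sepalLength", 0.3));
        System.out.println(aggregate(locals));
    }
}
```

With the two toy explanations above, petalWidth averages to 0.3 and sepalLength to 0.2 (the numbers are made up for illustration).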

AggregatedLimeExplainer globalExplainer = new AggregatedLimeExplainer();

Of course, we need access to the decision service we want to explain, exposed, as usual in TrustyAI, via a PredictionProvider. For the sake of this post we want to explain a PMML logistic regression model trained on the Iris dataset.

We define a getModel() method that wraps such a model as a PredictionProvider.

PredictionProvider getModel() {
    // Each call scores a whole batch of inputs asynchronously
    return inputs -> CompletableFuture.supplyAsync(() -> {
        List<PredictionOutput> outputs = new ArrayList<>();
        for (PredictionInput input : inputs) {
            // The four Iris measurements, passed positionally to the PMML executor
            List<Feature> features = input.getFeatures();
            LogisticRegressionIrisDataExecutor pmmlModel = new LogisticRegressionIrisDataExecutor(
                    features.get(0).getValue().asNumber(),
                    features.get(1).getValue().asNumber(),
                    features.get(2).getValue().asNumber(),
                    features.get(3).getValue().asNumber());
            // logisticRegressionIrisRuntime: the PMML runtime holding the model, initialized elsewhere
            PMML4Result result = pmmlModel.execute(logisticRegressionIrisRuntime);
            // Predicted class and the probability the model assigns to it
            String species = result.getResultVariables().get("Species").toString();
            double score = Double.parseDouble(
                    result.getResultVariables().get("Probability_" + species).toString());
            PredictionOutput predictionOutput = new PredictionOutput(
                    List.of(new Output("species", Type.TEXT, new Value(species), score)));
            outputs.add(predictionOutput);
        }
        return outputs;
    });
}
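If you want to experiment with the shape of this contract without the TrustyAI and PMML dependencies, the following hypothetical mirror keeps the essential idea: a batch of inputs goes in and a future with one output per input, in the same order, comes out. ToyProvider, ToyIrisModel and the decision rule are made up for illustration and are not the PMML model used in this post; the feature order (sepalLength, sepalWidth, petalLength, petalWidth) is also an assumption.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;

// Hypothetical, dependency-free mirror of the PredictionProvider contract
interface ToyProvider {
    CompletableFuture<List<String>> predictAsync(List<double[]> inputs);
}

class ToyIrisModel {
    // Toy decision rule (NOT the PMML logistic regression);
    // assumed feature order: sepalLength, sepalWidth, petalLength, petalWidth
    static String classify(double[] f) {
        if (f[2] < 2.5) return "setosa";
        return f[3] < 1.75 ? "versicolor" : "virginica";
    }

    static ToyProvider getModel() {
        return inputs -> CompletableFuture.supplyAsync(() -> {
            List<String> outputs = new ArrayList<>();
            for (double[] input : inputs) {
                outputs.add(classify(input)); // one output per input, same order
            }
            return outputs;
        });
    }

    public static void main(String[] args) {
        List<String> species = getModel()
                .predictAsync(List.of(new double[] {5.1, 3.5, 1.4, 0.2},
                                      new double[] {6.3, 3.3, 6.0, 2.5}))
                .join();
        System.out.println(species); // prints [setosa, virginica]
    }
}
```

Keeping outputs in input order matters, because predictions are later built by pairing each input with its output.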

We then feed the explainer a collection of existing predictions to get our feature importance explanation.

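PredictionProvider model = getModel();
List<PredictionInput> samples = getSamples();
// Score all samples, then pair each input with the output it produced
List<PredictionOutput> predictionOutputs = model.predictAsync(samples).get();
List<Prediction> predictions = DataUtils.getPredictions(samples, predictionOutputs);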
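Conceptually, a prediction is just an input together with the output the model produced for it. The plain-Java sketch below assumes a positional pairing of the two lists; it is a hypothetical illustration of that idea, not the source of DataUtils.getPredictions, and ToyPrediction and PairingSketch are made-up names.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a TrustyAI Prediction: an input plus the output it produced
record ToyPrediction<I, O>(I input, O output) {}

class PairingSketch {
    // Zip inputs and outputs by position; this relies on the provider returning
    // outputs in the same order as the inputs it received
    static <I, O> List<ToyPrediction<I, O>> getPredictions(List<I> inputs, List<O> outputs) {
        if (inputs.size() != outputs.size()) {
            throw new IllegalArgumentException("inputs and outputs must have the same size");
        }
        List<ToyPrediction<I, O>> predictions = new ArrayList<>();
        for (int i = 0; i < inputs.size(); i++) {
            predictions.add(new ToyPrediction<>(inputs.get(i), outputs.get(i)));
        }
        return predictions;
    }

    public static void main(String[] args) {
        List<ToyPrediction<String, String>> p =
                getPredictions(List.of("sample-1", "sample-2"), List.of("setosa", "virginica"));
        System.out.println(p.get(1)); // prints ToyPrediction[input=sample-2, output=virginica]
    }
}
```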

The getSamples() method can fetch samples from the Iris training, validation or test sets, or alternatively generate random samples.
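For the random option, one simple approach is to draw each feature uniformly within its observed range in the Iris dataset (sepal length 4.3 to 7.9 cm, sepal width 2.0 to 4.4 cm, petal length 1.0 to 6.9 cm, petal width 0.1 to 2.5 cm). The sketch below returns raw double arrays rather than TrustyAI PredictionInput objects, and its names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical getSamples() variant: uniform random samples within each Iris feature's range.
// Assumed order: sepalLength, sepalWidth, petalLength, petalWidth (cm).
class RandomIrisSamples {
    static final double[] MIN = {4.3, 2.0, 1.0, 0.1};
    static final double[] MAX = {7.9, 4.4, 6.9, 2.5};

    static List<double[]> getSamples(int n, long seed) {
        Random rnd = new Random(seed); // fixed seed for reproducible experiments
        List<double[]> samples = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            double[] sample = new double[MIN.length];
            for (int j = 0; j < MIN.length; j++) {
                sample[j] = MIN[j] + rnd.nextDouble() * (MAX[j] - MIN[j]);
            }
            samples.add(sample);
        }
        return samples;
    }

    public static void main(String[] args) {
        for (double[] s : getSamples(3, 7L)) {
            System.out.println(java.util.Arrays.toString(s));
        }
    }
}
```

Uniform sampling ignores correlations between features (petal length and width are strongly correlated in Iris), so some random samples may look unlike real flowers; samples taken from the actual dataset avoid that.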

At this point we can get our global explanation from the AggregatedLimeExplainer.

Map<String, Saliency> saliencyMap = globalExplainer.explainFromPredictions(model, 
     predictions).get(Config.INSTANCE.getAsyncTimeout(), 
     Config.INSTANCE.getAsyncTimeUnit());

You might have noticed that the explanation API is asynchronous, so we also specify how long to wait before timing out the execution.
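The timeout pattern is standard JDK CompletableFuture behavior, independent of TrustyAI: get(timeout, unit) waits at most the given time and throws TimeoutException otherwise. The helper below (a hypothetical name) shows one way to handle the possible outcomes.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

// Plain-JDK illustration of "async call + bounded wait"
class TimeoutSketch {
    static String resultOrFallback(CompletableFuture<String> future, long timeout, TimeUnit unit) {
        try {
            return future.get(timeout, unit); // blocks for at most `timeout`
        } catch (TimeoutException e) {
            // Mark the future as cancelled; note that CompletableFuture does not
            // interrupt a computation that is already running
            future.cancel(true);
            return "timed out";
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt flag
            return "interrupted";
        } catch (ExecutionException e) {
            return "failed: " + e.getCause().getMessage();
        }
    }

    public static void main(String[] args) {
        CompletableFuture<String> fast = CompletableFuture.supplyAsync(() -> "done");
        CompletableFuture<String> never = new CompletableFuture<>(); // never completed
        System.out.println(resultOrFallback(fast, 1, TimeUnit.SECONDS));
        System.out.println(resultOrFallback(never, 100, TimeUnit.MILLISECONDS)); // prints timed out
    }
}
```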

Let’s print the mean LIME score of each of the four features for the PMML model.

for (Saliency saliency : saliencyMap.values()) {
    for (FeatureImportance fi : saliency.getPerFeatureImportance()) {
        System.out.println(fi.getFeature().getName() + ": " + fi.getScore());
    }
}
petalLength: 8.674779230045368E-5
petalWidth: 8.68437473088059E-5
sepalLength: 8.674876155890406E-5
sepalWidth: 8.673110128646624E-5

The model seems to distribute importance very evenly among the four features, with slightly more importance on the petalWidth feature. The two length-related features, instead, have almost exactly the same importance, which is interesting!
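To put "slightly higher" into numbers, this small helper (hypothetical, plain Java) ranks the four mean scores printed above and expresses each as a relative deviation from their average.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Rank features by score and express each score as a relative deviation from the mean score
class CompareScores {
    static Map<String, Double> relativeDeviation(Map<String, Double> scores) {
        double mean = scores.values().stream().mapToDouble(Double::doubleValue).average().orElse(0);
        Map<String, Double> out = new LinkedHashMap<>(); // keeps the descending-score order
        scores.entrySet().stream()
                .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
                .forEach(e -> out.put(e.getKey(), (e.getValue() - mean) / mean));
        return out;
    }

    public static void main(String[] args) {
        // The mean LIME scores printed by the explainer in this post
        Map<String, Double> scores = Map.of(
                "petalLength", 8.674779230045368E-5,
                "petalWidth", 8.68437473088059E-5,
                "sepalLength", 8.674876155890406E-5,
                "sepalWidth", 8.673110128646624E-5);
        relativeDeviation(scores).forEach((name, dev) ->
                System.out.printf("%s: %+.4f%%%n", name, 100 * dev));
    }
}
```

With these numbers petalWidth comes out first but less than 0.1% above the average score, followed by sepalLength, petalLength and sepalWidth, which supports reading the distribution as nearly uniform.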

Global explanation tools are very useful for gaining insight into the inner workings of ML models, AI systems or, more generally, black-box decision services. A data scientist can use such insights, for example, to check whether the importance a decision service assigns to each feature is in line with the model's design, or whether it reveals unwanted or unexpected behaviors.
