Module connectome.visualization.viz_framework

Framework function for visualizing model predictions or important features.

Functions

def visualization_framework(model, X, y=None, viz_method: str = 'GFI', **kwargs)

Computes and visualizes feature importances or other model explanations, depending on viz_method. "GFI" and "GFI_only" work for elastic net and random forest models, provided the features can be aggregated by the Yeo 7-network (yeo7) parcellation. "FI" and "FI_only" work for elastic net and random forest models. "elastic_net" works for elastic net models. "shapley" works for random forest and gradient boosting models. "feature_attribution" works for CNNs. For details on each method, see the documentation of the respective function.
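The method/model compatibility rules above can be summarized as a lookup table. The following is a minimal sketch, not the module's actual implementation: the model-type labels ("elastic_net", "random_forest", "gradient_boosting", "cnn") and the helper check_viz_method are hypothetical, and the real function may instead inspect the model object's class. The yeo7 aggregation condition for "GFI" and "GFI_only" is not checked here.

```python
# Hypothetical sketch: which model types each viz_method supports,
# transcribed from the description above. Labels are illustrative only.
SUPPORTED_MODELS = {
    "GFI": {"elastic_net", "random_forest"},
    "GFI_only": {"elastic_net", "random_forest"},
    "FI": {"elastic_net", "random_forest"},
    "FI_only": {"elastic_net", "random_forest"},
    "elastic_net": {"elastic_net"},
    "shapley": {"random_forest", "gradient_boosting"},
    "feature_attribution": {"cnn"},
}


def check_viz_method(viz_method: str, model_type: str) -> None:
    """Raise ValueError if viz_method is unknown or unsupported for model_type."""
    if viz_method not in SUPPORTED_MODELS:
        valid = ", ".join(sorted(SUPPORTED_MODELS))
        raise ValueError(f"Unknown viz_method {viz_method!r}; choose one of: {valid}")
    if model_type not in SUPPORTED_MODELS[viz_method]:
        raise ValueError(
            f"viz_method {viz_method!r} does not support model type {model_type!r}"
        )
```

For example, check_viz_method("shapley", "random_forest") passes silently, while check_viz_method("elastic_net", "cnn") raises a ValueError before any expensive computation starts.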

Examples:

>>> # Visualize saliency maps for neural networks
>>> visualization_framework(model=model,
...                         X=X_test,
...                         y=y_test,
...                         viz_method='feature_attribution',
...                         method='saliency',
...                         average=True,
...                         ordered=True)
>>> # Calculate and visualize the grouped permutation feature importance, e.g. for an elastic net model.
>>> # Works similarly for 'GFI_only', 'FI' and 'FI_only'.
>>> visualization_framework(model=model,
...                         X=X_test,
...                         y=y_test,
...                         viz_method='GFI',
...                         m=20)  # more permutations (m) give a more accurate result but a longer runtime
>>> # Plot the coefficients of an elastic net model
>>> visualization_framework(model=model,
...                         X=X_test,
...                         y=y_test,
...                         viz_method='elastic_net')
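To illustrate what the m parameter of "GFI" controls, here is a minimal sketch of grouped permutation feature importance, assuming R^2 as the score and a plain predict callable; the module's own implementation may use a different score, grouping, or output format. The function grouped_permutation_importance and its groups argument are hypothetical: with yeo7 aggregation, each group would map a network name to the column indices of its connectivity features, but that mapping is not shown here.

```python
import numpy as np


def grouped_permutation_importance(predict, X, y, groups, m=20, seed=0):
    """Mean drop in R^2 when each feature group is permuted jointly, over m repeats.

    predict: callable mapping an (n_samples, n_features) array to predictions.
    groups:  dict mapping a group name to a list of column indices (hypothetical).
    """
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)

    def r2(y_true, y_pred):
        ss_res = np.sum((y_true - y_pred) ** 2)
        ss_tot = np.sum((y_true - y_true.mean()) ** 2)
        return 1.0 - ss_res / ss_tot

    baseline = r2(y, predict(X))
    importances = {}
    for name, cols in groups.items():
        drops = []
        for _ in range(m):
            X_perm = X.copy()
            # Shuffle all columns of the group with the same row order,
            # so correlations within the group are preserved.
            perm = rng.permutation(X.shape[0])
            X_perm[:, cols] = X[perm][:, cols]
            drops.append(baseline - r2(y, predict(X_perm)))
        # Averaging over more permutations reduces the variance of the estimate.
        importances[name] = float(np.mean(drops))
    return importances
```

Ungrouped permutation importance (as "FI" presumably computes) corresponds to singleton groups, e.g. {f"x{j}": [j] for j in range(X.shape[1])}. With a scikit-learn regressor one would pass model.predict as the predict callable.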

Args

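model
A trained ML model
X
The input data, as a dataframe
y
The labels (optional; defaults to None)
viz_method
One of "GFI" (default), "GFI_only", "FI", "FI_only", "elastic_net", "shapley", or "feature_attribution"
**kwargs
Additional keyword arguments passed on to the selected method, e.g. m (number of permutations) for "GFI", or method, average and ordered for "feature_attribution"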

Returns

List of reordered connectivity matrices