Source code for symdet.analysis.model_visualization

"""
This program and the accompanying materials are made available under the terms of the
Eclipse Public License v2.0 which accompanies this distribution, and is available at
https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincware Project.

Model Visualization
===================
Visualize the NN models in different ways
"""
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np


class Visualizer:
    """
    Class for the visualization of NN models.

    Attributes
    ----------
    data : tf.Tensor
            Data to be visualized by the TSNE.
    colour_map : list
            A colour map to be applied to the data so that it matches that of
            the cluster map.
    """

    def __init__(self, data, colour_map):
        """
        Constructor for the visualizer class.

        Parameters
        ----------
        data : tf.Tensor
                Data to be visualized by the TSNE.
        colour_map : list
                A colour map to be applied to the data so that it matches that
                of the cluster map.
        """
        self.data = data
        self.colour_map = colour_map

    def tsne_visualization(
        self, perplexity=50, n_components=2, plot: bool = True, save: bool = False
    ) -> np.ndarray:
        """
        Display a TSNE representation of the model's embedding layer.

        Parameters
        ----------
        perplexity : int
                Perplexity of the tsne representation.
        n_components : int
                Dimensionality of the TSNE representation. Only the first two
                components are drawn; a one-dimensional representation is
                drawn along a zero baseline.
        plot : bool
                If true the TSNE representation will be plotted.
        save : bool
                If true the plot will be saved as
                ``tsne_representation_{perplexity}_{n_components}.svg`` in the
                current working directory. Only has an effect when ``plot`` is
                true, as there is otherwise no figure to save.

        Returns
        -------
        tsne_representation : np.ndarray
                Returns the tsne representation as an array.

        Notes
        -----
        See the theory documentation for a full overview of these parameters,
        particularly in the case of the TSNE values.
        """
        # The seed is fixed so the same data yields the same embedding.
        tsne_model = TSNE(
            n_components=n_components, perplexity=perplexity, random_state=1
        )
        tsne_representation = tsne_model.fit_transform(self.data)

        if plot:
            x_values = tsne_representation[:, 0]
            if tsne_representation.shape[1] > 1:
                y_values = tsne_representation[:, 1]
            else:
                # A 1-D embedding has no second column; indexing it would
                # raise IndexError, so draw the points on a flat baseline.
                y_values = np.zeros_like(x_values)

            plt.scatter(
                x_values,
                y_values,
                c=self.colour_map,
                marker='.',
                cmap="viridis",
                # NOTE(review): fixed colour scale; presumably chosen to match
                # the label range of the cluster map -- confirm.
                vmax=11,
                vmin=-1,
            )
            plt.colorbar()
            if save:
                # Save before show(): some backends clear the figure on show.
                plt.savefig(
                    f"tsne_representation_{perplexity}_{n_components}.svg",
                    dpi=800,
                    format="svg",
                )
            plt.show()

        return tsne_representation