Source code for symdet.symmetry_group_extraction.group_detection

"""
This program and the accompanying materials are made available under the terms of the
Eclipse Public License v2.0 which accompanies this distribution, and is available at
https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincware Project.

Group Detection
===============
Cluster raw data into symmetry groups
"""
from symdet.ml_models.dense_model import DenseModel
from symdet.analysis.model_visualization import Visualizer
from typing import Tuple
import numpy as np
import tensorflow as tf
from symdet.utils.data_clustering import compute_com, compute_radius_of_gyration


[docs]class GroupDetection: """ A class to cluster raw data into symmetry groups. Attributes ---------- model : DenseModel Model to use in the group detection. data_clusters : dict Data cluster class used for the partitioning of the data. representation_set : str Which set to use in the representation, train, validation, or test. """ def __init__(self, model: DenseModel, data_clusters: dict, representation_set: str = 'train'): """ Constructor for the GroupDetection class. Parameters ---------- model : DenseModel Model to use in the group detection. data_clusters : dict Data cluster class used for the partitioning of the data. representation_set : str Which set to use in the representation, train, validation, or test. """ self.model = model self.data = data_clusters self.representation_set = representation_set self.model.add_data(self.data) # add the data to the model. def _get_model_predictions(self) -> Tuple: """ Train the attached model. Returns ------- val_data : tf.Tensor Data on which the prediction were made. model_predictions : Tuple Embedding layer of the NN on validation data. """ self.model.train_model() if self.representation_set == 'train': val_data = self.model.train_ds predictions = self.model.model.predict(val_data[:, 0:self.model.input_shape]) elif self.representation_set == 'test:': val_data = self.model.test_ds predictions = self.model.model.predict(val_data[:, 0:self.model.input_shape]) else: val_data = self.model.val_ds predictions = self.model.model.predict(val_data[:, 0:self.model.input_shape]) return val_data, predictions def _run_visualization(self): """ Perform a visualization on the TSNE data. Returns ------- """ pass @staticmethod def _cluster_detection(function_data: np.ndarray, data: np.ndarray): """ Use the results of the TSNE reduction to extract clusters. Parameters ---------- function_data : tf.Tensor A tensor of the raw data to be collected. data : np.ndarray Results of the tsne representation Returns ------- clusters : dict An unordered point cloud of data belonging to the same cluster. e.g. {1: [radial values], 2: [radial_values], ...} """ net_array = np.concatenate((data, function_data), 1) sorted_data = tf.gather(net_array, tf.argsort(net_array[:, -1])).numpy() class_array = np.unique(function_data[:, -1]) point_cloud = {} # loop over the class array for i, item in enumerate(class_array): start = np.searchsorted(sorted_data[:, -1], item, side='left') stop = np.searchsorted(sorted_data[:, -1], item, side='right') - 1 com = compute_com(sorted_data[start:stop, 0:2]) rg = compute_radius_of_gyration(sorted_data[start:stop, 0:2], com) #print(f"Class: {item}, COM: {com}, Rg: {rg}") if rg > 1000: point_cloud[item] = sorted_data[start:stop, 2:-1] return point_cloud @staticmethod def _filter_data(predictions: tf.Tensor, targets: tf.Tensor): """ Check which data points are predicted well and include them in the data. Parameters ---------- targets : tf.Tensor Target values on which predictions were made. predictions : tf.Tensor Network predictions. Returns ------- """ accepted_candidates = np.zeros(len(predictions)) target_values = tf.keras.utils.to_categorical(targets[:, -1]) counter = 0 for i, item in enumerate(predictions): if np.linalg.norm(predictions[i] - target_values[i]) <= 2e-1: accepted_candidates[counter] = i counter += 1 accepted_candidates = tf.convert_to_tensor(accepted_candidates[0:counter], dtype=tf.int32) return tf.gather(targets, accepted_candidates)
[docs] def run_symmetry_detection(self, plot: bool = True, save: bool = False): """ Run the symmetry detection routine. Parameters ---------- plot : bool Plot the TSNE visualization. save : bool Save the image plotted. Returns ------- None """ validation_data, predictions = self._get_model_predictions() accepted_data = self._filter_data(predictions, validation_data) representation = self.model.get_embedding_layer_representation(accepted_data) # get the embedding layer visualizer = Visualizer(representation, accepted_data[:, -1]) data = visualizer.tsne_visualization(plot=plot, save=save) # determine coupled groups in the tSNE representation. point_cloud = self._cluster_detection(validation_data, data) return point_cloud