"""
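Class LocalAdaptivityCalculator provides methods to adaptively control micro simulations
in a local way. If the Micro Manager is run in parallel, simulations on one rank are compared only to
each other; simulations on different ranks are never compared. Only the maximum similarity distance,
which sets the coarsening and refining tolerances, is reduced across all ranks.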
"""
import sys
import numpy as np
from copy import deepcopy
from mpi4py import MPI

from .adaptivity import AdaptivityCalculator


class LocalAdaptivityCalculator(AdaptivityCalculator):
    """
    Adaptivity calculator that compares micro simulations only within the current rank.

    All simulation ids handled here are rank-local; for this calculator the "global" ids
    are identical to the local ids. The only cross-rank communication is the reduction of
    the maximum similarity distance (used to derive the coarsening and refining tolerances)
    and, depending on the output type, the gathering of metrics on rank 0.
    """

    def __init__(
        self,
        configurator,
        num_sims,
        base_logger,
        rank,
        comm,
        micro_problem_cls,
    ) -> None:
        """
        Class constructor.

        Parameters
        ----------
        configurator : object of class Config
            Object which has getter functions to get parameters defined in the configuration file.
        num_sims : int
            Number of micro simulations.
        base_logger : object of class Logger
            Logger object to log messages.
        rank : int
            Rank of the current MPI process.
        comm : MPI.Comm
            Communicator for MPI.
        micro_problem_cls : callable
            Class of micro problem.
        """
        # NOTE(review): the positional order passed to the base class (micro_problem_cls before
        # base_logger) differs from this constructor's order; presumably it matches the base
        # class signature — verify against AdaptivityCalculator.__init__.
        super().__init__(configurator, num_sims, micro_problem_cls, base_logger, rank)
        self._comm = comm

        # similarity_dists: 2D array having similarity distances between each micro simulation pair.
        # This matrix is modified in place via the method _update_similarity_dists.
        self._similarity_dists = np.zeros((num_sims, num_sims))

    def compute_adaptivity(
        self,
        dt,
        micro_sims,
        data_for_adaptivity: dict,
    ) -> None:
        """
        Compute adaptivity locally (within a rank).

        Parameters
        ----------
        dt : float
            Current time step
        micro_sims : list
            List containing simulation objects
        data_for_adaptivity : dict
            A dictionary whose keys are the names of the data to be used in adaptivity.
            The values are passed on unchanged to _update_similarity_dists.

        Raises
        ------
        ValueError
            If a key of data_for_adaptivity is not one of the configured adaptivity data names.
        """
        for name in data_for_adaptivity:
            if name not in self._adaptivity_data_names:
                # list(...) works whether the configured names are a dict or a plain sequence,
                # and gives a readable message instead of "dict_keys([...])".
                raise ValueError(
                    "Data for adaptivity must be one of the following: {}".format(
                        list(self._adaptivity_data_names)
                    )
                )

        self._update_similarity_dists(dt, data_for_adaptivity)

        # np.amax raises on an empty array. A rank without simulations must still take part
        # in the collective allreduce below; otherwise the other ranks would block forever.
        if self._similarity_dists.size > 0:
            self._local_max_similarity_dist = np.amax(self._similarity_dists)
        else:
            self._local_max_similarity_dist = 0.0

        # Gather maximum similarity distance from every rank, and use the global maximum distance
        self._max_similarity_dist = self._comm.allreduce(
            self._local_max_similarity_dist, op=MPI.MAX
        )

        self._update_active_sims()

        self._update_inactive_sims(micro_sims)

        self._associate_inactive_to_active()

    def get_active_sim_local_ids(self) -> np.ndarray:
        """
        Get the local ids of active simulations on this rank.

        Returns
        -------
        numpy array
            1D array of active simulation ids
        """
        return np.where(self._is_sim_active)[0]

    def get_active_sim_global_ids(self) -> np.ndarray:
        """
        Get the global ids of active simulations on this rank.

        For local adaptivity, global ids are the same as local ids.

        Returns
        -------
        numpy array
            1D array of active simulation ids
        """
        return self.get_active_sim_local_ids()

    def get_inactive_sim_local_ids(self) -> np.ndarray:
        """
        Get the local ids of inactive simulations on this rank.

        Returns
        -------
        numpy array
            1D array of inactive simulation ids
        """
        return np.where(np.logical_not(self._is_sim_active))[0]

    def get_inactive_sim_global_ids(self) -> np.ndarray:
        """
        Get the global ids of inactive simulations on this rank.

        For local adaptivity, global ids are the same as local ids.

        Returns
        -------
        numpy array
            1D array of inactive simulation ids
        """
        return self.get_inactive_sim_local_ids()

    def get_full_field_micro_output(self, micro_output: list) -> list:
        """
        Fill in the output of inactive simulations from their associated active simulations.

        The input list is not modified. A deep copy is returned, and in that copy each
        inactive entry is replaced by a deep copy of the output of its associated active simulation.

        Parameters
        ----------
        micro_output : list
            List of dicts having individual output of each simulation. Only the active simulation outputs are entered.

        Returns
        -------
        micro_output : list
            List of dicts having individual output of each simulation. Active and inactive simulation outputs are entered.
        """
        micro_sims_output = deepcopy(micro_output)

        for inactive_id in self.get_inactive_sim_local_ids():
            associated_active_id = self._sim_is_associated_to[inactive_id]
            # Deep copy so that the inactive entry does not alias the active sim's output.
            micro_sims_output[inactive_id] = deepcopy(
                micro_sims_output[associated_active_id]
            )

        return micro_sims_output

    def log_metrics(self, n: int) -> None:
        """
        Log the following metrics:

        Local metrics (output type "local" or "all"):
        - Time window at which the metrics are logged
        - Number of active simulations
        - Number of inactive simulations

        Global metrics (output type "global" or "all", written by rank 0):
        - Time window at which the metrics are logged
        - Global number of active simulations
        - Global number of inactive simulations
        - Average number of active simulations per rank
        - Average number of inactive simulations per rank
        - Maximum number of active simulations on a rank, and the rank holding it
        - Maximum number of inactive simulations on a rank, and the rank holding it

        Parameters
        ----------
        n : int
            Time step count at which the metrics are logged
        """
        active_sims_on_this_rank = int(np.count_nonzero(self._is_sim_active))
        inactive_sims_on_this_rank = self._is_sim_active.size - active_sims_on_this_rank

        if self._adaptivity_output_type in ("all", "local"):
            self._metrics_logger.log_info(
                "{}|{}|{}".format(
                    n,
                    active_sims_on_this_rank,
                    inactive_sims_on_this_rank,
                )
            )

        if self._adaptivity_output_type in ("global", "all"):
            # gather is collective: every rank has to reach these calls.
            active_sims_rankwise = self._comm.gather(active_sims_on_this_rank, root=0)
            inactive_sims_rankwise = self._comm.gather(
                inactive_sims_on_this_rank, root=0
            )

            if self._rank == 0:
                size = self._comm.Get_size()

                total_active = sum(active_sims_rankwise)
                total_inactive = sum(inactive_sims_rankwise)
                max_active = max(active_sims_rankwise)
                max_inactive = max(inactive_sims_rankwise)

                self._global_metrics_logger.log_info(
                    "{}|{}|{}|{}|{}|{}|{}|{}|{}".format(
                        n,
                        total_active,
                        total_inactive,
                        total_active / size,
                        total_inactive / size,
                        max_active,
                        active_sims_rankwise.index(max_active),
                        max_inactive,
                        inactive_sims_rankwise.index(max_inactive),
                    )
                )

    def _update_active_sims(self) -> None:
        """
        Update set of active micro simulations. Active micro simulations are compared to each other
        and if found similar, one of them is deactivated.

        Deactivated ids are recorded in self._just_deactivated so that _update_inactive_sims
        can tell them apart from simulations that were inactive before this step.
        """
        if self._max_similarity_dist == 0.0:
            self._base_logger.log_warning(
                "All similarity distances are zero, which means all the data for adaptivity is the same. Coarsening tolerance will be manually set to minimum float number."
            )
            # Use the smallest positive float instead of a zero tolerance.
            self._coarse_tol = sys.float_info.min
        else:
            self._coarse_tol = (
                self._coarse_const * self._refine_const * self._max_similarity_dist
            )

        active_lids = self.get_active_sim_local_ids().tolist()

        # Ids that are still active; shrinks as sims are deactivated so that a sim is never
        # deactivated because of its similarity to an already deactivated sim.
        active_lids_to_check = active_lids.copy()

        for lid in active_lids:
            if self._check_for_deactivation(lid, active_lids_to_check):
                self._is_sim_active[lid] = False
                self._just_deactivated.append(lid)
                active_lids_to_check.remove(lid)

    def _update_inactive_sims(self, micro_sims: list) -> None:
        """
        Update set of inactive micro simulations. Each inactive micro simulation is compared to all active ones
        and if it is not similar to any of them, it is activated.

        Some activated simulations were not deactivated in this same step, so their objects were
        discarded earlier. For each of these, a new simulation object is created and its state is
        copied from the associated active simulation. Finally, the objects of all simulations that
        remain inactive are discarded (replaced by 0).

        Parameters
        ----------
        micro_sims : list
            List containing micro simulation objects.
        """
        self._ref_tol = self._refine_const * self._max_similarity_dist

        active_lids = self.get_active_sim_local_ids()
        inactive_lids = self.get_inactive_sim_local_ids()

        # Set for O(1) membership tests; numpy integer ids hash equal to Python ints.
        just_deactivated = set(self._just_deactivated)

        to_be_activated_ids = []
        for lid in inactive_lids:
            if self._check_for_activation(lid, active_lids):
                self._is_sim_active[lid] = True
                # Every newly active sim, including one deactivated earlier in this step,
                # is a reference for the remaining activation checks.
                active_lids = np.append(active_lids, lid)
                # Sims deactivated in this same step still hold their simulation object,
                # so only the others need a new object.
                if lid not in just_deactivated:
                    to_be_activated_ids.append(lid)

        self._just_deactivated.clear()  # Clear the list of sims deactivated in this step

        # Create objects for newly activated sims and initialize them from their associated active sim
        for i in to_be_activated_ids:
            associated_active_id = self._sim_is_associated_to[i]
            micro_sims[i] = self._micro_problem_cls(i)
            micro_sims[i].set_state(micro_sims[associated_active_id].get_state())
            self._sim_is_associated_to[i] = -2  # Active sim cannot have an associated sim

        # Delete the inactive micro simulations which have not been activated
        for i in range(self._is_sim_active.size):
            if not self._is_sim_active[i]:
                micro_sims[i] = 0
