/*
 * Decompiled with CFR 0.152.
 */
package org.meteoinfo.math.distribution;

import java.util.Arrays;
import java.util.List;
import java.util.function.DoubleUnaryOperator;
import org.apache.commons.math4.legacy.distribution.MultivariateNormalDistribution;
import org.apache.commons.math4.legacy.distribution.MultivariateRealDistribution;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.simple.RandomSource;
import org.apache.commons.statistics.distribution.ContinuousDistribution;
import org.apache.commons.statistics.distribution.NormalDistribution;
import org.meteoinfo.ndarray.Array;
import org.meteoinfo.ndarray.DataType;
import org.meteoinfo.ndarray.IndexIterator;
import org.meteoinfo.ndarray.math.ArrayUtil;

/**
 * Static helpers that connect MeteoInfo {@link Array} data with Apache Commons distributions:
 * random sampling, density / log-density, cumulative probability and quantile (inverse CDF)
 * evaluation.
 *
 * <p>Each {@code rvs} call builds its own sampler from a newly created Mersenne Twister
 * generator ({@code RandomSource.MT.create()}, called without an explicit seed).
 */
public class DistributionUtil {

    /**
     * Creates a multivariate normal distribution from mean and covariance arrays.
     *
     * <p>The inputs are copied into Java arrays and cast to {@code double[]} and
     * {@code double[][]} respectively, so {@code means} is expected to be 1-D and
     * {@code covariances} 2-D.
     *
     * @param means mean of each dimension
     * @param covariances covariance matrix
     * @return the multivariate normal distribution
     */
    public static MultivariateNormalDistribution mvNormDist(Array means, Array covariances) {
        double[] m = (double[]) ArrayUtil.copyToNDJavaArray_Double(means);
        double[][] cov = (double[][]) ArrayUtil.copyToNDJavaArray_Double(covariances);
        return new MultivariateNormalDistribution(m, cov);
    }

    /**
     * Draws {@code n} independent samples from a univariate continuous distribution.
     *
     * @param dis the distribution to sample
     * @param n number of samples; must be non-negative
     * @return a double array of shape {@code [n]}
     * @throws IllegalArgumentException if {@code n} is negative
     */
    public static Array rvs(ContinuousDistribution dis, int n) {
        return Array.factory(DataType.DOUBLE, new int[]{n}, sample(dis, n));
    }

    /**
     * Draws samples from a univariate continuous distribution into an array of the given shape.
     * An empty {@code size} yields a single sample with an empty (scalar) shape.
     *
     * @param dis the distribution to sample
     * @param size output shape; every dimension must be non-negative
     * @return a double array whose shape is {@code size}
     * @throws IllegalArgumentException if any dimension is negative
     * @throws ArithmeticException if the total element count overflows {@code int}
     */
    public static Array rvs(ContinuousDistribution dis, List<Integer> size) {
        int[] shape = size.stream().mapToInt(Integer::intValue).toArray();
        return Array.factory(DataType.DOUBLE, shape, sample(dis, elementCount(shape)));
    }

    /**
     * Draws {@code n} vectors from a multivariate distribution.
     *
     * @param dis the distribution to sample
     * @param n number of vectors; must be non-negative
     * @return a double array of shape {@code [n, dis.getDimension()]}, one sample per row
     * @throws IllegalArgumentException if {@code n} is negative
     * @throws ArithmeticException if {@code n * dimension} overflows {@code int}
     */
    public static Array rvs(MultivariateRealDistribution dis, int n) {
        return sampleRows(dis, n);
    }

    /**
     * Draws {@code n} vectors from a multivariate normal distribution.
     * Kept as a separate overload for compatibility with existing callers.
     *
     * @param dis the distribution to sample
     * @param n number of vectors; must be non-negative
     * @return a double array of shape {@code [n, dis.getDimension()]}, one sample per row
     * @throws IllegalArgumentException if {@code n} is negative
     * @throws ArithmeticException if {@code n * dimension} overflows {@code int}
     */
    public static Array rvs(MultivariateNormalDistribution dis, int n) {
        return sampleRows(dis, n);
    }

    /**
     * Draws vectors from a multivariate normal distribution. The number of vectors is the
     * number of elements in {@code size}.
     *
     * @param dis the distribution to sample
     * @param size array whose element count gives the number of vectors to draw
     * @return a double array of shape {@code [size.getSize(), dis.getDimension()]}
     * @throws ArithmeticException if the element count of {@code size} does not fit in an {@code int}
     */
    public static Array rvs(MultivariateNormalDistribution dis, Array size) {
        // NOTE(review): this uses the element COUNT of `size`, not the product of its values,
        // and the output is always 2-D. If callers pass a shape such as [3, 4], they may expect
        // 12 samples shaped [3, 4, dim]. Behavior is preserved here; confirm against callers.
        return sampleRows(dis, Math.toIntExact(size.getSize()));
    }

    /**
     * Evaluates the probability density function at a single point.
     *
     * @param dis the distribution
     * @param x the point
     * @return the density at {@code x}
     */
    public static double pdf(ContinuousDistribution dis, Number x) {
        return dis.density(x.doubleValue());
    }

    /**
     * Evaluates the probability density function element-wise.
     *
     * @param dis the distribution
     * @param x points to evaluate
     * @return a double array with the same shape as {@code x}
     */
    public static Array pdf(ContinuousDistribution dis, Array x) {
        return applyElementwise(x, dis::density);
    }

    /**
     * Evaluates the natural log of the density at a single point.
     *
     * @param dis the normal distribution
     * @param x the point
     * @return the log-density at {@code x}
     */
    public static double logpdf(NormalDistribution dis, Number x) {
        return dis.logDensity(x.doubleValue());
    }

    /**
     * Evaluates the natural log of the density element-wise.
     *
     * @param dis the normal distribution
     * @param x points to evaluate
     * @return a double array with the same shape as {@code x}
     */
    public static Array logpdf(NormalDistribution dis, Array x) {
        return applyElementwise(x, dis::logDensity);
    }

    /**
     * Evaluates the natural log of the density at a single point, for any continuous
     * distribution.
     *
     * @param dis the distribution
     * @param x the point
     * @return the log-density at {@code x}
     */
    public static double logpdf(ContinuousDistribution dis, Number x) {
        return dis.logDensity(x.doubleValue());
    }

    /**
     * Evaluates the natural log of the density element-wise, for any continuous distribution.
     *
     * @param dis the distribution
     * @param x points to evaluate
     * @return a double array with the same shape as {@code x}
     */
    public static Array logpdf(ContinuousDistribution dis, Array x) {
        return applyElementwise(x, dis::logDensity);
    }

    /**
     * Evaluates the cumulative distribution function at a single point.
     *
     * @param dis the distribution
     * @param x the point
     * @return P(X &lt;= x)
     */
    public static double cdf(ContinuousDistribution dis, Number x) {
        return dis.cumulativeProbability(x.doubleValue());
    }

    /**
     * Evaluates the cumulative distribution function element-wise.
     *
     * @param dis the distribution
     * @param x points to evaluate
     * @return a double array with the same shape as {@code x}
     */
    public static Array cdf(ContinuousDistribution dis, Array x) {
        return applyElementwise(x, dis::cumulativeProbability);
    }

    /**
     * Evaluates {@code dis.probability(x, x)}, i.e. P(x &lt; X &lt;= x). For a continuous
     * distribution this is expected to be 0.
     *
     * @param dis the distribution
     * @param x the point
     * @return the probability of the degenerate interval {@code (x, x]}
     */
    public static double pmf(ContinuousDistribution dis, Number x) {
        double v = x.doubleValue();
        return dis.probability(v, v);
    }

    /**
     * Element-wise version of {@link #pmf(ContinuousDistribution, Number)}.
     *
     * @param dis the distribution
     * @param x points to evaluate
     * @return a double array with the same shape as {@code x}
     */
    public static Array pmf(ContinuousDistribution dis, Array x) {
        return applyElementwise(x, v -> dis.probability(v, v));
    }

    /**
     * Evaluates the quantile function (inverse CDF) at a single probability.
     *
     * @param dis the distribution
     * @param q probability, expected in [0, 1]
     * @return the quantile for {@code q}
     * @throws IllegalArgumentException if the distribution rejects {@code q}
     */
    public static double ppf(ContinuousDistribution dis, Number q) {
        return dis.inverseCumulativeProbability(q.doubleValue());
    }

    /**
     * Evaluates the quantile function (inverse CDF) element-wise.
     *
     * @param dis the distribution
     * @param q probabilities, expected in [0, 1]
     * @return a double array with the same shape as {@code q}
     * @throws IllegalArgumentException if the distribution rejects any element of {@code q}
     */
    public static Array ppf(ContinuousDistribution dis, Array q) {
        return applyElementwise(q, dis::inverseCumulativeProbability);
    }

    /** Draws {@code n} samples from {@code dis} using a freshly created MT generator. */
    private static double[] sample(ContinuousDistribution dis, int n) {
        if (n < 0) {
            throw new IllegalArgumentException("Sample count must be non-negative: " + n);
        }
        ContinuousDistribution.Sampler sampler = dis.createSampler(RandomSource.MT.create());
        double[] samples = new double[n];
        for (int i = 0; i < n; i++) {
            samples[i] = sampler.sample();
        }
        return samples;
    }

    /** Draws {@code n} vectors from {@code dis} and stores them row by row in an {@code [n, dim]} array. */
    private static Array sampleRows(MultivariateRealDistribution dis, int n) {
        if (n < 0) {
            throw new IllegalArgumentException("Sample count must be non-negative: " + n);
        }
        MultivariateRealDistribution.Sampler sampler = dis.createSampler(RandomSource.MT.create());
        int dim = dis.getDimension();
        // Fill the flat buffer directly rather than building a double[][] and flattening it.
        double[] flat = new double[Math.multiplyExact(n, dim)];
        for (int i = 0; i < n; i++) {
            System.arraycopy(sampler.sample(), 0, flat, i * dim, dim);
        }
        return Array.factory(DataType.DOUBLE, new int[]{n, dim}, flat);
    }

    /** Returns the product of {@code shape}, rejecting negative dimensions and {@code int} overflow. */
    private static int elementCount(int[] shape) {
        int count = 1;
        for (int d : shape) {
            if (d < 0) {
                throw new IllegalArgumentException("Negative dimension in size: " + Arrays.toString(shape));
            }
            count = Math.multiplyExact(count, d);
        }
        return count;
    }

    /** Applies {@code f} to each element of {@code x} (in iterator order) into a new double array of the same shape. */
    private static Array applyElementwise(Array x, DoubleUnaryOperator f) {
        Array r = Array.factory(DataType.DOUBLE, x.getShape());
        IndexIterator iter = x.getIndexIterator();
        long size = r.getSize();
        for (int i = 0; i < size; i++) {
            r.setDouble(i, f.applyAsDouble(iter.getDoubleNext()));
        }
        return r;
    }
}

