package io.virtdata.continuous.common;

import io.virtdata.api.ValueType;
import io.virtdata.api.VirtDataFunctionLibrary;
import io.virtdata.ast.FunctionCall;
import io.virtdata.ast.VirtDataFlow;
import io.virtdata.core.ResolvedFunction;
import io.virtdata.discrete.common.IntegerDistributions;
import io.virtdata.parser.VirtDataDSL;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.function.LongToDoubleFunction;
import org.apache.commons.lang3.reflect.ConstructorUtils;
import org.apache.commons.math4.distribution.EmpiricalDistribution;
import org.apache.commons.math4.distribution.EnumeratedRealDistribution;
import org.apache.commons.statistics.distribution.BetaDistribution;
import org.apache.commons.statistics.distribution.CauchyDistribution;
import org.apache.commons.statistics.distribution.ChiSquaredDistribution;
import org.apache.commons.statistics.distribution.ContinuousDistribution;
import org.apache.commons.statistics.distribution.ExponentialDistribution;
import org.apache.commons.statistics.distribution.FDistribution;
import org.apache.commons.statistics.distribution.GammaDistribution;
import org.apache.commons.statistics.distribution.GumbelDistribution;
import org.apache.commons.statistics.distribution.LaplaceDistribution;
import org.apache.commons.statistics.distribution.LevyDistribution;
import org.apache.commons.statistics.distribution.LogNormalDistribution;
import org.apache.commons.statistics.distribution.LogisticDistribution;
import org.apache.commons.statistics.distribution.NakagamiDistribution;
import org.apache.commons.statistics.distribution.NormalDistribution;
import org.apache.commons.statistics.distribution.ParetoDistribution;
import org.apache.commons.statistics.distribution.TDistribution;
import org.apache.commons.statistics.distribution.TriangularDistribution;
import org.apache.commons.statistics.distribution.UniformContinuousDistribution;
import org.apache.commons.statistics.distribution.WeibullDistribution;

/* loaded from: input_file:io/virtdata/continuous/common/RealDistributions.class */
/**
 * Function library that resolves samplers over continuous distributions by name.
 *
 * <p>A function name is a distribution name from {@link RealDistribution},
 * optionally prefixed with any of the modifiers {@code mapto_}, {@code hashto_},
 * {@code compute_} and {@code interpolate_}. The modifiers select which sampler
 * implementation is created; see
 * {@link #resolveFunctions(Class, Class, String, Object...)}.
 */
public class RealDistributions implements VirtDataFunctionLibrary {
    private static final String MAPTO = "mapto_";
    private static final String HASHTO = "hashto_";
    private static final String COMPUTE = "compute_";
    private static final String INTERPOLATE = "interpolate_";

    /**
     * The distributions known to this library. The constant name is the
     * function name (without modifiers) under which the distribution class
     * is resolved.
     */
    public enum RealDistribution {
        levy(LevyDistribution.class),
        nakagami(NakagamiDistribution.class),
        triangular(TriangularDistribution.class),
        exponential(ExponentialDistribution.class),
        logistic(LogisticDistribution.class),
        enumerated_real(EnumeratedRealDistribution.class),
        laplace(LaplaceDistribution.class),
        log_normal(LogNormalDistribution.class),
        cauchy(CauchyDistribution.class),
        f(FDistribution.class),
        t(TDistribution.class),
        empirical(EmpiricalDistribution.class),
        normal(NormalDistribution.class),
        weibull(WeibullDistribution.class),
        chi_squared(ChiSquaredDistribution.class),
        gumbel(GumbelDistribution.class),
        beta(BetaDistribution.class),
        pareto(ParetoDistribution.class),
        gamma(GammaDistribution.class),
        uniform_real(UniformContinuousDistribution.class);

        private final Class<? extends ContinuousDistribution> distribution;

        // NOTE(review): the parameter is deliberately left as a raw Class. Tightening it to
        // Class<? extends ContinuousDistribution> requires every class listed above (including
        // the commons-math4 ones) to satisfy that bound at compile time - TODO confirm.
        @SuppressWarnings({"rawtypes", "unchecked"})
        RealDistribution(Class cls) {
            this.distribution = cls;
        }

        /**
         * @return the distribution class associated with this constant
         */
        public Class<? extends ContinuousDistribution> getDistributionClass() {
            return this.distribution;
        }

        /**
         * Looks up a constant by its exact (case-sensitive) name without throwing.
         *
         * @param str the candidate distribution name
         * @return the matching constant, or an empty Optional if none matches
         */
        public static Optional<RealDistribution> optionalValueOf(String str) {
            for (RealDistribution candidate : values()) {
                if (candidate.toString().equals(str)) {
                    return Optional.of(candidate);
                }
            }
            return Optional.empty();
        }
    }

    /**
     * Parses a single function-call spec and resolves it to a long-to-double
     * sampler from this library.
     *
     * @param str the spec text, handed to {@link VirtDataDSL#parse}
     * @return the function object of the single resolved function
     * @throws RuntimeException if the spec fails to parse, contains more than one
     *         expression, declares an input type other than long or an output type
     *         other than double, or resolves to zero or to more than one function
     */
    public static LongToDoubleFunction forSpec(String str) {
        VirtDataDSL.ParseResult parsed = VirtDataDSL.parse(str);
        if (parsed.throwable != null) {
            throw new RuntimeException(parsed.throwable);
        }
        VirtDataFlow flow = parsed.flow;
        if (flow.getExpressions().size() > 1) {
            throw new RuntimeException("Unable to parse flows in " + RealDistributions.class);
        }
        FunctionCall call = flow.getLastExpression().getCall();

        // Type qualifiers are optional in the spec; null means "not specified".
        Class<?> inputType = Optional.ofNullable(call.getInputType())
                .map(ValueType::valueOfClassName)
                .map(ValueType::getValueClass)
                .orElse(null);
        Class<?> outputType = Optional.ofNullable(call.getOutputType())
                .map(ValueType::valueOfClassName)
                .map(ValueType::getValueClass)
                .orElse(null);
        if (inputType != null && inputType != Long.TYPE) {
            throw new RuntimeException("This only supports long for input.");
        }
        if (outputType != null && outputType != Double.TYPE) {
            throw new RuntimeException("This only supports double for output.");
        }

        // After the checks above the effective types can only be long -> double.
        List<ResolvedFunction> resolved = new RealDistributions()
                .resolveFunctions(Double.TYPE, Long.TYPE, call.getFunctionName(), call.getArguments());
        if (resolved.isEmpty()) {
            // An unknown distribution name yields an empty list; report it explicitly
            // rather than failing on get(0) with an IndexOutOfBoundsException.
            throw new RuntimeException("Unable to find a real distribution function for '"
                    + call.getFunctionName() + "' in spec '" + str + "'");
        }
        if (resolved.size() > 1) {
            throw new RuntimeException("Found " + resolved.size()
                    + " implementations, be more specific with input or output qualifiers as in int -> or -> long");
        }
        return (LongToDoubleFunction) resolved.get(0).getFunctionObject();
    }

    /**
     * Strips every modifier prefix from a function name, leaving the bare
     * distribution name.
     */
    private static String distributionNameFor(String str) {
        // The modifiers are plain literals, so literal replacement is sufficient.
        return str.replace(COMPUTE, "")
                .replace(INTERPOLATE, "")
                .replace(MAPTO, "")
                .replace(HASHTO, "");
    }

    @Override // io.virtdata.api.Named
    public String getName() {
        return "math4-ccurves";
    }

    /**
     * Resolves the samplers available for a (possibly modifier-prefixed)
     * distribution name and constructor arguments.
     *
     * <p>The distribution is instantiated reflectively from {@code objArr}. An
     * interpolating sampler is used unless the name contains {@code compute_}
     * without {@code interpolate_}. The boolean passed to the samplers is true
     * unless the name contains {@code mapto_} without {@code hashto_}.
     *
     * @param cls    requested output type; not consulted by this implementation
     * @param cls2   requested input type; {@code null} yields both the long and
     *               the int variants, otherwise only the matching variant
     * @param str    function name, including any modifier prefixes
     * @param objArr arguments for the distribution constructor
     * @return the resolved functions; empty if the distribution name is unknown
     * @throws RuntimeException if no accessible constructor matches the arguments,
     *         or if constructing the distribution or a sampler fails
     */
    @Override // io.virtdata.api.VirtDataFunctionLibrary
    public List<ResolvedFunction> resolveFunctions(Class<?> cls, Class<?> cls2, String str, Object... objArr) {
        List<ResolvedFunction> resolved = new ArrayList<>();
        ValueType inputValueType = cls2 == null ? null : ValueType.valueOfAssignableClass(cls2);
        Optional<RealDistribution> known = RealDistribution.optionalValueOf(distributionNameFor(str));
        if (!known.isPresent()) {
            return resolved;
        }
        Class<? extends ContinuousDistribution> distributionClass = known.get().getDistributionClass();

        Class<?>[] argTypes = new Class<?>[objArr.length];
        for (int i = 0; i < objArr.length; i++) {
            argTypes[i] = objArr[i].getClass();
        }
        Constructor<? extends ContinuousDistribution> ctor =
                ConstructorUtils.getMatchingAccessibleConstructor(distributionClass, argTypes);
        if (ctor == null) {
            // Fail here with a descriptive message instead of relying on a later call to fail.
            throw new RuntimeException("No accessible constructor of " + distributionClass.getName()
                    + " matches arguments " + Arrays.toString(objArr) + " for function '" + str + "'");
        }

        try {
            ContinuousDistribution distribution = ConstructorUtils.invokeConstructor(distributionClass, objArr);
            boolean interpolate = !str.contains(COMPUTE) || str.contains(INTERPOLATE);
            boolean hash = !str.contains(MAPTO) || str.contains(HASHTO);
            RealDistributionICDSource icdSource = new RealDistributionICDSource(distribution);
            // NOTE(review): 1000 is the second argument of the interpolating samplers -
            // presumably the interpolation resolution; confirm against those classes.
            if (inputValueType == ValueType.LONG || inputValueType == null) {
                Object sampler = interpolate
                        ? new InterpolatingLongDoubleSampler(icdSource, 1000, hash)
                        : new RealLongDoubleSampler(icdSource, hash);
                resolved.add(new ResolvedFunction(
                        sampler, true, ctor.getParameterTypes(), objArr, Long.TYPE, Double.TYPE, getName()));
            }
            if (inputValueType == ValueType.INT || inputValueType == null) {
                Object sampler = interpolate
                        ? new InterpolatingIntDoubleSampler(icdSource, 1000, hash)
                        : new RealIntDoubleSampler(icdSource, hash);
                resolved.add(new ResolvedFunction(
                        sampler, true, ctor.getParameterTypes(), objArr, Integer.TYPE, Double.TYPE, getName()));
            }
            return resolved;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Lists the function names this library advertises: for each distribution,
     * the bare name followed by its {@code mapto_}, {@code mapto_compute_} and
     * {@code compute_} variants, in that order.
     */
    @Override // io.virtdata.api.VirtDataFunctionLibrary
    public List<String> getDataMapperNames() {
        RealDistribution[] distributions = RealDistribution.values();
        List<String> names = new ArrayList<>(distributions.length * 4);
        for (RealDistribution distribution : distributions) {
            String name = String.valueOf(distribution);
            names.add(name);
            names.add(MAPTO + name);
            names.add("mapto_compute_" + name);
            names.add(COMPUTE + name);
        }
        return names;
    }
}
