package org.integratedmodelling.riskwiz.pfunction.distrib;

import Jama.Matrix;
import java.util.List;
import java.util.Vector;
import org.integratedmodelling.riskwiz.pfunction.JavaCondProbDistrib;

/* loaded from: input_file:lib/riskwiz-1.0.0.jar:org/integratedmodelling/riskwiz/pfunction/distrib/MultivarGaussian.class */
/**
 * Multivariate Gaussian (normal) distribution over {@code d}-dimensional column vectors.
 *
 * <p>The density of a column vector {@code x} is
 * {@code exp(-0.5 * (x - mu)' * inv(Sigma) * (x - mu)) / ((2*pi)^(d/2) * sqrt(det(Sigma)))}.
 * The mean {@code mu} (a {@code d}-by-1 {@link Matrix}) and the covariance {@code Sigma}
 * (a symmetric {@code d}-by-{@code d} {@link Matrix}) may each be fixed at construction time or
 * supplied per call through the argument list of {@link #getProb(List, Object)} and
 * {@link #sampleVal(List)}.
 *
 * <p>When a parameter is not fixed, every call to {@code getProb(List, Object)} or
 * {@code sampleVal(List)} overwrites the stored mean/covariance. Such instances therefore must
 * not be shared between threads without external synchronization.
 */
public class MultivarGaussian extends JavaCondProbDistrib {
    /** Relative tolerance used when checking that a covariance matrix is symmetric. */
    private static final double SYMMETRY_TOLERANCE = 1.0E-6d;

    /** Whether the mean was supplied at construction time. */
    private boolean fixedMean;
    /** Whether the covariance was supplied at construction time. */
    private boolean fixedCovariance;
    /** Dimension of the distribution; always positive. */
    private int d;
    /** Mean column vector ({@code d}-by-1); {@code null} until set. */
    private Matrix mu;
    /** Covariance matrix ({@code d}-by-{@code d}); {@code null} until set. */
    private Matrix sigma;
    /** {@code (2*pi)^(d/2)}, the dimension-dependent part of the normalizing constant. */
    private double dimFactor;
    /** {@code (2*pi)^(d/2) * sqrt(det(sigma))}, the full normalizing constant. */
    private double normConst;
    /** Cached inverse of {@link #sigma}. */
    private Matrix sigmaInverse;
    /** Cholesky factor {@code L} of {@link #sigma} ({@code sigma = L * L'}), used for sampling. */
    private Matrix sqrtSigma;

    /**
     * Creates a distribution with fixed mean and covariance.
     *
     * @param matrix the mean; a column vector whose length defines the dimension
     * @param matrix2 the covariance; a symmetric square matrix of the same dimension
     * @throws IllegalArgumentException if either parameter is malformed
     */
    public MultivarGaussian(Matrix matrix, Matrix matrix2) {
        setDimension(matrix.getRowDimension());
        this.fixedMean = true;
        setMean(matrix);
        this.fixedCovariance = true;
        setCovariance(matrix2);
    }

    /**
     * Creates a distribution from a parameter list of one to three elements.
     * <ul>
     *   <li>{@code [d]}: neither mean nor covariance is fixed; both are passed per call.</li>
     *   <li>{@code [d, covariance]}: the covariance is fixed; the mean is passed per call.</li>
     *   <li>{@code [d, mean, covariance]}: both are fixed.</li>
     * </ul>
     * {@code d} must be a positive {@link Integer}.
     *
     * @param vector the parameter list described above
     * @throws IllegalArgumentException if the list is empty, has more than three elements, or
     *     holds invalid values
     */
    public MultivarGaussian(Vector<?> vector) {
        if (vector.isEmpty()) {
            throw new IllegalArgumentException("Dimension of MultivarGaussian distribution must be specified as parameter.");
        }
        Object dimension = vector.get(0);
        if (!(dimension instanceof Integer)) {
            throw new IllegalArgumentException("Dimension of MultivarGaussian distribution must be an integer, not " + describe(dimension));
        }
        setDimension(((Integer) dimension).intValue());
        if (vector.size() == 1) {
            this.fixedMean = false;
            this.fixedCovariance = false;
            return;
        }
        if (vector.size() == 2) {
            this.fixedMean = false;
            this.fixedCovariance = true;
            setCovariance(vector.get(1));
        } else {
            if (vector.size() != 3) {
                throw new IllegalArgumentException("MultivarGaussian CPD expects at most 3 parameters, not " + vector.size());
            }
            this.fixedMean = true;
            setMean(vector.get(1));
            this.fixedCovariance = true;
            setCovariance(vector.get(2));
        }
    }

    /**
     * Returns the density at {@code obj}, first installing any non-fixed parameters from
     * {@code list}.
     *
     * @param list per-call parameters: the mean (if not fixed), followed by the covariance
     *     (if not fixed); empty when both are fixed
     * @param obj the point to evaluate; must be a {@code d}-by-1 {@link Matrix}
     * @return the probability density at {@code obj}
     * @throws IllegalArgumentException if {@code list} or {@code obj} is malformed
     */
    @Override // org.integratedmodelling.riskwiz.pfunction.ICondProbDistrib
    public double getProb(List list, Object obj) {
        initParams(list);
        return getProbInternal(requireColumnVector(obj, "getProb"));
    }

    /**
     * Returns the density at {@code matrix}; only valid when mean and covariance are fixed.
     *
     * @param matrix a {@code d}-by-1 column vector
     * @return the probability density at {@code matrix}
     * @throws IllegalStateException if mean or covariance is not fixed
     * @throws IllegalArgumentException if {@code matrix} is not a {@code d}-by-1 column vector
     */
    public double getProb(Matrix matrix) {
        requireFixedParameters();
        return getProbInternal(requireColumnVector(matrix, "getProb"));
    }

    /**
     * Returns the natural log of the density at {@code matrix}; only valid when mean and
     * covariance are fixed.
     *
     * @param matrix a {@code d}-by-1 column vector
     * @return the log density at {@code matrix}
     * @throws IllegalStateException if mean or covariance is not fixed
     * @throws IllegalArgumentException if {@code matrix} is not a {@code d}-by-1 column vector
     */
    public double getLogProb(Matrix matrix) {
        requireFixedParameters();
        return getLogProbInternal(requireColumnVector(matrix, "getLogProb"));
    }

    /**
     * Draws a sample after installing any non-fixed parameters from {@code list}.
     *
     * @param list per-call parameters, as for {@link #getProb(List, Object)}
     * @return a {@code d}-by-1 {@link Matrix} sample
     */
    @Override // org.integratedmodelling.riskwiz.pfunction.ICondProbDistrib
    public Object sampleVal(List list) {
        initParams(list);
        return sampleVal();
    }

    /**
     * Draws a sample as {@code mu + L * z}. Here {@code L} is the Cholesky factor of the
     * covariance, and each component of {@code z} is drawn from {@code UnivarGaussian.STANDARD}
     * (presumably N(0, 1); confirm against that class).
     *
     * <p>Uses whatever mean/covariance are currently stored. For non-fixed parameters, those
     * are the values from the most recent {@code getProb(List, Object)} or
     * {@code sampleVal(List)} call.
     *
     * @return a {@code d}-by-1 {@link Matrix} sample
     * @throws IllegalStateException if mean or covariance has never been set
     */
    public Matrix sampleVal() {
        if (this.mu == null || this.sqrtSigma == null) {
            throw new IllegalStateException("Mean and covariance have not been set; supply them via sampleVal(List) or construct with fixed parameters.");
        }
        Matrix matrix = new Matrix(this.d, 1);
        for (int i = 0; i < this.d; i++) {
            matrix.set(i, 0, UnivarGaussian.STANDARD.sampleVal());
        }
        return this.mu.plus(this.sqrtSigma.times(matrix));
    }

    /**
     * Returns the fixed mean, or {@code null} if the mean is supplied per call.
     *
     * @return the internal mean column vector (not a copy), or {@code null}
     */
    public Matrix getMean() {
        if (this.fixedMean) {
            return this.mu;
        }
        return null;
    }

    /**
     * Returns the fixed covariance, or {@code null} if the covariance is supplied per call.
     *
     * @return the internal covariance matrix (not a copy), or {@code null}
     */
    public Matrix getCovar() {
        if (this.fixedCovariance) {
            return this.sigma;
        }
        return null;
    }

    /** Throws unless both mean and covariance were fixed at construction time. */
    private void requireFixedParameters() {
        if (!this.fixedMean || !this.fixedCovariance) {
            throw new IllegalStateException("Mean and covariance are not fixed.");
        }
    }

    /** Returns {@code obj} as a {@link Matrix} if it is a {@code d}-by-1 column vector, else throws. */
    private Matrix requireColumnVector(Object obj, String method) {
        if (obj instanceof Matrix) {
            Matrix vector = (Matrix) obj;
            if (vector.getRowDimension() == this.d && vector.getColumnDimension() == 1) {
                return vector;
            }
        }
        throw new IllegalArgumentException("The value passed to the " + this.d + "-dimensional multivariate Gaussian distribution's " + method + " method must be a column vector of length " + this.d + ", not " + obj);
    }

    /** Computes the quadratic form {@code (x - mu)' * inv(sigma) * (x - mu)} once per evaluation. */
    private double mahalanobisSquared(Matrix matrix) {
        Matrix diff = matrix.minus(this.mu);
        return diff.transpose().times(this.sigmaInverse).times(diff).get(0, 0);
    }

    private double getProbInternal(Matrix matrix) {
        return Math.exp((-0.5d) * mahalanobisSquared(matrix)) / this.normConst;
    }

    private double getLogProbInternal(Matrix matrix) {
        return ((-0.5d) * mahalanobisSquared(matrix)) - Math.log(this.normConst);
    }

    /**
     * Validates the per-call argument list and installs non-fixed parameters from it.
     *
     * @param list the mean first (if not fixed), then the covariance (if not fixed)
     * @throws IllegalArgumentException on a wrong argument count or malformed values
     */
    private void initParams(List<?> list) {
        if (this.fixedMean) {
            if (list.size() > 0) {
                throw new IllegalArgumentException("MultivarGaussian CPD with fixed mean expects no arguments.");
            }
            return;
        }
        if (list.size() < 1) {
            throw new IllegalArgumentException("MultivarGaussian CPD created without a fixed mean; requires mean as an argument.");
        }
        setMean(list.get(0));
        if (this.fixedCovariance) {
            if (list.size() > 1) {
                throw new IllegalArgumentException("MultivarGaussian CPD with fixed covariance matrix expects only one argument.");
            }
        } else {
            if (list.size() < 2) {
                throw new IllegalArgumentException("MultivarGaussian CPD created without a fixed covariance matrix; requires covariance matrix as argument.");
            }
            setCovariance(list.get(1));
        }
    }

    /** Sets the dimension and precomputes {@code (2*pi)^(d/2)}. */
    private void setDimension(int i) {
        if (i <= 0) {
            throw new IllegalArgumentException("Dimension of MultivarGaussian distribution must be positive, not " + i);
        }
        this.d = i;
        this.dimFactor = Math.pow(2.0d * Math.PI, this.d / 2.0d);
    }

    /** Validates {@code obj} as a {@code d}-by-1 column vector and only then stores it as the mean. */
    private void setMean(Object obj) {
        if (!(obj instanceof Matrix) || ((Matrix) obj).getColumnDimension() != 1) {
            throw new IllegalArgumentException("The mean of a MultivarGaussian distribution must be a column vector, not " + describe(obj));
        }
        Matrix mean = (Matrix) obj;
        if (mean.getRowDimension() != this.d) {
            throw new IllegalArgumentException("Mean of " + this.d + "-dimensional Gaussian distribution must be column vector of length " + this.d);
        }
        this.mu = mean;
    }

    /**
     * Validates {@code obj} as a symmetric {@code d}-by-{@code d} covariance matrix and stores a
     * copy of it, together with its normalizing constant, inverse and Cholesky factor. No field
     * is modified unless every check and computation succeeds.
     */
    private void setCovariance(Object obj) {
        // Both dimensions must equal d (the row check was previously missing).
        if (!(obj instanceof Matrix) || ((Matrix) obj).getRowDimension() != this.d || ((Matrix) obj).getColumnDimension() != this.d) {
            throw new IllegalArgumentException("The covariance matrix of a " + this.d + "-dimensional Gaussian distribution must be a " + this.d + "-by-" + this.d + " Matrix, not " + describe(obj));
        }
        // Defensive copy: the cached inverse/Cholesky factor would go stale if the caller
        // later mutated the matrix it passed in.
        Matrix covariance = ((Matrix) obj).copy();
        for (int i = 0; i < this.d; i++) {
            for (int j = 0; j < this.d; j++) {
                double a = covariance.get(i, j);
                double b = covariance.get(j, i);
                // Division-free form of |a/b - 1| > tol; a pair of zero entries passes.
                if (Math.abs(a - b) > SYMMETRY_TOLERANCE * Math.abs(b)) {
                    throw new IllegalArgumentException("Invalid covariance matrix (not symmetric): " + covariance);
                }
            }
        }
        double det = covariance.det();
        // A valid covariance has a positive determinant; otherwise normConst would be NaN
        // (or inverse() would fail on a singular matrix).
        if (!(det > 0.0d)) {
            throw new IllegalArgumentException("Invalid covariance matrix (determinant " + det + " is not positive): " + covariance);
        }
        Matrix inverse = covariance.inverse();
        Matrix choleskyFactor = covariance.chol().getL();
        this.sigma = covariance;
        this.normConst = Math.sqrt(det) * this.dimFactor;
        this.sigmaInverse = inverse;
        this.sqrtSigma = choleskyFactor;
    }

    /** Renders a value and its class for error messages without failing on {@code null}. */
    private static String describe(Object obj) {
        return obj == null ? "null" : obj + " of " + obj.getClass();
    }
}
