package org.openimaj.demos;

import com.jmatio.io.MatFileReader;
import com.jmatio.io.MatFileWriter;
import com.jmatio.types.MLArray;
import com.jmatio.types.MLDouble;
import com.jmatio.types.MLSingle;
import com.jmatio.types.MLStructure;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import org.openimaj.feature.FloatFV;
import org.openimaj.feature.local.list.MemoryLocalFeatureList;
import org.openimaj.image.feature.dense.gradient.dsift.FloatDSIFTKeypoint;
import org.openimaj.image.feature.local.aggregate.FisherVector;
import org.openimaj.math.statistics.distribution.DiagonalMultivariateGaussian;
import org.openimaj.math.statistics.distribution.MixtureOfGaussians;
import org.openimaj.math.statistics.distribution.MultivariateGaussian;

/* loaded from: input_file:org/openimaj/demos/FVFWCheckGMM.class */
/**
 * Demo that loads a diagonal-covariance Gaussian mixture exported from MATLAB, encodes
 * dense-SIFT feature matrices stored in MAT-files with a {@link FisherVector}, prints the
 * result, and writes it back out as a MAT-file next to each input.
 */
public class FVFWCheckGMM {
    /**
     * MAT-file holding a struct {@code codebook} with single-precision fields {@code mean},
     * {@code variance} (both dims x components) and {@code coef} (one weight per component).
     */
    private static final String GMM_MATLAB_FILE = "/Users/ss/Experiments/FVFW/data/gmm_512.mat";

    /** Feature files encoded when no paths are supplied on the command line. */
    private static final String[] FACE_DSIFTS_PCA = {"/Users/ss/Experiments/FVFW/data/aaron-pcadsiftaug.mat"};

    /**
     * Encodes each feature MAT-file as a Fisher vector and writes it to {@code <input>.fisher.mat}.
     *
     * @param strArr optional paths of feature MAT-files (each containing a single-precision
     *               {@code feats} matrix); when empty, {@link #FACE_DSIFTS_PCA} is used
     * @throws IOException if a MAT-file cannot be read or written, or does not contain the
     *                     expected variables
     */
    public static void main(String[] strArr) throws IOException {
        // NOTE(review): the two flags presumably enable Hellinger and L2 normalisation —
        // confirm against the FisherVector constructor documentation.
        FisherVector fisherVector = new FisherVector(loadMoG(), true, true);

        final String[] inputs = strArr.length > 0 ? strArr : FACE_DSIFTS_PCA;
        for (final String path : inputs) {
            final FloatFV encoded = fisherVector.aggregate(loadDSIFTPCA(path));
            System.out.println(String.format("%s: %s", path, encoded));
            System.out.println("Writing...");
            // The writer instance is discarded: this relies on the MatFileWriter constructor
            // performing the write itself.
            new MatFileWriter(new File(path + ".fisher.mat"), Arrays.asList(toMLArray(encoded)));
        }
    }

    /**
     * Loads the {@code feats} matrix from a MAT-file, turning each column into the descriptor
     * of one {@link FloatDSIFTKeypoint} (keypoint locations are left at their defaults).
     *
     * @param str path of the MAT-file
     * @return one keypoint per column of {@code feats}
     * @throws IOException if the file cannot be read or lacks a single-precision {@code feats}
     */
    private static MemoryLocalFeatureList<FloatDSIFTKeypoint> loadDSIFTPCA(String str) throws IOException {
        final File file = new File(str);
        final MLSingle feats = requireSingle(readVariable(file, "feats"), "feats", file);
        final int dims = feats.getM();
        final int count = feats.getN();

        final MemoryLocalFeatureList<FloatDSIFTKeypoint> features = new MemoryLocalFeatureList<>();
        for (int col = 0; col < count; col++) {
            final float[] descriptor = new float[dims];
            for (int row = 0; row < dims; row++) {
                descriptor[row] = feats.get(row, col);
            }
            final FloatDSIFTKeypoint keypoint = new FloatDSIFTKeypoint();
            keypoint.descriptor = descriptor;
            features.add(keypoint);
        }
        return features;
    }

    /**
     * Converts a feature vector into a double-precision column vector named {@code fisherface}.
     *
     * @param floatFV the vector to convert
     * @return an {@code length x 1} {@link MLDouble}
     */
    private static MLArray toMLArray(FloatFV floatFV) {
        final float[] values = floatFV.values;
        final MLDouble column = new MLDouble("fisherface", new int[] { values.length, 1 });
        for (int i = 0; i < values.length; i++) {
            column.set(Double.valueOf(values[i]), i, 0);
        }
        return column;
    }

    /**
     * Builds a mixture of diagonal Gaussians from the {@code codebook} struct in
     * {@link #GMM_MATLAB_FILE}. Column {@code k} of {@code mean}/{@code variance} describes
     * component {@code k}; element {@code k} of {@code coef} is its weight.
     *
     * @return the loaded mixture
     * @throws IOException if the file cannot be read, a field is missing or not single-precision,
     *                     or the field dimensions are inconsistent
     */
    private static MixtureOfGaussians loadMoG() throws IOException {
        final File file = new File(GMM_MATLAB_FILE);
        final MLArray codebookArray = readVariable(file, "codebook");
        if (!(codebookArray instanceof MLStructure)) {
            throw new IOException(String.format("Variable 'codebook' in %s is not a struct but %s",
                    file, codebookArray.getClass().getSimpleName()));
        }
        final MLStructure codebook = (MLStructure) codebookArray;

        // getField returns a plain MLArray, so each field must be checked and cast explicitly.
        final MLSingle mean = requireSingle(codebook.getField("mean"), "codebook.mean", file);
        final MLSingle variance = requireSingle(codebook.getField("variance"), "codebook.variance", file);
        final MLSingle coef = requireSingle(codebook.getField("coef"), "codebook.coef", file);

        final int components = mean.getN();
        final int dims = mean.getM();
        if (variance.getM() != dims || variance.getN() != components) {
            throw new IOException(String.format(
                    "codebook.variance is %dx%d but codebook.mean is %dx%d in %s",
                    variance.getM(), variance.getN(), dims, components, file));
        }
        if (coef.getM() * coef.getN() < components) {
            throw new IOException(String.format(
                    "codebook.coef has %d elements but %d components are required in %s",
                    coef.getM() * coef.getN(), components, file));
        }

        final MultivariateGaussian[] gaussians = new MultivariateGaussian[components];
        final double[] weights = new double[components];
        for (int k = 0; k < components; k++) {
            weights[k] = coef.get(k, 0);
            final DiagonalMultivariateGaussian gaussian = new DiagonalMultivariateGaussian(dims);
            for (int d = 0; d < dims; d++) {
                gaussian.mean.set(0, d, mean.get(d, k));
                gaussian.variance[d] = variance.get(d, k);
            }
            gaussians[k] = gaussian;
        }
        return new MixtureOfGaussians(gaussians, weights);
    }

    /**
     * Reads a MAT-file and returns one of its top-level variables.
     *
     * @throws IOException if the file cannot be read or the variable is absent
     */
    private static MLArray readVariable(File file, String name) throws IOException {
        final MLArray array = new MatFileReader(file).getContent().get(name);
        if (array == null) {
            throw new IOException(String.format("Variable '%s' not found in %s", name, file));
        }
        return array;
    }

    /**
     * Checks that {@code array} is a single-precision MATLAB array and returns it as such.
     *
     * @throws IOException if {@code array} is {@code null} or of another MATLAB class
     */
    private static MLSingle requireSingle(MLArray array, String name, File file) throws IOException {
        if (!(array instanceof MLSingle)) {
            throw new IOException(String.format("Expected single-precision '%s' in %s but found %s",
                    name, file, array == null ? "nothing" : array.getClass().getSimpleName()));
        }
        return (MLSingle) array;
    }
}
