package org.nd4j.linalg.dataset.api.preprocessor;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerType;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/dataset/api/preprocessor/VGG16ImagePreProcessor.class */
/**
 * {@link DataNormalization} that applies fixed VGG16-style per-channel mean subtraction to features.
 *
 * <p>Nothing is learned from data: both {@code fit} methods are no-ops. {@link #preProcess(INDArray)}
 * subtracts {@link #VGG_MEAN_OFFSET_BGR} from the features in place, and
 * {@link #revertFeatures(INDArray)} adds it back. Labels are never modified, and mask arrays are
 * ignored.
 */
public class VGG16ImagePreProcessor implements DataNormalization {
    private static final Logger log = LoggerFactory.getLogger(VGG16ImagePreProcessor.class);

    /**
     * Fixed per-channel offsets subtracted from the features by {@link #preProcess(INDArray)}.
     *
     * <p>NOTE(review): these values match the commonly published ImageNet means in R, G, B order
     * (R=123.68, G=116.779, B=103.939), i.e. RGB order despite the {@code _BGR} suffix. Confirm this
     * against the channel order produced by the image loader. The name is kept for backward
     * compatibility.
     *
     * <p>This is a single shared static array; callers must not modify it in place, or every instance
     * of this preprocessor will be affected.
     */
    public static final INDArray VGG_MEAN_OFFSET_BGR = Nd4j.create(new double[]{123.68d, 116.779d, 103.939d});

    /** No-op: this preprocessor uses fixed offsets and has nothing to fit. */
    @Override
    public void fit(DataSet dataSet) {
    }

    /** No-op: this preprocessor uses fixed offsets and has nothing to fit. */
    @Override
    public void fit(DataSetIterator iterator) {
    }

    /**
     * Applies mean subtraction to the features of {@code dataSet} in place; labels are untouched.
     *
     * @param dataSet the data set whose features are modified
     */
    @Override
    public void preProcess(DataSet dataSet) {
        preProcess(dataSet.getFeatures());
    }

    /**
     * Subtracts {@link #VGG_MEAN_OFFSET_BGR} from {@code features} in place, broadcasting along
     * dimension 1.
     *
     * @param features feature array to modify in place; presumably NCHW with 3 channels at
     *                 dimension 1 (TODO confirm)
     */
    public void preProcess(INDArray features) {
        // The op reads from a copy (x) and writes the result into the caller's array (z), so the
        // caller's array is updated in place without the input and output buffers aliasing.
        Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastSubOp(features.dup(), VGG_MEAN_OFFSET_BGR, features, 1));
    }

    /** Same as {@link #preProcess(DataSet)}. */
    @Override
    public void transform(DataSet dataSet) {
        preProcess(dataSet);
    }

    /** Same as {@link #preProcess(INDArray)}. */
    @Override
    public void transform(INDArray features) {
        preProcess(features);
    }

    /**
     * Same as {@link #transform(INDArray)}.
     *
     * @param features     feature array to modify in place
     * @param featuresMask ignored
     */
    @Override
    public void transform(INDArray features, INDArray featuresMask) {
        transform(features);
    }

    /** No-op: labels are never modified by this preprocessor. */
    @Override
    public void transformLabel(INDArray labels) {
    }

    /** No-op: labels are never modified by this preprocessor; the mask is ignored. */
    @Override
    public void transformLabel(INDArray labels, INDArray labelsMask) {
        transformLabel(labels);
    }

    /**
     * Undoes {@link #preProcess(DataSet)} by adding the offsets back to the features in place.
     *
     * @param dataSet the data set whose features are modified
     */
    @Override
    public void revert(DataSet dataSet) {
        revertFeatures(dataSet.getFeatures());
    }

    /** @return {@link NormalizerType#IMAGE_VGG16} */
    @Override
    public NormalizerType getType() {
        return NormalizerType.IMAGE_VGG16;
    }

    /**
     * Adds {@link #VGG_MEAN_OFFSET_BGR} back to {@code features} in place, broadcasting along
     * dimension 1; this inverts {@link #preProcess(INDArray)}.
     *
     * @param features feature array to modify in place
     */
    @Override
    public void revertFeatures(INDArray features) {
        // Same in-place pattern as preProcess: read from a copy, write into the caller's array.
        Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastAddOp(features.dup(), VGG_MEAN_OFFSET_BGR, features, 1));
    }

    /**
     * Same as {@link #revertFeatures(INDArray)}.
     *
     * @param features     feature array to modify in place
     * @param featuresMask ignored
     */
    @Override
    public void revertFeatures(INDArray features, INDArray featuresMask) {
        revertFeatures(features);
    }

    /** No-op: labels are never modified, so there is nothing to revert. */
    @Override
    public void revertLabels(INDArray labels) {
    }

    /** No-op: labels are never modified, so there is nothing to revert; the mask is ignored. */
    @Override
    public void revertLabels(INDArray labels, INDArray labelsMask) {
        revertLabels(labels);
    }

    /**
     * Label fitting is not supported. Requesting it only logs a warning; labels remain unmodified.
     *
     * @param fitLabels whether label fitting was requested
     */
    @Override
    public void fitLabel(boolean fitLabels) {
        if (fitLabels) {
            log.warn("Labels fitting not currently supported for VGG16ImagePreProcessor. Labels will not be modified");
        }
    }

    /** @return always {@code false}; label fitting is not supported */
    @Override
    public boolean isFitLabel() {
        return false;
    }
}
