package org.deeplearning4j.nn.modelexport.solr.handler;

import java.io.IOException;
import java.io.InputStream;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.TupleStream;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.handler.SolrDefaultStreamFactory;
import org.deeplearning4j.core.util.ModelGuesser;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.util.NetworkUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Solr streaming-expression decorator that runs a deeplearning4j {@link Model} over every tuple
 * read from a single wrapped {@link TupleStream}.
 *
 * <p>Expression shape, as parsed by the constructor (the function name itself is whatever the
 * {@link StreamFactory} maps to this class):
 *
 * <pre>
 * fn(&lt;stream&gt;, serializedModelFileName="...", inputKeys="a,b,c", outputKeys="x,y")
 * </pre>
 *
 * <p>For each non-EOF tuple, the values under {@code inputKeys} (in order) are read as doubles and
 * packed into a single-row matrix. That matrix is passed to {@link NetworkUtils#output}, and the
 * first {@code outputKeys.length} values of the result are written back into the same tuple as
 * floats under {@code outputKeys}.
 *
 * <p>The model file is resolved via the core's {@link SolrResourceLoader}, so this stream requires
 * a {@link SolrDefaultStreamFactory}.
 */
public class ModelTupleStream extends TupleStream implements Expressible {
    private static final String SERIALIZED_MODEL_FILE_NAME_PARAM = "serializedModelFileName";
    private static final String INPUT_KEYS_PARAM = "inputKeys";
    private static final String OUTPUT_KEYS_PARAM = "outputKeys";

    /** The decorated stream whose tuples are scored. */
    private final TupleStream tupleStream;
    private final String serializedModelFileName;
    /** Raw comma-separated parameter values, kept verbatim so toExpression can round-trip them. */
    private final String inputKeysParam;
    private final String outputKeysParam;
    private final String[] inputKeys;
    private final String[] outputKeys;
    private final SolrResourceLoader solrResourceLoader;
    private final Model model;

    /**
     * Parses the streaming expression and loads the model.
     *
     * <p>Note: this constructor calls the protected hooks {@link #openInputStream()} and
     * {@link #restoreModel(InputStream)}. Subclasses overriding them must not rely on their own
     * fields being initialised yet.
     *
     * @param streamExpression the parsed expression for this stream
     * @param streamFactory the factory; must be a {@link SolrDefaultStreamFactory}
     * @throws IOException if the expression does not contain exactly one stream operand, a required
     *     named parameter is missing, the factory is of the wrong type, or the model cannot be loaded
     */
    public ModelTupleStream(StreamExpression streamExpression, StreamFactory streamFactory) throws IOException {
        List<StreamExpression> streamOperands = streamFactory.getExpressionOperandsRepresentingTypes(
                streamExpression, new Class<?>[]{Expressible.class, TupleStream.class});
        if (streamOperands.size() != 1) {
            throw new IOException("Expected exactly one stream in expression: " + streamExpression);
        }
        this.tupleStream = streamFactory.constructStream(streamOperands.get(0));
        this.serializedModelFileName = getOperandValue(streamExpression, streamFactory, SERIALIZED_MODEL_FILE_NAME_PARAM);
        this.inputKeysParam = getOperandValue(streamExpression, streamFactory, INPUT_KEYS_PARAM);
        this.inputKeys = this.inputKeysParam.split(",");
        this.outputKeysParam = getOperandValue(streamExpression, streamFactory, OUTPUT_KEYS_PARAM);
        this.outputKeys = this.outputKeysParam.split(",");
        if (!(streamFactory instanceof SolrDefaultStreamFactory)) {
            throw new IOException(getClass().getName() + " requires a " + SolrDefaultStreamFactory.class.getName() + " StreamFactory");
        }
        this.solrResourceLoader = ((SolrDefaultStreamFactory) streamFactory).getSolrResourceLoader();
        // Close the model resource once it has been read. Previously the stream was leaked.
        try (InputStream modelInputStream = openInputStream()) {
            this.model = restoreModel(modelInputStream);
        }
    }

    /**
     * Returns the string value of a required named operand.
     *
     * @param streamExpression the expression to search
     * @param streamFactory the factory used to look up the named operand
     * @param operandName the parameter name, e.g. {@code inputKeys}
     * @return the operand's value
     * @throws IOException if the operand is absent or is not a plain value
     */
    private static String getOperandValue(StreamExpression streamExpression, StreamFactory streamFactory, String operandName) throws IOException {
        StreamExpressionNamedParameter namedOperand = streamFactory.getNamedOperand(streamExpression, operandName);
        String operandValue = null;
        if (namedOperand != null && namedOperand.getParameter() instanceof StreamExpressionValue) {
            operandValue = ((StreamExpressionValue) namedOperand.getParameter()).getValue();
        }
        if (operandValue == null) {
            throw new IOException("Expected '" + operandName + "' in expression: " + streamExpression);
        }
        return operandValue;
    }

    @Override
    public Map toMap(Map<String, Object> map) {
        return super.toMap(map);
    }

    /** Forwards the context to the wrapped stream. This stream keeps no context of its own. */
    @Override
    public void setStreamContext(StreamContext streamContext) {
        this.tupleStream.setStreamContext(streamContext);
    }

    /**
     * Returns this decorator's direct children: just the wrapped stream.
     *
     * <p>Previously this returned the wrapped stream's own children, which skipped a level of the
     * stream tree.
     */
    @Override
    public List<TupleStream> children() {
        return Collections.singletonList(this.tupleStream);
    }

    @Override
    public void open() throws IOException {
        this.tupleStream.open();
    }

    @Override
    public void close() throws IOException {
        this.tupleStream.close();
    }

    /**
     * Reads the next tuple from the wrapped stream and augments it with model outputs.
     *
     * @return the EOF tuple unchanged, or the next tuple with {@code outputKeys} populated
     * @throws IOException if the wrapped stream fails
     */
    @Override
    public Tuple read() throws IOException {
        Tuple tuple = this.tupleStream.read();
        if (tuple.EOF) {
            // The EOF marker tuple is passed through without scoring.
            return tuple;
        }
        INDArray inputs = getInputsFromTuple(tuple);
        INDArray outputs = NetworkUtils.output(this.model, inputs);
        return applyOutputsToTuple(tuple, outputs);
    }

    /** The wrapped stream's sort order is preserved, because tuples are only augmented. */
    @Override
    public StreamComparator getStreamSort() {
        return this.tupleStream.getStreamSort();
    }

    @Override
    public Explanation toExplanation(StreamFactory streamFactory) throws IOException {
        return new StreamExplanation(getStreamNodeId().toString())
                .withChildren(new Explanation[]{this.tupleStream.toExplanation(streamFactory)})
                .withExpressionType("stream-decorator")
                .withFunctionName(streamFactory.getFunctionName(getClass()))
                .withImplementingClass(getClass().getName())
                .withExpression(toExpression(streamFactory, false).toString());
    }

    @Override
    public StreamExpressionParameter toExpression(StreamFactory streamFactory) throws IOException {
        return toExpression(streamFactory, true);
    }

    /**
     * Rebuilds the streaming expression for this stream.
     *
     * @param streamFactory the factory used to resolve function names
     * @param includeStreams if true, the wrapped stream's full expression is embedded. If false, a
     *     {@code <stream>} placeholder is used instead (as in explanations).
     * @throws IOException if {@code includeStreams} is true and the wrapped stream is not Expressible
     */
    private StreamExpression toExpression(StreamFactory streamFactory, boolean includeStreams) throws IOException {
        StreamExpression streamExpression = new StreamExpression(streamFactory.getFunctionName(getClass()));
        if (!includeStreams) {
            streamExpression.addParameter("<stream>");
        } else {
            if (!(this.tupleStream instanceof Expressible)) {
                throw new IOException("This " + getClass().getName() + " contains a non-Expressible TupleStream " + this.tupleStream.getClass().getName());
            }
            streamExpression.addParameter(((Expressible) this.tupleStream).toExpression(streamFactory));
        }
        streamExpression.addParameter(new StreamExpressionNamedParameter(SERIALIZED_MODEL_FILE_NAME_PARAM, this.serializedModelFileName));
        streamExpression.addParameter(new StreamExpressionNamedParameter(INPUT_KEYS_PARAM, this.inputKeysParam));
        streamExpression.addParameter(new StreamExpressionNamedParameter(OUTPUT_KEYS_PARAM, this.outputKeysParam));
        return streamExpression;
    }

    /**
     * Opens the serialized model file through the core's resource loader. Subclasses may override
     * this hook, e.g. in tests. It is invoked from the constructor.
     */
    protected InputStream openInputStream() throws IOException {
        return this.solrResourceLoader.openResource(this.serializedModelFileName);
    }

    /**
     * Deserializes the model, letting {@link ModelGuesser} detect the format. The core's instance
     * path is passed as ModelGuesser's working directory. The caller owns and closes the stream.
     *
     * @throws IOException wrapping any failure, with the model file name in the message
     */
    protected Model restoreModel(InputStream inputStream) throws IOException {
        try {
            return ModelGuesser.loadModelGuess(inputStream, this.solrResourceLoader.getInstancePath().toFile());
        } catch (Exception e) {
            throw new IOException("Failed to restore model from given file (" + this.serializedModelFileName + ")", e);
        }
    }

    /**
     * Builds the model input for one tuple: a single-row matrix whose columns are the values under
     * {@code inputKeys}, in order.
     *
     * @throws IllegalArgumentException if the tuple has no value for one of the input keys
     */
    protected INDArray getInputsFromTuple(Tuple tuple) {
        double[] inputs = new double[this.inputKeys.length];
        for (int i = 0; i < this.inputKeys.length; i++) {
            Double value = tuple.getDouble(this.inputKeys[i]);
            if (value == null) {
                throw new IllegalArgumentException("Tuple has no value for input key '" + this.inputKeys[i] + "'");
            }
            inputs[i] = value;
        }
        // One example per call: shape [1, inputKeys.length].
        return Nd4j.create(new double[][]{inputs});
    }

    /**
     * Writes the model outputs into the tuple, storing output {@code i} (linear index) as a Float
     * under {@code outputKeys[i]}.
     *
     * @return the same, now mutated, tuple
     */
    protected Tuple applyOutputsToTuple(Tuple tuple, INDArray iNDArray) {
        for (int i = 0; i < this.outputKeys.length; i++) {
            tuple.put(this.outputKeys[i], Float.valueOf(iNDArray.getFloat(i)));
        }
        return tuple;
    }
}
