package org.marketcetera.tensorflow;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.runner.RunWith;
import org.marketcetera.core.publisher.Subscriber;
import org.marketcetera.module.DataFlowID;
import org.marketcetera.module.DataRequest;
import org.marketcetera.module.ModuleManager;
import org.marketcetera.module.ModuleURN;
import org.marketcetera.modules.headwater.HeadwaterModuleFactory;
import org.marketcetera.modules.publisher.PublisherModuleFactory;
import org.marketcetera.tensorflow.converter.TensorFlowConverterModuleFactory;
import org.marketcetera.tensorflow.converters.TensorFromObjectConverter;
import org.marketcetera.tensorflow.dao.GraphContainerDao;
import org.marketcetera.tensorflow.model.TensorFlowModelModuleFactory;
import org.marketcetera.tensorflow.model.TensorFlowRunner;
import org.marketcetera.tensorflow.service.TensorFlowService;
import org.marketcetera.util.log.SLF4JLoggerProxy;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;

/**
 * Base class for TensorFlow module tests.
 *
 * <p>Loads the Spring context from {@code classpath:/test.xml} and provides helpers that assemble
 * and start data flows of the shape headwater &rarr; [converter] &rarr; [model] &rarr; publisher.
 * Every object the publisher module receives is queued in {@link #receivedData}, and tests drain
 * it with {@link #waitForData()}. After each test, {@link #reset()} cancels the data flows started
 * through this class, clears the received-data queue and deletes all entries in
 * {@link GraphContainerDao}.
 */
@ContextConfiguration(locations = {"classpath:/test.xml"})
@RunWith(SpringJUnit4ClassRunner.class)
public class TensorFlowTestBase {
    /** How long {@link #waitForData()} waits for data before failing the test. */
    private static final long DATA_TIMEOUT_SECONDS = 10;
    /** Upper bound on a single wait on {@link #receivedData}, in milliseconds. */
    private static final long POLL_INTERVAL_MILLIS = 100L;

    /** Data flows started by this test and cancelled in {@link #reset()}; guarded by itself. */
    protected final Collection<DataFlowID> dataFlows = Lists.newArrayList();
    /** Objects delivered to the publisher module, oldest first; guarded by itself. */
    protected final Deque<Object> receivedData = Lists.newLinkedList();
    /** URN of the headwater module; {@code null} until {@link #createHeadwaterModule()} runs. */
    protected ModuleURN headwaterUrn;
    /** URN of the publisher module; {@code null} until {@link #createPublisherModule()} runs. */
    protected ModuleURN publisherUrn;
    /** Instance name given to the headwater module when it is created. */
    protected String headwaterInstance;

    @Autowired
    protected ApplicationContext applicationContext;

    @Autowired
    protected ModuleManager moduleManager;

    @Autowired
    protected TensorFlowService tensorFlowService;

    @Autowired
    protected GraphContainerDao graphContainerDao;

    /**
     * Per-test setup hook. It does nothing here and exists so subclasses can override it.
     *
     * @throws Exception if a subclass's setup fails
     */
    @Before
    public void setup() throws Exception {
    }

    /**
     * Per-test cleanup. Delegates to {@link #reset()}.
     *
     * @throws Exception if a subclass's cleanup fails
     */
    @After
    public void cleanup() throws Exception {
        reset();
    }

    /**
     * Cancels every tracked data flow, clears the received-data queue and deletes all graph
     * containers via {@link GraphContainerDao#deleteAll()}.
     *
     * <p>Cancellation is best-effort. A failure to cancel one flow is logged and does not stop
     * the remaining flows from being cancelled.
     */
    protected void reset() {
        synchronized (dataFlows) {
            for (DataFlowID dataFlow : dataFlows) {
                try {
                    moduleManager.cancel(dataFlow);
                } catch (Exception e) {
                    // best effort: e.g. the test may already have cancelled this flow itself
                    SLF4JLoggerProxy.warn(this, e, "Unable to cancel data flow {}", dataFlow);
                }
            }
            dataFlows.clear();
        }
        synchronized (receivedData) {
            receivedData.clear();
        }
        graphContainerDao.deleteAll();
    }

    /**
     * Removes and returns the oldest object received by the publisher module. If nothing has
     * arrived yet, waits up to {@value #DATA_TIMEOUT_SECONDS} seconds for one.
     *
     * @return the oldest received, non-{@code null} object
     * @throws Exception if the waiting thread is interrupted
     * @throws AssertionError if nothing arrives before the timeout
     */
    public Object waitForData() throws Exception {
        // nanoTime is monotonic, so wall-clock adjustments cannot distort the timeout
        long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(DATA_TIMEOUT_SECONDS);
        synchronized (receivedData) {
            while (true) {
                Object data = receivedData.poll();
                if (data != null) {
                    return data;
                }
                long remainingNanos = deadline - System.nanoTime();
                if (remainingNanos <= 0) {
                    break;
                }
                // never pass 0 to wait(): that would mean "wait forever"
                long waitMillis = Math.max(1L, Math.min(POLL_INTERVAL_MILLIS,
                        TimeUnit.NANOSECONDS.toMillis(remainingNanos)));
                receivedData.wait(waitMillis);
            }
        }
        Assert.fail("No tensor received in 10s");
        return null;
    }

    /**
     * Starts a headwater &rarr; converter &rarr; publisher data flow and tracks it for cleanup.
     *
     * @param tensorFromObjectConverter converter passed to the TensorFlow converter module
     * @return the ID of the started data flow
     */
    public DataFlowID startConverterDataFlow(TensorFromObjectConverter<?> tensorFromObjectConverter) {
        return startDataFlow(getConverterDataRequest(tensorFromObjectConverter));
    }

    /**
     * Starts a headwater &rarr; converter &rarr; model &rarr; publisher data flow and tracks it for
     * cleanup.
     *
     * @param tensorFromObjectConverter converter passed to the TensorFlow converter module
     * @param tensorFlowRunner runner passed to the TensorFlow model module
     * @return the ID of the started data flow
     */
    public DataFlowID startModelDataFlow(TensorFromObjectConverter<?> tensorFromObjectConverter,
                                         TensorFlowRunner tensorFlowRunner) {
        return startDataFlow(getModelDataRequest(tensorFromObjectConverter, tensorFlowRunner));
    }

    /**
     * Builds the requests for a headwater &rarr; converter &rarr; publisher flow. Creates the
     * headwater and publisher modules if necessary.
     *
     * @param tensorFromObjectConverter converter passed to the TensorFlow converter module
     * @return the data requests, in flow order
     */
    protected DataRequest[] getConverterDataRequest(TensorFromObjectConverter<?> tensorFromObjectConverter) {
        return buildDataRequests(
                new DataRequest(TensorFlowConverterModuleFactory.INSTANCE_URN, tensorFromObjectConverter));
    }

    /**
     * Builds the requests for a headwater &rarr; converter &rarr; model &rarr; publisher flow.
     * Creates the headwater and publisher modules if necessary.
     *
     * @param tensorFromObjectConverter converter passed to the TensorFlow converter module
     * @param tensorFlowRunner runner passed to the TensorFlow model module
     * @return the data requests, in flow order
     */
    protected DataRequest[] getModelDataRequest(TensorFromObjectConverter<?> tensorFromObjectConverter,
                                                TensorFlowRunner tensorFlowRunner) {
        return buildDataRequests(
                new DataRequest(TensorFlowConverterModuleFactory.INSTANCE_URN, tensorFromObjectConverter),
                new DataRequest(TensorFlowModelModuleFactory.INSTANCE_URN, tensorFlowRunner));
    }

    /**
     * Builds the requests for a headwater &rarr; model &rarr; publisher flow. The model module
     * receives {@code str} as its request data. Creates the headwater and publisher modules if
     * necessary.
     *
     * @param str request data passed to the TensorFlow model module
     * @return the data requests, in flow order
     */
    protected DataRequest[] getModelDataRequest(String str) {
        return buildDataRequests(new DataRequest(TensorFlowModelModuleFactory.INSTANCE_URN, str));
    }

    /**
     * Creates the headwater module on first use, under a unique instance name.
     *
     * @return the headwater module URN
     */
    protected ModuleURN createHeadwaterModule() {
        if (headwaterUrn == null) {
            headwaterInstance = "hw" + System.nanoTime();
            headwaterUrn = moduleManager.createModule(HeadwaterModuleFactory.PROVIDER_URN,
                                                      new Object[] { headwaterInstance });
        }
        return headwaterUrn;
    }

    /**
     * Creates the publisher module on first use. Its subscriber accepts every object, appends it
     * to {@link #receivedData} and wakes any thread blocked in {@link #waitForData()}.
     *
     * @return the publisher module URN
     */
    protected ModuleURN createPublisherModule() {
        if (publisherUrn == null) {
            Subscriber subscriber = new Subscriber() {
                @Override
                public boolean isInteresting(Object obj) {
                    return true;
                }

                @Override
                public void publishTo(Object obj) {
                    SLF4JLoggerProxy.debug(TensorFlowTestBase.this, "Received {}", new Object[] { obj });
                    synchronized (receivedData) {
                        receivedData.add(obj);
                        receivedData.notifyAll();
                    }
                }
            };
            publisherUrn = moduleManager.createModule(PublisherModuleFactory.PROVIDER_URN,
                                                      new Object[] { subscriber });
        }
        return publisherUrn;
    }

    /**
     * Creates the data flow for the given requests and tracks it so {@link #reset()} cancels it.
     *
     * @param requests the data requests, in flow order
     * @return the ID of the started data flow
     */
    private DataFlowID startDataFlow(DataRequest[] requests) {
        DataFlowID dataFlowId = moduleManager.createDataFlow(requests);
        synchronized (dataFlows) {
            dataFlows.add(dataFlowId);
        }
        return dataFlowId;
    }

    /**
     * Wraps the given intermediate stages between the headwater and publisher modules, creating
     * those two modules if needed.
     *
     * @param stages the requests between headwater and publisher, in flow order
     * @return headwater request, then {@code stages}, then publisher request
     */
    private DataRequest[] buildDataRequests(DataRequest... stages) {
        createHeadwaterModule();
        createPublisherModule();
        List<DataRequest> requests = Lists.newArrayListWithCapacity(stages.length + 2);
        requests.add(new DataRequest(headwaterUrn));
        Collections.addAll(requests, stages);
        requests.add(new DataRequest(publisherUrn));
        return requests.toArray(new DataRequest[requests.size()]);
    }
}
