package org.deeplearning4j.iterativereduce.tracker.statetracker.hazelcast;

import com.hazelcast.client.HazelcastClient;
import com.hazelcast.client.config.ClientConfig;
import com.hazelcast.config.Config;
import com.hazelcast.config.JoinConfig;
import com.hazelcast.config.ListConfig;
import com.hazelcast.config.MapConfig;
import com.hazelcast.core.Hazelcast;
import com.hazelcast.core.HazelcastInstance;
import com.hazelcast.core.IAtomicReference;
import com.hazelcast.core.IList;
import com.hazelcast.core.IMap;
import com.hazelcast.core.MemberAttributeEvent;
import com.hazelcast.core.MembershipEvent;
import com.hazelcast.core.MembershipListener;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.compress.utils.IOUtils;
import org.deeplearning4j.iterativereduce.actor.core.Job;
import org.deeplearning4j.iterativereduce.actor.util.PortTaken;
import org.deeplearning4j.iterativereduce.tracker.statetracker.DataSetCache;
import org.deeplearning4j.iterativereduce.tracker.statetracker.IterateAndUpdate;
import org.deeplearning4j.iterativereduce.tracker.statetracker.StateTracker;
import org.deeplearning4j.iterativereduce.tracker.statetracker.UpdateSaver;
import org.deeplearning4j.iterativereduce.tracker.statetracker.WorkRetriever;
import org.deeplearning4j.iterativereduce.tracker.statetracker.datasetcache.LocalDataSetCache;
import org.deeplearning4j.iterativereduce.tracker.statetracker.updatesaver.LocalFileUpdateSaver;
import org.deeplearning4j.iterativereduce.tracker.statetracker.workretriever.LocalWorkRetriever;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.optimize.OutputLayerTrainingEvaluator;
import org.deeplearning4j.optimize.api.TrainingEvaluator;
import org.deeplearning4j.scaleout.iterativereduce.Updateable;
import org.nd4j.linalg.dataset.DataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;

/* loaded from: input_file:org/deeplearning4j/iterativereduce/tracker/statetracker/hazelcast/BaseHazelCastStateTracker.class */
public abstract class BaseHazelCastStateTracker<E extends Updateable<?>> implements StateTracker<E> {
    private static final long serialVersionUID = -7374372180080957334L;

    // Names used to look up (and, on the member side, configure) the shared
    // Hazelcast lists, maps and atomic references held by this tracker.
    public static final String JOBS = "org.deeplearning4j.jobs";
    public static final String NUM_TIMES_PRETRAIN_RAN = "pretrainran";
    public static final String WORKERS = "org.deeplearning4j.workers";
    public static final String AVAILABLE_WORKERS = "AVAILABLE_WORKERS";
    public static final String NUM_TIMES_RUN_PRETRAIN = "PRETRAIN";
    public static final String TOPICS = "topics";
    public static final String RESULT = "RESULT";
    public static final String DONE = "done";
    public static final String UPDATES = "updates";
    public static final String REPLICATE_WEIGHTS = "replicate";
    public static final String HEART_BEAT = "heartbeat";
    public static final String WORKER_ENABLED = "workerenabled";
    public static final String INPUT_SPLIT = "inputsplit";
    public static final String IS_PRETRAIN = "ispretrain";
    public static final String BEST_LOSS = "bestloss";
    public static final String IMPROVEMENT_THRESHOLD = "improvementthreshold";
    public static final String VALIDATION_EPOCHS = "validationepochs";
    public static final String EARLY_STOP = "earlystop";
    public static final String PATIENCE = "patience";
    public static final String PATIENCE_INCREASE = "patienceincrease";
    public static final String BEGUN = "begun";
    public static final String NUM_BATCHES_SO_FAR_RAN = "numbatches";

    // Handles onto Hazelcast structures; all of them are assigned in the
    // three-argument constructor from the HazelcastInstance `h`.
    private volatile transient IAtomicReference<Object> master;
    private volatile transient IList<Job> jobs;
    private volatile transient IAtomicReference<Integer> numTimesPretrain;
    private volatile transient IAtomicReference<Integer> numTimesPretrainRan;
    private volatile transient IAtomicReference<Double> bestLoss;
    private volatile transient IAtomicReference<Double> improvementThreshold;
    private volatile transient IAtomicReference<Integer> numBatches;
    private volatile transient IAtomicReference<Boolean> earlyStop;
    private volatile transient IAtomicReference<Boolean> done;
    private volatile transient IList<String> replicate;
    private volatile transient IMap<String, Boolean> workerEnabled;
    private volatile transient IList<String> workers;
    private volatile transient IList<String> topics;
    private volatile transient IList<String> updates;
    // NOTE(review): unlike the handles above, the following references are not
    // declared transient -- confirm whether that difference is intentional.
    private volatile IAtomicReference<Double> patience;
    private volatile IAtomicReference<Boolean> begunTraining;
    private volatile IAtomicReference<Double> patienceIncrease;
    private volatile IAtomicReference<Integer> validationEpochs;
    private volatile IAtomicReference<Integer> miniBatchSize;

    // Collaborators for per-worker data, saved updates and the test data set.
    private WorkRetriever workRetriever;
    // Only assigned by the constructor when the tracker type is "master";
    // otherwise it stays unset until setUpdateSaver is called.
    protected UpdateSaver<E> saver;
    private DataSetCache cache;
    private volatile IAtomicReference<Boolean> isPretrain;
    // NOTE(review): the logger is created for HazelCastStateTracker, not for this class.
    private static Logger log = LoggerFactory.getLogger(HazelCastStateTracker.class);
    // Member-side configuration; only built when this tracker starts its own Hazelcast member.
    private transient Config config;
    public static final int DEFAULT_HAZELCAST_PORT = 2510;
    // The Hazelcast member or client instance every shared structure is obtained from.
    private transient HazelcastInstance h;
    // Tracker role; the constructor compares it against "master".
    private String type;
    private int hazelCastPort;
    private String connectionString;
    // Worker id -> timestamp written by addWorker (System.currentTimeMillis()).
    private Map<String, Long> heartbeat;
    private StateTrackerDropWizardResource resource;
    // System property consulted for the host part of the master connection string.
    public static final String HAZELCAST_HOST = "hazelcast.host";

    /**
     * Creates a tracker of type "master" with the connection string "master"
     * on {@link #DEFAULT_HAZELCAST_PORT}.
     *
     * @throws Exception if the underlying three-argument constructor fails
     */
    public BaseHazelCastStateTracker() throws Exception {
        this("master", "master", DEFAULT_HAZELCAST_PORT);
    }

    /**
     * Reads the batch counter stored in the {@code numBatches} reference.
     *
     * @return the current counter value
     */
    @Override
    public int numBatchesRan() {
        Integer ran = this.numBatches.get();
        return ran.intValue();
    }

    /**
     * Adds {@code i} to the batch counter by reading the current value and
     * writing back the sum (a plain get followed by a set).
     *
     * @param i amount to add to the counter
     */
    @Override
    public void incrementBatchesRan(int i) {
        int soFar = this.numBatches.get().intValue();
        this.numBatches.set(Integer.valueOf(i + soFar));
    }

    /**
     * Starts the REST resource for this tracker.
     *
     * <p>Copies the classpath resource {@code dropwizard.yml} to a file of the
     * same name in the working directory and runs the resource with the
     * arguments {@code server <absolute path>}.
     *
     * @throws RuntimeException wrapping any failure while copying the
     *         configuration or running the resource
     */
    @Override
    public void startRestApi() {
        this.resource = new StateTrackerDropWizardResource(this);
        try {
            InputStream yamlIn = new ClassPathResource("dropwizard.yml").getInputStream();
            File file = new File("dropwizard.yml");
            try {
                BufferedOutputStream yamlOut = new BufferedOutputStream(new FileOutputStream(file));
                try {
                    IOUtils.copy(yamlIn, yamlOut);
                    yamlOut.flush();
                } finally {
                    // Previously never closed: leaked a file handle per call.
                    yamlOut.close();
                }
            } finally {
                yamlIn.close();
            }
            // Register for deletion before running so the copy is cleaned up
            // even when run(...) throws.
            file.deleteOnExit();
            this.resource.run(new String[]{"server", file.getAbsolutePath()});
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Supplies the {@link UpdateSaver} for this tracker. The three-argument
     * constructor calls this when the tracker type is "master".
     *
     * @return the update saver to use
     */
    public abstract UpdateSaver<E> createUpdateSaver();

    /**
     * Reads the value held by the {@code miniBatchSize} reference.
     *
     * @return the stored mini batch size
     */
    @Override
    public int miniBatchSize() {
        Integer size = this.miniBatchSize.get();
        return size.intValue();
    }

    /** Sets the {@code begunTraining} flag to {@code true}. */
    @Override
    public void beginTraining() {
        begunTraining.set(Boolean.TRUE);
    }

    /**
     * Reads the {@code begunTraining} flag.
     *
     * @return the stored flag value
     */
    @Override
    public boolean hasBegun() {
        Boolean begun = this.begunTraining.get();
        return begun.booleanValue();
    }

    /**
     * Asks the work retriever to clear whatever it holds for the given worker.
     *
     * @param str worker id
     */
    @Override
    public void removeWorkerData(String str) {
        WorkRetriever retriever = this.workRetriever;
        retriever.clear(str);
    }

    /**
     * Delegates to {@link WorkRetriever#workers()}.
     *
     * @return the collection reported by the work retriever
     */
    @Override
    public Collection<String> workerData() {
        WorkRetriever retriever = this.workRetriever;
        return retriever.workers();
    }

    /**
     * Replaces the work retriever used for per-worker data.
     *
     * @param workRetriever the retriever to use from now on (not null-checked)
     */
    @Override // org.deeplearning4j.iterativereduce.tracker.statetracker.StateTracker
    public void setWorkRetriever(WorkRetriever workRetriever) {
        this.workRetriever = workRetriever;
    }

    /**
     * Loads the data set the work retriever holds for a worker.
     *
     * @param str worker id
     * @return whatever {@link WorkRetriever#load} returns for that id
     */
    @Override
    public DataSet loadForWorker(String str) {
        WorkRetriever retriever = this.workRetriever;
        return retriever.load(str);
    }

    /**
     * Hands a data set to the work retriever under the given worker id.
     *
     * @param str     worker id
     * @param dataSet data to store for that worker
     */
    @Override
    public void saveWorker(String str, DataSet dataSet) {
        WorkRetriever retriever = this.workRetriever;
        retriever.save(str, dataSet);
    }

    /**
     * Builds an {@link OutputLayerTrainingEvaluator} for the given network from
     * the tracker's current best loss, improvement threshold, patience, test
     * set, validation epochs and patience increase.
     *
     * @param baseMultiLayerNetwork network the evaluator is built for
     * @return the evaluator produced by the builder
     */
    @Override
    public TrainingEvaluator create(BaseMultiLayerNetwork baseMultiLayerNetwork) {
        // The chain order matches the order in which the tracker values are read.
        return new OutputLayerTrainingEvaluator.Builder()
                .bestLoss(bestLoss())
                .improvementThreshold(improvementThreshold())
                .patience(patience())
                .testSet(testSet())
                .withNetwork(baseMultiLayerNetwork)
                .validationEpochs(validationEpochs())
                .patienceIncrease(this.patienceIncrease.get().doubleValue())
                .build();
    }

    /**
     * Replaces the data set cache.
     *
     * @param dataSetCache the cache to use; must not be null
     * @throws IllegalArgumentException if {@code dataSetCache} is null
     */
    @Override
    public void setDataSetCache(DataSetCache dataSetCache) {
        if (dataSetCache != null) {
            this.cache = dataSetCache;
            return;
        }
        throw new IllegalArgumentException("Cache must not be null");
    }

    /**
     * Stores a new improvement threshold.
     *
     * @param d threshold value to store
     */
    @Override
    public void setImprovmentThreshold(double d) {
        Double boxed = Double.valueOf(d);
        improvementThreshold.set(boxed);
    }

    /**
     * Reads the stored improvement threshold.
     *
     * @return the current threshold
     */
    @Override
    public double improvementThreshold() {
        Double threshold = this.improvementThreshold.get();
        return threshold.doubleValue();
    }

    /**
     * Stores a new patience value.
     *
     * @param d patience value to store
     */
    @Override
    public void setPatience(double d) {
        Double boxed = Double.valueOf(d);
        patience.set(boxed);
    }

    /**
     * Reads the stored patience value.
     *
     * @return the current patience
     */
    @Override
    public double patience() {
        Double current = this.patience.get();
        return current.doubleValue();
    }

    /**
     * Reads the stored improvement threshold; returns the same value as
     * {@link #improvementThreshold()} under the alternative spelling.
     *
     * @return the current threshold
     */
    @Override
    public double improvmentThreshold() {
        Double threshold = this.improvementThreshold.get();
        return threshold.doubleValue();
    }

    /**
     * Fetches the data set held by the data set cache.
     *
     * @return whatever {@link DataSetCache#get()} returns
     */
    @Override
    public DataSet testSet() {
        DataSetCache source = this.cache;
        return source.get();
    }

    /**
     * Stores a new best-loss value.
     *
     * @param d loss value to store
     */
    @Override
    public void setBestLoss(double d) {
        Double boxed = Double.valueOf(d);
        bestLoss.set(boxed);
    }

    /**
     * Reads the stored best-loss value.
     *
     * @return the current best loss
     */
    @Override
    public double bestLoss() {
        Double loss = this.bestLoss.get();
        return loss.doubleValue();
    }

    /**
     * Reads the stored number of validation epochs.
     *
     * @return the current value
     */
    @Override
    public int validationEpochs() {
        Integer epochs = this.validationEpochs.get();
        return epochs.intValue();
    }

    /**
     * Reads the stored early-stop flag.
     *
     * @return the current flag value
     */
    @Override
    public boolean isEarlyStopTesting() {
        Boolean enabled = this.earlyStop.get();
        return enabled.booleanValue();
    }

    /**
     * Exposes the {@code updates} list itself (no copy is made).
     *
     * @return the list of ids recorded by {@link #addUpdate}
     */
    @Override
    public Collection<String> workerUpdates() {
        return updates;
    }

    /**
     * Replaces the update saver.
     *
     * @param updateSaver the saver to use from now on (not null-checked)
     */
    @Override
    public void setUpdateSaver(UpdateSaver<E> updateSaver) {
        saver = updateSaver;
    }

    /**
     * Returns the current update saver.
     *
     * @return the saver, or {@code null} if none has been assigned
     */
    @Override
    public UpdateSaver<E> updateSaver() {
        return saver;
    }

    /**
     * Stores a new mini batch size.
     *
     * @param i size to store
     */
    @Override
    public void setMiniBatchSize(int i) {
        Integer boxed = Integer.valueOf(i);
        miniBatchSize.set(boxed);
    }

    /**
     * Returns the input split size, which equals the stored mini batch size.
     *
     * <p>The previous implementation computed
     * {@code (miniBatchSize * numWorkers()) / numWorkers()}. That reduces to the
     * mini batch size, but it threw {@link ArithmeticException} when no worker
     * was registered, could overflow for large sizes, and called
     * {@code numWorkers()} twice, so the two reads could disagree.
     *
     * @return the stored mini batch size
     */
    @Override
    public int inputSplit() {
        return this.miniBatchSize.get().intValue();
    }

    /**
     * Same value as {@link #inputSplit()}.
     *
     * @return the result of {@code inputSplit()}
     */
    @Override
    public int partition() {
        int split = inputSplit();
        return split;
    }

    /**
     * Tells whether the given worker is currently marked as enabled.
     *
     * <p>Uses a single map lookup. The earlier {@code containsKey} followed by
     * {@code get} needed two calls and could hit a {@code null} if the entry
     * disappeared in between.
     *
     * @param str worker id
     * @return {@code true} only if the map holds {@code true} for that id
     */
    @Override
    public boolean workerEnabled(String str) {
        return Boolean.TRUE.equals(this.workerEnabled.get(str));
    }

    /**
     * Marks the given worker as enabled.
     *
     * @param str worker id
     */
    @Override
    public void enableWorker(String str) {
        workerEnabled.put(str, Boolean.TRUE);
    }

    /**
     * Marks the given worker as disabled.
     *
     * @param str worker id
     */
    @Override
    public void disableWorker(String str) {
        workerEnabled.put(str, Boolean.FALSE);
    }

    /**
     * Removes the given worker id from the {@code replicate} list.
     *
     * @param str worker id
     */
    @Override
    public void doneReplicating(String str) {
        IList<String> pending = this.replicate;
        pending.remove(str);
    }

    /**
     * Adds the given worker id to the {@code replicate} list unless the list
     * already contains it.
     *
     * @param str worker id
     */
    @Override
    public void addReplicate(String str) {
        boolean alreadyQueued = this.replicate.contains(str);
        if (!alreadyQueued) {
            this.replicate.add(str);
        }
    }

    /**
     * Tells whether the given worker id is in the {@code replicate} list.
     *
     * @param str worker id
     * @return {@code true} if the list contains the id
     */
    @Override
    public boolean needsReplicate(String str) {
        IList<String> pending = this.replicate;
        return pending.contains(str);
    }

    /**
     * Saves an update through the update saver and records its id in the
     * {@code updates} list. A {@code null} update is ignored.
     *
     * @param str id the update is saved under
     * @param e   the update; nothing happens when it is null
     * @throws RuntimeException wrapping any exception raised while saving or
     *         recording the update
     */
    @Override
    public void addUpdate(String str, E e) {
        if (e != null) {
            try {
                updateSaver().save(str, e);
                this.updates.add(str);
            } catch (Exception failure) {
                throw new RuntimeException(failure);
            }
        }
    }

    /**
     * Subclass hook returning the {@link IterateAndUpdate} for this tracker;
     * what it iterates over is decided by the implementation.
     *
     * @return the iterate-and-update helper
     */
    @Override // org.deeplearning4j.iterativereduce.tracker.statetracker.StateTracker
    public abstract IterateAndUpdate<E> updates();

    /**
     * Stores the connection string.
     *
     * @param str the connection string to remember
     */
    @Override
    public void setConnectionString(String str) {
        connectionString = str;
    }

    /**
     * Returns the stored connection string.
     *
     * @return the connection string, possibly {@code null} if never set
     */
    @Override
    public String connectionString() {
        return connectionString;
    }

    /**
     * Creates a tracker of type "master" with the connection string "master"
     * on the given port.
     *
     * @param i Hazelcast port to use
     * @throws Exception if the underlying three-argument constructor fails
     */
    public BaseHazelCastStateTracker(int i) throws Exception {
        this("master", "master", i);
    }

    /**
     * Creates a tracker of type "worker" for the given connection string,
     * passing {@link #DEFAULT_HAZELCAST_PORT} as the port.
     *
     * @param str connection string of the Hazelcast instance to connect to
     * @throws Exception if the underlying three-argument constructor fails
     */
    public BaseHazelCastStateTracker(String str) throws Exception {
        this(str, "worker", DEFAULT_HAZELCAST_PORT);
    }

    /**
     * Sets up the tracker either as a Hazelcast member or as a Hazelcast client.
     *
     * <p>When the type is "master" and the port is free, a new Hazelcast member
     * is started from {@link #hazelcast()}. When the type is "master" and the
     * port is taken, construction fails. For any other type a Hazelcast client
     * is created for the address {@code str}. Afterwards all shared lists, maps
     * and atomic references are looked up, and for "master" their default
     * values are written.
     *
     * @param str  connection string; for a master the literal "master" makes the
     *             constructor derive {@code host:port} itself
     * @param str2 tracker type; only the value "master" is treated specially
     * @param i    Hazelcast port
     * @throws IllegalStateException if the type is "master" and the port is taken
     * @throws Exception             declared by the signature; the wait loop for
     *                               non-master types can be interrupted
     */
    public BaseHazelCastStateTracker(String str, String str2, int i) throws Exception {
        String str3;
        // Initial collaborators; both are replaced by Hazelcast-backed
        // instances at the end of this constructor.
        this.workRetriever = new LocalWorkRetriever();
        this.cache = new LocalDataSetCache();
        this.type = "master";
        this.hazelCastPort = -1;
        log.info("Setting up hazelcast with type " + str2 + " connection string " + str + " and port " + i);
        if (str2.equals("master") && !PortTaken.portTaken(i)) {
            if (str.equals("master")) {
                // Host comes from the hazelcast.host system property, falling
                // back to the local host name, then to 0.0.0.0 on any failure.
                try {
                    str3 = System.getProperty(HAZELCAST_HOST, InetAddress.getLocalHost().getHostName());
                } catch (Exception e) {
                    str3 = "0.0.0.0";
                }
                this.connectionString = str3 + ":" + i;
            }
            this.hazelCastPort = i;
            this.config = hazelcast();
            this.h = Hazelcast.newHazelcastInstance(this.config);
            // The listener only logs membership changes.
            this.h.getCluster().addMembershipListener(new MembershipListener() { // from class: org.deeplearning4j.iterativereduce.tracker.statetracker.hazelcast.BaseHazelCastStateTracker.1
                public void memberAdded(MembershipEvent membershipEvent) {
                    BaseHazelCastStateTracker.log.info("Member added " + membershipEvent.toString());
                }

                public void memberRemoved(MembershipEvent membershipEvent) {
                    BaseHazelCastStateTracker.log.info("Member removed " + membershipEvent.toString());
                }

                public void memberAttributeChanged(MemberAttributeEvent memberAttributeEvent) {
                    BaseHazelCastStateTracker.log.info("Member changed " + memberAttributeEvent.toString());
                }
            });
        } else {
            // NOTE(review): the port is probed a second time here, separately
            // from the check above.
            if (str2.equals("master") && PortTaken.portTaken(i)) {
                throw new IllegalStateException("Specified type was master and the port specified was taken, please specify a different port");
            }
            setConnectionString(str);
            log.info("Connecting to hazelcast on " + str);
            ClientConfig clientConfig = new ClientConfig();
            clientConfig.getNetworkConfig().addAddress(new String[]{str});
            this.h = HazelcastClient.newHazelcastClient(clientConfig);
        }
        this.type = str2;
        this.jobs = this.h.getList(JOBS);
        this.workers = this.h.getList(WORKERS);
        if (!this.type.equals("master")) {
            // Non-master trackers poll once a second until the workers list has
            // at least one entry; there is no timeout.
            while (this.workers.isEmpty()) {
                log.warn("Waiting for data sync...");
                Thread.sleep(1000L);
            }
            log.info("Workers is " + this.workers.size());
        }
        this.begunTraining = this.h.getAtomicReference(BEGUN);
        this.miniBatchSize = this.h.getAtomicReference(INPUT_SPLIT);
        this.workerEnabled = this.h.getMap(WORKER_ENABLED);
        this.replicate = this.h.getList(REPLICATE_WEIGHTS);
        this.topics = this.h.getList("topics");
        this.updates = this.h.getList(UPDATES);
        this.heartbeat = this.h.getMap(HEART_BEAT);
        this.master = this.h.getAtomicReference(RESULT);
        this.isPretrain = this.h.getAtomicReference(IS_PRETRAIN);
        this.numTimesPretrain = this.h.getAtomicReference(NUM_TIMES_RUN_PRETRAIN);
        this.numTimesPretrainRan = this.h.getAtomicReference(NUM_TIMES_PRETRAIN_RAN);
        this.done = this.h.getAtomicReference(DONE);
        this.validationEpochs = this.h.getAtomicReference(VALIDATION_EPOCHS);
        this.improvementThreshold = this.h.getAtomicReference(IMPROVEMENT_THRESHOLD);
        this.bestLoss = this.h.getAtomicReference(BEST_LOSS);
        this.earlyStop = this.h.getAtomicReference(EARLY_STOP);
        this.patience = this.h.getAtomicReference(PATIENCE);
        this.patienceIncrease = this.h.getAtomicReference(PATIENCE_INCREASE);
        this.numBatches = this.h.getAtomicReference(NUM_BATCHES_SO_FAR_RAN);
        if (str2.equals("master")) {
            // Defaults are written only by the master. The mini batch size
            // reference is not given a default here.
            this.begunTraining.set(false);
            // NOTE(review): abstract method invoked from the constructor, so a
            // subclass implementation runs before the subclass's own
            // constructor body.
            this.saver = createUpdateSaver();
            this.numTimesPretrainRan.set(0);
            this.numTimesPretrain.set(1);
            this.isPretrain.set(true);
            this.done.set(false);
            this.resource = new StateTrackerDropWizardResource(this);
            this.bestLoss.set(Double.valueOf(Double.POSITIVE_INFINITY));
            this.earlyStop.set(true);
            this.patience.set(Double.valueOf(40.0d));
            this.patienceIncrease.set(Double.valueOf(2.0d));
            this.improvementThreshold.set(Double.valueOf(0.995d));
            // Reads back the patience just stored: min(10, 40 / 2) = 10.
            this.validationEpochs.set(Integer.valueOf((int) Math.min(10.0d, patience() / 2.0d)));
            this.numBatches.set(0);
        }
        this.workRetriever = new LocalWorkRetriever(this.h);
        this.cache = new LocalDataSetCache(".", this.h);
    }

    /**
     * Builds the configuration for a Hazelcast member started by this tracker.
     *
     * <p>The member binds to {@code hazelCastPort} without port auto-increment.
     * It joins via AWS when the system property {@code hazelcast.aws} is
     * exactly "true" and via multicast otherwise, and optionally restricts
     * itself to the interface named by {@code hazelcast.interface}. List and
     * map configurations are registered for the structures the tracker and its
     * collaborators use.
     *
     * @return the assembled configuration
     */
    private Config hazelcast() {
        Config conf = new Config();
        conf.getNetworkConfig().setPort(this.hazelCastPort);
        conf.getNetworkConfig().setPortAutoIncrement(false);
        conf.setProperty("hazelcast.initial.min.cluster.size", "1");
        conf.setProperty("hazelcast.shutdownhook.enabled", "false");

        JoinConfig joinConfig = conf.getNetworkConfig().getJoin();
        boolean useAws = System.getProperty("hazelcast.aws", "false").equals("true");
        String joinerName = useAws ? "AWS" : "Multicast";
        log.info("Setting up Joiner with this being " + joinerName);
        joinConfig.getAwsConfig().setEnabled(useAws);
        if (useAws) {
            joinConfig.getAwsConfig().setAccessKey(System.getProperty("hazelcast.access-key"));
            joinConfig.getAwsConfig().setSecretKey(System.getProperty("hazelcast.access-secret"));
        }
        joinConfig.getMulticastConfig().setEnabled(!useAws);

        String networkInterface = System.getProperty("hazelcast.interface");
        if (networkInterface != null) {
            conf.getNetworkConfig().getInterfaces().setEnabled(true).addInterface(networkInterface);
        }

        // Registration order is kept exactly as before.
        registerListConfig(conf, JOBS);
        registerListConfig(conf, REPLICATE_WEIGHTS);
        registerListConfig(conf, "topics");
        registerListConfig(conf, UPDATES);
        registerListConfig(conf, AVAILABLE_WORKERS);

        registerMapConfig(conf, HEART_BEAT);
        registerMapConfig(conf, WORKER_ENABLED);
        registerMapConfig(conf, LocalDataSetCache.DATA_SET_MAP);
        registerMapConfig(conf, LocalFileUpdateSaver.UPDATE_SAVER);
        registerMapConfig(conf, LocalWorkRetriever.WORK_RETRIEVER);
        return conf;
    }

    /** Adds a default {@link ListConfig} with the given name to {@code conf}. */
    private static void registerListConfig(Config conf, String name) {
        ListConfig listConfig = new ListConfig();
        listConfig.setName(name);
        conf.addListConfig(listConfig);
    }

    /** Adds a default {@link MapConfig} with the given name to {@code conf}. */
    private static void registerMapConfig(Config conf, String name) {
        MapConfig mapConfig = new MapConfig();
        mapConfig.setName(name);
        conf.addMapConfig(mapConfig);
    }

    /**
     * Assigns a job to its worker, or to another worker when the requested one
     * already has a job.
     *
     * <p>If the reference {@code job-<workerId>} is already occupied, the
     * workers list is scanned until a worker with no job (per
     * {@link #jobFor}) is found, and the job is redirected to it. The job is
     * then stored in the chosen reference, its work payload is set to null, and
     * it is appended to the jobs list.
     *
     * <p>Compared with the previous version, the scan now stops at the first
     * free worker. Before, it kept reassigning and logging a redirect for every
     * free worker, so the last one won. It also pauses briefly between
     * unsuccessful scans instead of spinning.
     *
     * @param job the job to register; its worker id may be rewritten
     * @return always {@code true}
     * @throws Exception if interrupted while waiting for a free worker, or on
     *         failures of the underlying calls
     */
    @Override
    public boolean addJobToCurrent(Job job) throws Exception {
        IAtomicReference<Job> slot = this.h.getAtomicReference("job-" + job.getWorkerId());
        if (!slot.isNull()) {
            String freeWorker = null;
            while (freeWorker == null) {
                for (String candidate : workers()) {
                    if (jobFor(candidate) == null) {
                        freeWorker = candidate;
                        break;
                    }
                }
                if (freeWorker == null) {
                    // Nothing free yet: back off rather than hammering the cluster.
                    Thread.sleep(100L);
                }
            }
            log.info("Redirecting worker " + job.getWorkerId() + " to " + freeWorker + " due to work already being allocated");
            slot = this.h.getAtomicReference("job-" + freeWorker);
            job.setWorkerId(freeWorker);
        }
        slot.set(job);
        // The entry kept in the jobs list carries no work payload.
        job.setWork(null);
        this.jobs.add(job);
        return true;
    }

    /**
     * Stores the Hazelcast port value. Only the field is updated.
     *
     * @param i port to remember
     */
    @Override
    public void setServerPort(int i) {
        hazelCastPort = i;
    }

    /**
     * Returns the stored Hazelcast port value.
     *
     * @return the port field; the constructor initialises it to -1 before a
     *         member port is assigned
     */
    @Override
    public int getServerPort() {
        return hazelCastPort;
    }

    /**
     * Returns a copy of the jobs list.
     *
     * @return a new {@link ArrayList} holding the jobs list's current elements
     * @throws Exception declared by the interface
     */
    @Override
    public List<Job> currentJobs() throws Exception {
        List<Job> snapshot = new ArrayList<Job>(this.jobs);
        return snapshot;
    }

    /**
     * Overwrites the reference {@code job-<workerId>} with the given job.
     *
     * @param job the job to store; its worker id selects the reference
     */
    @Override
    public void updateJob(Job job) {
        String refName = "job-" + job.getWorkerId();
        this.h.getAtomicReference(refName).set(job);
    }

    /**
     * Clears the job reference of a worker and removes the first entry in the
     * jobs list whose worker id matches.
     *
     * <p>Does nothing beyond a warning when {@code str} is null, and nothing at
     * all when the reference {@code job-<str>} is already empty.
     *
     * <p>Previously, when no list entry matched, {@code jobs.remove(null)} was
     * called, which Hazelcast collections reject with a
     * {@link NullPointerException}. Now the list is only touched when a match
     * exists. Entries with a null worker id no longer cause a failure either.
     *
     * @param str worker id whose job should be cleared
     * @throws Exception declared by the interface
     */
    @Override
    public void clearJob(String str) throws Exception {
        if (str == null) {
            log.warn("No job to clear; was null, returning");
            return;
        }
        IAtomicReference<Job> ref = this.h.getAtomicReference("job-" + str);
        if (ref.isNull()) {
            return;
        }
        ref.clear();
        log.info("Destroyed job ref " + str);
        for (Job candidate : this.jobs) {
            if (str.equals(candidate.getWorkerId())) {
                // Only the first match is removed, as before.
                this.jobs.remove(candidate);
                return;
            }
        }
    }

    /** Shuts the Hazelcast instance down, if one is present. */
    @Override
    public void shutdown() {
        if (this.h == null) {
            return;
        }
        this.h.shutdown();
    }

    /**
     * Appends a topic to the {@code topics} list; duplicates are not filtered.
     *
     * @param str topic to add
     * @throws Exception declared by the interface
     */
    @Override
    public void addTopic(String str) throws Exception {
        IList<String> known = this.topics;
        known.add(str);
    }

    /**
     * Exposes the {@code topics} list itself (no copy is made).
     *
     * @return the topics list
     * @throws Exception declared by the interface
     */
    @Override
    public List<String> topics() throws Exception {
        return topics;
    }

    /**
     * Returns the value held by the {@code master} reference.
     *
     * @return the stored update, or {@code null} when nothing has been stored
     * @throws Exception declared by the interface
     */
    @Override
    public E getCurrent() throws Exception {
        @SuppressWarnings("unchecked")
        E current = (E) this.master.get();
        return current;
    }

    /**
     * Stores the given update in the {@code master} reference. An update that
     * is null, or whose {@code get()} returns null, is skipped with a warning.
     *
     * @param e the update to store
     * @throws Exception declared by the interface
     */
    @Override
    public void setCurrent(E e) throws Exception {
        boolean hasPayload = e != null && e.get() != null;
        if (hasPayload) {
            this.master.set(e);
            return;
        }
        log.warn("Not setting a null update");
    }

    /**
     * Reads the stored pretrain flag.
     *
     * @return the current flag value
     */
    @Override
    public boolean isPretrain() {
        Boolean pretraining = this.isPretrain.get();
        return pretraining.booleanValue();
    }

    /**
     * Switches from pretraining to finetuning: logs the transition, clears the
     * pretrain flag and resets the batch counter to zero.
     */
    @Override
    public void moveToFinetune() {
        log.info("Moving to finetune");
        isPretrain.set(Boolean.FALSE);
        numBatches.set(Integer.valueOf(0));
    }

    /**
     * Looks up the job stored in the reference {@code job-<str>}.
     *
     * @param str worker id
     * @return the stored job, or {@code null} when the reference is empty or
     *         {@code isCurrentlyJob(str)} reports {@code true}
     */
    @Override
    public Job jobFor(String str) {
        IAtomicReference<Job> ref = this.h.getAtomicReference("job-" + str);
        // Short-circuit order matters: isCurrentlyJob is only consulted when
        // the reference holds something.
        boolean unavailable = ref.isNull() || isCurrentlyJob(str);
        return unavailable ? null : ref.get();
    }

    /**
     * Scans the jobs list for an element that {@code equals} the given string.
     *
     * <p>NOTE(review): each element is a {@code Job} and is compared against a
     * {@code String}. Unless {@code Job.equals} (not visible here) accepts
     * strings, this presumably never matches. Comparing worker ids instead
     * would change what {@link #jobFor} returns for every assigned job, so the
     * comparison is left as-is. Confirm the intent before touching it.
     *
     * @param str value compared against each job
     * @return {@code true} if some job reports equality with {@code str}
     */
    private boolean isCurrentlyJob(String str) {
        for (Job candidate : this.jobs) {
            if (candidate.equals(str)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Adds the given worker id to the workers list unless it is already there.
     *
     * @param str worker id
     */
    @Override
    public void availableForWork(String str) {
        boolean known = this.workers.contains(str);
        if (!known) {
            this.workers.add(str);
        }
    }

    /**
     * Collects the worker id of every entry in the jobs list, in list order.
     *
     * @return a new list of worker ids
     */
    @Override
    public List<String> jobIds() {
        List<String> ids = new ArrayList<String>();
        for (Job job : this.jobs) {
            ids.add(job.getWorkerId());
        }
        return ids;
    }

    /**
     * Records a heartbeat for the worker and adds it to the workers list if it
     * is not already present. The heartbeat timestamp is refreshed on every
     * call, whether or not the worker is new.
     *
     * @param str worker id
     */
    @Override
    public void addWorker(String str) {
        long now = System.currentTimeMillis();
        this.heartbeat.put(str, Long.valueOf(now));
        if (!this.workers.contains(str)) {
            log.info("Adding worker " + str);
            this.workers.add(str);
            log.info("Number of workers is now " + this.workers.size());
        }
    }

    /**
     * Removes a worker from the workers list and clears its job, if it has one.
     *
     * <p>A failure while clearing the job is logged and not propagated. The
     * exception is now included in the log entry; before, it was discarded.
     *
     * @param str worker id
     */
    @Override
    public void removeWorker(String str) {
        this.workers.remove(str);
        if (jobFor(str) != null) {
            try {
                clearJob(str);
            } catch (Exception e) {
                log.warn("Unable to clear job for worker with id" + str, e);
            }
        }
    }

    /**
     * Exposes the workers list itself (no copy is made).
     *
     * @return the workers list
     */
    @Override
    public List<String> workers() {
        return workers;
    }

    /**
     * Counts the entries of the workers list.
     *
     * @return the list's current size
     */
    @Override
    public int numWorkers() {
        IList<String> registered = this.workers;
        return registered.size();
    }

    /**
     * Returns the Hazelcast instance used by this tracker.
     *
     * @return the current instance
     */
    public synchronized HazelcastInstance getH() {
        return h;
    }

    /**
     * Replaces the Hazelcast instance field. Structures already looked up from
     * the previous instance are not refreshed.
     *
     * @param hazelcastInstance the instance to store
     */
    public synchronized void setH(HazelcastInstance hazelcastInstance) {
        h = hazelcastInstance;
    }

    /**
     * Exposes the heartbeat map itself (no copy is made).
     *
     * @return map from worker id to the timestamp written by {@link #addWorker}
     */
    @Override
    public Map<String, Long> getHeartBeats() {
        return heartbeat;
    }

    /**
     * Stores the value of the {@code numTimesPretrain} reference.
     *
     * @param i value to store
     */
    @Override
    public void runPreTrainIterations(int i) {
        Integer boxed = Integer.valueOf(i);
        numTimesPretrain.set(boxed);
    }

    /**
     * Reads the value of the {@code numTimesPretrain} reference.
     *
     * @return the stored value
     */
    @Override
    public int runPreTrainIterations() {
        Integer configured = this.numTimesPretrain.get();
        return configured.intValue();
    }

    /**
     * Reads the value of the {@code numTimesPretrainRan} counter.
     *
     * @return the stored counter value
     */
    @Override
    public int numTimesPreTrainRun() {
        Integer ran = this.numTimesPretrainRan.get();
        return ran.intValue();
    }

    /**
     * Increases the {@code numTimesPretrainRan} counter by one, by reading the
     * current value and writing back the successor.
     */
    @Override
    public void incrementNumTimesPreTrainRan() {
        int next = numTimesPreTrainRun() + 1;
        this.numTimesPretrainRan.set(Integer.valueOf(next));
    }

    /**
     * Reads the {@code done} flag. Any exception while reading is logged as a
     * warning and reported as done.
     *
     * @return the stored flag, or {@code true} if reading it failed
     */
    @Override
    public boolean isDone() {
        boolean finished;
        try {
            finished = this.done.get().booleanValue();
        } catch (Exception ex) {
            log.warn("Hazelcast already shutdown...returning true on isDone()");
            finished = true;
        }
        return finished;
    }

    /**
     * Marks the tracker as done and lets the update saver clean up.
     *
     * <p>Failures are logged and not propagated. The exception is now attached
     * to the log entry so the real cause is visible. A tracker with no update
     * saver (the constructor only creates one for type "master") skips the
     * cleanup; before, that case raised a {@link NullPointerException} that was
     * reported as a Hazelcast shutdown.
     */
    @Override
    public void finish() {
        try {
            this.done.set(true);
            UpdateSaver<E> current = updateSaver();
            if (current != null) {
                current.cleanup();
            }
        } catch (Exception e) {
            log.warn("Hazelcast already shutdown...done() being applyTransformToDestination is pointless", e);
        }
    }
}
