package org.deeplearning4j.iterativereduce.actor.multilayer;

import akka.actor.ActorRef;
import akka.actor.Props;
import akka.contrib.pattern.DistributedPubSubMediator;
import java.util.List;
import org.deeplearning4j.iterativereduce.actor.core.Ack;
import org.deeplearning4j.iterativereduce.actor.core.ClearWorker;
import org.deeplearning4j.iterativereduce.tracker.statetracker.StateTracker;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.optimize.api.TrainingEvaluator;
import org.deeplearning4j.scaleout.conf.Conf;
import org.deeplearning4j.scaleout.iterativereduce.multi.UpdateableImpl;
import org.nd4j.linalg.dataset.DataSet;

/* loaded from: input_file:org/deeplearning4j/iterativereduce/actor/multilayer/WorkerActor.class */
/**
 * Worker actor that trains a {@link BaseMultiLayerNetwork} on the job currently assigned to it and
 * returns the trained network wrapped in an {@link UpdateableImpl}.
 *
 * <p>On construction the actor:
 * <ul>
 *   <li>registers itself with the pub/sub mediator;</li>
 *   <li>subscribes to the master's broadcast topic and to a topic named after its own id;</li>
 *   <li>starts its heartbeat;</li>
 *   <li>adds itself to the {@link StateTracker}.</li>
 * </ul>
 */
public class WorkerActor extends org.deeplearning4j.iterativereduce.actor.core.actor.WorkerActor<UpdateableImpl> {

    /**
     * Creates a worker and registers it with the cluster.
     *
     * @param conf         training configuration, also passed to {@link #setup(Conf)}
     * @param stateTracker tracker this worker adds itself to
     * @throws Exception if base-class construction or registration fails
     */
    public WorkerActor(Conf conf, StateTracker<UpdateableImpl> stateTracker) throws Exception {
        super(conf, stateTracker);
        setup(conf);
        register(stateTracker);
    }

    /**
     * Creates a worker with an explicit actor reference and registers it with the cluster.
     *
     * @param actorRef     actor reference handed to the base-class constructor
     *                     (NOTE(review): presumably the pub/sub mediator — confirm in the base class)
     * @param conf         training configuration, also passed to {@link #setup(Conf)}
     * @param stateTracker tracker this worker adds itself to
     * @throws Exception if base-class construction or registration fails
     */
    public WorkerActor(ActorRef actorRef, Conf conf, StateTracker<UpdateableImpl> stateTracker) throws Exception {
        super(conf, actorRef, stateTracker);
        setup(conf);
        register(stateTracker);
    }

    /**
     * Registration sequence shared by both constructors.
     *
     * <p>It puts this actor into the mediator, subscribes to the master broadcast topic and to this
     * worker's id topic, starts the heartbeat, and finally adds this worker to the tracker.
     */
    private void register(StateTracker<UpdateableImpl> stateTracker) throws Exception {
        ActorRef self = getSelf();
        this.mediator.tell(new DistributedPubSubMediator.Put(self), self);
        this.mediator.tell(new DistributedPubSubMediator.Subscribe(org.deeplearning4j.iterativereduce.actor.core.actor.MasterActor.BROADCAST, self), self);
        this.mediator.tell(new DistributedPubSubMediator.Subscribe(this.id, self), self);
        heartbeat();
        stateTracker.addWorker(this.id);
    }

    /**
     * Builds {@link Props} for the {@code (ActorRef, Conf, StateTracker)} constructor.
     *
     * @param actorRef     actor reference forwarded to the constructor
     * @param conf         training configuration
     * @param stateTracker shared state tracker
     * @return props that create this actor
     */
    public static Props propsFor(ActorRef actorRef, Conf conf, StateTracker<UpdateableImpl> stateTracker) {
        return Props.create(WorkerActor.class, new Object[]{actorRef, conf, stateTracker});
    }

    /**
     * Builds {@link Props} for the {@code (Conf, StateTracker)} constructor.
     *
     * @param conf         training configuration
     * @param stateTracker shared state tracker
     * @return props that create this actor
     */
    public static Props propsFor(Conf conf, StateTracker<UpdateableImpl> stateTracker) {
        return Props.create(WorkerActor.class, new Object[]{conf, stateTracker});
    }

    /**
     * Handles incoming messages.
     *
     * <ul>
     *   <li>(Un)subscribe acks are re-published on the {@code "topics"} topic and logged.</li>
     *   <li>A {@link BaseMultiLayerNetwork} replaces the network this worker holds.</li>
     *   <li>{@link Ack} messages are logged.</li>
     *   <li>Anything else is passed to {@code unhandled}.</li>
     * </ul>
     */
    @Override
    public void onReceive(Object obj) throws Exception {
        if (obj instanceof DistributedPubSubMediator.SubscribeAck
                || obj instanceof DistributedPubSubMediator.UnsubscribeAck) {
            this.mediator.tell(new DistributedPubSubMediator.Publish("topics", obj), getSelf());
            // UnsubscribeAck is not a SubscribeAck, so it must not be cast to one; only the log text differs.
            if (obj instanceof DistributedPubSubMediator.SubscribeAck) {
                log.info("Subscribed to " + obj);
            } else {
                log.info("Unsubscribed from " + obj);
            }
        } else if (obj instanceof BaseMultiLayerNetwork) {
            storeNetwork((BaseMultiLayerNetwork) obj);
            log.info("Set network");
        } else if (obj instanceof Ack) {
            log.info("Ack from master on worker " + this.id);
        } else {
            unhandled(obj);
        }
    }

    /**
     * Stores {@code network} as this worker's current result.
     *
     * <p>The {@link UpdateableImpl} wrapper is created lazily on first use; after that, its
     * {@code set} method is called to replace the network.
     */
    private void storeNetwork(BaseMultiLayerNetwork network) {
        if (this.results == null) {
            this.results = new UpdateableImpl(network);
        } else {
            this.results.set(network);
        }
    }

    /**
     * Ignores {@code list} and delegates to {@link #compute()}.
     *
     * @param list updates from other workers (unused)
     * @return see {@link #compute()}
     */
    @Override // org.deeplearning4j.iterativereduce.actor.core.actor.WorkerActor
    public UpdateableImpl compute(List<UpdateableImpl> list) {
        return compute();
    }

    /**
     * Trains the current network on the job assigned to this worker.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Obtain the network. If none is held yet, poll {@code tracker.getCurrent()} until it yields
     *       a non-null network.</li>
     *   <li>Turn the current job's work into a {@link DataSet}; a list of data sets is merged.</li>
     *   <li>Optionally normalize and scale the data set, as configured.</li>
     *   <li>Pretrain or finetune, depending on {@code tracker.isPretrain()}.</li>
     * </ol>
     *
     * @return the trained network, or {@code null} if the tracker is done, this worker is disabled,
     *     or there is no current job
     * @throws IllegalStateException if the job yields no data set, the worker was disabled while
     *     holding a job, or the features/labels are null
     * @throws RuntimeException wrapping any exception thrown by the tracker while fetching the
     *     network or clearing the job
     */
    @Override // org.deeplearning4j.iterativereduce.actor.core.actor.WorkerActor
    public UpdateableImpl compute() {
        if (this.tracker.isDone()) {
            return null;
        }
        if (!this.tracker.workerEnabled(this.id)) {
            log.info("Worker " + this.id + " should be re enabled if not doing work");
            return null;
        }
        log.info("Training network on worker " + this.id);
        UpdateableImpl current = getResults();
        BaseMultiLayerNetwork network = current == null ? null : current.get();
        this.isWorking.set(true);
        // NOTE(review): busy-polls the tracker until it has a network; there is no back-off or timeout.
        while (network == null) {
            try {
                network = this.tracker.getCurrent().get();
                storeNetwork(network);
                log.info("Network is currently null");
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        if (this.currentJob == null) {
            log.warn("No job found for " + this.id + " despite compute being called");
            // Nothing will be trained, so don't leave the worker flagged as busy.
            this.isWorking.set(false);
            return null;
        }
        if (!this.tracker.workerEnabled(this.id)) {
            // The worker was disabled between the check above and now while still holding a job.
            log.warn("No job found for " + this.id + " despite compute being called");
            throw new IllegalStateException("No job found for worker " + this.id);
        }
        log.info("Found job for worker " + this.id);
        DataSet dataSet = toDataSet(this.currentJob.getWork());
        if (dataSet == null) {
            throw new IllegalStateException("No job found for worker " + this.id);
        }
        if (this.conf.isNormalizeZeroMeanAndUnitVariance()) {
            dataSet.normalizeZeroMeanZeroUnitVariance();
        }
        if (this.conf.isScale()) {
            dataSet.scale();
        }
        if (dataSet.getFeatureMatrix() == null || dataSet.getLabels() == null) {
            throw new IllegalStateException("Input cant be null");
        }

        train(network, dataSet);

        try {
            if (!this.tracker.isDone()) {
                this.tracker.clearJob(this.id);
                this.isWorking.set(false);
            }
            return new UpdateableImpl(network);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Converts a job's work payload into a single {@link DataSet}.
     *
     * <p>A {@link List} payload is merged with {@link DataSet#merge}. Any other payload is cast to
     * {@link DataSet}, so a {@code null} payload yields {@code null}.
     */
    @SuppressWarnings("unchecked") // job payloads that are lists are assumed to hold DataSets
    private static DataSet toDataSet(Object work) {
        if (work instanceof List) {
            return DataSet.merge((List<DataSet>) work);
        }
        return (DataSet) work;
    }

    /**
     * Runs one training pass on {@code network}.
     *
     * <p>If the tracker is in pretrain mode, this pretrains on the features. Otherwise it sets the
     * input and finetunes on the labels, passing an evaluator only when the tracker has a test set.
     */
    private void train(BaseMultiLayerNetwork network, DataSet dataSet) {
        if (this.tracker.isPretrain()) {
            log.info("Worker " + this.id + " pretraining");
            network.pretrain(dataSet.getFeatureMatrix(), this.conf.getDeepLearningParams());
            return;
        }
        network.setInput(dataSet.getFeatureMatrix());
        log.info("Worker " + this.id + " finetune");
        if (this.tracker.testSet() != null) {
            network.finetune(dataSet.getLabels(), this.conf.getConf().getFinetuneLearningRate(), this.conf.getConf().getFinetuneEpochs(), this.tracker.create(network));
        } else {
            network.finetune(dataSet.getLabels(), this.conf.getConf().getFinetuneLearningRate(), this.conf.getConf().getFinetuneEpochs(), (TrainingEvaluator) null);
        }
    }

    /**
     * This worker does not track iterations itself.
     *
     * @return always {@code false}
     */
    @Override // org.deeplearning4j.iterativereduce.actor.core.actor.WorkerActor
    public boolean incrementIteration() {
        return false;
    }

    /**
     * Delegates to the base-class setup; no worker-specific configuration is applied.
     *
     * @param conf training configuration
     */
    @Override // org.deeplearning4j.iterativereduce.actor.core.actor.WorkerActor
    public void setup(Conf conf) {
        super.setup(conf);
    }

    /**
     * On stop, runs the parent stop hook and publishes a {@link ClearWorker} for this worker to the
     * master topic.
     *
     * <p>The heartbeat is always cancelled, even if either of the steps above throws.
     */
    @Override
    public void aroundPostStop() {
        try {
            super.aroundPostStop();
            this.mediator.tell(new DistributedPubSubMediator.Publish(org.deeplearning4j.iterativereduce.actor.core.actor.MasterActor.MASTER, new ClearWorker(this.id)), getSelf());
        } finally {
            if (this.heartbeat != null) {
                this.heartbeat.cancel();
            }
        }
    }

    /**
     * Returns the current result.
     *
     * <p>If none is held yet, the result is loaded lazily from {@code tracker.getCurrent()}.
     *
     * @return the current result; may be {@code null} if the tracker has none either
     * @throws RuntimeException wrapping any exception thrown by the tracker
     */
    @Override // org.deeplearning4j.iterativereduce.actor.core.actor.WorkerActor
    public UpdateableImpl getResults() {
        if (this.results == null) {
            try {
                this.results = this.tracker.getCurrent();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        return this.results;
    }

    /**
     * Replaces this worker's current result.
     *
     * @param updateableImpl the new result
     */
    @Override // org.deeplearning4j.iterativereduce.actor.core.actor.WorkerActor
    public void update(UpdateableImpl updateableImpl) {
        this.results = updateableImpl;
    }
}
