package fr.emac.gind.ml;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import fr.emac.gind.commons.utils.jaxb.SOAException;
import fr.emac.gind.commons.utils.xml.XMLPrettyPrinter;
import fr.emac.gind.event.ml.MLHandler;
import fr.emac.gind.marshaller.XMLJAXBContext;
import fr.emac.gind.tweet.GJaxbTweet;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.lang.reflect.Field;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import javax.xml.namespace.QName;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Document;

/* loaded from: input_file:fr/emac/gind/ml/RNNML.class */
/**
 * {@link MLHandler} that classifies the text of an incoming XML document with a Keras sequential
 * model imported through DL4J.
 *
 * <p>The document text (or the {@code text} of a tweet, when the root element name contains
 * {@code "Tweet"}) is split into tokens. Each token is mapped to an integer id through
 * {@link #revDic}. The ids are right-aligned into a zero-padded sequence and fed to the network.
 * When the arg-max output class is one of {@link #FACT_CLASS_INDICES}, a {@code "fact"} event is
 * emitted with the matching label from {@link #dicLabels}.
 *
 * <p>NOTE(review): the model is lazily initialized in {@link #execute} without synchronization;
 * confirm that one instance is never used concurrently from several threads.
 */
public class RNNML extends MLHandler {
    private static final Logger LOG = LoggerFactory.getLogger(RNNML.class);

    /** Sequence length used when the first layer exposes no usable {@code inputLength} field. */
    private static final long DEFAULT_SEQUENCE_LENGTH = 141;

    /**
     * Output class indices that trigger a "fact" event.
     * NOTE(review): hard-coded; their meaning depends on the label dictionary loaded by
     * {@link #createDictionary} — re-check whenever the model or the labels file changes.
     */
    private static final List<Integer> FACT_CLASS_INDICES = List.of(2, 3, 5, 8, 9);

    /** Reused JSON reader for the dictionary files (ObjectMapper is safe to share once configured). */
    private static final ObjectMapper JSON = new ObjectMapper();

    /** Word id to word, loaded from the features dictionary. */
    protected Map<Integer, String> dic = new HashMap<>();
    /** Word to word id; used to encode the input text. */
    protected Map<String, Integer> revDic = new HashMap<>();
    /** Label id to label name, loaded from the labels dictionary. */
    protected Map<Integer, String> dicLabels = new HashMap<>();
    /** Label name to label id. */
    protected Map<String, Integer> revDicLabels = new HashMap<>();
    private MultiLayerNetwork model = null;

    /**
     * Loads the dictionaries and imports the Keras model referenced by {@code trainedModels}.
     *
     * <p>When {@code useTrainedModel} is false and {@code datasetsToTrain} is set, the model is
     * imported and then fitted on the dataset. This path currently always fails, because
     * {@link #buildFeaturesAndLabels} and {@link #readINDArray} are not implemented.
     * When {@code useTrainedModel} is true and both {@code features} and {@code labels} are set,
     * the dictionaries are loaded from them and the model is imported. Otherwise the model is
     * left unchanged.
     *
     * @throws Exception if a configuration file cannot be resolved, read or imported
     */
    public void initModel() throws Exception {
        // Resolved up front, exactly as before, since both loading paths need it.
        File modelFile = extractConfigurationFile(this.trainedModels);
        if (!this.useTrainedModel && this.datasetsToTrain != null) {
            List<File> featuresAndLabels =
                    buildFeaturesAndLabels(extractConfigurationFile(this.datasetsToTrain));
            File featuresFile = featuresAndLabels.get(0);
            File labelsFile = featuresAndLabels.get(1);
            createDictionary(featuresFile, labelsFile);
            this.model = importModel(modelFile);
            this.model.fit(readINDArray(featuresFile), readINDArray(labelsFile));
            return;
        }
        if (this.useTrainedModel && this.features != null && this.labels != null) {
            createDictionary(extractConfigurationFile(this.features), extractConfigurationFile(this.labels));
            this.model = importModel(modelFile);
        }
    }

    /** Imports a Keras sequential model (with weights) from a file, always closing the stream. */
    private static MultiLayerNetwork importModel(File modelFile) throws Exception {
        try (FileInputStream in = new FileInputStream(modelFile)) {
            return KerasModelImport.importKerasSequentialModelAndWeights(in);
        }
    }

    private INDArray readINDArray(File file) throws Exception {
        throw new UnsupportedOperationException("Not Yet Implemented");
    }

    private List<File> buildFeaturesAndLabels(File file) throws Exception {
        throw new UnsupportedOperationException("Not Yet Implemented");
    }

    /**
     * Classifies the text carried by {@code document}. If the predicted class is one of
     * {@link #FACT_CLASS_INDICES}, emits a "fact" event for it.
     *
     * @param document XML document whose text is analysed
     * @param map      not read by this implementation
     * @throws IllegalStateException if no model is available after {@link #initModel()}
     * @throws Exception             on unmarshalling, model import or inference failure
     */
    public void execute(Document document, Map<QName, String> map) throws Exception {
        if (LOG.isDebugEnabled()) {
            // Pretty-printing the whole document is costly; only do it when it will be logged.
            LOG.debug("analyze " + XMLPrettyPrinter.print(document));
        }
        String textContent = extractText(document);
        // NOTE(review): the return value of cleanLine is ignored (as in the original code); Java
        // strings are immutable, so textContent is unchanged here — confirm whether it should be
        // reassigned.
        cleanLine(textContent);
        LOG.debug("text to analyze {}", textContent);
        if (this.model == null) {
            LOG.debug("model is null, initializing it");
            initModel();
        }
        if (this.model == null) {
            throw new IllegalStateException("no RNN model loaded: initModel() found no usable "
                    + "configuration (useTrainedModel/features/labels/datasetsToTrain)");
        }
        LOG.debug("dicts size (features/labels) {}/{}", this.dic.size(), this.dicLabels.size());
        List<String> words = Arrays.asList(cleanAndSplit(textContent));
        long sequenceLength = resolveSequenceLength(this.model.getLayer(0).conf().getLayer());
        INDArray sequence = encodeSequence(words, sequenceLength);
        INDArray output = this.model.output(sequence.reshape(new long[]{1, sequenceLength}));
        int predictedClass = output.argMax(new int[]{1}).getInt(new int[]{0});
        if (LOG.isDebugEnabled()) {
            LOG.debug("output {}", output.toStringFull());
            LOG.debug("argmax {}", predictedClass);
        }
        if (FACT_CLASS_INDICES.contains(predictedClass)) {
            String label = this.dicLabels.get(predictedClass);
            LOG.debug("class of fact " + label);
            sendEventFromFoundConceptsWithSpecifiedConcept("fact", label, words);
        }
    }

    /** Returns the document's text content, or the tweet text when the root element is a tweet. */
    private String extractText(Document document) throws Exception {
        if (!document.getDocumentElement().getNodeName().contains("Tweet")) {
            return document.getDocumentElement().getTextContent();
        }
        LOG.debug("it's a tweet...");
        try {
            return XMLJAXBContext.getInstance().unmarshallDocument(document, GJaxbTweet.class).getText();
        } catch (SOAException e) {
            throw new IOException(e);
        }
    }

    /**
     * Returns the {@code inputLength} declared on {@code layer}'s concrete class (read
     * reflectively). Falls back to {@link #DEFAULT_SEQUENCE_LENGTH} when there is no such field
     * or its value is not positive.
     */
    private static long resolveSequenceLength(Layer layer) throws IllegalAccessException {
        for (Field field : layer.getClass().getDeclaredFields()) {
            if ("inputLength".equals(field.getName())) {
                field.setAccessible(true);
                int inputLength = field.getInt(layer);
                LOG.debug("set finalPad to {}", inputLength);
                // A non-positive length would produce an empty or invalid input array.
                return inputLength > 0 ? inputLength : DEFAULT_SEQUENCE_LENGTH;
            }
        }
        return DEFAULT_SEQUENCE_LENGTH;
    }

    /**
     * Encodes {@code words} as a zero-padded vector of word ids of length {@code sequenceLength}.
     * The last word goes into the last slot, and earlier words fill earlier slots. Unknown words
     * leave a 0 but still consume a slot. When there are more words than slots, only the most
     * recent ones are kept, instead of writing at negative indices.
     */
    private INDArray encodeSequence(List<String> words, long sequenceLength) {
        INDArray sequence = Nd4j.zeros(new int[]{(int) sequenceLength});
        long position = sequenceLength - 1;
        // NOTE(review): the loop stops before index 0, so the first token is never encoded. This
        // matches the original behaviour — confirm whether cleanAndSplit yields a leading token
        // that must be skipped.
        for (int i = words.size() - 1; i > 0 && position >= 0; i--, position--) {
            Integer wordId = this.revDic.get(words.get(i));
            if (wordId != null) {
                sequence.putScalar(position, wordId.doubleValue());
            }
        }
        return sequence;
    }

    /**
     * Loads the two dictionaries from JSON object files.
     *
     * <p>The first file maps words to ids, with ids given as strings or numbers. It fills
     * {@link #revDic} and {@link #dic}. The second file maps label ids (string keys) to label
     * names. It fills {@link #dicLabels} and {@link #revDicLabels}. Existing entries are kept and
     * overwritten on key collision.
     *
     * @param file  JSON word dictionary
     * @param file2 JSON label dictionary
     * @throws Exception if a file cannot be read or parsed, or an id is not an integer
     */
    public void createDictionary(File file, File file2) throws Exception {
        // Jackson detects the JSON encoding itself (UTF-8 by default) instead of relying on the
        // platform charset, so accented tokens decode correctly everywhere.
        Map<String, String> wordIds = JSON.readValue(file, new TypeReference<Map<String, String>>() {
        });
        for (Map.Entry<String, String> entry : wordIds.entrySet()) {
            int wordId = Integer.parseInt(entry.getValue());
            this.revDic.put(entry.getKey(), wordId);
            this.dic.put(wordId, entry.getKey());
        }
        Map<String, String> labelNames = JSON.readValue(file2, new TypeReference<Map<String, String>>() {
        });
        for (Map.Entry<String, String> entry : labelNames.entrySet()) {
            int labelId = Integer.parseInt(entry.getKey());
            this.dicLabels.put(labelId, entry.getValue());
            this.revDicLabels.put(entry.getValue(), labelId);
        }
    }

    /**
     * Builds a token-to-label map from lines of the form {@code "label token1 token2 ..."}.
     * Each line is split on single spaces. NOTE(review): not referenced within this class.
     */
    private static Map<String, String> createLabels(List<String> list) {
        Map<String, String> labelByToken = new HashMap<>();
        for (String line : list) {
            String[] parts = line.split(" ");
            String label = parts[0];
            for (int i = 1; i < parts.length; i++) {
                labelByToken.put(parts[i], label);
            }
        }
        return labelByToken;
    }

    public Map<Integer, String> getDic() {
        return this.dic;
    }

    public Map<String, Integer> getRevDic() {
        return this.revDic;
    }

    public Map<Integer, String> getDicLabels() {
        return this.dicLabels;
    }

    public Map<String, Integer> getRevDicLabels() {
        return this.revDicLabels;
    }

    public MultiLayerNetwork getModel() {
        return this.model;
    }

    public void setModel(MultiLayerNetwork multiLayerNetwork) {
        this.model = multiLayerNetwork;
    }
}
