package BNlearning;

import com.agenarisk.learning.structure.config.Config;
import com.agenarisk.learning.structure.exception.StructureLearningException;
import com.agenarisk.learning.structure.logger.BLogger;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:BNlearning/SaiyanH_phase_1a_marginalDep.class */
/**
 * Phase 1a of SaiyanH: computes a marginal dependency score for every pair of variables.
 *
 * <p>For each ordered pair (target, conditioner), the distribution of the target is estimated
 * from the training data separately for every state of the conditioner. Each estimate is scored
 * against {@code global.priorDistributions[target]} with {@code scoreMMD.getDistDif}. The
 * per-state scores are summed and then divided by the conditioner's number of states. The
 * pairwise score is the mean of both directions, unless a constraint overrides it (see
 * {@link #getPairwiseScore}).
 *
 * <p>Results are exposed through {@link #marginalDep}, {@link #marginalDepBelowTheta} and a copy
 * stored in {@code global.pairwiseDependencies}.
 */
public class SaiyanH_phase_1a_marginalDep {
    /** For each target variable, the indices of the variables it was conditioned on. */
    private final ArrayList<Integer>[] directedDependencies;
    /**
     * Scores parallel to {@link #directedDependencies}:
     * {@code dependencyScores[t].get(k)} belongs to {@code directedDependencies[t].get(k)}.
     */
    private final ArrayList<Double>[] dependencyScores;
    /**
     * Directed (per-direction) scores indexed as {@code [conditioner][target]}; despite the name,
     * this matrix is not symmetric. Returned by {@link #getDirectedScore}.
     */
    private final Double[][] undirectedDependencies;
    /** Pairwise scores after constraints have been applied; see {@link #getPairwiseScore}. */
    public Double[][] undirectedAveragedDependencies;
    /** One {@code [nameI, nameJ, score]} row for every variable pair with {@code i < j}. */
    public List<List<String>> marginalDep = new ArrayList<>();
    /** {@code [nameI, nameJ]} rows for the pairs whose score is strictly below the threshold. */
    public List<List<String>> marginalDepBelowTheta = new ArrayList<>();
    /** Number of {@link #getPairwiseScore} calls made with the counting flag set. */
    public Integer undirectedDependenciesCounter = 0;
    /** Graphviz graph attributes; filled by {@link #handleGraphDetails}. */
    private String graphDetails = "";

    /**
     * Runs the pairwise search and stores all results.
     *
     * @param variableCount number of variables used to size the score matrices.
     *     NOTE(review): several helpers iterate up to {@code global.varCount} instead, so the two
     *     are assumed to be equal — confirm with callers
     * @param rowCount number of data rows to scan; rows {@code 1..rowCount} of
     *     {@code Database.trainingData} are read (row 0 is used as the header of variable names)
     * @param directed directed-edge constraints (pairs forced to score 1.0)
     * @param undirected undirected-edge constraints (pairs forced to score 1.0)
     * @param forbidden forbidden-edge constraints (pairs forced to score 0.0)
     * @param temporal temporal tiers (same-tier pairs forced to 0.0 when configured)
     * @param threshold pairs scoring strictly below this go to {@link #marginalDepBelowTheta}
     * @param inputGraph input-graph constraints (pairs forced to score 1.0)
     * @throws Exception if missing data is found or writing {@code marginalDep.csv} fails
     */
    public SaiyanH_phase_1a_marginalDep(Integer variableCount, Integer rowCount, constraintsDirected directed, constraintsUndirected undirected, constraintsForbidden forbidden, constraintsTemporal temporal, Double threshold, constraintsInputGraph inputGraph) throws Exception {
        final int n = variableCount.intValue();
        // Generic array creation is not allowed, hence the unchecked raw-array allocation.
        @SuppressWarnings("unchecked")
        ArrayList<Double>[] scores = new ArrayList[n];
        @SuppressWarnings("unchecked")
        ArrayList<Integer>[] conditioners = new ArrayList[n];
        this.dependencyScores = scores;
        this.directedDependencies = conditioners;
        this.undirectedDependencies = new Double[n][n];
        this.undirectedAveragedDependencies = new Double[n][n];
        for (int a = 0; a < n; a++) {
            conditioners[a] = new ArrayList<>();
            scores[a] = new ArrayList<>();
            for (int b = 0; b < n; b++) {
                this.undirectedDependencies[a][b] = Double.valueOf(0.0d);
                this.undirectedAveragedDependencies[a][b] = Double.valueOf(0.0d);
            }
        }
        initialisePairwiseSearch(n, rowCount.intValue());
        saveAveragedDependencies(directed, undirected, forbidden, temporal, inputGraph);
        saveMarginalDep(threshold);
        if (Config.getInstance().getLearningSaiyanHSaveAssocScores().booleanValue()) {
            saveDependenciesFile();
        }
        global.pairwiseDependencies = global.copyNestedList(this.marginalDep);
    }

    /**
     * Fills {@link #marginalDep} with one row per unordered pair {@code i < j}, and
     * {@link #marginalDepBelowTheta} with the pairs whose score is strictly below the threshold.
     */
    private void saveMarginalDep(Double threshold) {
        final int n = global.varCount.intValue();
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                String nameI = global.getVariableName(Integer.valueOf(i));
                String nameJ = global.getVariableName(Integer.valueOf(j));
                double score = this.undirectedAveragedDependencies[i][j].doubleValue();

                List<String> row = new ArrayList<>();
                row.add(nameI);
                row.add(nameJ);
                row.add(Double.toString(score));
                this.marginalDep.add(row);

                if (Double.compare(score, threshold.doubleValue()) < 0) {
                    List<String> below = new ArrayList<>();
                    below.add(nameI);
                    below.add(nameJ);
                    this.marginalDepBelowTheta.add(below);
                }
            }
        }
    }

    /**
     * Scores every ordered pair of distinct variables, then averages each accumulated score
     * over the conditioner's number of states.
     */
    private void initialisePairwiseSearch(int variableCount, int rowCount) throws Exception {
        for (int target = 0; target < variableCount; target++) {
            for (int conditioner = 0; conditioner < variableCount; conditioner++) {
                if (conditioner != target) {
                    handleSearchMMD(target, conditioner, rowCount);
                }
            }
        }
        averageDependencies(variableCount, true);
    }

    /**
     * For each state of {@code conditioner}, estimates the distribution of {@code target} over
     * the rows where the conditioner takes that state. Each estimate is scored against the
     * target's prior distribution and added to the directed score of (target, conditioner).
     *
     * @throws StructureLearningException if reading a data row fails (reported as missing data)
     */
    private void handleSearchMMD(int target, int conditioner, int rowCount) {
        final int targetStateCount = global.states[target].size();
        // Reused for every conditioner state; reset to zeros after each use.
        ArrayList<Double> conditional = new ArrayList<>(targetStateCount);
        for (int k = 0; k < targetStateCount; k++) {
            conditional.add(Double.valueOf(0.0d));
        }
        final int conditionerStateCount = global.states[conditioner].size();
        for (int s = 0; s < conditionerStateCount; s++) {
            int matchedRows = 0;
            for (int row = 0; row < rowCount; row++) {
                try {
                    // Row 0 of trainingData is skipped, hence the +1 offset.
                    if (Database.trainingData[row + 1][conditioner].equals(global.states[conditioner].get(s))) {
                        matchedRows++;
                        for (int k = 0; k < targetStateCount; k++) {
                            if (global.states[target].get(k).equals(Database.trainingData[row + 1][target])) {
                                conditional.set(k, Double.valueOf(conditional.get(k).doubleValue() + 1.0d));
                                break;
                            }
                        }
                    }
                } catch (Exception e) {
                    // Any failure here (e.g. an NPE from an empty cell) is reported as missing data.
                    // NOTE(review): the message says the system exits, but an exception is thrown instead.
                    BLogger.out.println("\u001b[31mMissing data value found in data row " + (row + 1));
                    BLogger.out.println("\u001b[31mTo perform structure learning with missing data, fill empty cells with a new value that corresponds to missing data; e.g., 'missing'.");
                    BLogger.out.println("\u001b[31mSystem exits.");
                    throw new StructureLearningException("Missing data value found in data row " + (row + 1) + " To perform structure learning with missing data, fill empty cells with a new value that corresponds to missing data; e.g., 'missing'.");
                }
            }
            // NOTE(review): if this conditioner state never occurs in the data, matchedRows is 0
            // and every entry becomes NaN (0.0 / 0) before being scored.
            for (int k = 0; k < conditional.size(); k++) {
                conditional.set(k, Double.valueOf(conditional.get(k).doubleValue() / matchedRows));
            }
            saveDependencies(target, conditioner, scoreMMD.getDistDif(global.priorDistributions[target], conditional, Integer.valueOf(target), Integer.valueOf(conditioner), -1));
            for (int k = 0; k < targetStateCount; k++) {
                conditional.set(k, Double.valueOf(0.0d));
            }
        }
    }

    /** Adds {@code score} to the accumulated directed score of (target, conditioner). */
    private void saveDependencies(int target, int conditioner, double score) {
        Integer key = Integer.valueOf(conditioner);
        int index = this.directedDependencies[target].indexOf(key);
        if (index >= 0) {
            this.dependencyScores[target].set(index, Double.valueOf(score + this.dependencyScores[target].get(index).doubleValue()));
        } else {
            this.dependencyScores[target].add(Double.valueOf(score));
            this.directedDependencies[target].add(key);
        }
    }

    /** Fills {@link #undirectedAveragedDependencies} with {@link #getPairwiseScore} for every ordered pair. */
    private void saveAveragedDependencies(constraintsDirected directed, constraintsUndirected undirected, constraintsForbidden forbidden, constraintsTemporal temporal, constraintsInputGraph inputGraph) {
        final int n = global.varCount.intValue();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                this.undirectedAveragedDependencies[i][j] = getPairwiseScore(i, j, true, directed, undirected, forbidden, temporal, inputGraph);
            }
        }
    }

    /** Writes {@link #marginalDep} as {@code name1,name2,score} lines to {@code marginalDep.csv}. */
    private void saveDependenciesFile() throws IOException {
        String path = Config.getInstance().getPathOutput().resolve(Config.getInstance().getLearningAlgorithm().toString()).resolve("marginalDep.csv").toString();
        try (FileWriter out = new FileWriter(path)) {
            for (List<String> row : this.marginalDep) {
                out.append(row.get(0)).append(',')
                        .append(row.get(1)).append(',')
                        .append(row.get(2)).append('\n');
            }
            out.flush();
        }
        BLogger.out.println("marginalDep.csv saved.");
    }

    /** Writes {@link #undirectedAveragedDependencies} as a labelled CSV matrix. Not called anywhere in this class. */
    private void saveDependencyScoresFile() throws IOException {
        writeMatrixCsv("marginalDepScores.csv", this.undirectedAveragedDependencies);
    }

    /** Writes the directed score matrix as a labelled CSV matrix. Not called anywhere in this class. */
    private void saveDirectedDependencyScoresFile() throws IOException {
        writeMatrixCsv("marginalDirectedDepScores.csv", this.undirectedDependencies);
    }

    /**
     * Writes {@code matrix} to {@code <output>/SaiyanH/<fileName>}. The header row and first
     * column hold the variable names, and every cell is followed by a comma.
     */
    private void writeMatrixCsv(String fileName, Double[][] matrix) throws IOException {
        String path = Config.getInstance().getPathOutput().resolve("SaiyanH").resolve(fileName).toString();
        final int n = global.varCount.intValue();
        try (FileWriter out = new FileWriter(path)) {
            out.append(',');
            for (int c = 0; c < n; c++) {
                out.append(global.getVariableName(Integer.valueOf(c))).append(',');
            }
            out.append('\n');
            for (int r = 0; r < n; r++) {
                out.append(global.getVariableName(Integer.valueOf(r))).append(',');
                for (int c = 0; c < n; c++) {
                    out.append(Double.toString(matrix[r][c].doubleValue())).append(',');
                }
                out.append('\n');
            }
            out.flush();
        }
        BLogger.out.println(fileName + " saved.");
    }

    /**
     * Normalises the accumulated directed scores and mirrors them into
     * {@link #undirectedDependencies} as {@code [conditioner][target]}.
     *
     * @param variableCount number of target variables to process
     * @param divideByStates when true, each score is divided by its conditioner's number of states
     */
    private void averageDependencies(int variableCount, boolean divideByStates) {
        for (int target = 0; target < variableCount; target++) {
            for (int k = 0; k < this.directedDependencies[target].size(); k++) {
                int conditioner = this.directedDependencies[target].get(k).intValue();
                Double score = this.dependencyScores[target].get(k);
                if (divideByStates) {
                    score = Double.valueOf(score.doubleValue() / global.states[conditioner].size());
                }
                this.dependencyScores[target].set(k, score);
                this.undirectedDependencies[conditioner][target] = score;
            }
        }
    }

    /** Logs every discovered (conditioner, score) pair per variable. Not called anywhere in this class. */
    private void printDependenciesAndScores(int variableCount) {
        BLogger.out.println("Printing the dependencies discovered per set of variables");
        for (int v = 0; v < variableCount; v++) {
            BLogger.out.println("Parent/s discovered for node " + Database.trainingData[0][v] + ": ");
            for (int k = 0; k < this.directedDependencies[v].size(); k++) {
                int parent = this.directedDependencies[v].get(k).intValue();
                if (parent < 0) {
                    BLogger.out.println("No dependencies discovered.");
                } else {
                    BLogger.out.println(Database.trainingData[0][parent] + " with score " + this.dependencyScores[v].get(k));
                }
            }
        }
    }

    /**
     * Returns the constraint-adjusted pairwise score of variables {@code i} and {@code i2}.
     *
     * <ul>
     *   <li>0.0 when the pair is in the same temporal tier (and the config prohibits such edges),
     *       is forbidden, or either variable has fewer than two states;</li>
     *   <li>otherwise 1.0 when a directed, undirected or input-graph constraint covers the pair;</li>
     *   <li>otherwise the mean of the two directed scores.</li>
     * </ul>
     *
     * @param bool when true, {@link #undirectedDependenciesCounter} is incremented
     */
    public Double getPairwiseScore(int i, int i2, Boolean bool, constraintsDirected constraintsdirected, constraintsUndirected constraintsundirected, constraintsForbidden constraintsforbidden, constraintsTemporal constraintstemporal, constraintsInputGraph constraintsinputgraph) {
        if (bool.booleanValue()) {
            this.undirectedDependenciesCounter = Integer.valueOf(this.undirectedDependenciesCounter.intValue() + 1);
        }
        Integer a = Integer.valueOf(i);
        Integer b = Integer.valueOf(i2);

        boolean sameTierProhibited = Config.getInstance().getConstraintsProhibitEdgesSameTemporalTier().booleanValue()
                && constraintstemporal.sameTier(a, b).booleanValue();
        if (sameTierProhibited
                || constraintsforbidden.forbiddenConstraint(a, b).booleanValue()
                || global.states[i].size() < 2
                || global.states[i2].size() < 2) {
            return Double.valueOf(0.0d);
        }
        // NOTE(review): the original checked the undirected constraint twice; the duplicate was
        // removed. If a reversed (i2, i) check was intended, add it explicitly.
        if (constraintsdirected.edgeConstraint(a, b).booleanValue()
                || constraintsundirected.edgeConstraint(a, b).booleanValue()
                || constraintsinputgraph.edgeConstraint(a, b).booleanValue()) {
            return Double.valueOf(1.0d);
        }
        return Double.valueOf((this.undirectedDependencies[i][i2].doubleValue() + this.undirectedDependencies[i2][i].doubleValue()) / 2.0d);
    }

    /** Returns the stored directed score at {@code [i][i2]}, i.e. target {@code i2} conditioned on {@code i}. */
    public Double getDirectedScore(int i, int i2) {
        return this.undirectedDependencies[i][i2];
    }

    /** Renders {@link #marginalDep} as an undirected Graphviz graph. Not called anywhere in this class. */
    private void generateGraph1(constraintsDirected constraintsdirected) throws Exception {
        StringBuilder dot = new StringBuilder("edge [dir=none] ");
        for (List<String> row : this.marginalDep) {
            // Non-word characters are stripped so names are valid Graphviz identifiers.
            dot.append(row.get(0).replaceAll("\\W", ""))
                    .append("->")
                    .append(row.get(1).replaceAll("\\W", ""))
                    .append("[label=\"")
                    .append(String.format("%.3f", Double.valueOf(Double.parseDouble(row.get(2)))))
                    .append("\"];");
        }
        handleGraphDetails();
        global.graphVizInputFunction(dot.toString() + this.graphDetails, "Saiyan2_Phase_1", "learning");
    }

    /** Builds the Graphviz graph-attribute string (font and descriptive label) into {@link #graphDetails}. */
    private void handleGraphDetails() {
        Config config = Config.getInstance();
        // NOTE(review): "Depencences" is a typo in the emitted label; left unchanged here.
        this.graphDetails = "graph[fontname=Arial, fontsize = 10,  label=\"Algorithm: " + config.getLearningAlgorithm()
                + " \\lRelationship score: "
                + (config.getLearningSaiyanHdiscScoreType().equals("MI")
                        ? ""
                        : config.getLearningSaiyanHdiscScoreType() + "[" + config.getLearningSaiyanHdiscDistanceType() + "]")
                + " \\lPhase_1: Undirected graph with all dependencies that have dependency score > " + global.dependencyThresholdMI
                + " \\lTotal dependencies: " + this.marginalDep.size()
                + "\\lDepencences with a (max) score of 1.0 indicate directed constraints. \\l\"]";
    }
}
