package ann;

import data.EstimPars;
import data.Stats;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.Serializable;
import utils.Error;

/* loaded from: input_file:lib/artificialneuralnets.jar:ann/ANNEnsemble.class */
/**
 * A fixed-capacity collection of feed-forward networks ({@link FF}) whose
 * predictions are combined according to one of the modes declared below.
 *
 * <p>Each member carries a score in {@code values}; lower scores are treated as
 * better (see {@link #getBest()} and {@link #addANN(FF, double)}). Members may
 * also carry a rank weight in {@code weights}, used by {@link #WEIG_AVER_RANK}.
 *
 * <p>Data-partition codes used throughout this class (taken from how
 * {@link #getAllErrors(NNDataMatrix, Cases, int)} fills the statistics):
 * 1 = training, 2 = test, 3 = validation, 4 = second validation set.
 */
public class ANNEnsemble implements Serializable {
    /** Combination mode: use only the member with the lowest stored value. */
    public static final int BEST = 0;
    /** Combination mode: arithmetic mean of the members' raw outputs. */
    public static final int AVERAGE = 1;
    /** Combination mode: take the member answer with the highest probability. */
    public static final int PROB = 2;
    /** Combination mode: majority vote over the members' predicted classes. */
    public static final int VOTING = 3;
    /** Combination mode: average of the raw outputs weighted by member rank. */
    public static final int WEIG_AVER_RANK = 4;

    /** Member networks; only the first {@code num_anns} slots are counted as in use. */
    FF[] anns;
    /** Score of each member (lower is better). */
    double[] values;
    /** Rank weight of each member, filled in by the training routine. */
    double[] weights;
    /** Number of members currently counted as part of the ensemble. */
    int num_anns;

    /** Creates an ensemble with no storage allocated. */
    public ANNEnsemble() {
    }

    /**
     * Creates an empty ensemble able to hold up to {@code i} networks.
     *
     * @param i capacity of the ensemble
     */
    public ANNEnsemble(int i) {
        this.anns = new FF[i];
        this.values = new double[i];
        this.weights = new double[i];
        this.num_anns = 0;
    }

    /**
     * Loads {@code i} networks, each followed by a line holding its score, from
     * the text file {@code str} (the layout written by {@link #save(String)}).
     *
     * @param str path of the file to read
     * @param i   number of networks stored in the file
     * @throws Exception if the file cannot be read or parsed
     */
    public ANNEnsemble(String str, int i) throws Exception {
        this.anns = new FF[i];
        this.values = new double[i];
        // Allocated so that the weight accessors and rank weighting do not hit a null array.
        this.weights = new double[i];
        this.num_anns = 0;
        // try-with-resources: the reader is released even if parsing fails.
        try (BufferedReader reader = new BufferedReader(new FileReader(str))) {
            for (int k = 0; k < i; k++) {
                this.anns[k] = new FF(reader);
                this.values[k] = Double.parseDouble(reader.readLine());
            }
        }
        // All slots were just filled; without this the loaded networks would
        // be ignored by every method that iterates over num_anns.
        this.num_anns = i;
    }

    /**
     * Writes every slot of the ensemble (network followed by its score on its
     * own line) to the text file {@code str}.
     *
     * <p>All {@code anns.length} slots are written, so every slot must hold a
     * network.
     *
     * @param str path of the file to write
     * @throws Exception if the file cannot be written
     */
    public void save(String str) throws Exception {
        // Closing the writer flushes its buffer; without it output can be lost.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(str))) {
            for (int k = 0; k < this.anns.length; k++) {
                this.anns[k].print(writer);
                writer.write(String.valueOf(this.values[k]) + "\n");
            }
        }
    }

    /**
     * Offers a network to the ensemble. While there is free capacity the
     * network is appended; once full, it replaces the member with the highest
     * score, but only if its own score {@code d} is strictly lower.
     *
     * @param ff network to add
     * @param d  score of the network (lower is better)
     */
    public void addANN(FF ff, double d) {
        if (this.num_anns < this.anns.length) {
            this.anns[this.num_anns] = ff;
            this.values[this.num_anns] = d;
            this.num_anns++;
            return;
        }
        // Ensemble is full: locate the worst (highest-scoring) member.
        int worst = 0;
        double worstValue = this.values[0];
        for (int k = 1; k < this.anns.length; k++) {
            if (this.values[k] > worstValue) {
                worstValue = this.values[k];
                worst = k;
            }
        }
        if (d < worstValue) {
            this.anns[worst] = ff;
            this.values[worst] = d;
        }
    }

    /**
     * Stores {@code ff} in slot {@code i}. The member count is not changed.
     *
     * @param ff network to store
     * @param i  slot index
     */
    public void setANN(FF ff, int i) {
        this.anns[i] = ff;
    }

    /**
     * Returns the capacity of the ensemble (number of slots, not the number of
     * members currently counted).
     */
    public int getSize() {
        return this.anns.length;
    }

    /**
     * Returns the member with the lowest score among the first
     * {@code num_anns} slots; slot 0 is returned when no other member is
     * strictly better.
     */
    public FF getBest() {
        int best = 0;
        double bestValue = this.values[0];
        for (int k = 1; k < this.num_anns; k++) {
            if (this.values[k] < bestValue) {
                bestValue = this.values[k];
                best = k;
            }
        }
        return this.anns[best];
    }

    /** Returns the network stored in slot {@code i}. */
    public FF getANN(int i) {
        return this.anns[i];
    }

    /** Returns the score stored for slot {@code i}. */
    public double getValue(int i) {
        return this.values[i];
    }

    /** Returns the rank weight stored for slot {@code i}. */
    public double getWeight(int i) {
        return this.weights[i];
    }

    /** Sets the score of slot {@code i} to {@code d}. */
    public void setValue(int i, double d) {
        this.values[i] = d;
    }

    /** Sets the rank weight of slot {@code i} to {@code d}. */
    public void setWeight(int i, double d) {
        this.weights[i] = d;
    }

    /**
     * Computes the error statistics on every data partition, building the
     * case set from {@code nNDataMatrix}.
     *
     * @param nNDataMatrix data to evaluate on
     * @param i            combination mode (one of the constants of this class)
     * @return statistics for all partitions
     */
    public ANNTestStats getAllErrors(NNDataMatrix nNDataMatrix, int i) {
        return getAllErrors(nNDataMatrix, new Cases(nNDataMatrix), i);
    }

    /**
     * Computes the error statistics on every data partition.
     *
     * <p>For {@link #BEST} the work is delegated to the best member; for any
     * other mode the combined outputs are evaluated per partition and
     * {@code cases} is not used.
     *
     * @param nNDataMatrix data to evaluate on
     * @param cases        case set handed to the best member in {@link #BEST} mode
     * @param i            combination mode
     * @return statistics for all partitions
     */
    public ANNTestStats getAllErrors(NNDataMatrix nNDataMatrix, Cases cases, int i) {
        if (i == BEST) {
            return getBest().getAllErrors(nNDataMatrix, cases);
        }
        ANNTestStats result = new ANNTestStats();
        result.TrainStats = error(nNDataMatrix, i, 1);
        result.ValidStats = error(nNDataMatrix, i, 3);
        result.Valid2Stats = error(nNDataMatrix, i, 4);
        result.TestStats = error(nNDataMatrix, i, 2);
        result.nweights = 0;
        result.nfolds = 1;
        return result;
    }

    /**
     * Re-initialises and retrains every member on {@code cases}, then returns
     * the error statistics of the ensemble.
     *
     * <p>In {@link #WEIG_AVER_RANK} mode the members' scores are refreshed and
     * rank weights are derived from them: the member with the lowest score
     * receives the largest weight, and the weights add up to
     * {@code n * (n + 1) / 2} for {@code n} members.
     *
     * @param nNDataMatrix data to evaluate on
     * @param aNNPars      training parameters
     * @param cases        case set used for training
     * @param i            combination mode
     * @return statistics for all partitions
     */
    public ANNTestStats trainGetAllErrors(NNDataMatrix nNDataMatrix, ANNPars aNNPars, Cases cases, int i) {
        Error error = new Error(1);
        for (int m = 0; m < this.num_anns; m++) {
            this.anns[m].random_weights(-1.0d, 1.0d);
            this.anns[m].ep = 0;
            this.anns[m].earlystopping(aNNPars, cases);
        }
        if (i == WEIG_AVER_RANK) {
            for (int m = 0; m < this.num_anns; m++) {
                this.values[m] = this.anns[m].error(error, cases, 1);
            }
            // Pairwise comparison: every member starts at weight 1 and the
            // better member of each pair (lower score) gains one more point.
            for (int m = 0; m < this.num_anns; m++) {
                this.weights[m] = 1.0d;
                for (int earlier = 0; earlier < m; earlier++) {
                    if (this.values[earlier] < this.values[m]) {
                        this.weights[earlier] += 1.0d;
                    } else {
                        this.weights[m] += 1.0d;
                    }
                }
            }
        }
        return getAllErrors(nNDataMatrix, cases, i);
    }

    /**
     * Same as {@link #trainGetAllErrors(NNDataMatrix, ANNPars, Cases, int)}
     * with a case set built from {@code nNDataMatrix}. The case set is also
     * saved to the file {@code casos.txt}.
     *
     * @param nNDataMatrix data to train and evaluate on
     * @param aNNPars      training parameters
     * @param i            combination mode
     * @return statistics for all partitions
     */
    public ANNTestStats trainGetAllErrors(NNDataMatrix nNDataMatrix, ANNPars aNNPars, int i) {
        Cases cases = new Cases(nNDataMatrix);
        cases.save("casos.txt");
        return trainGetAllErrors(nNDataMatrix, aNNPars, cases, i);
    }

    /**
     * Runs the estimation procedure selected by {@code estimPars.getMethod()}
     * for {@code estimPars.getRuns()} runs and aggregates the results.
     *
     * <p>Method codes handled: 0 training error, 1 holdout, 2 k-fold cross
     * validation, 3 cross validation with one fold per used example, 5 plain
     * train-and-evaluate. Any other code prints an error and leaves that run's
     * entry {@code null}.
     *
     * @param nNDataMatrix data to train and evaluate on
     * @param aNNPars      training parameters
     * @param estimPars    estimation settings
     * @param i            combination mode
     * @return aggregated statistics over all runs
     */
    public ANNAvgTestStats estimateError(NNDataMatrix nNDataMatrix, ANNPars aNNPars, EstimPars estimPars, int i) {
        ANNTestStats[] runs = new ANNTestStats[estimPars.getRuns()];
        for (int run = 0; run < estimPars.getRuns(); run++) {
            if (estimPars.getShuffle()) {
                nNDataMatrix.getDM().shuffleAllExamples();
            }
            switch (estimPars.getMethod()) {
                case 0:
                    runs[run] = trainingError(nNDataMatrix, aNNPars, estimPars, i);
                    break;
                case 1:
                    runs[run] = holdoutEstimate(nNDataMatrix, aNNPars, estimPars, i);
                    break;
                case 2:
                    if (estimPars.getStrat()) {
                        nNDataMatrix.getDM().createKfoldStratPartitions(estimPars.getFolds());
                    }
                    runs[run] = crossValidation(nNDataMatrix, aNNPars, estimPars.getFolds(), estimPars, i);
                    break;
                case 3:
                    runs[run] = crossValidation(nNDataMatrix, aNNPars, nNDataMatrix.getDM().getUsedNumExamples(), estimPars, i);
                    break;
                case 5:
                    runs[run] = trainGetAllErrors(nNDataMatrix, aNNPars, i);
                    break;
                case 4:
                default:
                    System.out.println("ERROR:NOT IMPLEMENTED");
                    break;
            }
        }
        return new ANNAvgTestStats(runs);
    }

    /**
     * Performs an {@code i}-fold cross validation, retraining the ensemble on
     * each fold and summing the per-fold statistics.
     *
     * @param nNDataMatrix data to train and evaluate on
     * @param aNNPars      training parameters
     * @param i            number of folds
     * @param estimPars    estimation settings (validation-set usage)
     * @param i2           combination mode
     * @return summed statistics, or {@code null} when {@code i} is not positive
     */
    public ANNTestStats crossValidation(NNDataMatrix nNDataMatrix, ANNPars aNNPars, int i, EstimPars estimPars, int i2) {
        ANNTestStats total = null;
        Cases cases = new Cases(nNDataMatrix);
        for (int fold = 0; fold < i; fold++) {
            nNDataMatrix.getDM().Kfold(i, fold);
            if (estimPars.getUseVal() > 0) {
                nNDataMatrix.getDM().setValidation(estimPars.getUseVal());
            }
            cases.update_indexes(nNDataMatrix);
            ANNTestStats foldStats = trainGetAllErrors(nNDataMatrix, aNNPars, cases, i2);
            if (fold == 0) {
                total = foldStats;
            } else {
                total.sum(foldStats);
            }
        }
        return total;
    }

    /**
     * Splits the data with a holdout of {@code estimPars.getTest()}, optionally
     * sets aside a validation set, then trains and evaluates the ensemble.
     *
     * @param nNDataMatrix data to train and evaluate on
     * @param aNNPars      training parameters
     * @param estimPars    estimation settings
     * @param i            combination mode
     * @return statistics for all partitions
     */
    public ANNTestStats holdoutEstimate(NNDataMatrix nNDataMatrix, ANNPars aNNPars, EstimPars estimPars, int i) {
        nNDataMatrix.getDM().holdout(estimPars.getTest());
        if (estimPars.getUseVal() > 0) {
            nNDataMatrix.getDM().setValidation(estimPars.getUseVal());
        }
        return trainGetAllErrors(nNDataMatrix, aNNPars, i);
    }

    /**
     * Optionally sets aside a validation set, then trains and evaluates the
     * ensemble without any further re-partitioning of the data.
     *
     * @param nNDataMatrix data to train and evaluate on
     * @param aNNPars      training parameters
     * @param estimPars    estimation settings
     * @param i            combination mode
     * @return statistics for all partitions
     */
    public ANNTestStats trainingError(NNDataMatrix nNDataMatrix, ANNPars aNNPars, EstimPars estimPars, int i) {
        if (estimPars.getUseVal() > 0) {
            nNDataMatrix.getDM().setValidation(estimPars.getUseVal());
        }
        return trainGetAllErrors(nNDataMatrix, aNNPars, i);
    }

    /**
     * Computes the error statistics of the ensemble on one data partition.
     *
     * @param nNDataMatrix data to evaluate on
     * @param i            combination mode
     * @param i2           data-partition code
     * @return the statistics, or {@code null} when the partition holds no
     *         examples (modes other than {@link #BEST})
     */
    public Stats error(NNDataMatrix nNDataMatrix, int i, int i2) {
        if (i == BEST) {
            return getBest().error(nNDataMatrix, i2);
        }
        if (nNDataMatrix.getDM().getNumberExamples(i2) > 0) {
            return nNDataMatrix.getStats(getOutputs(nNDataMatrix, i2, i), 0, i2);
        }
        return null;
    }

    /**
     * Produces the combined ensemble outputs for every example of a data
     * partition.
     *
     * <p>For {@link #AVERAGE} and {@link #WEIG_AVER_RANK} the raw outputs are
     * blended first and then post-processed; for the other modes each member's
     * outputs are post-processed individually and then consolidated.
     *
     * @param nNDataMatrix data to evaluate on
     * @param i            data-partition code
     * @param i2           combination mode
     * @return combined outputs, indexed by example and then by output
     */
    public NNOutput[][] getOutputs(NNDataMatrix nNDataMatrix, int i, int i2) {
        Cases cases = new Cases(nNDataMatrix);
        // Translate the partition code into the index expected by FF.outputs.
        int caseSet;
        switch (i) {
            case 2:
                caseSet = 2;
                break;
            case 3:
                caseSet = 1;
                break;
            case 4:
                caseSet = 3;
                break;
            case 1:
            default:
                caseSet = 0;
                break;
        }
        // raw[member][example][output]
        double[][][] raw = new double[this.num_anns][][];
        if (i2 == AVERAGE || i2 == WEIG_AVER_RANK) {
            for (int m = 0; m < this.num_anns; m++) {
                raw[m] = this.anns[m].outputs(cases, caseSet);
            }
            double[][] blended = new double[raw[0].length][raw[0][0].length];
            for (int ex = 0; ex < blended.length; ex++) {
                for (int out = 0; out < blended[0].length; out++) {
                    double sum = 0.0d;
                    if (i2 == AVERAGE) {
                        for (int m = 0; m < this.num_anns; m++) {
                            sum += raw[m][ex][out];
                        }
                        blended[ex][out] = sum / this.num_anns;
                    } else {
                        for (int m = 0; m < this.num_anns; m++) {
                            sum += this.weights[m] * raw[m][ex][out];
                        }
                        // Divides by n(n+1)/2, the total of rank weights 1..n.
                        blended[ex][out] = sum / ((this.num_anns * (this.num_anns + 1)) / 2.0d);
                    }
                }
            }
            return nNDataMatrix.treatOutputs(blended, i);
        }
        // perMember[member][example][output]
        NNOutput[][][] perMember = new NNOutput[this.num_anns][][];
        for (int m = 0; m < this.num_anns; m++) {
            raw[m] = this.anns[m].outputs(cases, caseSet);
            perMember[m] = new NNOutput[raw[m].length][];
            int row = 0;
            // One row is consumed per example whose status equals the partition code.
            for (int ex = 0; ex < nNDataMatrix.getDM().getNumberExamples(); ex++) {
                if (nNDataMatrix.getDM().getStatus(ex) == i) {
                    perMember[m][row] = nNDataMatrix.treatOutputs(raw[m][row]);
                    row++;
                }
            }
        }
        return consolidate(perMember, i2);
    }

    /**
     * Consolidates per-member outputs into a single output per example and
     * output position.
     *
     * @param nNOutputArr outputs indexed by member, example and output
     * @param i           combination mode
     * @return consolidated outputs indexed by example and output
     */
    public NNOutput[][] consolidate(NNOutput[][][] nNOutputArr, int i) {
        NNOutput[][] merged = new NNOutput[nNOutputArr[0].length][nNOutputArr[0][0].length];
        for (int ex = 0; ex < nNOutputArr[0].length; ex++) {
            for (int out = 0; out < nNOutputArr[0][ex].length; out++) {
                // Gather what every member answered for this example/output.
                NNOutput[] answers = new NNOutput[nNOutputArr.length];
                for (int m = 0; m < nNOutputArr.length; m++) {
                    answers[m] = nNOutputArr[m][ex][out];
                }
                merged[ex][out] = consolidate(answers, i);
            }
        }
        return merged;
    }

    /**
     * Consolidates the answers of all members for one output.
     *
     * <ul>
     *   <li>{@link #PROB}: the answer with the highest probability wins and is
     *       returned with that probability (the first such answer on ties).</li>
     *   <li>{@link #VOTING}: the most frequent class wins (the first one seen
     *       on ties) and is returned with the fraction of members that voted
     *       for it.</li>
     * </ul>
     *
     * @param nNOutputArr one answer per member
     * @param i           combination mode
     * @return the consolidated answer, or {@code null} for any other mode
     */
    public NNOutput consolidate(NNOutput[] nNOutputArr, int i) {
        switch (i) {
            case PROB: {
                int winner = nNOutputArr[0].getIValue();
                double bestProb = nNOutputArr[0].getProb();
                for (int m = 1; m < nNOutputArr.length; m++) {
                    if (nNOutputArr[m].getProb() > bestProb) {
                        bestProb = nNOutputArr[m].getProb();
                        winner = nNOutputArr[m].getIValue();
                    }
                }
                return new NNOutput(winner, bestProb);
            }
            case VOTING: {
                int voters = nNOutputArr.length;
                int[] classes = new int[voters];
                int[] votes = new int[voters];
                int distinct = 0;
                for (int m = 0; m < voters; m++) {
                    int predicted = nNOutputArr[m].getIValue();
                    int slot = 0;
                    while (slot < distinct && classes[slot] != predicted) {
                        slot++;
                    }
                    if (slot < distinct) {
                        // Class already seen: credit its own slot.
                        votes[slot]++;
                    } else {
                        classes[slot] = predicted;
                        votes[slot] = 1;
                        distinct++;
                    }
                }
                int winner = classes[0];
                int bestVotes = votes[0];
                for (int slot = 1; slot < distinct; slot++) {
                    if (votes[slot] > bestVotes) {
                        bestVotes = votes[slot];
                        winner = classes[slot];
                    }
                }
                // Floating-point share of the vote; an integer division would
                // truncate to 0 or 1.
                return new NNOutput(winner, (double) bestVotes / voters);
            }
            default:
                return null;
        }
    }

    /**
     * Computes the combined ensemble answer for a single input vector.
     *
     * @param dArr         input vector handed to every member
     * @param nNDataMatrix supplies the number of network outputs and the
     *                     post-processing of raw outputs
     * @param i            combination mode
     * @return the combined answer (first post-processed output)
     */
    public NNOutput output(double[] dArr, NNDataMatrix nNDataMatrix, int i) {
        // raw[member][output]
        double[][] raw = new double[this.num_anns][];
        for (int m = 0; m < this.num_anns; m++) {
            raw[m] = new double[nNDataMatrix.getNNOutputs()];
            this.anns[m].output(dArr, raw[m]);
        }
        if (i == AVERAGE || i == WEIG_AVER_RANK) {
            double[] blended = new double[nNDataMatrix.getNNOutputs()];
            for (int out = 0; out < nNDataMatrix.getNNOutputs(); out++) {
                double sum = 0.0d;
                if (i == AVERAGE) {
                    for (int m = 0; m < this.num_anns; m++) {
                        sum += raw[m][out];
                    }
                    blended[out] = sum / this.num_anns;
                } else {
                    for (int m = 0; m < this.num_anns; m++) {
                        sum += this.weights[m] * raw[m][out];
                    }
                    // Divides by n(n+1)/2, the total of rank weights 1..n.
                    blended[out] = sum / ((this.num_anns * (this.num_anns + 1)) / 2.0d);
                }
            }
            return nNDataMatrix.treatOutputs(blended)[0];
        }
        NNOutput[] answers = new NNOutput[this.num_anns];
        for (int m = 0; m < this.num_anns; m++) {
            answers[m] = nNDataMatrix.treatOutputs(raw[m])[0];
        }
        return consolidate(answers, i);
    }
}
