package data;

import java.io.BufferedWriter;
import java.io.PrintWriter;
import utils.MatUtils;

/* loaded from: input_file:lib/artificialneuralnets.jar:data/AvgStats.class */
/**
 * Aggregates per-run evaluation statistics ({@link Stats}) and reports their per-run means and
 * sample standard deviations.
 *
 * <p>Only running sums and sums of squares are stored, so aggregates can be merged cheaply with
 * {@link #sum(AvgStats)} or {@link #AvgStats(AvgStats[])}. The confusion-matrix accumulators are
 * {@code null} when the aggregated runs reported no classes ({@code getNClasses() <= 0}); in that
 * case the classification metrics are neither accumulated nor printed.
 *
 * <p>No synchronization is performed; instances are mutated in place by {@link #sum(AvgStats)}.
 */
public class AvgStats {
    // Each "sum_" field is a running sum over runs; the matching "sumsq_" field holds the sum of
    // squares so that a sample standard deviation can be derived without keeping every value.
    // Fields stay package-private because other classes in this package may access them directly.

    /** Total number of examples over all runs (a sum, not a per-run average). */
    int nexamples;
    /** Number of runs aggregated so far. */
    int runs;
    double sum_correct;
    double sumsq_correct;
    /** Per-cell confusion-matrix sums; {@code null} when no class information was aggregated. */
    double[][] sum_matrix;
    /** Per-cell confusion-matrix sums of squares; {@code null} together with {@link #sum_matrix}. */
    double[][] sumsq_matrix;
    double sum_sse;
    double sumsq_sse;
    double sum_sad;
    double sumsq_sad;
    double sum_rmse;
    double sumsq_rmse;

    /** Creates an empty aggregate: zero runs, zero sums and no confusion matrix. */
    public AvgStats() {
        // Java zero-initialises the numeric fields and nulls the array fields already.
    }

    /**
     * Aggregates one {@link Stats} per run.
     *
     * <p>The number of classes is taken from the first element; confusion-matrix and correct-count
     * sums are accumulated only when it is positive. An empty array yields an empty aggregate.
     *
     * @param statsArr per-run statistics; must not be {@code null}
     */
    public AvgStats(Stats[] statsArr) {
        this();
        if (statsArr.length == 0) {
            return;
        }
        this.runs = statsArr.length;
        int nClasses = statsArr[0].getNClasses();
        boolean hasClasses = nClasses > 0;
        if (hasClasses) {
            // New Java arrays are zero-filled, so no explicit initialisation is required.
            this.sum_matrix = new double[nClasses][nClasses];
            this.sumsq_matrix = new double[nClasses][nClasses];
        }
        for (Stats run : statsArr) {
            this.nexamples += run.getNumExamples();
            if (hasClasses) {
                for (int row = 0; row < nClasses; row++) {
                    for (int col = 0; col < nClasses; col++) {
                        double cell = run.getMatrixVal(row, col);
                        this.sum_matrix[row][col] += cell;
                        this.sumsq_matrix[row][col] += cell * cell;
                    }
                }
                double correct = run.getCorrect();
                this.sum_correct += correct;
                this.sumsq_correct += correct * correct;
            }
            double sse = run.getSSE();
            this.sum_sse += sse;
            this.sumsq_sse += sse * sse;
            double sad = run.getSAD();
            this.sum_sad += sad;
            this.sumsq_sad += sad * sad;
            double rmse = run.getRMSE();
            this.sum_rmse += rmse;
            this.sumsq_rmse += rmse * rmse;
        }
    }

    /**
     * Merges several aggregates into a new one, as if every run they contain had been aggregated
     * together. An empty array yields an empty aggregate.
     *
     * @param avgStatsArr aggregates to merge; must not be {@code null}
     * @throws IllegalArgumentException if two aggregates carry confusion matrices of different sizes
     */
    public AvgStats(AvgStats[] avgStatsArr) {
        this();
        for (AvgStats other : avgStatsArr) {
            accumulate(other);
        }
    }

    /**
     * Returns the mean number of examples per run, truncated to an {@code int}.
     *
     * <p>The return type is kept for compatibility; derived ratios in this class use the exact
     * mean instead of this truncated value.
     *
     * @throws ArithmeticException if no runs have been aggregated
     */
    public int getNumExamples() {
        return this.nexamples / this.runs;
    }

    /** Returns the mean per-run correct count (the values reported by {@code Stats.getCorrect()}). */
    public double getCorrect() {
        return mean(this.sum_correct);
    }

    /** Returns the total correct count divided by the total number of examples (a fraction, 0..1). */
    public double getCorrectPerc() {
        return this.sum_correct / this.nexamples;
    }

    /** Returns the sample standard deviation of the per-run correct count. */
    public double getStdevCorrect() {
        return sampleStdev(this.sum_correct, this.sumsq_correct);
    }

    /** Returns {@link #getStdevCorrect()} divided by the exact mean number of examples per run. */
    public double getStdevCorrectPerc() {
        return getStdevCorrect() / meanExamples();
    }

    /** Returns the mean per-run SSE (as reported by {@code Stats.getSSE()}). */
    public double getSSE() {
        return mean(this.sum_sse);
    }

    /** Returns the mean per-run RMSE (as reported by {@code Stats.getRMSE()}). */
    public double getRMSE() {
        return mean(this.sum_rmse);
    }

    /** Returns the sample standard deviation of the per-run SSE. */
    public double getStdevSSE() {
        return sampleStdev(this.sum_sse, this.sumsq_sse);
    }

    /** Returns the sample standard deviation of the per-run RMSE. */
    public double getStdevRMSE() {
        return sampleStdev(this.sum_rmse, this.sumsq_rmse);
    }

    /** Returns the mean per-run SAD (as reported by {@code Stats.getSAD()}). */
    public double getSAD() {
        return mean(this.sum_sad);
    }

    /** Returns the mean per-run SAD divided by the exact mean number of examples per run. */
    public double getMAD() {
        return getSAD() / meanExamples();
    }

    /** Returns the sample standard deviation of the per-run SAD. */
    public double getStdevSAD() {
        return sampleStdev(this.sum_sad, this.sumsq_sad);
    }

    /** Returns {@link #getStdevSAD()} divided by the exact mean number of examples per run. */
    public double getStdevMAD() {
        return getStdevSAD() / meanExamples();
    }

    /**
     * Returns the mean per-run value of confusion-matrix cell ({@code i}, {@code i2}).
     *
     * @throws NullPointerException if no confusion matrix was aggregated
     */
    public double getAvgMatrixVal(int i, int i2) {
        return mean(this.sum_matrix[i][i2]);
    }

    /**
     * Returns the sample standard deviation of confusion-matrix cell ({@code i}, {@code i2}).
     *
     * @throws NullPointerException if no confusion matrix was aggregated
     */
    public double getStdMatrixVal(int i, int i2) {
        return sampleStdev(this.sum_matrix[i][i2], this.sumsq_matrix[i][i2]);
    }

    /**
     * Merges {@code avgStats} into this aggregate in place.
     *
     * <p>If this aggregate has no confusion matrix yet, a copy of the other's is adopted; if the
     * other has none, only the scalar sums are merged.
     *
     * @param avgStats aggregate to add; not modified
     * @throws IllegalArgumentException if both carry confusion matrices of different sizes
     */
    public void sum(AvgStats avgStats) {
        accumulate(avgStats);
    }

    /**
     * Writes a human-readable summary to {@code bufferedWriter} and flushes it (without closing).
     *
     * @param bufferedWriter destination of the report
     * @throws Exception if writing fails
     */
    public void print(BufferedWriter bufferedWriter) throws Exception {
        bufferedWriter.write("N.examples: " + this.nexamples + "\n");
        if (this.sum_matrix != null) {
            bufferedWriter.write("Correct: " + MatUtils.doubleToString(getCorrectPerc() * 100.0d, 3) + "%\n");
            bufferedWriter.write("STD Correct: " + MatUtils.doubleToString(getStdevCorrectPerc() * 100.0d, 3) + "%\n");
            bufferedWriter.write("Confusion matrix (lines:correct)(col:predicted) \n");
            writeMatrix(bufferedWriter, false);
            bufferedWriter.write("STDEV Confusion matrix \n");
            writeMatrix(bufferedWriter, true);
        }
        bufferedWriter.write("RMSE:" + MatUtils.doubleToString(getRMSE(), 3) + "\n");
        bufferedWriter.write("Stdev RMSE:" + MatUtils.doubleToString(getStdevRMSE(), 3) + "\n");
        bufferedWriter.write("MAD:" + MatUtils.doubleToString(getMAD(), 3) + "\n");
        bufferedWriter.write("Stdev MAD:" + MatUtils.doubleToString(getStdevMAD(), 3) + "\n");
        bufferedWriter.flush();
    }

    /**
     * Writes the summary to standard output.
     *
     * @throws Exception if writing fails
     */
    public void print() throws Exception {
        // Deliberately not closed: closing the writer chain would close System.out too.
        print(new BufferedWriter(new PrintWriter(System.out)));
    }

    /** Adds every accumulator of {@code other} into this instance (shared by the ctor and sum). */
    private void accumulate(AvgStats other) {
        if (other.sum_matrix != null) {
            if (this.sum_matrix == null) {
                // Copy rather than alias so later merges never mutate the other aggregate.
                this.sum_matrix = copyMatrix(other.sum_matrix);
                this.sumsq_matrix = copyMatrix(other.sumsq_matrix);
            } else {
                if (this.sum_matrix.length != other.sum_matrix.length) {
                    throw new IllegalArgumentException("Confusion matrix size mismatch: "
                            + this.sum_matrix.length + " vs " + other.sum_matrix.length);
                }
                addInto(this.sum_matrix, other.sum_matrix);
                addInto(this.sumsq_matrix, other.sumsq_matrix);
            }
        }
        this.nexamples += other.nexamples;
        this.runs += other.runs;
        this.sum_correct += other.sum_correct;
        this.sumsq_correct += other.sumsq_correct;
        this.sum_sse += other.sum_sse;
        this.sumsq_sse += other.sumsq_sse;
        this.sum_rmse += other.sum_rmse;
        this.sumsq_rmse += other.sumsq_rmse;
        this.sum_sad += other.sum_sad;
        this.sumsq_sad += other.sumsq_sad;
    }

    /** Adds {@code source} element-wise into {@code target} (same row count assumed). */
    private static void addInto(double[][] target, double[][] source) {
        for (int row = 0; row < target.length; row++) {
            for (int col = 0; col < target[row].length; col++) {
                target[row][col] += source[row][col];
            }
        }
    }

    /** Returns a deep copy of a rectangular matrix. */
    private static double[][] copyMatrix(double[][] matrix) {
        double[][] copy = new double[matrix.length][];
        for (int row = 0; row < matrix.length; row++) {
            copy[row] = matrix[row].clone();
        }
        return copy;
    }

    /** Returns {@code total / runs}: the per-run mean of a summed quantity. */
    private double mean(double total) {
        return total / this.runs;
    }

    /** Returns the exact (non-truncated) mean number of examples per run. */
    private double meanExamples() {
        return (double) this.nexamples / this.runs;
    }

    /**
     * Returns the sample standard deviation (n - 1 denominator) from a sum and a sum of squares,
     * or NaN when fewer than two runs were aggregated (the estimator is undefined there).
     */
    private double sampleStdev(double sum, double sumsq) {
        if (this.runs < 2) {
            return Double.NaN;
        }
        double squaredDeviations = sumsq - (sum * sum) / this.runs;
        // Floating-point cancellation can leave a tiny negative residue when all runs are (nearly)
        // identical; the true value is never negative, and sqrt of a negative would yield NaN.
        return Math.sqrt(Math.max(0.0d, squaredDeviations) / (this.runs - 1));
    }

    /** Writes the per-cell means ({@code deviations == false}) or stdevs of the confusion matrix. */
    private void writeMatrix(BufferedWriter out, boolean deviations) throws Exception {
        int size = this.sum_matrix.length;
        for (int row = 0; row < size; row++) {
            StringBuilder line = new StringBuilder();
            for (int col = 0; col < size; col++) {
                double value = deviations ? getStdMatrixVal(row, col) : getAvgMatrixVal(row, col);
                line.append(MatUtils.doubleToString(value, 3)).append(' ');
            }
            line.append('\n');
            out.write(line.toString());
        }
    }
}
