/*
 * Decompiled with CFR 0.152.
 */
package marytts.htsengine;

import marytts.htsengine.HMMData;
import marytts.htsengine.HTSDWin;
import marytts.util.MaryUtils;
import org.apache.log4j.Logger;

/**
 * Parameter stream for one feature type: holds per-frame means ({@code mseq}) and per-frame
 * precision weights ({@code ivseq}, presumably inverse variances) for the static and dynamic
 * windows, and generates the static trajectory {@code par} by solving the banded system
 * (W' U^-1 W) c = W' U^-1 M, optionally followed by a global-variance (GV) optimization.
 *
 * <p>Column layout of {@code mseq}/{@code ivseq}: window {@code i} (as numbered by
 * {@link HTSDWin}) occupies columns {@code i * order .. i * order + order - 1}.
 */
public class HTSPStream {
    /** Index passed to {@code HTSDWin.getWidth} for the left extent of a window. */
    public static final int WLEFT = 0;
    /** Index passed to {@code HTSDWin.getWidth} for the right extent of a window. */
    public static final int WRIGHT = 1;
    private int feaType;
    private int vSize;
    private int order;
    private int nT;
    private int width;
    private double[][] par;
    private double[][] mseq;
    private double[][] ivseq;
    private double[] g;
    private double[][] wuw;
    private double[] wum;
    private HTSDWin dw;
    private double mean;
    private double var;
    private int maxGVIter = 200;
    private double GVepsilon = 1.0E-4;
    private double minEucNorm = 0.01;
    private double stepInit = 0.1;
    private double stepDec = 0.5;
    private double stepInc = 1.2;
    private double w1 = 1.0;
    private double w2 = 1.0;
    private double lzero = -1.0E10;
    private double norm = 0.0;
    private double GVobj = 0.0;
    private double HMMobj = 0.0;
    private double[] gvmean;
    private double[] gvcovInv;
    private boolean[] gvSwitch;
    // Number of frames whose gvSwitch is true; kept in sync by setGvSwitch.
    private int gvLength;
    private Logger logger = MaryUtils.getLogger("PStream");

    /**
     * Allocates all per-frame buffers for an utterance.
     *
     * @param vector_size total number of columns in mseq/ivseq (all windows together)
     * @param utt_length number of frames
     * @param fea_type feature type identifier (stored; not used inside this class)
     * @param maxIterationsGV maximum number of GV optimization iterations
     * @throws Exception as declared; presumably propagated from {@code HTSDWin} construction
     */
    public HTSPStream(int vector_size, int utt_length, int fea_type, int maxIterationsGV) throws Exception {
        this.dw = new HTSDWin();
        this.feaType = fea_type;
        this.vSize = vector_size;
        // Static order = columns per window.
        this.order = vector_size / this.dw.getNum();
        this.nT = utt_length;
        this.maxGVIter = maxIterationsGV;
        // NOTE(review): a band of 3 columns (diagonal + 2 upper diagonals) only suffices when every
        // window reaches at most one frame to each side; presumably this should be derived from
        // the HTSDWin widths — TODO confirm.
        this.width = 3;
        this.par = new double[this.nT][this.order];
        this.mseq = new double[this.nT][this.vSize];
        this.ivseq = new double[this.nT][this.vSize];
        this.g = new double[this.nT];
        this.wuw = new double[this.nT][this.width];
        this.wum = new double[this.nT];
        this.gvSwitch = new boolean[this.nT];
        for (int i = 0; i < this.nT; ++i) {
            this.gvSwitch[i] = true;
        }
        this.gvLength = this.nT;
    }

    /** Sets the vector size (number of columns of mseq/ivseq). Does not reallocate buffers. */
    public void setVsize(int val) {
        this.vSize = val;
    }

    /** Returns the vector size. */
    public int getVsize() {
        return this.vSize;
    }

    /** Sets the static order. Does not reallocate buffers. */
    public void setOrder(int val) {
        this.order = val;
    }

    /** Returns the static order (columns per window). */
    public int getOrder() {
        return this.order;
    }

    /** Sets generated parameter {@code par[i][j]} (frame i, coefficient j). */
    public void setPar(int i, int j, double val) {
        this.par[i][j] = val;
    }

    /** Returns generated parameter {@code par[i][j]} (frame i, coefficient j). */
    public double getPar(int i, int j) {
        return this.par[i][j];
    }

    /** Returns the number of frames. */
    public int getT() {
        return this.nT;
    }

    /** Sets mean {@code mseq[i][j]} (frame i, column j). */
    public void setMseq(int i, int j, double val) {
        this.mseq[i][j] = val;
    }

    /** Returns mean {@code mseq[i][j]} (frame i, column j). */
    public double getMseq(int i, int j) {
        return this.mseq[i][j];
    }

    /** Sets precision weight {@code ivseq[i][j]} (frame i, column j). */
    public void setIvseq(int i, int j, double val) {
        this.ivseq[i][j] = val;
    }

    /** Returns precision weight {@code ivseq[i][j]} (frame i, column j). */
    public double getIvseq(int i, int j) {
        return this.ivseq[i][j];
    }

    /** Sets work-vector entry {@code g[i]}. */
    public void setG(int i, double val) {
        this.g[i] = val;
    }

    /** Returns work-vector entry {@code g[i]}. */
    public double getG(int i) {
        return this.g[i];
    }

    /** Sets band-matrix entry {@code wuw[i][j]}. */
    public void setWUW(int i, int j, double val) {
        this.wuw[i][j] = val;
    }

    /** Returns band-matrix entry {@code wuw[i][j]}. */
    public double getWUW(int i, int j) {
        return this.wuw[i][j];
    }

    /** Sets right-hand-side entry {@code wum[i]}. */
    public void setWUM(int i, double val) {
        this.wum[i] = val;
    }

    /** Returns right-hand-side entry {@code wum[i]}. */
    public double getWUM(int i) {
        return this.wum[i];
    }

    /** Returns {@code HTSDWin.getWidth(i, j)} for window i, side j ({@link #WLEFT}/{@link #WRIGHT}). */
    public int getDWwidth(int i, int j) {
        return this.dw.getWidth(i, j);
    }

    /**
     * Sets the GV statistics used by the GV optimization. The arrays are stored by reference.
     *
     * @param mean target variance per coefficient (convGV rescales towards {@code mean[m]})
     * @param ivar per-coefficient weight used as the GV precision (presumably inverse variance)
     */
    public void setGvMeanVar(double[] mean, double[] ivar) {
        this.gvmean = mean;
        this.gvcovInv = ivar;
    }

    /**
     * Enables or disables frame {@code i} for the GV statistics and GV gradient term.
     * {@code gvLength} is only adjusted when the switch actually changes, so repeated calls with
     * the same value are idempotent and re-enabling a frame restores the count.
     *
     * @param i frame index
     * @param bv true to include the frame in GV computations
     */
    public void setGvSwitch(int i, boolean bv) {
        if (this.gvSwitch[i] != bv) {
            this.gvLength += bv ? 1 : -1;
        }
        this.gvSwitch[i] = bv;
    }

    /** Debug helper: prints row {@code t} of the band matrix to stdout. */
    private void printWUW(int t) {
        for (int i = 0; i < this.width; ++i) {
            System.out.print("WUW[" + t + "][" + i + "]=" + this.wuw[t][i] + "  ");
        }
        System.out.println("");
    }

    /**
     * Generates the static trajectory {@code par[.][m]} for every coefficient m by solving the
     * banded normal equations, then optionally applies GV optimization.
     *
     * @param htsData supplies the GV method selection (gradient vs. derivative)
     * @param useGV whether to run GV optimization (skipped anyway when no frame is GV-enabled)
     */
    public void mlpg(HMMData htsData, boolean useGV) {
        boolean debug = false;
        if (useGV) {
            if (htsData.getUseContextDependentGV()) {
                this.logger.info("Context-dependent global variance optimization: gvLength = " + this.gvLength);
            } else {
                this.logger.info("Global variance optimization");
            }
        }
        for (int m = 0; m < this.order; ++m) {
            this.calcWUWandWUM(m, debug);
            this.ldlFactorization(debug);
            this.forwardSubstitution();
            this.backwardSubstitution(m);
            if (useGV && this.gvLength > 0) {
                if (htsData.getGvMethodGradient()) {
                    this.gvParmGenGradient(m, debug);
                } else {
                    this.gvParmGenDerivative(m, debug);
                }
            }
        }
    }

    /**
     * Builds, for coefficient m, the right-hand side {@code wum = W' U^-1 M} and the upper band
     * of {@code wuw = W' U^-1 W} (column k of row t holds the entry at (t, t + k)).
     */
    private void calcWUWandWUM(int m, boolean debug) {
        for (int t = 0; t < this.nT; ++t) {
            this.wum[t] = 0.0;
            for (int k = 0; k < this.width; ++k) {
                this.wuw[t][k] = 0.0;
            }
            for (int i = 0; i < this.dw.getNum(); ++i) {
                int iorder = i * this.order + m;
                for (int j = this.dw.getWidth(i, WLEFT); j <= this.dw.getWidth(i, WRIGHT); ++j) {
                    if (t + j < 0 || t + j >= this.nT || this.dw.getCoef(i, -j) == 0.0) {
                        continue;
                    }
                    double wu = this.dw.getCoef(i, -j) * this.ivseq[t + j][iorder];
                    this.wum[t] += wu * this.mseq[t + j][iorder];
                    for (int k = 0; k < this.width && t + k < this.nT; ++k) {
                        if (k - j > this.dw.getWidth(i, WRIGHT) || this.dw.getCoef(i, k - j) == 0.0) {
                            continue;
                        }
                        this.wuw[t][k] += wu * this.dw.getCoef(i, k - j);
                    }
                }
            }
        }
        if (debug) {
            for (int t = 0; t < this.nT; ++t) {
                System.out.format("t=%d wum=%f  wuw:", t, this.wum[t]);
                for (int k = 0; k < this.wuw[t].length; ++k) {
                    System.out.format("%f ", this.wuw[t][k]);
                }
                System.out.format("\n");
            }
            System.out.format("\n");
        }
    }

    /**
     * Factorizes the band matrix in {@code wuw} in place: afterwards column 0 holds the diagonal
     * factor and columns 1..width-1 hold the off-diagonal factor entries divided by that diagonal.
     */
    private void ldlFactorization(boolean debug) {
        for (int t = 0; t < this.nT; ++t) {
            if (debug) {
                System.out.println("WUW calculation:");
                this.printWUW(t);
            }
            for (int i = 1; i < this.width && t - i >= 0; ++i) {
                this.wuw[t][0] -= this.wuw[t - i][i] * this.wuw[t - i][i] * this.wuw[t - i][0];
            }
            for (int i = 2; i <= this.width; ++i) {
                for (int j = 1; i + j <= this.width && t - j >= 0; ++j) {
                    this.wuw[t][i - 1] -= this.wuw[t - j][j] * this.wuw[t - j][i + j - 1] * this.wuw[t - j][0];
                }
                this.wuw[t][i - 1] /= this.wuw[t][0];
            }
            if (debug) {
                System.out.println("LDL factorization:");
                this.printWUW(t);
                System.out.println();
            }
        }
    }

    /** Solves the unit-lower-triangular system of the factorization, writing the result to {@code g}. */
    private void forwardSubstitution() {
        for (int t = 0; t < this.nT; ++t) {
            this.g[t] = this.wum[t];
            for (int i = 1; i < this.width && t - i >= 0; ++i) {
                this.g[t] -= this.wuw[t - i][i] * this.g[t - i];
            }
        }
    }

    /** Divides by the diagonal and back-substitutes, writing the solution into {@code par[.][m]}. */
    private void backwardSubstitution(int m) {
        for (int t = this.nT - 1; t >= 0; --t) {
            this.par[t][m] = this.g[t] / this.wuw[t][0];
            for (int i = 1; i < this.width && t + i < this.nT; ++i) {
                this.par[t][m] -= this.wuw[t][i] * this.par[t + i][m];
            }
        }
    }

    /**
     * GV optimization using {@link #calcDerivative}: always runs {@code maxGVIter} update steps,
     * adapting the step size according to whether the returned objective rose or fell.
     */
    private void gvParmGenDerivative(int m, boolean debug) {
        double step = this.stepInit;
        double prev = -this.lzero;
        this.mean = 0.0;
        this.var = 0.0;
        for (int t = 0; t < this.nT; ++t) {
            this.g[t] = 0.0;
        }
        this.convGV(m);
        // Rebuild the (unfactored) matrix: ldlFactorization overwrote wuw in place.
        this.calcWUWandWUM(m, false);
        int iter;
        for (iter = 1; iter <= this.maxGVIter; ++iter) {
            // calcDerivative returns the negated objective, so a smaller value is better.
            double obj = this.calcDerivative(m);
            if (obj > prev) {
                step *= this.stepDec;
            }
            if (obj < prev) {
                step *= this.stepInc;
            }
            for (int t = 0; t < this.nT; ++t) {
                this.par[t][m] += step * this.g[t];
            }
            prev = obj;
        }
        this.logger.info("Derivative GV optimization for feature: (" + m + ")  number of iterations=" + (iter - 1));
    }

    /**
     * GV optimization using {@link #calcGradient}: takes steps along the returned update, shrinking
     * and retrying the previous step when the objective decreases, and stops on a small update
     * norm or a small objective change. If {@code maxGVIter} is exhausted the original trajectory
     * is restored.
     */
    private void gvParmGenGradient(int m, boolean debug) {
        double step = this.stepInit;
        double prev = 0.0;
        // Update direction applied in the previous accepted iteration (needed to retract it).
        double[] lastUpdate = new double[this.nT];
        double[] parOriginal = new double[this.nT];
        this.mean = 0.0;
        this.var = 0.0;
        int numDown = 0;
        for (int t = 0; t < this.nT; ++t) {
            this.g[t] = 0.0;
            parOriginal[t] = this.par[t][m];
        }
        this.convGV(m);
        // Rebuild the (unfactored) matrix: ldlFactorization overwrote wuw in place.
        this.calcWUWandWUM(m, false);
        int iter;
        for (iter = 1; iter <= this.maxGVIter; ++iter) {
            double obj = this.calcGradient(m);
            if (iter > 1) {
                if (obj > prev) {
                    step *= this.stepInc;
                    numDown = 0;
                }
                if (obj < prev) {
                    // Retract the last step, shrink it, re-apply, and retry this iteration.
                    for (int t = 0; t < this.nT; ++t) {
                        this.par[t][m] -= step * lastUpdate[t];
                    }
                    step *= this.stepDec;
                    for (int t = 0; t < this.nT; ++t) {
                        this.par[t][m] += step * lastUpdate[t];
                    }
                    --iter;
                    if (++numDown < 100) {
                        continue;
                    }
                    this.logger.info("  ***Convergence problems....optimization stopped. Number of iterations: " + iter);
                    break;
                }
            } else if (debug) {
                this.logger.info("  First iteration:  GVobj=" + obj + " (HMMobj=" + this.HMMobj + "  GVobj=" + this.GVobj + ")");
            }
            if (this.norm < this.minEucNorm || (iter > 1 && Math.abs(obj - prev) < this.GVepsilon)) {
                if (debug) {
                    this.logger.info("  Number of iterations: [   " + iter + "   ] GVobj=" + obj + " (HMMobj=" + this.HMMobj + "  GVobj=" + this.GVobj + ")");
                    if (iter > 1) {
                        this.logger.info("  Converged (norm=" + this.norm + ", change=" + Math.abs(obj - prev) + ")");
                    } else {
                        this.logger.info("  Converged (norm=" + this.norm + ")");
                    }
                }
                break;
            }
            for (int t = 0; t < this.nT; ++t) {
                this.par[t][m] += step * this.g[t];
                lastUpdate[t] = this.g[t];
            }
            prev = obj;
        }
        if (iter > this.maxGVIter) {
            this.logger.info("   optimization stopped by reaching max number of iterations (no global variance applied)");
            for (int t = 0; t < this.nT; ++t) {
                this.par[t][m] = parOriginal[t];
            }
        }
        this.logger.info("Gradient GV optimization for feature: (" + m + ")  number of iterations=" + iter);
    }

    /**
     * Stores the product of the symmetric band matrix in {@code wuw} (upper band only) with the
     * trajectory {@code par[.][m]} into {@code g}.
     */
    private void multiplyWUWByPar(int m) {
        for (int t = 0; t < this.nT; ++t) {
            this.g[t] = this.wuw[t][0] * this.par[t][m];
            for (int i = 2; i <= this.width; ++i) {
                if (t + i - 1 < this.nT) {
                    this.g[t] += this.wuw[t][i - 1] * this.par[t + i - 1][m];
                }
                if (t - i + 1 >= 0) {
                    this.g[t] += this.wuw[t - i + 1][i - 1] * this.par[t - i + 1][m];
                }
            }
        }
    }

    /**
     * Evaluates the combined HMM + GV objective for coefficient m and leaves the per-frame
     * update in {@code g} and its Euclidean norm in {@code norm}.
     *
     * @return {@code HMMobj + GVobj}
     */
    private double calcGradient(int m) {
        double w = 1.0 / (double) (this.dw.getNum() * this.nT);
        // Computed in double to avoid int overflow of nT * nT for very long utterances.
        double nT2 = (double) this.nT * this.nT;
        this.calcGV(m);
        this.GVobj = -0.5 * this.w2 * (this.var - this.gvmean[m]) * this.gvcovInv[m] * (this.var - this.gvmean[m]);
        double vd = this.gvcovInv[m] * (this.var - this.gvmean[m]);
        this.multiplyWUWByPar(m);
        this.HMMobj = 0.0;
        this.norm = 0.0;
        for (int t = 0; t < this.nT; ++t) {
            this.HMMobj += -0.5 * this.w1 * w * this.par[t][m] * (this.g[t] - 2.0 * this.wum[t]);
            double h = (double) (this.nT - 1) * vd + 2.0 * this.gvcovInv[m] * (this.par[t][m] - this.mean) * (this.par[t][m] - this.mean);
            h = -this.w1 * w * this.wuw[t][0] - this.w2 * 2.0 / nT2 * h;
            h = -1.0 / h;
            if (this.gvSwitch[t]) {
                double aux = (this.par[t][m] - this.mean) * vd;
                this.g[t] = h * (this.w1 * w * (-this.g[t] + this.wum[t]) + this.w2 * -2.0 / (double) this.nT * aux);
            } else {
                this.g[t] = h * (this.w1 * w * (-this.g[t] + this.wum[t]));
            }
            this.norm += this.g[t] * this.g[t];
        }
        this.norm = Math.sqrt(this.norm);
        return this.HMMobj + this.GVobj;
    }

    /**
     * Evaluates the combined HMM + GV objective for coefficient m (derivative variant) and leaves
     * the per-frame update in {@code g}.
     *
     * @return the negated objective, {@code -(HMMobj + GVobj)}
     */
    private double calcDerivative(int m) {
        double w = 1.0 / (double) (this.dw.getNum() * this.nT);
        // Computed in double to avoid int overflow of nT * nT for very long utterances.
        double nT2 = (double) this.nT * this.nT;
        this.calcGV(m);
        this.GVobj = -0.5 * this.w2 * this.var * this.gvcovInv[m] * (this.var - 2.0 * this.gvmean[m]);
        double vd = -2.0 * this.gvcovInv[m] * (this.var - this.gvmean[m]) / (double) this.nT;
        this.multiplyWUWByPar(m);
        this.HMMobj = 0.0;
        for (int t = 0; t < this.nT; ++t) {
            this.HMMobj += this.w1 * w * this.par[t][m] * (this.wum[t] - 0.5 * this.g[t]);
            double h = -this.w1 * w * this.wuw[t][0] - this.w2 * 2.0 / nT2 * ((double) (this.nT - 1) * this.gvcovInv[m] * (this.var - this.gvmean[m]) + 2.0 * this.gvcovInv[m] * (this.par[t][m] - this.mean) * (this.par[t][m] - this.mean));
            this.g[t] = this.gvSwitch[t] ? 1.0 / h * (this.w1 * w * (-this.g[t] + this.wum[t]) + this.w2 * vd * (this.par[t][m] - this.mean)) : 1.0 / h * (this.w1 * w * (-this.g[t] + this.wum[t]));
        }
        return -(this.HMMobj + this.GVobj);
    }

    /**
     * Rescales the GV-enabled frames of coefficient m around their mean so that their variance
     * becomes {@code gvmean[m]}. The rescale is skipped when the scale factor is not finite
     * (e.g. a constant trajectory with variance 0). Otherwise {@code Infinity * 0} would
     * overwrite every enabled frame with NaN.
     */
    private void convGV(int m) {
        this.calcGV(m);
        double ratio = Math.sqrt(this.gvmean[m] / this.var);
        if (Double.isNaN(ratio) || Double.isInfinite(ratio)) {
            return;
        }
        for (int t = 0; t < this.nT; ++t) {
            if (this.gvSwitch[t]) {
                this.par[t][m] = ratio * (this.par[t][m] - this.mean) + this.mean;
            }
        }
    }

    /**
     * Computes {@code mean} and (population) {@code var} of {@code par[.][m]} over GV-enabled
     * frames, dividing by {@code gvLength}. Callers must ensure {@code gvLength > 0}.
     */
    private void calcGV(int m) {
        this.mean = 0.0;
        this.var = 0.0;
        for (int t = 0; t < this.nT; ++t) {
            if (this.gvSwitch[t]) {
                this.mean += this.par[t][m];
            }
        }
        this.mean /= (double) this.gvLength;
        for (int t = 0; t < this.nT; ++t) {
            if (this.gvSwitch[t]) {
                this.var += (this.par[t][m] - this.mean) * (this.par[t][m] - this.mean);
            }
        }
        this.var /= (double) this.gvLength;
    }
}

