/*
 * Decompiled with CFR 0.152.
 */
package org.jcodec.codecs.vpx.vp9;

import org.jcodec.codecs.vpx.VPXBooleanDecoder;
import org.jcodec.codecs.vpx.vp9.Consts;
import org.jcodec.codecs.vpx.vp9.DecodingContext;
import org.jcodec.codecs.vpx.vp9.ModeInfo;
import org.jcodec.codecs.vpx.vp9.Probabilities;

public class Residual {
    /**
     * Energy class of each token value (VP9 specification {@code energy_class} table). It is stored
     * in the token cache so that later coefficients of the same transform block can derive their
     * probability context from already-decoded neighbours.
     */
    private static final int[] ENERGY_CLASS = {0, 1, 2, 3, 3, 4, 4, 5, 5, 5, 5, 5};

    /** Per plane, one coefficient array per transform block; entries stay null when not coded. */
    private final int[][][] coefs;

    /**
     * Block-size index lookup used by {@link #read} to derive the chroma block size, indexed as
     * {@code [subH][subW - subH + 1]}. An entry of -1 marks a shape that has no block-size index.
     */
    public static int[][] blk_size_lookup = new int[][]{{-1, 0, 2}, {1, 3, 5}, {4, 6, 8}, {7, 9, 11}, {10, 12, -1}};

    /**
     * Wraps already-decoded residual coefficients.
     *
     * @param coefs per plane, one coefficient array per transform block
     */
    public Residual(int[][][] coefs) {
        this.coefs = coefs;
    }

    /**
     * Reads the coefficient tokens of every transform block in all three planes of one block.
     *
     * @param miCol     block column; shifted left by 3 to obtain a luma pixel offset
     * @param miRow     block row; shifted left by 3 to obtain a luma pixel offset
     * @param blSz      block-size index into the {@code Consts.blW}/{@code Consts.blH} tables
     * @param decoder   boolean decoder positioned at the block's residual data
     * @param probStore source of the coefficient probabilities
     * @param c         frame-level decoding state (subsampling, frame size, scans, contexts)
     * @param mode      mode info of the block (transform size, skip and inter flags)
     * @return the decoded coefficients; a transform block's entry stays {@code null} when the block
     *         is skipped or the transform block starts outside the frame
     */
    public static Residual read(int miCol, int miRow, int blSz, VPXBooleanDecoder decoder, Probabilities probStore, DecodingContext c, ModeInfo mode) {
        // The chroma block size is the same for both chroma planes, so derive it once.
        // getSubX()/getSubY() are shift amounts (they are used with '>>' for the frame size below),
        // so subsampling must be a shift: dividing by a 0/1 shift amount is wrong.
        // NOTE(review): the "- 2" treats Consts.blW/blH as pixel sizes, while the loop below treats
        // the same tables as counts of 4x4 blocks -- confirm which unit Consts actually uses.
        int uvSubW = msb(Consts.blW[blSz] >> c.getSubX()) - 2;
        int uvSubH = msb(Consts.blH[blSz] >> c.getSubY()) - 2;
        int uvBlSz = blk_size_lookup[uvSubH][uvSubW - uvSubH + 1];

        int[][][] coefs = new int[3][][];
        for (int pl = 0; pl < 3; ++pl) {
            // Only the chroma planes are subsampled; luma is addressed at full resolution
            // (the same convention calcTokenContext uses).
            int subX = pl == 0 ? 0 : c.getSubX();
            int subY = pl == 0 ? 0 : c.getSubY();
            int txSize = pl == 0 ? mode.getTxSize() : uvTxSize(c, mode.getTxSize(), blSz, uvBlSz);
            int step4x4 = 1 << txSize;
            int plBlSz = pl == 0 ? blSz : uvBlSz;
            int frameWPix = c.getMiFrameWidth() << 3 >> subX;
            int frameHPix = c.getMiFrameHeight() << 3 >> subY;
            int blX = miCol << 3 >> subX;
            int blY = miRow << 3 >> subY;
            int blocksHigh = Consts.blH[plBlSz];
            int blocksWide = Consts.blW[plBlSz];
            coefs[pl] = new int[blocksHigh * blocksWide][];
            int blkIdx = 0;
            for (int y = 0; y < blocksHigh; y += step4x4) {
                for (int x = 0; x < blocksWide; x += step4x4, ++blkIdx) {
                    int posX = blX + (x << 2);
                    int posY = blY + (y << 2);
                    if (!mode.isSkip() && posX < frameWPix && posY < frameHPix) {
                        coefs[pl][blkIdx] = tokens(pl, posX, posY, txSize, blkIdx, mode.isInter(), decoder, probStore, c);
                    }
                }
            }
            // NOTE(review): the above/left non-zero contexts read by calcTokenContext are not
            // updated anywhere in this class -- confirm that the caller maintains them.
        }
        return new Residual(coefs);
    }

    /**
     * Returns the index of the highest set bit among the low 8 bits of {@code v}
     * (0 when {@code v} is 0 or 1).
     */
    private static int msb(int v) {
        if (((v &= 0xFF) & 0xF0) != 0) {
            if ((v & 0xC0) != 0) {
                return 6 | v >> 7;
            }
            return 4 | v >> 5;
        }
        if ((v & 0xC) != 0) {
            return 2 | v >> 3;
        }
        return v >> 1;
    }

    /**
     * Transform size for the chroma planes. Block-size indices below 3 use transform size 0.
     * Otherwise the luma transform size is capped by {@code Consts.maxTxLookup} for the chroma
     * block size.
     */
    private static int uvTxSize(DecodingContext c, int txSize, int blSz, int uvBlSz) {
        if (blSz < 3) {
            return 0;
        }
        return Math.min(txSize, Consts.maxTxLookup[uvBlSz]);
    }

    /**
     * Decodes the coefficient tokens of one transform block.
     *
     * @param plane    plane index (0 = luma)
     * @param startX   pixel x of the transform block within its plane
     * @param startY   pixel y of the transform block within its plane
     * @param txSz     transform size; the block holds {@code 16 << (2 * txSz)} coefficients
     * @param blockIdx index of the transform block inside the current block (selects scan/type)
     * @param isInter  whether the block is inter-predicted (selects the probability set)
     * @param decoder  boolean decoder to read from
     * @param probStore source of the coefficient probabilities
     * @param c        decoding state providing scans, transform types and the token cache
     * @return coefficients in raster order; positions after end-of-block are 0
     */
    public static int[] tokens(int plane, int startX, int startY, int txSz, int blockIdx, boolean isInter, VPXBooleanDecoder decoder, Probabilities probStore, DecodingContext c) {
        int maxCoeff = 16 << (txSz << 1);
        int[] scan = c.getScan(plane, txSz, blockIdx);
        int txType = c.getTxType(plane, txSz, blockIdx);
        // Assumes the token cache holds at least maxCoeff entries -- TODO confirm in DecodingContext.
        int[] tokenCache = c.getTokenCache();
        int[] coefs = new int[maxCoeff];
        // An end-of-block flag is read before the first coefficient and after every non-zero one.
        // It is skipped only right after a ZERO token, because end-of-block cannot follow a zero.
        boolean checkEob = true;
        for (int cf = 0; cf < maxCoeff; ++cf) {
            int band = txSz == 0 ? Consts.coefband_4x4[cf] : Consts.coefband_8x8plus[cf];
            int pos = scan[cf];
            if (checkEob && !readMoreCoefs(plane, pos, txSz, startX, startY, txType, band, isInter, decoder, probStore, c)) {
                break;
            }
            int token = readToken(plane, pos, txSz, startX, startY, txType, band, isInter, decoder, probStore, c);
            // Record this position's energy so later coefficients can use it as a neighbour context.
            tokenCache[pos] = ENERGY_CLASS[token];
            if (token == 0) {
                // coefs[pos] is already 0 (fresh array).
                checkEob = false;
                continue;
            }
            int coef = readCoef(token, decoder, c);
            int sign = decoder.readBitEq();
            coefs[pos] = sign == 1 ? -coef : coef;
            checkEob = true;
        }
        return coefs;
    }

    /**
     * Reads the magnitude of a non-zero token: the token's base value plus its extra bits.
     * For token 10, {@code bitDepth - 8} additional high-order bits are read first.
     */
    private static int readCoef(int token, VPXBooleanDecoder decoder, DecodingContext c) {
        int cat = Consts.extra_bits[token][0];
        int numExtra = Consts.extra_bits[token][1];
        int coef = Consts.extra_bits[token][2];
        if (token == 10) {
            for (int bit = 0; bit < c.getBitDepth() - 8; ++bit) {
                int highBit = decoder.readBit(255);
                coef += highBit << (5 + c.getBitDepth() - bit);
            }
        }
        for (int bit = 0; bit < numExtra; ++bit) {
            int coefBit = decoder.readBit(Consts.cat_probs[cat][bit]);
            coef += coefBit << (numExtra - 1 - bit);
        }
        return coef;
    }

    /**
     * Probability for token-tree node {@code bin}. Nodes 0 and 1 use {@code prob} directly.
     * Deeper nodes come from PARETO_TABLE, averaging two rows when {@code prob} is even.
     */
    private static int pareto(int bin, int prob) {
        if (bin < 2) {
            return prob;
        }
        int x = (prob - 1) / 2;
        if ((prob & 1) != 0) {
            return Consts.PARETO_TABLE[x][bin - 2];
        }
        return (Consts.PARETO_TABLE[x][bin - 2] + Consts.PARETO_TABLE[x + 1][bin - 2]) >> 1;
    }

    /** Reads one coefficient token using the context-selected coefficient probabilities. */
    private static int readToken(int plane, int coefi, int txSz, int posX, int posY, int txType, int band, boolean isInter, VPXBooleanDecoder decoder, Probabilities probStore, DecodingContext c) {
        int ctx = calcTokenContext(plane, coefi, txSz, posX, posY, txType, c);
        int[] p = probStore.getCoefProbs()[txSz][plane > 0 ? 1 : 0][isInter ? 1 : 0][band][ctx];
        int prob0 = pareto(0, p[1]);
        int prob1 = pareto(1, p[2]);
        // NOTE(review): pareto() supports nodes >= 2, but only the node 0/1 probabilities are passed
        // here -- confirm that readTree derives the deeper node probabilities itself.
        return decoder.readTree(Consts.TOKEN_TREE, prob0, prob1);
    }

    /** Reads the "more coefficients" (not end-of-block) flag for the coefficient at {@code coefi}. */
    private static boolean readMoreCoefs(int plane, int coefi, int txSz, int posX, int posY, int txType, int band, boolean isInter, VPXBooleanDecoder decoder, Probabilities probStore, DecodingContext c) {
        int ctx = calcTokenContext(plane, coefi, txSz, posX, posY, txType, c);
        int[][][][][][] probs = probStore.getCoefProbs();
        return decoder.readBit(probs[txSz][plane > 0 ? 1 : 0][isInter ? 1 : 0][band][ctx][0]) == 1;
    }

    /**
     * Probability context for the coefficient at raster position {@code coefi}.
     * For position 0 it is derived from the above/left non-zero contexts of neighbouring transform
     * blocks. Otherwise it is derived from the token-cache energies of the above/left coefficients
     * inside the current transform block.
     */
    private static int calcTokenContext(int plane, int coefi, int txSz, int posX, int posY, int txType, DecodingContext c) {
        if (coefi == 0) {
            int[][] aboveNonzeroContext = c.getAboveNonzeroContext();
            int[][] leftNonzeroContext = c.getLeftNonzeroContext();
            int subX = plane > 0 ? c.getSubX() : 0;
            int subY = plane > 0 ? c.getSubY() : 0;
            // Frame extent in 4x4 units for this plane; context entries beyond it are ignored.
            int max4x = c.getMiFrameWidth() << 1 >> subX;
            int max4y = c.getMiFrameHeight() << 1 >> subY;
            int tx4 = 1 << txSz;
            int pos4x = posX >> 2;
            int pos4y = posY >> 2;
            int aboveNz = 0;
            int leftNz = 0;
            for (int i = 0; i < tx4; ++i) {
                if (pos4x + i < max4x) {
                    aboveNz |= aboveNonzeroContext[plane][pos4x + i];
                }
                if (pos4y + i < max4y) {
                    leftNz |= leftNonzeroContext[plane][pos4y + i];
                }
            }
            return aboveNz + leftNz;
        }
        int log2Width = txSz + 2;
        int y = coefi >> log2Width;
        int x = coefi & ((1 << log2Width) - 1);
        int abovePos = ((y - 1) << log2Width) + x;
        int leftPos = (y << log2Width) + x - 1;
        if (y == 0) {
            // First row: no coefficient above, so use the left neighbour for both.
            abovePos = leftPos;
        } else if (x == 0) {
            // First column: no coefficient to the left, so use the above neighbour for both.
            leftPos = abovePos;
        } else if (txType == 2) {
            // Presumably DCT_ADST in VP9 tx-type numbering: both neighbours taken from above.
            leftPos = abovePos;
        } else if (txType == 1) {
            // Presumably ADST_DCT in VP9 tx-type numbering: both neighbours taken from the left.
            abovePos = leftPos;
        }
        int[] tokenCache = c.getTokenCache();
        return (1 + tokenCache[abovePos] + tokenCache[leftPos]) >> 1;
    }
}

