package com.speechify.client.helpers.content.standard.book

import com.speechify.client.api.content.TextEnrichment
import com.speechify.client.api.content.hasEnrichment
import com.speechify.client.api.util.images.Viewport
import com.speechify.client.api.util.images.verticalDistanceTo
import com.speechify.client.helpers.content.standard.book.FootnoteClassifier.Companion.TARGET_SCORE_KEY
import kotlin.math.abs
import kotlin.math.exp
import kotlin.math.sqrt

/**
 * Detects which of the given line groups are footnotes.
 *
 * One feature vector is built per line group (see [getVectors]) and scored by the default classifier.
 *
 * @return the entries of [lines] that were classified as footnotes, in their original order.
 */
internal fun detectFootnotes(
    lines: List<LineGroup>,
    stats: LineStats,
    pageSize: Viewport,
): List<LineGroup> {
    // The classifier answers with one flag per vector, i.e. one flag per line group.
    val footnoteFlags = DEFAULT_CLASSIFIER.classify(getVectors(lines, stats, pageSize))
    return lines.filterIndexed { index, _ -> footnoteFlags[index] }
}

/**
 * Builds the feature vector for every line group in [lines].
 *
 * Each vector is assembled as a feature-name -> value map and flattened with [toVector], which orders the values
 * alphabetically by feature name. The names used here therefore have to match the keys of the weights given to
 * [FootnoteClassifier].
 *
 * Degenerate input is tolerated instead of throwing: a group without lines, a line without chunks or a chunk with
 * empty text simply does not trigger the text-based features, and ratios with a zero denominator are `0.0`
 * rather than `NaN`.
 *
 * @param lines the line groups to describe; the neighbouring entries of each group feed the gap and font-size
 * features.
 * @param stats provides the usual line height that each group's line height is compared against.
 * @param pageSize its height scales the normalized vertical gaps before they are squashed into `[0, 1)`.
 * @return one feature vector per entry of [lines], in the same order.
 */
internal fun getVectors(lines: List<LineGroup>, stats: LineStats, pageSize: Viewport): List<DoubleArray> {
    // Squashes a normalized vertical gap into [0, 1): 0 for a zero gap, approaching 1 as the gap grows.
    fun gapScore(normalizedGap: Double): Double = 1 - exp(-abs(normalizedGap * pageSize.height) / 20)

    return lines.mapIndexed { index, group ->
        val previous = lines.getOrNull(index - 1)
        val next = lines.getOrNull(index + 1)
        val box = group.normalizedBox

        // The first/last group has no neighbour on one side; its gap on that side counts as 0.
        val gapAbove = if (previous != null) box.verticalDistanceTo(previous.normalizedBox) else 0.0
        val gapBelow = if (next != null) box.verticalDistanceTo(next.normalizedBox) else 0.0

        val lineCount = group.lines.size
        // A group without lines is treated as a single line so that the height stays finite.
        val lineHeight = box.height / lineCount.coerceAtLeast(1)
        val fractionOfUsualLineHeight = lineHeight / stats.usualLineHeight

        // Without a previous group (or with an empty one) there is nothing to compare against: use a neutral 1.0.
        val previousLineHeights = previous?.lines?.map { line -> line.normalizedBox.height }.orEmpty()
        val fontSizeRatioToPreviousItem =
            if (previousLineHeights.isNotEmpty()) lineHeight / previousLineHeights.average() else 1.0

        // 0 while the top of the group is at or above (1 - footnoteArea), rising linearly to 1 at the bottom edge.
        val footnoteArea = 0.40
        val edgeProximityScore = (1 - ((1 - box.top) / footnoteArea)).coerceIn(0.0, 1.0)

        // NOTE(review): a blank first chunk also counts here because `all` is true for empty text. Kept as-is
        // since the weights were presumably produced against this definition - confirm whether it is intended.
        val linesThatStartWithDigits = group.lines.count { line ->
            line.chunks.firstOrNull()?.text?.text?.trim()?.all { c -> c.isDigit() } == true
        }
        val linesThatStartWithDigitsRatio = ratioOrZero(linesThatStartWithDigits, lineCount)

        val numberOfSubscriptChunks = group.lines.sumOf { line ->
            line.chunks.count { chunk -> chunk.text.hasEnrichment(TextEnrichment.Subscript) }
        }
        val numberOfChunks = group.lines.sumOf { line -> line.chunks.size }
        val subscriptChunksRatio = ratioOrZero(numberOfSubscriptChunks, numberOfChunks)

        val allTexts = group.lines.flatMap { line -> line.chunks.map { chunk -> chunk.text.text } }
        val digits = allTexts.sumOf { text -> text.count { c -> c.isDigit() } }
        val ratioOfDigits = ratioOrZero(digits, allTexts.sumOf { text -> text.length })

        // Null when the group has no lines or its first line has no chunks.
        val firstChunkText = group.lines.firstOrNull()?.chunks?.firstOrNull()?.text?.text

        // Looks at the untrimmed text, so leading whitespace prevents a match.
        val startsWithNumber = firstChunkText?.firstOrNull()?.isDigit() == true

        // NOTE(review): a lone "." also matches because `all` is true for the empty prefix - confirm intended.
        val firstChunkIsNumberAndPeriod = firstChunkText?.trim()?.let { trimmed ->
            trimmed.dropLast(1).all { c -> c.isDigit() } && trimmed.endsWith(".")
        } ?: false

        val secondChunkIsSubscript =
            group.lines.firstOrNull()?.chunks?.getOrNull(1)?.text?.hasEnrichment(TextEnrichment.Subscript) == true

        // A map keeps adding/removing features readable; it is flattened to an array for the actual processing.
        mapOf(
            "lineCount" to lineCount.toDouble(),
            "differenceFromUsualLineHeight" to stats.usualLineHeight - lineHeight,
            "fractionOfUsualLineHeightIsSmaller" to asFeature(fractionOfUsualLineHeight < 1.0),
            "fractionOfUsualLineHeightIsLarger" to asFeature(fractionOfUsualLineHeight > 1.0),
            "gapAbove" to gapScore(gapAbove),
            "gapBelow" to gapScore(gapBelow),
            "startsWithNumber" to asFeature(startsWithNumber),
            "isInBottom40OfPage" to asFeature(box.top > 0.6),
            "edgeProximityScore" to edgeProximityScore,
            "fontSizeRatioToPreviousItem" to fontSizeRatioToPreviousItem,
            "fontSizeRatioToPreviousItemInverted" to 1 - fontSizeRatioToPreviousItem,
            "secondChunkIsSubScript" to asFeature(secondChunkIsSubscript),
            "linesThatStartWithDigits" to linesThatStartWithDigits.toDouble(),
            "subscriptChunksRatio" to subscriptChunksRatio,
            "subscriptChunksRatioGreater75" to asFeature(subscriptChunksRatio > 0.75),
            "linesThatStartWithDigitsRatio" to linesThatStartWithDigitsRatio,
            "linesThatStartWithDigitsRatioGreater75" to asFeature(linesThatStartWithDigitsRatio > 0.75),
            "ratioOfDigits" to ratioOfDigits,
            "isRatioOfDigitsGreater75" to asFeature(ratioOfDigits > 0.75),
            "firstChunkIsNumberAndPeriod" to asFeature(firstChunkIsNumberAndPeriod),
        ).toVector()
    }
}

/** Returns [numerator] / [denominator] as a double, or `0.0` when [denominator] is zero (instead of `NaN`). */
private fun ratioOrZero(numerator: Int, denominator: Int): Double =
    if (denominator == 0) 0.0 else numerator.toDouble() / denominator

/** Encodes a boolean feature as `1.0` (true) or `0.0` (false). */
private fun asFeature(flag: Boolean): Double = if (flag) 1.0 else 0.0

/**
 * Classifies feature vectors as footnote / not footnote using a weighted sum of (optionally normalized) features.
 *
 * @param weights feature name -> weight, plus the mandatory [TARGET_SCORE_KEY] entry holding the score threshold.
 * The weights are flattened with [toVector], so the vectors passed to [classify] and [score] must use the same
 * alphabetical feature order.
 * @param normalizations feature name -> normalization data for the features that must be normalized before being
 * weighted. Features without an entry are used as-is.
 * @throws IllegalArgumentException if [weights] has no [TARGET_SCORE_KEY] entry.
 */
internal class FootnoteClassifier(
    weights: Map<String, Double>,
    normalizations: Map<String, NormalizationData>,
) {

    private val targetScore = requireNotNull(weights[TARGET_SCORE_KEY]) {
        "Weights must contain the \"$TARGET_SCORE_KEY\" entry"
    }
    private val weightVector = weights.toVector()

    /**
     * Normalization data stored at the same index as the weight/feature it applies to (null = no normalization).
     * The ordering below has to mirror [toVector]: drop the target score, then sort by feature name.
     */
    private val normalizationVector: Array<NormalizationData?> =
        weights.keys
            .filterNot { it == TARGET_SCORE_KEY }
            .sorted()
            .map { featureName -> normalizations[featureName] }
            .toTypedArray()

    /**
     * Checks each feature vector (as created by [getVectors]) against the weights.
     *
     * @return one flag per vector, in order: true when the vector's [score] is strictly greater than the target
     * score, false otherwise.
     * @throws IllegalArgumentException if a vector does not have one value per weight.
     */
    fun classify(vectors: List<DoubleArray>): List<Boolean> =
        vectors.map { vector -> score(vector) > targetScore }

    /**
     * Scores a feature vector: every value is normalized where normalization data exists, multiplied by its
     * weight, and the products are summed.
     * Internal for the training code.
     *
     * @throws IllegalArgumentException if [vector] does not have one value per weight.
     */
    internal fun score(vector: DoubleArray): Double {
        require(vector.size == weightVector.size) { "Vector size does not match weights size" }

        var total = 0.0
        for (i in vector.indices) {
            val normalizedValue = normalizationVector[i]?.normalize(vector[i]) ?: vector[i]
            total += normalizedValue * weightVector[i]
        }
        return total
    }

    /**
     * Inspired by https://www.tensorflow.org/api_docs/python/tf/keras/layers/Normalization
     * Lets the classifier use raw, unbounded features (for example the raw count of lines) next to binary or
     * 0-1 features.
     * The mean and variance are calculated on the training data.
     */
    internal data class NormalizationData(val mean: Double, val variance: Double) {
        /**
         * Returns [value] shifted by [mean] and divided by the standard deviation.
         * The standard deviation is floored at a small epsilon so that a zero variance yields a finite result
         * instead of `NaN`/`Infinity`.
         */
        fun normalize(value: Double): Double =
            (value - mean) / sqrt(variance).coerceAtLeast(MIN_STANDARD_DEVIATION)

        private companion object {
            /** Lower bound for the divisor in [normalize]. */
            const val MIN_STANDARD_DEVIATION = 1e-7
        }
    }

    companion object {
        /**
         * Special key for the weights that describes the cut off where an item is considered a footnote.
         * Also determined using our code to find the best weights.
         */
        internal const val TARGET_SCORE_KEY = "TARGET_SCORE"
    }
}

/**
 * Flattens this map into a vector: the values ordered alphabetically by their key, with the
 * [FootnoteClassifier.TARGET_SCORE_KEY] entry left out.
 * Working on plain double arrays is much faster than combining two maps entry by entry.
 */
internal fun Map<String, Double>.toVector(): DoubleArray =
    (keys - FootnoteClassifier.TARGET_SCORE_KEY)
        .sorted()
        .map { featureName -> getValue(featureName) }
        .toDoubleArray()

/**
 * The default weights for the classifier, keyed by feature name.
 * These can be determined using the "find weights" functionality in the FootnoteWeightDetection.
 * The definition of what each value represents can be found in [getVectors]; the keys here have to match the
 * feature names used there.
 * The [TARGET_SCORE_KEY] entry is not a feature weight: it is the score a vector has to exceed to be classified
 * as a footnote.
 */
internal val MACHINE_LEARNED_WEIGHTS = mapOf(
    "lineCount" to 0.191411758072082,
    "differenceFromUsualLineHeight" to 1.5256162745857,
    "fractionOfUsualLineHeightIsSmaller" to 0.17020749792811163,
    "fractionOfUsualLineHeightIsLarger" to -0.1400300059721555,
    "gapAbove" to 0.641489215566626,
    "gapBelow" to 0.0826246601976311,
    "startsWithNumber" to 0.7391743057048685,
    "isInBottom40OfPage" to 1.7598707879688826,
    "edgeProximityScore" to 0.0,
    "fontSizeRatioToPreviousItem" to -1.6004488152551835,
    "fontSizeRatioToPreviousItemInverted" to -1.0822046473469906,
    "secondChunkIsSubScript" to 0.2929141228376006,
    "linesThatStartWithDigits" to 0.2120307729703964,
    "subscriptChunksRatio" to 0.3927752920191987,
    "subscriptChunksRatioGreater75" to 0.0,
    "linesThatStartWithDigitsRatio" to 0.09527920356706776,
    "linesThatStartWithDigitsRatioGreater75" to -0.0,
    "ratioOfDigits" to 0.02775871521461709,
    "isRatioOfDigitsGreater75" to -1.498298461262424,
    "firstChunkIsNumberAndPeriod" to 0.9639645133952163,
    TARGET_SCORE_KEY to 1.7695271326828261,

    // NOTE(review): disabled alternative weight set for the same features, presumably from another training run.
    // Confirm whether it is still needed for reference or can be removed.
    /*"lineCount" to 0.191411758072082,
    "differenceFromUsualLineHeight" to 0.16015093212230225,
    "fractionOfUsualLineHeightIsSmaller" to 0.17020749792811163,
    "fractionOfUsualLineHeightIsLarger" to -0.1400300059721555,
    "gapAbove" to 1.0443033539030475,
    "gapBelow" to 0.0,
    "startsWithNumber" to 0.7391743057048685,
    "isInBottom40OfPage" to 1.7598707879688826,
    "edgeProximityScore" to 0.0,
    "fontSizeRatioToPreviousItem" to -1.7357562516979956,
    "fontSizeRatioToPreviousItemInverted" to -0.5411023236734953,
    "secondChunkIsSubScript" to 0.1464570614188003,
    "linesThatStartWithDigits" to 0.2120307729703964,
    "subscriptChunksRatio" to 0.0,
    "subscriptChunksRatioGreater75" to 0.22668554290232729,
    "linesThatStartWithDigitsRatio" to -0.14851002091013732,
    "linesThatStartWithDigitsRatioGreater75" to -0.6457147125947085,
    "ratioOfDigits" to 0.0,
    "isRatioOfDigitsGreater75" to -6.05687686458533,
    "firstChunkIsNumberAndPeriod" to 0.9082990267876035,
    "TARGET_SCORE" to 1.7510644153863397,*/
)

/**
 * Names of the features in our feature vector that need to be normalized before they are weighted.
 * This is a subset of the keys in [MACHINE_LEARNED_WEIGHTS]; the default classifier in this file supplies
 * normalization data for exactly these names.
 * This is used in the training code to find the required normalization data.
 */
internal val FIELDS_TO_NORMALIZE = arrayOf("lineCount", "linesThatStartWithDigits")

/**
 * The classifier used by [detectFootnotes]: [MACHINE_LEARNED_WEIGHTS] combined with the normalization data for
 * the "lineCount" and "linesThatStartWithDigits" features.
 */
private val DEFAULT_CLASSIFIER = FootnoteClassifier(
    weights = MACHINE_LEARNED_WEIGHTS,
    normalizations = mapOf(
        "lineCount" to FootnoteClassifier.NormalizationData(
            mean = 4.829733163913596,
            variance = 35.10442724773078,
        ),
        "linesThatStartWithDigits" to FootnoteClassifier.NormalizationData(
            mean = 0.21664548919949175,
            variance = 0.5178677815648085,
        ),
    ),
)
