package com.speechify.client.internal.util.extensions.collections

import kotlin.math.pow

/**
 * One group of items produced by [kMeans], together with summary statistics of the numeric
 * values that were extracted from those items.
 *
 * All properties are mutable because [kMeans] resets and refills its clusters on every pass.
 *
 * @property average mean of the extracted values of [items]; [kMeans] leaves it at `0.0` for a
 *   cluster that ended up without items.
 * @property max largest extracted value among [items]. [kMeans] only assigns it for non-empty
 *   clusters; otherwise the default stays in place.
 *   NOTE(review): `Double.MIN_VALUE` is the smallest *positive* Double, not the most negative
 *   one - confirm that this is the intended placeholder for "no maximum yet".
 * @property min smallest extracted value among [items]. [kMeans] only assigns it for non-empty
 *   clusters; otherwise the default stays in place.
 * @property count number of entries in [items] as tracked by [kMeans].
 * @property items the elements assigned to this cluster.
 */
internal data class Cluster<T>(
    var average: Double,
    var max: Double = Double.MIN_VALUE,
    var min: Double = Double.MAX_VALUE,
    var count: Int,
    var items: MutableList<T>,
)

/**
 * Working record used by [kMeans] for a single input element.
 *
 * @property value the number extracted from [item]; this is what the clustering compares.
 * @property item the original element, carried along so it can be placed into a [Cluster].
 * @property cluster index of the cluster this point is currently assigned to; rewritten
 *   whenever a pass finds a nearer cluster.
 */
private data class DataPoint<T>(
    val value: Double,
    val item: T,
    var cluster: Int,

)

// Clustering algorithm copied from: https://github.com/TheAlgorithms/C/blob/master/machine_learning/k_means_clustering.c
// The original had a small bug where it sometimes divided by 0; this is fixed in our implementation.
/**
 * Groups the elements of this list into clusters by running a one-dimensional k-means over the
 * numbers produced by [valueExtractor].
 *
 * Behaviour by input:
 * - `suggestedK <= 1`: one cluster holding every element is returned. For an empty list that
 *   cluster has no items and a `NaN` average.
 * - `suggestedK >= size`: every element becomes its own single-item cluster.
 * - otherwise: the algorithm runs with `min(suggestedK, size / 2)` clusters. Elements are first
 *   assigned round-robin; each pass recomputes the cluster averages and moves every element to
 *   the cluster with the nearest average. It stops once the number of moved elements is no
 *   larger than `size / 10_000`, or once the retry budget is exhausted.
 *
 * A cluster that loses all of its elements keeps an average of `0.0`, a count of `0`, and the
 * default `min`/`max` of [Cluster]; such empty clusters are still part of the result.
 *
 * @param suggestedK the desired number of clusters; it may be lowered as described above.
 * @param valueExtractor maps an element to the number it is clustered by. It is invoked more
 *   than once per element.
 * @return the clusters, with `min` and `max` filled in for every non-empty cluster.
 */
internal fun <T> List<T>.kMeans(suggestedK: Int, valueExtractor: (T) -> Double): List<Cluster<T>> {
    val observations = this

    val grouped: List<Cluster<T>> = when {
        suggestedK <= 1 -> {
            // One bucket (or a nonsensical amount) was requested: everything goes together.
            val mean = observations.map(valueExtractor).average()
            listOf(Cluster(average = mean, count = observations.size, items = observations.toMutableList()))
        }

        suggestedK >= observations.size -> {
            // At least as many clusters as observations: each observation is its own cluster.
            observations.map { observation ->
                Cluster(average = valueExtractor(observation), count = 1, items = mutableListOf(observation))
            }
        }

        else -> {
            // More clusters than half the observations just produces a useless result,
            // so the requested amount is capped at a more sensible one.
            val clusterCount = minOf(suggestedK, observations.size / 2)

            // Initial assignment: distribute the observations round-robin over the clusters.
            val points = observations.mapIndexed { index, observation ->
                DataPoint(value = valueExtractor(observation), item = observation, cluster = index % clusterCount)
            }
            val clusters = List(clusterCount) { Cluster<T>(average = 0.0, count = 0, items = mutableListOf()) }
            val tolerance = observations.size / 10_000

            // Termination of the refinement loop cannot be proven, so it is bounded by an
            // arbitrary retry budget. The result is an approximation either way: there is no
            // wrong outcome, only a worse user experience.
            var remainingRetries = 10_000
            do {
                // Start every pass from empty clusters.
                clusters.forEach { cluster ->
                    cluster.average = 0.0
                    cluster.count = 0
                    cluster.items.clear()
                }

                // Sum up the values and collect the members of each cluster.
                points.forEach { point ->
                    val target = clusters[point.cluster]
                    target.average += point.value
                    target.count += 1
                    target.items.add(point.item)
                }

                // Turn the sums into means; empty clusters are skipped to avoid dividing by zero.
                clusters.forEach { cluster ->
                    if (cluster.count != 0) cluster.average /= cluster.count
                }

                // Re-assign every point to its nearest mean and count how many moved.
                var moved = 0
                points.forEach { point ->
                    val nearest = calculateNearest(point, clusters)
                    if (nearest != point.cluster) {
                        moved++
                        point.cluster = nearest
                    }
                }
            } while (moved > tolerance && remainingRetries-- > 0)

            clusters
        }
    }

    // Fill in the value range of every cluster that actually has members.
    return grouped.map { cluster ->
        if (cluster.items.isNotEmpty()) {
            cluster.max = cluster.items.maxOf { member -> valueExtractor(member) }
            cluster.min = cluster.items.minOf { member -> valueExtractor(member) }
        }
        cluster
    }
}

/**
 * Finds the cluster whose current average is closest to the value of [point].
 *
 * Closeness is the squared difference between the cluster average and the point's value. When
 * several clusters are equally close, the one with the lowest index is chosen.
 *
 * @param point the data point to place.
 * @param clusters the candidate clusters; must not be empty.
 * @return the index into [clusters] of the nearest cluster.
 * @throws NoSuchElementException if [clusters] is empty.
 */
private fun <T> calculateNearest(point: DataPoint<T>, clusters: List<Cluster<T>>): Int =
    clusters.indices.minBy { index -> (clusters[index].average - point.value).pow(2) }
