package com.speechify.client.internal.util.extensions.collections

import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow

/**
 * A version of [Flow.windowed] whose windows can be seen as batches of items from the original flow, each of size
 * [batchSize]. The batches never overlap and never miss any items, so concatenating all batches would produce the
 * same sequence as the original flow.
 *
 * This is the same as [Flow.windowed] with `partialWindows=true` and `size=step=batchSize`. The operation is also
 * sometimes called `chunked` (e.g. [Sequence.chunked]).
 */
internal fun <T> Flow<T>.windowedToBatches(
    /**
     * The number of items in each batch. Note that the last batch may be smaller.
     */
    batchSize: Int,
) =
    this.windowed(
        // Window size and step are equal, so consecutive windows neither overlap nor leave gaps.
        size = batchSize,
        step = batchSize,
        // Keep the trailing, possibly shorter, window.
        partialWindows = true,
    )

/**
 * A variant of [windowedToBatches] where the batches do not have a constant size (each batch can contain a
 * different number of items). A batch is closed not when it reaches a requested item count, but when the sum of
 * [itemSize] over its items reaches [aimedSizeSumInEachBatch].
 *
 * Any items still pending when the upstream flow completes are emitted as one final batch, which may have a sum
 * below the aim.
 */
internal fun <E> Flow<E>.windowedToBatchesOfAimedSizeSum(
    itemSize: (item: E) -> Int,
    /**
     * The size is 'aimed', because a batch will only be of this exact size if the items in the batch happen to add
     * up to this number.
     *
     * **NOTE: The strategy to achieve the 'aimed' sum can change**. For now it is a simple 'minimum' (a batch is
     * emitted as soon as its sum is at least this value), but this may still lead to batches significantly larger
     * than the aim, so in the future there may be some choices that end up producing a batch that is smaller than
     * [aimedSizeSumInEachBatch]. This is so far in line with all calls to this function, but if a caller requires a
     * specific strategy, make sure that the function signature reflects the intent (e.g. take a parameter for
     * specifying the strategy or have a separate overload).
     */
    aimedSizeSumInEachBatch: Int,
) = flow<List<E>> {
    var pendingBatch = mutableListOf<E>()
    var pendingSizeSum = 0

    this@windowedToBatchesOfAimedSizeSum.collect { element ->
        pendingBatch.add(element)
        pendingSizeSum += itemSize(element)

        // Keep accumulating until the aimed sum is reached or exceeded.
        if (pendingSizeSum < aimedSizeSumInEachBatch) return@collect

        emit(pendingBatch)
        // Start over with a fresh list rather than clearing - the emitted instance was handed to the collector.
        pendingBatch = mutableListOf()
        pendingSizeSum = 0
    }

    // The upstream flow has completed - flush the items that never reached the aimed sum.
    if (pendingBatch.isNotEmpty()) {
        emit(pendingBatch)
    }
}

/**
 * A variant of [windowedToBatchesOfAimedSizeSum] where the batches aim to have size [aimedSizeSumInEachBatch] and
 * each one is immediately mapped to a [BatchTransformResult] using [mapBatchWithRealignment].
 *
 * [mapBatchWithRealignment] has a special 'realignment' capability: it may decide to map only the first few items
 * of the proposed batch, in which case the rest is moved to the next batch.
 * This is especially useful for implementing a cache of such batched processing, where different starting items
 * are possible.
 *
 * This overload additionally caps each batch passed to [mapBatchWithRealignment] according to [maxSizeOptions], by
 * wrapping the mapper with [withSplittingToComplyWithMaxSize] before delegating to the overload without
 * `maxSizeOptions`.
 */
internal fun <InputItem, BatchTransformResult> Flow<InputItem>.windowedToBatchesOfAimedSizeSumWithMapAndRealignment(
    itemSize: GetItemSize<InputItem>,
    /**
     * For detailed documentation see the parameter of the same name in [windowedToBatchesOfAimedSizeSum].
     */
    aimedSizeSumInEachBatch: Int,
    maxSizeOptions: MaxSizeOptions<InputItem>,
    /**
     * The `batch` parameter will always have at least one element.
     *
     * NOTE: This function is wrapped with [withSplittingToComplyWithMaxSize] here, which adds the support for
     * splitting to ensure the maximum batch size.
     */
    mapBatchWithRealignment: MapBatchWithRealignmentFn<InputItem, BatchTransformResult>,
): Flow<TransformedBatchResult<InputItem, BatchTransformResult>> {
    val mapperCappedAtMaxSize = mapBatchWithRealignment.withSplittingToComplyWithMaxSize(
        options = maxSizeOptions,
        itemSize = itemSize,
    )

    return windowedToBatchesOfAimedSizeSumWithMapAndRealignment(
        itemSize = itemSize,
        aimedSizeSumInEachBatch = aimedSizeSumInEachBatch,
        mapBatchWithRealignment = mapperCappedAtMaxSize,
    )
}

/**
 * The overload without the `maxSizeOptions` parameter: it supports neither capping of the batch size nor splitting
 * of [InputItem]s.
 *
 * Items are accumulated until the sum of their [itemSize]s reaches [aimedSizeSumInEachBatch]; the accumulated items
 * are then offered to [mapBatchWithRealignment] together with a running batch index. Whatever the mapper reports as
 * not consumed stays pending and becomes the head of the next proposed batch.
 */
internal fun <InputItem, BatchTransformResult> Flow<InputItem>.windowedToBatchesOfAimedSizeSumWithMapAndRealignment(
    itemSize: GetItemSize<InputItem>,
    aimedSizeSumInEachBatch: Int,
    mapBatchWithRealignment: MapBatchWithRealignmentFn<InputItem, BatchTransformResult>,
): Flow<TransformedBatchResult<InputItem, BatchTransformResult>> = flow {
    var pendingItems = mutableListOf<InputItem>()
    var pendingSizeSum = 0
    var nextBatchIndex = 0

    /**
     * Repeatedly offers all pending items to the mapper and emits its result, for as long as [shouldContinue]
     * holds. The pending state is updated before each emission.
     */
    suspend fun emitBatchesWhile(shouldContinue: () -> Boolean) {
        while (shouldContinue()) {
            val outcome = mapBatchWithRealignment(
                IndexedValue(index = nextBatchIndex, value = pendingItems),
            )

            val consumedItems: List<InputItem> = when (outcome) {
                is BatchTransformResultWithRealignment.ExactMatch -> {
                    // Everything that was offered got consumed, so start over with an empty pending list.
                    val everythingOffered = pendingItems
                    pendingItems = mutableListOf()
                    pendingSizeSum = 0
                    everythingOffered
                }

                is BatchTransformResultWithRealignment.MatchOfShorterInput -> {
                    // Only a head was consumed - the reported remainder becomes the new pending state.
                    pendingItems = outcome.remainingUnmatchedItems.toMutableList()
                    pendingSizeSum = pendingItems.sumOf(itemSize)
                    outcome.inputItemsConsumedInThisBatch
                }
            }

            emit(
                TransformedBatchResult(
                    inputItems = consumedItems,
                    resultOfTransform = outcome.result,
                ),
            )
            nextBatchIndex += 1
        }
    }

    this@windowedToBatchesOfAimedSizeSumWithMapAndRealignment.collect { incomingItem ->
        pendingItems.add(incomingItem)
        pendingSizeSum += itemSize(incomingItem)

        /* Note - looping here may lead to multiple emissions for a single incoming item, and this is desired.
           This is because each emission is free to reduce the batch, and it's possible that there are some small
           items at the beginning which will make the remaining items still fulfill this condition.
           It is desired because we do want to emit the rest immediately in that case, as else we are creating an
           unnatural batch that has no chance of aligning (it's bigger than any other one would have been in the
           first run without cache, which yielded no realignment).
         */
        emitBatchesWhile { pendingSizeSum >= aimedSizeSumInEachBatch }
    }

    /* There are no more items. Now we just need to empty the pending items (still in a loop, as there may be
       realignments based on previous fragments).
     */
    emitBatchesWhile { pendingItems.isNotEmpty() }
}

/**
 * A suspending mapper applied to one proposed batch. It receives the items of the batch together with the index of
 * the batch, and returns a [BatchTransformResultWithRealignment], which carries the result of the transform and
 * tells the caller whether the whole proposed batch was consumed or only a head of it.
 */
internal typealias MapBatchWithRealignmentFn<InputItem, BatchTransformResult> =
    suspend (indexedBatch: IndexedValue<List<InputItem>>) ->
    BatchTransformResultWithRealignment<InputItem, BatchTransformResult>

/**
 * Returns the size of a single item. The batching functions in this file sum these values over the items of a batch
 * and compare the sum with the aimed (and, where supported, maximum) batch size. No unit is prescribed here.
 */
internal typealias GetItemSize<InputItem> = (item: InputItem) -> Int

/**
 * Wraps the receiver mapper so that it supports capping at a maximum batch size, including the situation where a
 * single item is larger than the maximum batch size - this is solved by splitting such items using
 * `options.splitItemExceedingSize`.
 *
 * When a suggested batch exceeds the maximum, the wrapped mapper is only offered the leading part of it, and the
 * cut-off items are reported back as remaining (unmatched), so that they get moved to the next batch.
 */
private fun <InputItem, BatchTransformResult>
MapBatchWithRealignmentFn<InputItem, BatchTransformResult>.withSplittingToComplyWithMaxSize(
    options: MaxSizeOptions<InputItem>,
    itemSize: GetItemSize<InputItem>,
): MapBatchWithRealignmentFn<InputItem, BatchTransformResult> =
    { suggestedBatch ->
        val allSuggestedItems: List<InputItem> = suggestedBatch.value

        if (allSuggestedItems.sumOf(itemSize) <= options.maxSizeSumInEachBatch) {
            // The whole suggestion is within the maximum - nothing to cap, delegate untouched.
            this(suggestedBatch)
        } else {
            // Find the leading whole items (if any) whose running sum still fits the maximum size.
            val (wholeItemsWithinMax, wholeItemsBeyondMax) = allSuggestedItems.asSequence()
                .withRunningSum(itemSize)
                .partition(
                    predicate = { (runningSum: Int, _: InputItem) ->
                        runningSum <= options.maxSizeSumInEachBatch
                    },
                )
                .mapPairItems { sumsWithItems -> sumsWithItems.map { (_: Int, item) -> item } }

            val itemsToAttempt: List<InputItem>
            val itemsCutOff: List<InputItem>
            if (wholeItemsWithinMax.isEmpty()) {
                // Not even the first item fits on its own, so it needs to be split.
                val (oversizedItem, itemsAfterOversized) = wholeItemsBeyondMax.partitionToFirstAndRestOrNull()
                    ?: error(
                        "Arrived at no items to split. This may be a bug in the function of this exception.",
                    )

                val (firstPart, furtherParts) = options.splitItemExceedingSize(
                    /* item = */ oversizedItem,
                    /* maxSize = */ options.maxSizeSumInEachBatch,
                ).toList().partitionToFirstAndRestOrNull()
                    ?: error(
                        "Arrived at no items after splitting. This may be a bug in `options.splitItemExceedingSize`.",
                    )

                itemsToAttempt = listOf(firstPart)
                /* Order is important here: first what the splitting just produced, and then what we have cut-off
                   earlier.
                 */
                itemsCutOff = furtherParts + itemsAfterOversized
            } else {
                itemsToAttempt = wholeItemsWithinMax
                itemsCutOff = wholeItemsBeyondMax
            }

            val reducedBatch = IndexedValue(
                index = suggestedBatch.index,
                value = itemsToAttempt,
            )

            when (val resultForReducedBatch = this(reducedBatch)) {
                // Exact for the reduced batch, but still shorter than what was originally suggested.
                is BatchTransformResultWithRealignment.ExactMatch ->
                    BatchTransformResultWithRealignment.MatchOfShorterInput(
                        result = resultForReducedBatch.result,
                        inputItemsConsumedInThisBatch = itemsToAttempt,
                        remainingUnmatchedItems = itemsCutOff,
                    )

                is BatchTransformResultWithRealignment.MatchOfShorterInput ->
                    BatchTransformResultWithRealignment.MatchOfShorterInput(
                        result = resultForReducedBatch.result,
                        inputItemsConsumedInThisBatch = resultForReducedBatch.inputItemsConsumedInThisBatch,
                        /* Order is important here too: first the remaining items that the original mapper just
                           produced, and then what we have cut-off and never passed to the original mapper.
                         */
                        remainingUnmatchedItems = resultForReducedBatch.remainingUnmatchedItems + itemsCutOff,
                    )
            }
        }
    }

/**
 * Options for capping the summed item size of a batch, as consumed by [withSplittingToComplyWithMaxSize].
 *
 * @property maxSizeSumInEachBatch The cap that the summed item sizes of a batch are compared against. A batch
 *  whose sum is at most this value is passed to the mapper unchanged.
 * @property splitItemExceedingSize Called with an item and [maxSizeSumInEachBatch] when the first item of a batch
 *  alone already exceeds the cap. It must return at least one part - an empty result is reported as an
 *  [IllegalStateException].
 */
internal class MaxSizeOptions<T>(
    val maxSizeSumInEachBatch: Int,
    val splitItemExceedingSize: SplitItemByMaxSizeFn<T>,
)

/**
 * Splits `item` into a sequence of parts, given the maximum size. The caller in this file takes the first part as
 * the batch to attempt and moves the other parts, in order, to the front of the remaining items.
 * NOTE(review): presumably each returned part is expected to be no larger than `maxSize` - this is not verified by
 * the caller in this file, so confirm against the implementations.
 */
internal typealias SplitItemByMaxSizeFn<T> = (item: T, maxSize: Int) -> Sequence<T>

/**
 * A single element emitted by [windowedToBatchesOfAimedSizeSumWithMapAndRealignment].
 *
 * @property inputItems The input items that were consumed to produce this batch (after any realignment).
 * @property resultOfTransform The result that the mapper returned for this batch.
 */
internal class TransformedBatchResult<InputItem, BatchTransformResult>(
    val inputItems: List<InputItem>,
    val resultOfTransform: BatchTransformResult,
)

/**
 * The result of mapping (transforming) a batch of items taken from some larger sequence, which enables a mechanism
 * for realigning the batching of the input stream to what the producer prefers.
 *
 * This is typically used in caching that speeds up processing of a sequence of input, where cache entries cover
 * batches of input items, and subsequent requests can start from a different item, producing different batches.
 * Realigning avoids the situation where part (or even the entirety) of already cached content ends up never used
 * because the second streaming starts from an item that does not equal the start of any cache entry produced by
 * the initial batching. Without this mechanism the misalignment could theoretically persist through the entire
 * cached content, so that perfectly valid cache contents are never reused.
 */
internal sealed class BatchTransformResultWithRealignment<out InputItem, out Result>(val result: Result) {

    /**
     * Returns a result of the same kind, with [result] replaced by the outcome of [transform].
     */
    abstract fun <NewResult> mapResult(transform: (Result) -> NewResult):
        BatchTransformResultWithRealignment<InputItem, NewResult>

    /**
     * Represents a result that fully matches the requested input batch, so no realignment is required.
     */
    class ExactMatch<out Result>(result: Result) :
        BatchTransformResultWithRealignment<Nothing, Result>(result) {
        override fun <NewResult> mapResult(transform: (Result) -> NewResult): ExactMatch<NewResult> {
            val transformedResult = transform(result)
            return ExactMatch(transformedResult)
        }
    }

    /**
     * Represents a result that contains a useful product for the consumer but, while it starts correctly, is based
     * on a shorter input batch than requested. The requester is required to move the [remainingUnmatchedItems] to
     * the next batch.
     */
    class MatchOfShorterInput<InputItem, out Result>(
        result: Result,
        /**
         * This always contains some head of the input items.
         * NOTE: It may also contain slices of items, if the usage supports splitting the input.
         */
        val inputItemsConsumedInThisBatch: List<InputItem>,
        /**
         * This always contains some tail of the input items.
         * NOTE: It may also contain slices of items, if the usage supports splitting the input.
         */
        val remainingUnmatchedItems: Iterable<InputItem>,
    ) : BatchTransformResultWithRealignment<InputItem, Result>(result) {
        override fun <NewResult> mapResult(
            transform: (Result) -> NewResult,
        ): MatchOfShorterInput<InputItem, NewResult> {
            // Only the result is transformed - the consumed and remaining items are carried over as they are.
            val transformedResult = transform(result)
            return MatchOfShorterInput(
                result = transformedResult,
                inputItemsConsumedInThisBatch = inputItemsConsumedInThisBatch,
                remainingUnmatchedItems = remainingUnmatchedItems,
            )
        }
    }
}
