package com.speechify.client.internal.util.collections.maps

import com.speechify.client.internal.launchAsync
import com.speechify.client.internal.sync.AtomicRef
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update

/**
 * A thread-safe map whose values are produced asynchronously and obtained by suspending.
 * If you don't need to suspend when obtaining a value, use [BlockingThreadsafeMap] instead.
 *
 * Every value is held as a [Deferred]; entries supplied via `initialMap` are stored as already-completed ones.
 */
internal open class LockingThreadsafeMap<K, V>(
    initialMap: Map<K, V> = emptyMap(),
) :
    ThreadSafeMutableMapAsyncOperations<K, V>,
    MapFilledWhileAsyncGettingWithRetrieval<K, V>,
    MapFilledWhileAsyncGettingWithRetrievalIncludingPending<K, V> {
    /** Backing storage (visible to subclasses); holds pending as well as completed values. */
    protected val threadSafeMap: BlockingThreadsafeMap<K, Deferred<V>> =
        initialMap
            .mapValues<K, V, Deferred<V>> { (_, value) -> CompletableDeferred(value) }
            .toMutableMap()
            .asBlockingThreadsafeMap()

    /**
     * Returns the value for [key], suspending until it is available.
     *
     * When [key] is absent, a computation of [defaultValue] is created with [CoroutineStart.LAZY] and stored under
     * [key] before being awaited. Callers that find an existing entry await that same [Deferred].
     * Whether only one computation is ever stored for concurrent callers depends on
     * [BlockingThreadsafeMap.getOrPut] being atomic. NOTE(review): presumably it is; confirm.
     */
    override suspend fun getOrPut(
        key: K,
        defaultValue: suspend () -> V,
    ): V {
        val pendingOrCompleted = threadSafeMap.getOrPut(key) {
            launchAsync(start = CoroutineStart.LAZY) { defaultValue() }
        }
        return pendingOrCompleted.await()
    }

    /**
     * Non-suspending lookup that never waits for pending values.
     *
     * Returns the value for [key] if its computation has already completed. Returns `null` when [key] is absent or
     * its value is still pending. Use [getIncludingPending] to obtain the pending [Deferred] instead.
     *
     * If the computation for [key] completed exceptionally (failed or was cancelled), [Deferred.getCompleted]
     * rethrows that exception here rather than returning `null`.
     */
    @OptIn(ExperimentalCoroutinesApi::class)
    override operator fun get(key: K): V? =
        threadSafeMap[key]
            ?.takeIf { it.isCompleted }
            ?.getCompleted()

    /** Returns the stored [Deferred] for [key], whether completed or still pending, or `null` if absent. */
    override fun getIncludingPending(key: K): Deferred<V>? =
        threadSafeMap[key]

    /** Removes [key] and returns the [Deferred] that was stored for it (completed or pending), or `null` if absent. */
    override fun removeIncludingPending(key: K): Deferred<V>? =
        threadSafeMap.remove(key)
}

/**
 * A version of [LockingThreadsafeMap] that also publishes a snapshot of its known values via [mapUpdateEventFlow].
 *
 * How the snapshot changes:
 * - It starts as a copy of `initialMap`.
 * - It gains `key to value` whenever [getOrPut] returns a value for `key`.
 * - It loses `key` whenever [removeIncludingPending] is called for it.
 *
 * NOTE(review): a [getOrPut] that was already suspended when the key got removed will re-publish its value once it
 * completes.
 * NOTE(review): mutation paths inherited from the interfaces (e.g. `add`) are only reflected here if they route
 * through [getOrPut]. Confirm.
 */
internal class LockingThreadsafeMapWithUpdateCallback<K, V>(
    initialMap: Map<K, V> = emptyMap(),
) : LockingThreadsafeMap<K, V>(initialMap) {

    // Copied so that later mutation of a caller-owned mutable map cannot leak into the published snapshot.
    private val mapUpdateEventSink = MutableStateFlow<Map<K, V>>(initialMap.toMap())

    /** Read-only view of the latest published snapshot of the map's values. */
    val mapUpdateEventFlow = mapUpdateEventSink.asStateFlow()

    /**
     * Same as [LockingThreadsafeMap.getOrPut], and additionally publishes `key to value` to [mapUpdateEventFlow].
     *
     * Publication happens only after the value is available, so the snapshot only contains completed values.
     */
    override suspend fun getOrPut(key: K, defaultValue: suspend () -> V): V =
        super.getOrPut(key, defaultValue).also { value ->
            mapUpdateEventSink.update { current -> current + (key to value) }
        }

    /**
     * Same as [LockingThreadsafeMap.removeIncludingPending], and additionally drops [key] from [mapUpdateEventFlow].
     *
     * Without this, the published snapshot would keep reporting an entry that is no longer in the map.
     */
    override fun removeIncludingPending(key: K): Deferred<V>? =
        super.removeIncludingPending(key).also {
            mapUpdateEventSink.update { current -> current - key }
        }
}

/**
 * A locking version of [MutableMapBy].
 *
 * Entries are stored under `getEquatableKey(key)` rather than under the rich key itself. Two rich keys that map to
 * the same equatable key therefore share one entry.
 */
internal class LockingThreadsafeMapBy<KRich, KWithSemanticEqual, Value>(
    private val getEquatableKey: (keyRich: KRich) -> KWithSemanticEqual,
) : MapFilledWhileAsyncGettingWithRetrievalIncludingPending<KRich, Value> {
    /** The actual storage, keyed by the equatable form of each rich key. */
    private val backingMap = LockingThreadsafeMap<KWithSemanticEqual, Value>()

    /** Delegates to [LockingThreadsafeMap.getOrPut] using the equatable form of [key]. */
    override suspend fun getOrPut(
        key: KRich,
        defaultValue: suspend () -> Value,
    ): Value {
        val equatableKey = getEquatableKey(key)
        return backingMap.getOrPut(equatableKey, defaultValue)
    }

    /** Delegates to [LockingThreadsafeMap.getIncludingPending] using the equatable form of [key]. */
    override fun getIncludingPending(key: KRich): Deferred<Value>? {
        val equatableKey = getEquatableKey(key)
        return backingMap.getIncludingPending(equatableKey)
    }

    /** Delegates to the backing map's `add` using the equatable form of [key]. */
    override suspend fun add(key: KRich, value: suspend () -> Value): Value {
        val equatableKey = getEquatableKey(key)
        return backingMap.add(
            key = equatableKey,
            value = value,
        )
    }
}

/**
 * Use this when the values are requested in batches.
 *
 * All values missing for one batch of keys are produced by a single call to `produceMissingValues`. Each of those
 * keys then gets its own [Deferred], which picks its element out of that call's result by position.
 */
internal class LockingThreadsafeMapWithMulti<K, V>(
    initialMap: Map<K, V> = emptyMap(),
) :
    MapFilledWhileAsyncGetting<K, V>,
    MapFilledWhileAsyncGettingWithMultiAndClear<K, V> {
    /* Held in an [AtomicRef] so that [clear] can replace the whole map with a fresh one in a single swap. */
    private val threadSafeMapAtomic: AtomicRef<BlockingThreadsafeMap<K, Deferred<V>>> =
        AtomicRef(
            initialMap.mapValues {
                CompletableDeferred(it.value) as Deferred<V>
            }
                .toMutableMap()
                .asBlockingThreadsafeMap(),
        )

    /**
     * Returns the values for [keys], suspending until all of them are available.
     *
     * Keys missing from the map are produced together by one call to [produceMissingValues]. That call must return
     * exactly one value per key it is given, positionally matching `keysOfMissingEntries`. Otherwise the batch
     * computation fails with [IllegalArgumentException].
     * NOTE(review): presumably that failure surfaces to callers through `await()`; confirm against `launchAsync`.
     *
     * The current map is read once per call. If [clear] runs concurrently, this call's entries go into the map that
     * was just discarded.
     */
    override suspend fun getOrPutMulti(
        keys: List<K>,
        produceMissingValues: suspend (keysOfMissingEntries: List<K>) -> List<V>,
    ): List<V> =
        threadSafeMapAtomic.value.getOrPutMulti(
            keys,
            produceMissingValues = { keysOfMissingEntries ->
                /* The caller's API produces all missing values at once, so there is a single Deferred for the
                 * whole batch here. Below, it gets split into one Deferred per key.
                 */
                val produced = launchAsync(
                    /* Not started here: it only starts once one of the per-key Deferreds below awaits it. That
                     happens only if some of them end up in the map and get awaited. */
                    start = CoroutineStart.LAZY,
                ) {
                    val values = produceMissingValues(keysOfMissingEntries)
                    require(values.size == keysOfMissingEntries.size) {
                        "Inputs were not the same length while this was expected. First: ${values.size}, " +
                            "second: ${keysOfMissingEntries.size}"
                    }
                    values
                }

                /* Now break the batch down into one Deferred per missing key, each picking its element by position. */
                keysOfMissingEntries.indices.map { index ->
                    launchAsync(
                        /* Likewise lazy: only the Deferreds that the map returns (and that get awaited) will run. */
                        start = CoroutineStart.LAZY,
                    ) { produced.await()[index] }
                }
            },
        )
            .map { it.await() }

    /**
     * Replaces the map with a fresh, empty one and returns the entries of the old map (pending ones included).
     *
     * The returned [Deferred]s are not cancelled by this method.
     */
    override fun clear(): List<Pair<K, Deferred<V>>> =
        /* TODO - consider introducing `ThreadsafeMutableMapClear.clearGetLastEntries` and having one here, thanks
         *  to which this map could be vastly simplified.
         */
        threadSafeMapAtomic.swap(
            with = BlockingThreadsafeMap(),
        )
            .entries
            .toList()
}
