@file:OptIn(ExperimentalSerializationApi::class) // ktlint-disable filename

package com.speechify.client.api.util.boundary

import com.speechify.client.api.util.Result
import com.speechify.client.api.util.SDKError
import com.speechify.client.api.util.successfully
import com.speechify.client.internal.util.boundary.SdkBoundaryMap
import kotlinx.serialization.DeserializationStrategy
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.SerializationException
import kotlinx.serialization.SerializationStrategy
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.SerialKind
import kotlinx.serialization.descriptors.StructureKind
import kotlinx.serialization.encoding.AbstractDecoder
import kotlinx.serialization.encoding.AbstractEncoder
import kotlinx.serialization.encoding.CompositeDecoder
import kotlinx.serialization.encoding.CompositeEncoder
import kotlinx.serialization.modules.EmptySerializersModule
import kotlinx.serialization.modules.SerializersModule
import kotlinx.serialization.serializer
import kotlin.reflect.KClass

/**
 * There is one gripe I have with the encoding API: two methods that should mirror each other don't.
 *
 * [AbstractEncoder.beginStructure] is called when a new substructure starts and expects a nested encoder to be
 * returned, so it can encode said nested structure.
 *
 * [AbstractEncoder.endStructure] is called when the new substructure has finished being encoded. The problem is
 * that this second method is called on the child encoder and not on the parent.
 *
 * Pseudocode that drives the encoder:
 * ```
 * encoder2 = encoder.beginStructure()
 *  encoder2.encodeField()
 *  encoder2.encodeField()
 *  encoder2.encodeField()
 *  encoder2.endStructure() // this is called in the nested encoder
 * ```
 *
 * This means we need a way to tell the parent that encoding is done, so we extend the API of our encoders with a
 * clean-up hook. The encoders store a reference to their parent, so they can tell the parent that their encoding
 * has finished.
 */
private interface BoundaryEncoder : CompositeEncoder {
    /**
     * Invoked on a parent encoder by its child once the child's structure, described by [descriptor], is complete.
     * The default implementation does nothing.
     */
    fun cleanUpStructure(descriptor: SerialDescriptor) = Unit
}

/**
 * Encodes a composite value into an [SdkBoundaryMap], writing one entry per serialized element.
 *
 * Nested structures are delegated to child encoders created by `encoderForKind`:
 * - lists are collected in a [MutableList], which is converted to an array once the list is finished;
 * - everything else becomes a nested [SdkBoundaryMap].
 *
 * A child reports completion by calling [BoundaryEncoder.cleanUpStructure] on its parent. Longs are rejected.
 *
 * @property map the map being filled; for the top-level encoder this is the encoding result.
 */
private class BoundaryMapEncoder(
    val map: SdkBoundaryMap<Any?> = SdkBoundaryMap.of(),
    override val serializersModule: SerializersModule = EmptySerializersModule(),
    private val parent: BoundaryEncoder? = null,
) : AbstractEncoder(), BoundaryEncoder {

    /** Name of the element announced by [encodeElement] that has not been written to [map] yet. */
    var nextKey: String? = null

    /**
     * Returns the pending [nextKey] and clears it.
     *
     * @throws BoundaryMapSerializationException if no element is pending. This happens when a value is written
     * without an enclosing class, i.e. the top level type is not a composite object.
     */
    fun consumeKey(): String {
        val key = nextKey ?: throw BoundaryMapSerializationException("top level type must be a composite object")
        nextKey = null
        return key
    }

    override fun encodeElement(descriptor: SerialDescriptor, index: Int): Boolean {
        // A pending key means the previously announced element was never written.
        if (nextKey != null) throw IllegalStateException("called encode element twice")
        nextKey = descriptor.getElementName(index)
        return true
    }

    /**
     * Longs are rejected. This is a plain [IllegalArgumentException], not a [SerializationException], so it is not
     * turned into a `Result.Failure` by `Boundary.encodeToBoundaryMap`.
     */
    override fun encodeLong(value: Long) {
        throw IllegalArgumentException("longs are not safe to pass across boundaries")
    }

    override fun encodeValue(value: Any) {
        map[consumeKey()] = value
    }

    override fun encodeNull() {
        map[consumeKey()] = null
    }

    /**
     * Because our format is a boundary map, we only allow classes as the top level type that we encode, i.e. we can't
     * encode a list or a string as the first thing. This means that the first thing the "driver" will call for this
     * encoder will always be [beginStructure]. To avoid unnecessary nesting we ignore the first call to
     * [beginStructure]. We can tell it is the first call because [encodeElement] has not been called yet, so
     * [nextKey] is null.
     *
     * For nested structures the key is NOT consumed here. The child writes into the container stored under the key,
     * and the key is consumed in [cleanUpStructure] once the child reports completion.
     *
     * @see BoundaryMapDecoder.ignoreBeginStructure for the reverse hack
     */
    override fun beginStructure(descriptor: SerialDescriptor) = when (val key = nextKey) {
        null -> this
        else -> encoderForKind(descriptor.kind) { map[key] = it }
    }

    override fun endStructure(descriptor: SerialDescriptor) {
        parent?.cleanUpStructure(descriptor)
    }

    /** Called by a child encoder when the structure stored under [nextKey] is complete; consumes that key. */
    override fun cleanUpStructure(descriptor: SerialDescriptor) {
        val k = consumeKey()
        /*
         * If we just finished encoding a list we have to turn the mutable list in the map into an array
         * to keep it boundary safe. We use a list as this is the only way to mutate it in-place, avoiding passing
         * a reference to the map around.
         */
        when (val v = map[k]) {
            is MutableList<*> -> map[k] = v.toTypedArray()
        }
    }

    /** Encodes the elements of a list by appending each one to [list]. */
    private class SimpleArrayEncoder(
        val list: MutableList<Any?>,
        override val serializersModule: SerializersModule,
        private val parent: BoundaryEncoder?,
    ) :
        AbstractEncoder(), BoundaryEncoder {

        // Same restriction as the enclosing map encoder, so longs cannot slip through inside lists.
        override fun encodeLong(value: Long) {
            throw IllegalArgumentException("longs are not safe to pass across boundaries")
        }

        override fun encodeValue(value: Any) {
            list += value
        }

        override fun encodeNull() {
            list += null
        }

        override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder {
            return encoderForKind(descriptor.kind) { list += it }
        }

        override fun endStructure(descriptor: SerialDescriptor) {
            parent?.cleanUpStructure(descriptor)
        }

        /**
         * Called by a child encoder when the nested structure it was writing is complete. That structure is the last
         * element of [list], because [beginStructure] appended it before the child started.
         *
         * If it is a list, it is converted to an array, just as map values are. Otherwise a list nested directly in
         * a list would be left as a [MutableList] in the output.
         */
        override fun cleanUpStructure(descriptor: SerialDescriptor) {
            val lastIndex = list.lastIndex
            val finished = list.getOrNull(lastIndex)
            if (finished is MutableList<*>) {
                list[lastIndex] = finished.toTypedArray()
            }
        }
    }

    private companion object {
        /**
         * Creates the child encoder for a nested structure of [kind]. It stores the child's (still empty) container
         * via [save] and links the child back to the receiver as its parent.
         */
        private fun BoundaryEncoder.encoderForKind(kind: SerialKind, save: (Any) -> Unit): BoundaryEncoder {
            return when (kind) {
                is StructureKind.LIST -> {
                    val newArray = mutableListOf<Any?>()
                    save(newArray)
                    SimpleArrayEncoder(newArray, serializersModule, this)
                }

                else -> {
                    val newMap = SdkBoundaryMap.of<Any?>()
                    save(newMap)
                    BoundaryMapEncoder(newMap, serializersModule, this)
                }
            }
        }
    }
}

/**
 * Decodes a value from a [BoundaryMap] in the layout produced by `BoundaryMapEncoder`.
 *
 * Nested maps are decoded by nested [BoundaryMapDecoder]s and arrays by [ArrayDecoder]s. Keys of [map] are consumed
 * one at a time; see [peekNextKey] for the order.
 */
private class BoundaryMapDecoder(
    private val map: BoundaryMap<*>,
    override val serializersModule: SerializersModule = EmptySerializersModule(),
    /**
     * This lets us conform to the encoding done by remembering to also skip the first [beginStructure]
     */
    private var ignoreBeginStructure: Boolean = false,
) : AbstractDecoder() {
    /** Keys of [map] that have not been decoded yet. */
    private val entries = map.keys().toMutableList()

    override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
        val entry = peekNextKey() ?: return CompositeDecoder.DECODE_DONE
        return descriptor.getElementIndex(entry)
    }

    /**
     * Tells kotlinx.serialization whether the element about to be decoded is non-null.
     *
     * [AbstractDecoder] always answers `true`. That would send a stored `null` to [decodeValue] (which throws) or to
     * [beginStructure] instead of [decodeNull], so nullable properties holding `null` could never be decoded.
     *
     * This only peeks at the next key; the decode call that follows pops it.
     */
    override fun decodeNotNullMark(): Boolean {
        // Nothing left: let the actual decode call report the error.
        val key = peekNextKey() ?: return true
        return map[key] != null
    }

    /*
     * Due to the way kotlinx.serialization works we need to provide these keys in this order, otherwise it crashes.
     * This is only relevant for serialization of open inheritance hierarchies (this includes interfaces)
     *
     * For this code
     * ```
     * interface Foo
     * class Bar(val x: Int) : Foo
     * ```
     * instances of bar are serialized like so:
     * ```
     * { "type": "Bar", "value": { "x": 42 } }
     * ```
     * and kotlin mandates that `type` be provided first
     */
    private fun peekNextKey() = when {
        "type" in entries -> "type"
        "value" in entries -> "value"
        else -> entries.lastOrNull()
    }

    private fun popNextKey() = when {
        entries.remove("type") -> "type"
        entries.remove("value") -> "value"
        else -> {
            entries.removeLastOrNull()
                ?: throw BoundaryMapSerializationException("no more values to deserialize")
        }
    }

    override fun decodeValue(): Any {
        val value = map[popNextKey()]
        if (value == null) {
            throw BoundaryMapSerializationException("expected value got null")
        }
        return value
    }

    override fun decodeNull(): Nothing? {
        val key = popNextKey()
        return if (key in map && map[key] == null) {
            null
        } else {
            throw BoundaryMapSerializationException("expected '$key' to map to null but maps to '${map[key]}'")
        }
    }

    override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder {
        // See BoundaryMapEncoder.beginStructure for why this is needed
        if (ignoreBeginStructure) {
            ignoreBeginStructure = false
            return this
        }
        val key = popNextKey()
        return when (val structure = map[key]) {
            is Array<*> -> {
                ArrayDecoder(serializersModule, structure)
            }

            is BoundaryMap<*> -> {
                BoundaryMapDecoder(structure, serializersModule)
            }

            // A non-nullable structure stored as null: report a serialization failure instead of an NPE.
            null -> {
                throw BoundaryMapSerializationException("expected '$key' to map to an array or map, got null")
            }

            else -> {
                throw BoundaryMapSerializationException(
                    "expected '$key' to map to be an array or map, got ${structure::class}",
                )
            }
        }
    }

    /** Decodes the elements of [array] in order. */
    private class ArrayDecoder(
        override val serializersModule: SerializersModule,
        val array: Array<*>,
    ) : AbstractDecoder() {
        /** Position of the next element to decode. */
        var index: Int = 0

        override fun decodeElementIndex(descriptor: SerialDescriptor) =
            if (index < array.size) index else CompositeDecoder.DECODE_DONE

        // Peek at the next element without consuming it, so that nullable elements holding null reach decodeNull().
        override fun decodeNotNullMark(): Boolean = array.getOrNull(index) != null

        /**
         * Elements that are themselves classes or lists were encoded as nested maps or arrays, so they are handed to
         * a nested decoder.
         *
         * The default [AbstractDecoder.beginStructure] returns this decoder instead. The nested structure's fields
         * would then be read from this array.
         */
        override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder =
            when (val structure = next()) {
                is Array<*> -> ArrayDecoder(serializersModule, structure)
                is BoundaryMap<*> -> BoundaryMapDecoder(structure, serializersModule)
                else -> throw BoundaryMapSerializationException(
                    "expected an array or map, got ${structure?.let { it::class }}",
                )
            }

        override fun decodeValue() = when (val next = next()) {
            null -> throw BoundaryMapSerializationException("expected a value found null")
            else -> next
        }

        override fun decodeNull(): Nothing? {
            val next = next()
            if (next != null) throw BoundaryMapSerializationException("expected null got '$next'")
            return null
        }

        private fun next(): Any? = array.getOrElse(index++) {
            throw BoundaryMapSerializationException("empty array")
        }
    }
}

/**
 * Raised by the boundary-map encoders and decoders in this file when a value does not match the expected layout.
 * Examples include a missing key, an unexpected `null`, or a non-composite top level.
 */
private class BoundaryMapSerializationException(
    message: String? = null,
    cause: Throwable? = null,
) : SerializationException(message, cause)

/**
 * Converts between `@Serializable` Kotlin values and [SdkBoundaryMap]s.
 *
 * Use the [Default] companion for an empty [SerializersModule], or the `Boundary { ... }` builder function to
 * supply a custom one.
 */
sealed class Boundary private constructor(
    private val serializersModule: SerializersModule,
) {
    /** Options collected by the `Boundary { ... }` builder function. */
    class BoundarySettings {
        /** Module handed to the encoder and decoder; an empty module unless overridden. */
        var serializersModule: SerializersModule = EmptySerializersModule()
    }

    /** Shared instance backed by an empty [SerializersModule]. */
    companion object Default : Boundary(EmptySerializersModule())

    /** Instance created by the `Boundary { ... }` builder with a caller-supplied module. */
    internal class Impl(module: SerializersModule) : Boundary(module)

    /**
     * Encodes [value] with [serializer] into a new [SdkBoundaryMap].
     *
     * @return the map on success, or a [Result.Failure] wrapping the [SerializationException] raised while encoding.
     * @throws IllegalArgumentException if the serializer describes a list or a primitive, since those cannot be the
     * top level of a boundary map. Other non-serialization exceptions raised by the encoder also propagate.
     */
    fun <T : Any> encodeToBoundaryMap(
        serializer: SerializationStrategy<T>,
        value: T,
    ): Result<SdkBoundaryMap<Any?>> {
        val kind = serializer.descriptor.kind
        // The top level T must be an object or a map to be encodable as a boundary map. Anything else is considered
        // a programmer error, so we throw instead of returning a failure.
        if (kind is StructureKind.LIST || kind is PrimitiveKind) {
            throw IllegalArgumentException("can't serialize $kind as boundary map")
        }
        return try {
            val encoder = BoundaryMapEncoder(serializersModule = serializersModule)
            encoder.encodeSerializableValue(serializer, value)
            encoder.map.successfully()
        } catch (e: SerializationException) {
            Result.Failure(SDKError.Serialization(e, value, BoundaryMap::class))
        }
    }

    /** Reified shorthand for [encodeToBoundaryMap] that looks up the serializer of [T]. */
    inline fun <reified T : Any> encodeToBoundaryMap(value: T): Result<SdkBoundaryMap<Any?>> =
        encodeToBoundaryMap(serializer<T>(), value)

    /**
     * Decodes a [T] from [map] with [deserializer].
     *
     * @param targetClass the class being decoded, reported in the failure.
     * @return the decoded value, or a [Result.Failure] wrapping the [SerializationException] raised while decoding.
     */
    fun <T : Any> decodeFromBoundaryMap(
        deserializer: DeserializationStrategy<T>,
        map: SdkBoundaryMap<Any?>,
        targetClass: KClass<T>,
    ): Result<T> = try {
        // The map itself is the top-level structure, so the decoder must skip the first beginStructure call.
        val decoder = BoundaryMapDecoder(map, serializersModule = serializersModule, ignoreBeginStructure = true)
        decoder.decodeSerializableValue(deserializer).successfully()
    } catch (e: SerializationException) {
        Result.Failure(SDKError.Serialization(e, map, targetClass))
    }

    /** Reified shorthand for [decodeFromBoundaryMap] that looks up the serializer and class of [T]. */
    inline fun <reified T : Any> decodeFromBoundaryMap(map: SdkBoundaryMap<Any?>): Result<T> =
        decodeFromBoundaryMap(serializer<T>(), map, T::class)
}

/**
 * Creates a [Boundary] configured by [builder], e.g. `Boundary { serializersModule = myModule }`.
 */
fun Boundary(builder: Boundary.BoundarySettings.() -> Unit): Boundary {
    val settings = Boundary.BoundarySettings()
    settings.builder()
    return Boundary.Impl(settings.serializersModule)
}
