package com.speechify.client.reader.classic

import com.speechify.client.api.content.Content
import com.speechify.client.api.content.ContentCursor
import com.speechify.client.api.content.ContentText
import com.speechify.client.api.content.ContentTextUtils
import com.speechify.client.api.content.contains
import com.speechify.client.api.content.containsCursor
import com.speechify.client.api.content.hasNontrivialIntersectionWith
import com.speechify.client.api.content.view.standard.StandardElement
import com.speechify.client.api.content.view.standard.toClassicBlockStyle
import com.speechify.client.internal.util.extensions.collections.groupConsecutiveBy
import com.speechify.client.internal.util.text.groupingToSentences.getSentencesAsIndexRanges
import com.speechify.client.internal.util.text.groupingToWords.getWordsWithPunctuationAsIndexRanges
import com.speechify.client.reader.classic.ClassicBlock.List.FlatList
import com.speechify.client.reader.classic.ClassicBlock.List.NestedList
import com.speechify.client.reader.classic.ClassicBlock.List.NestedListItemContent
import com.speechify.client.reader.classic.ClassicBlock.List.Style
import com.speechify.client.reader.classic.ClassicBlockBuilder.Companion.buildToText
import com.speechify.client.reader.core.ReaderFeatures
import com.speechify.client.reader.core.ReaderScope
import com.speechify.client.reader.core.Selection
import com.speechify.client.reader.core.SelectionHandle
import com.speechify.client.reader.core.SerialLocation
import com.speechify.client.reader.core.dispatch
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.Flow

/**
 * Parses each element on its own and merges every resulting [ClassicBlockBuilder.Text] into a single text builder.
 * Non-text builders (images, tables, lists) are discarded.
 */
private fun List<StandardElement>.parseToText(): ClassicBlockBuilder.Text {
    val textBuilders = flatMap { element -> element.parse() }
        .filterIsInstance<ClassicBlockBuilder.Text>()
    return ClassicBlockBuilder.Text.join(texts = textBuilders)
}
/**
 * Parses this element and merges every resulting [ClassicBlockBuilder.Text] into one; non-text builders are dropped.
 */
private fun StandardElement.parseToText(): ClassicBlockBuilder.Text =
    ClassicBlockBuilder.Text.join(texts = parse().filterIsInstance<ClassicBlockBuilder.Text>())

/**
 * Parses this cell's children into a single [ClassicBlockBuilder.Text].
 *
 * Builders that implement [ProvidesFallbackContent] (e.g. nested tables or lists) are replaced by their fallback
 * text. Any other non-text builder (e.g. an image) is dropped.
 *
 * @return the merged text, or `null` when the cell yields no text elements at all.
 */
private fun StandardElement.Table.Cell.tryParseToTextWithFallback(): ClassicBlockBuilder.Text? {
    val textBuilders = _elements
        .flatMap { child -> child.parse() }
        .map { builder -> (builder as? ProvidesFallbackContent)?.fallbackTextBuilder ?: builder }
        .filterIsInstance<ClassicBlockBuilder.Text>()
    return ClassicBlockBuilder.Text.join(texts = textBuilders)
        .takeIf { merged -> merged.elements.isNotEmpty() }
}

/**
 * Returns an element of the same kind as this one, with [elements] as its children.
 *
 * The element's own attributes are preserved: anchor URL or cursor, heading level, and list style. Images are
 * returned unchanged because they have no children to replace.
 *
 * @throws IllegalStateException for [StandardElement.Text], which is a leaf. `parse` handles Text before it
 * reaches its generic re-wrapping branch.
 * @throws NotImplementedError for table elements, which are not supported here (`TODO()`).
 */
private fun StandardElement.withChildren(elements: List<StandardElement>): StandardElement = when (this) {
    is StandardElement.Anchor.External -> StandardElement.Anchor.External(
        url = this.url,
        _elements = elements,
    )

    is StandardElement.Anchor.Internal -> StandardElement.Anchor.Internal(
        cursor = this.cursor,
        _elements = elements,
    )

    is StandardElement.Text -> error("Unreachable")
    is StandardElement.Underlined -> StandardElement.Underlined(elements)
    is StandardElement.Bold -> StandardElement.Bold(elements)
    is StandardElement.Code -> StandardElement.Code(elements)
    is StandardElement.Heading -> StandardElement.Heading(this.level, elements)
    is StandardElement.Italics -> StandardElement.Italics(elements)
    is StandardElement.Paragraph -> StandardElement.Paragraph(elements)
    is StandardElement.Image.Local -> this
    is StandardElement.Image.Remote -> this

    is StandardElement.Table.Cell -> TODO()
    is StandardElement.Table.Row -> TODO()
    is StandardElement.Table -> TODO()
    is StandardElement.List -> StandardElement.List(
        // A list can only hold list items; any other child is dropped.
        elements.filterIsInstance<StandardElement.List.ListItem>(),
        // Keep the original list's style instead of forcing DISC, so a rebuilt numbered list stays numbered.
        this.listStyle,
    )

    is StandardElement.List.ListItem -> StandardElement.List.ListItem(elements)
}

/**
 * Converts a single [StandardElement] into the [ClassicBlockBuilder]s it contributes.
 *
 * - Images map to [ClassicBlockBuilder.Image].
 * - Tables map to [ClassicBlockBuilder.Table]. Cells that yield no text are dropped, and the table's plain text is
 *   kept as a fallback.
 * - Lists keep only items containing at least one element allowed in a list item. When no item qualifies, the
 *   list degrades to its plain text.
 * - Text maps to a single-element [ClassicBlockBuilder.Text].
 * - Any other (container) element is parsed through its children. Each resulting text run is re-wrapped in a copy
 *   of this element via `withChildren`, and non-text builders pass through unchanged.
 */
private fun StandardElement.parse(): List<ClassicBlockBuilder> = when (this) {
    is StandardElement.Image -> listOf(ClassicBlockBuilder.Image(this))

    is StandardElement.Table -> {
        val tableRows = this.rows.map { row ->
            val textCells = row.cells.mapNotNull { cell ->
                cell.tryParseToTextWithFallback()?.let { cellText ->
                    ClassicBlockBuilder.Table.Cell(
                        text = cellText,
                        rowSpan = cell.rowSpan,
                        colSpan = cell.colSpan,
                        isHeader = cell.isHeader,
                    )
                }
            }
            ClassicBlockBuilder.Table.Row(cells = textCells)
        }
        val tableAsPlainText = ClassicBlockBuilder.Text(
            elements = listOf(StandardElement.Text(text = this.text)),
        )
        listOf(ClassicBlockBuilder.Table(rows = tableRows, fallbackTextBuilder = tableAsPlainText))
    }

    is StandardElement.List -> {
        val listAsPlainText = ClassicBlockBuilder.Text(
            elements = listOf(StandardElement.Text(text = this.text)),
        )
        // An item is renderable when at least one of its elements is allowed inside a list item.
        val renderableItems = items.filter { item ->
            item.elements.any { element -> !element.isDisallowedInListItem() }
        }
        if (renderableItems.isEmpty()) {
            // No renderable item at all: show the list as plain text instead.
            listOf(listAsPlainText)
        } else {
            val itemBuilders = renderableItems.map { item ->
                ClassicBlockBuilder.List.ListItem(elements = item.elements.toList())
            }
            listOf(
                ClassicBlockBuilder.List(
                    items = itemBuilders,
                    listStyle = this.listStyle,
                    fallbackTextBuilder = listAsPlainText,
                ),
            )
        }
    }

    is StandardElement.Text -> listOf(ClassicBlockBuilder.Text(listOf(this)))

    else -> _elements.parse().map { child ->
        if (child is ClassicBlockBuilder.Text) {
            ClassicBlockBuilder.Text(listOf(this.withChildren(child.elements)))
        } else {
            child
        }
    }
}

/**
 * Parses every element and merges each run of consecutive [ClassicBlockBuilder.Text] builders into one.
 * Non-text builders (images, tables, lists) are kept as they are, in document order.
 */
private fun List<StandardElement>.parse(): List<ClassicBlockBuilder> =
    asSequence()
        .flatMap { element -> element.parse() }
        .groupConsecutiveBy { builder -> builder is ClassicBlockBuilder.Text }
        .flatMap { (isTextRun, run) ->
            if (isTextRun) {
                // Every builder in a text run is a Text, so this filter only narrows the type.
                listOf(ClassicBlockBuilder.Text.join(run.filterIsInstance<ClassicBlockBuilder.Text>()))
            } else {
                run
            }
        }
        .toList()

/**
 * Merges the text content of these elements into a single [ClassicBlockBuilder.Text], discarding non-text builders.
 */
internal fun List<StandardElement>.joinToTextBlock(): ClassicBlockBuilder.Text = buildToText(elements = this)

/**
 * Accumulates a [ContentText] together with the [FormattedRange]s that apply to it, so that nested inline elements
 * can be flattened into a single [FormattedText].
 *
 * @property contentText the text covered by this builder.
 * @property formatting formatting ranges. Their indices are relative to the start of [contentText]; [join] shifts
 * them when builders are concatenated.
 */
private class FormattedTextBuilder(
    val contentText: ContentText,
    val formatting: Array<FormattedRange>,
) {

    /**
     * Creates the [FormattedText] for this builder.
     *
     * Uses `scope.dispatch` and a [ClassicTextFeaturesHelper] built for [contentText] and [readerFeatures].
     */
    fun build(scope: ReaderScope, readerFeatures: Flow<ReaderFeatures>): FormattedText {
        return FormattedText(
            dispatch = scope.dispatch,
            contentText = contentText,
            formatting = formatting,
            featuresHelper = ClassicTextFeaturesHelper(scope, readerFeatures, contentText),
        )
    }

    /**
     * Applies [formatting] across this builder's whole text.
     *
     * Returns `this` unchanged when [formatting] is `null`. Otherwise returns a new builder in which every existing
     * range with equal formatting is dropped, and a single range from `formatting.spanning(contentText.text)` is
     * added instead.
     */
    fun wrap(formatting: Formatting?): FormattedTextBuilder {
        if (formatting == null) return this
        return FormattedTextBuilder(
            contentText = contentText,
            formatting = this.formatting
                // all ranges with same formatting as the wrapper will be joined into a single range with full span
                .filterNot { it.formatting == formatting }
                .let { it + formatting.spanning(contentText.text) }
                .toTypedArray(),
        )
    }

    companion object {
        /**
         * Concatenates [builders], in order, into one builder.
         *
         * The texts are concatenated with [ContentTextUtils.concat]. Each builder's ranges are shifted by the length
         * of the text joined before it. A range that reaches the end of the text joined so far is merged (via
         * `joinedWith`) with a same-formatting range starting at index 0 of the next builder, so adjacent runs of
         * identical formatting become a single range.
         *
         * NOTE(review): [builders] presumably must be non-empty. Elsewhere in this file, concatenating an empty
         * list of ContentText is noted to throw. Confirm.
         *
         * @throws IllegalStateException (from `check`) when two ranges with the same formatting both reach the end
         * of the text joined so far, or when a range being merged finds two same-formatting ranges starting at
         * index 0 of the next builder.
         */
        fun join(builders: List<FormattedTextBuilder>): FormattedTextBuilder {
            return FormattedTextBuilder(
                contentText = ContentTextUtils.concat(builders.map { it.contentText }),
                formatting = builders
                    .fold(emptyList<FormattedRange>() to 0) { (joinedRanges, joinedTextLength), text ->
                        // Accumulator: ranges so far, and the length of the text joined so far.
                        // Ranges ending strictly before the joined text's end cannot touch the next builder.
                        // The remaining ranges reach the end and may continue into it.
                        val (leftNotJoin, leftMaybeJoin) = joinedRanges
                            .partition { it.range.endIndexExclusive < joinedTextLength }
                        // Only ranges starting at the next builder's index 0 can continue a left range.
                        val (rightMaybeJoin, rightNotJoin) = text.formatting
                            .partition { it.range.startIndex == 0 }
                        val leftJoinKeyed = leftMaybeJoin.groupBy { it.formatting }
                        val rightJoinKeyed = rightMaybeJoin.groupBy { it.formatting }

                        // Merge each boundary range on the left with its same-formatting counterpart on the right,
                        // or keep it as-is when there is none.
                        // NOTE(review): the right range is passed to joinedWith without offsetBy(joinedTextLength).
                        // Presumably joinedWith only uses its length to extend the left range; confirm.
                        val joined = leftJoinKeyed.map {
                            check(it.value.size == 1) {
                                "Text block builder had two overlapping ranges of same type"
                            }
                            val leftJoinable = it.value.first()
                            val otherJoinables = rightJoinKeyed[it.key] ?: return@map leftJoinable
                            check(otherJoinables.size == 1) {
                                "Text block builder had two overlapping ranges of same type"
                            }
                            leftJoinable.joinedWith(otherJoinables.first())
                        }
                        // Right boundary ranges without a left counterpart are kept, shifted into joined coordinates.
                        val rightUnjoined = rightJoinKeyed.flatMap {
                            when {
                                leftJoinKeyed[it.key] == null -> it.value.map { it.offsetBy(joinedTextLength) }
                                else -> emptyList()
                            }
                        }
                        // Shift the remaining right ranges. No shift is needed while nothing has been joined yet.
                        val rightNotJoinWithOffset = when {
                            joinedTextLength > 0 -> rightNotJoin.map { it.offsetBy(joinedTextLength) }
                            else -> rightNotJoin
                        }
                        val newJoinedRanges = leftNotJoin + joined + rightUnjoined + rightNotJoinWithOffset
                        newJoinedRanges to (joinedTextLength + text.contentText.length)
                    }.first.toTypedArray(),
            )
        }
    }
}

/**
 * Implemented by structured builders (e.g. tables and lists) that can also be represented as plain text.
 *
 * Where structured rendering is not possible, such as a table nested inside a table cell, callers substitute
 * [fallbackTextBuilder].
 */
internal interface ProvidesFallbackContent {
    /** Plain-text stand-in for the whole structured block. */
    val fallbackTextBuilder: ClassicBlockBuilder.Text
}

/**
 * Intermediate representation of parsed [StandardElement]s. Builders are produced by this file's `parse` functions
 * and converted into renderable [ClassicBlock]s by [splitIntoClassicBlocks].
 *
 * Inside this class the nested [List] shadows `kotlin.collections.List`, which is why collection types are written
 * fully qualified here.
 */
internal sealed class ClassicBlockBuilder {
    /**
     * A run of inline [StandardElement]s rendered together as one block of formatted text.
     *
     * @property elements the inline elements, in document order. [containsLocation] and the `isLocation*` helpers
     * call `first()`/`last()` on this list, so they assume it is non-empty.
     */
    data class Text(internal val elements: kotlin.collections.List<StandardElement>) :
        ClassicBlockBuilder() {

        /** The [Formatting] this element applies to its children, or `null` if it applies none. */
        private fun StandardElement.toFormatting(): Formatting? = when (this) {
            is StandardElement.Anchor.External -> Formatting.Linked(Link.External(this.url))
            is StandardElement.Bold -> Formatting.Bold
            is StandardElement.Code -> Formatting.Code
            is StandardElement.Italics -> Formatting.Italics
            is StandardElement.Paragraph -> null
            is StandardElement.Underlined -> Formatting.Underline
            else -> null
        }

        /**
         * Recursively flattens this element into a [FormattedTextBuilder].
         *
         * Returns `null` when the element has no text: images, and containers none of whose children yield text.
         */
        private fun StandardElement.toFormattedText(): FormattedTextBuilder? = when (this) {
            is StandardElement.Text -> FormattedTextBuilder(
                contentText = this.text,
                formatting = emptyArray(),
            )
            // Images carry no text content.
            is StandardElement.Image.Local, is StandardElement.Image.Remote -> null

            else -> {
                // NOTE(review): `elements` presumably resolves to this element's own children (the extension
                // receiver takes precedence over the enclosing Text). Confirm.
                val builders = elements.mapNotNull { it.toFormattedText() }

                if (builders.isEmpty()) {
                    // Return null here to prevent an exception when
                    // concatenating an empty list of ContentText in join.
                    null
                } else {
                    val formatting = this.toFormatting()
                    FormattedTextBuilder.join(
                        builders = builders,
                    ).wrap(formatting)
                }
            }
        }

        /**
         * Builds a [FormattingTree] that mirrors the nesting of [elements].
         *
         * Each [StandardElement.Text] becomes a [FormattingTree.Text] leaf. Every other element becomes a
         * [FormattingTree.Element] carrying its [Formatting] (possibly `null`). The root element is unformatted.
         */
        fun buildTree(
            scope: ReaderScope,
            readerFeatures: Flow<ReaderFeatures>,
        ): FormattingTree {
            fun StandardElement.toFormattingTree(): FormattingTree = when (this) {
                is StandardElement.Text -> FormattingTree.Text(
                    dispatch = scope.dispatch,
                    contentText = this.text,
                    featuresHelper = ClassicTextFeaturesHelper(scope, readerFeatures, this.text),
                )

                else -> {
                    FormattingTree.Element(
                        formatting = this.toFormatting(),
                        children = _elements.map { it.toFormattingTree() }.toTypedArray(),
                    )
                }
            }

            return FormattingTree.Element(
                formatting = null,
                children = elements.map { it.toFormattingTree() }.toTypedArray(),
            )
        }

        /**
         * Flattens [elements] into a single [FormattedText] with merged formatting ranges.
         * Elements that yield no text (e.g. images) are skipped.
         */
        fun buildText(
            scope: CoroutineScope,
            readerFeatures: Flow<ReaderFeatures>,
        ): FormattedText {
            val builder = FormattedTextBuilder.join(elements.mapNotNull { it.toFormattedText() })
            return builder.build(scope, readerFeatures)
        }

        /**
         * Whether [location] falls within the span from the first element's start to the last element's end, as
         * decided by `containsCursor`.
         */
        fun containsLocation(location: SerialLocation): Boolean {
            return containsCursor(elements.first().start, elements.last().end, location.cursor)
        }

        fun isLocationBeforeOrAtEnd(location: SerialLocation) = location.cursor.isBeforeOrAt(elements.last().end)
        fun isLocationBefore(location: SerialLocation) = location.cursor.isBefore(elements.first().start)
        fun isLocationAfter(location: SerialLocation) = location.cursor.isAfter(elements.last().end)

        companion object {
            /** Concatenates the elements of all [texts], in order, into one [Text]. */
            fun join(texts: kotlin.collections.List<Text>): Text {
                return Text(texts.flatMap { it.elements })
            }
        }
    }

    /** A standalone image. */
    data class Image(val image: StandardElement.Image) : ClassicBlockBuilder()

    /**
     * A table whose cells each hold a [Text].
     *
     * @property fallbackTextBuilder plain-text stand-in for the whole table.
     */
    data class Table(
        val rows: kotlin.collections.List<Row>,
        override val fallbackTextBuilder: Text,
    ) : ClassicBlockBuilder(), ProvidesFallbackContent {
        data class Row(val cells: kotlin.collections.List<Cell>)
        data class Cell(val text: Text, val rowSpan: Int, val colSpan: Int, val isHeader: Boolean)
    }

    /**
     * A list block.
     *
     * @property items the list's items.
     * @property listStyle the style declared by the source list.
     * @property fallbackTextBuilder plain-text stand-in for the whole list.
     */
    data class List(
        val items: kotlin.collections.List<ListItem>,
        val listStyle: StandardElement.List.ListStyle,
        override val fallbackTextBuilder: Text,
    ) : ClassicBlockBuilder(), ProvidesFallbackContent {
        /**
         * Recursively flattens [items] into [FlatList.Item]s.
         *
         * - A text-bearing inline element becomes one item at [indent], using [style] (or [Style.Bullet] when
         *   [style] is null). Inline elements that yield no text (e.g. ones wrapping only an image) are skipped.
         * - A nested [StandardElement.List] is flattened one indent level deeper, using its own style.
         * - A [StandardElement.List.ListItem] is expanded in place at the current indent and style.
         * - Any other element is ignored.
         */
        fun buildFlatListItems(
            items: kotlin.collections.List<StandardElement>,
            indent: Int,
            style: Style?,
            scope: CoroutineScope,
            readerFeatures: Flow<ReaderFeatures>,
        ): kotlin.collections.List<FlatList.Item> = items.flatMap { item ->
            when (item) {
                // Same inline kinds that buildNestedItems renders as text. Handling only plain Text here would
                // silently drop formatted content (bold, links, paragraphs, ...) inside list items.
                is StandardElement.Text,
                is StandardElement.Underlined,
                is StandardElement.Bold,
                is StandardElement.Code,
                is StandardElement.Heading,
                is StandardElement.Italics,
                is StandardElement.Paragraph,
                is StandardElement.Anchor.External,
                is StandardElement.Anchor.Internal,
                -> {
                    val textBuilder = listOf(item).joinToTextBlock()
                    if (textBuilder.elements.isEmpty()) {
                        // Nothing to read or render, e.g. an inline element that only wraps an image.
                        emptyList<FlatList.Item>()
                    } else {
                        listOf(
                            FlatList.Item(
                                indentLevel = indent,
                                style = style ?: Style.Bullet,
                                builder = textBuilder,
                                scope = scope,
                                readerFeatures = readerFeatures,
                            ),
                        )
                    }
                }

                is StandardElement.List -> buildFlatListItems(
                    item._elements,
                    indent + 1,
                    item.listStyle.toClassicBlockStyle(),
                    scope,
                    readerFeatures,
                )

                is StandardElement.List.ListItem ->
                    buildFlatListItems(item._elements, indent, style, scope, readerFeatures)

                else -> emptyList<FlatList.Item>()
            }
        }

        /**
         * Converts [items] into nested-list content.
         *
         * Inline text elements become [NestedList.Text], and nested [StandardElement.List]s recursively become
         * [NestedList]s. Any other element is dropped.
         */
        fun buildNestedItems(
            items: Array<StandardElement>,
            scope: CoroutineScope,
            readerFeatures: Flow<ReaderFeatures>,
        ): kotlin.collections.List<NestedListItemContent> = items.mapNotNull { listItem ->
            when (listItem) {
                is StandardElement.Text,
                is StandardElement.Underlined,
                is StandardElement.Bold,
                is StandardElement.Code,
                is StandardElement.Heading,
                is StandardElement.Italics,
                is StandardElement.Paragraph,
                is StandardElement.Anchor.External,
                is StandardElement.Anchor.Internal,
                -> NestedList.Text(
                    listOf(listItem).joinToTextBlock(),
                    scope,
                    readerFeatures,
                )

                is StandardElement.List -> {
                    val nestedItems = listItem.elements.map {
                        NestedList.Item(buildNestedItems(it.elements, scope, readerFeatures).toTypedArray())
                    }.toTypedArray()
                    NestedList(nestedItems, listItem.listStyle.toClassicBlockStyle())
                }

                else -> null
            }
        }

        /**
         * One item of a [List], holding the item's child [elements].
         *
         * The location helpers call `first()`/`last()` on [elements], so they assume it is non-empty.
         */
        data class ListItem(
            val elements: kotlin.collections.List<StandardElement>,
        ) : ClassicBlockBuilder() {
            fun containsLocation(location: SerialLocation): Boolean {
                return containsCursor(elements.first().start, elements.last().end, location.cursor)
            }

            fun isLocationBeforeOrAtEnd(location: SerialLocation) = location.cursor.isBeforeOrAt(elements.last().end)
            fun isLocationBefore(location: SerialLocation) = location.cursor.isBefore(elements.first().start)
            fun isLocationAfter(location: SerialLocation) = location.cursor.isAfter(elements.last().end)
        }
    }

    internal companion object {

        /**
         * Parses [elements] and converts each resulting builder into a [ClassicBlock].
         *
         * Text runs become a [ClassicBlock.Paragraph], or a [ClassicBlock.Heading] at [headingLevel] when it is
         * non-null. Local images are not supported yet: they hit `TODO()`, which throws [NotImplementedError].
         */
        internal fun splitIntoClassicBlocks(
            elements: kotlin.collections.List<StandardElement>,
            scope: CoroutineScope,
            readerFeatures: Flow<ReaderFeatures>,

            // save a bit of code by handling cases with no heading and heading with nullable arg
            headingLevel: Int?,
        ): kotlin.collections.List<ClassicBlock> {
            return elements.parse().map { builder ->
                when (builder) {
                    is Text -> if (headingLevel == null) {
                        ClassicBlock.Paragraph(
                            scope = scope,
                            readerFeatures = readerFeatures,
                            builder = builder,
                        )
                    } else {
                        ClassicBlock.Heading(
                            scope = scope,
                            readerFeatures = readerFeatures,
                            builder = builder,
                            level = headingLevel,
                        )
                    }

                    is Image -> when (builder.image) {
                        is StandardElement.Image.Local -> TODO()
                        is StandardElement.Image.Remote -> ClassicBlock.Image.Remote(
                            height = builder.image.height,
                            width = builder.image.width,
                            altText = builder.image.altText,
                            url = builder.image.url,
                        )
                    }

                    is Table -> ClassicBlock.Table(
                        rows = builder.rows.map { row ->
                            ClassicBlock.Table.Row(
                                cells = row.cells.map { cell ->
                                    ClassicBlock.Table.Row.Cell(
                                        scope = scope,
                                        readerFeatures = readerFeatures,
                                        builder = cell.text,
                                        rowSpan = cell.rowSpan,
                                        colSpan = cell.colSpan,
                                        isHeader = cell.isHeader,
                                    )
                                }.toTypedArray(),
                            )
                        }.toTypedArray(),
                    )

                    is List -> ClassicBlock.List(
                        builder = builder,
                        style = builder.listStyle.toClassicBlockStyle(),
                        scope = scope,
                        readerFeatures = readerFeatures,
                    )

                    else -> error("Unreachable")
                }
            }
        }

        /** Merges the text content of [elements] into one [Text], discarding non-text builders. */
        fun buildToText(
            elements: kotlin.collections.List<StandardElement>,
        ): Text {
            return elements.parseToText()
        }
    }
}

/**
 * Returns the (inclusive) index range of the first word-with-punctuation in this text that ends at or after the
 * index of [cursor].
 *
 * @return `null` when [cursor] is not inside this text, or when no word range ends at or after the cursor's index
 * (presumably, for example, a cursor in trailing whitespace after the last word).
 */
internal fun ContentText.getIndexRangeOfWordContainingCursor(cursor: ContentCursor): IntRange? {
    if (!this.containsCursor(cursor)) return null
    val index = this.getFirstIndexOfCursor(cursor)
    val words = this.text.getWordsWithPunctuationAsIndexRanges()
    // `firstOrNull` honours the nullable contract; `first` would throw NoSuchElementException when nothing matches.
    return words.firstOrNull { it.last >= index }
}

/**
 * Returns the (inclusive) index range of the first sentence in this text that ends at or after the index of
 * [cursor].
 *
 * @return `null` when [cursor] is not inside this text, or when no sentence range ends at or after the cursor's
 * index (presumably, for example, a cursor in trailing whitespace after the last sentence).
 */
internal fun ContentText.getIndexRangeOfSentenceContainingCursor(cursor: ContentCursor): IntRange? {
    if (!this.containsCursor(cursor)) return null
    val index = this.getFirstIndexOfCursor(cursor)
    val sentences = this.text.getSentencesAsIndexRanges()
    // `firstOrNull` honours the nullable contract; `first` would throw NoSuchElementException when nothing matches.
    return sentences.firstOrNull { it.last >= index }
}

/**
 * Returns the slice of this text covering the sentence located by [getIndexRangeOfSentenceContainingCursor], or
 * `null` when no such sentence is found.
 */
internal fun ContentText.getSentenceContainingCursor(cursor: ContentCursor): ContentText? {
    val sentenceRange = getIndexRangeOfSentenceContainingCursor(cursor) ?: return null
    // The range is inclusive while `slice` takes an exclusive end, hence `+ 1`.
    return slice(sentenceRange.first, sentenceRange.last + 1)
}

/**
 * Returns the slice of this text covering the word located by [getIndexRangeOfWordContainingCursor], or `null`
 * when no such word is found.
 */
internal fun ContentText.getWordContainingCursor(cursor: ContentCursor): ContentText? {
    val wordRange = getIndexRangeOfWordContainingCursor(cursor) ?: return null
    // The range is inclusive while `slice` takes an exclusive end, hence `+ 1`.
    return slice(wordRange.first, wordRange.last + 1)
}

/**
 * Returns the slice of this text from the first index of [other]'s start cursor through the last index of
 * [other]'s end cursor, inclusive.
 *
 * Unlike `intersectionRange`, this does not first check that the two actually overlap.
 */
internal fun ContentText.intersect(other: Content): ContentText {
    val fromIndex = getFirstIndexOfCursor(other.start)
    val toIndexExclusive = getLastIndexOfCursor(other.end) + 1
    return slice(fromIndex, toIndexExclusive)
}

/**
 * Returns the inclusive index range of this text that [other] covers, or `null` when the two have no non-trivial
 * intersection.
 */
internal fun ContentText.intersectionRange(other: Content): IntRange? =
    if (hasNontrivialIntersectionWith(other)) {
        getFirstIndexOfCursor(other.start)..getLastIndexOfCursor(other.end)
    } else {
        null
    }

/**
 * Determines which handles of [completeSelection] belong on this [ContentText].
 *
 * This text is expected to overlap [completeSelection]. The start handle is included when the selection starts at
 * or after this text's start, and the end handle is included when the selection ends at or before this text's end.
 *
 * @return a pair of (start handle, end handle), where the side that does not apply is `null`. Returns `null`
 * overall when neither applies, i.e. this text lies strictly inside the selection.
 */
internal fun ContentText.findBoundaryHandles(completeSelection: Selection): Pair<SelectionHandle?, SelectionHandle?>? {
    val selectionStartsHere = completeSelection.start.isAfterOrAt(this.start)
    val selectionEndsHere = completeSelection.end.isBeforeOrAt(this.end)
    if (!selectionStartsHere && !selectionEndsHere) return null
    return Pair(
        if (selectionStartsHere) completeSelection.startHandle else null,
        if (selectionEndsHere) completeSelection.endHandle else null,
    )
}

/**
 * Whether this element cannot contribute text to a list item.
 *
 * A [StandardElement.Text] is always allowed. Any other element is disallowed unless at least one of its children
 * is allowed. As a result, childless elements (presumably including images) and subtrees that contain no Text node
 * are disallowed.
 */
private fun StandardElement.isDisallowedInListItem(): Boolean =
    // `all` is vacuously true for an empty child list, which covers the "no children" case.
    this !is StandardElement.Text && _elements.all { child -> child.isDisallowedInListItem() }
