diff --git a/aiapi/src/main/kotlin/eu/m724/aiapi/AiApi.kt b/aiapi/src/main/kotlin/eu/m724/aiapi/AiApi.kt index 2d65b3a..da7053b 100644 --- a/aiapi/src/main/kotlin/eu/m724/aiapi/AiApi.kt +++ b/aiapi/src/main/kotlin/eu/m724/aiapi/AiApi.kt @@ -7,6 +7,7 @@ import io.ktor.client.engine.cio.CIO import io.ktor.client.plugins.contentnegotiation.ContentNegotiation import io.ktor.client.plugins.sse.SSE import io.ktor.client.plugins.sse.SSEBufferPolicy +import io.ktor.client.plugins.sse.deserialize import io.ktor.client.plugins.sse.sse import io.ktor.client.request.header import io.ktor.client.request.setBody @@ -15,14 +16,11 @@ import io.ktor.http.HttpMethod import io.ktor.http.contentType import io.ktor.http.headers import io.ktor.serialization.kotlinx.json.json -import kotlinx.coroutines.channels.awaitClose +import io.ktor.sse.TypedServerSentEvent import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.callbackFlow -import kotlinx.coroutines.flow.flow -import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.mapNotNull -import kotlinx.coroutines.launch import kotlinx.serialization.json.Json +import kotlinx.serialization.serializer class AiApi( private val session: String @@ -50,17 +48,16 @@ class AiApi( method = HttpMethod.Post contentType(ContentType.Application.Json) setBody(requestBody) + }, + deserialize = { + typeInfo, jsonString -> + val serializer = Json.serializersModule.serializer(typeInfo.kotlinType!!) + Json.decodeFromString(serializer, jsonString)!! } ) { block( - incoming.mapNotNull { - try { - val streamGptEvent = Json.decodeFromString(it.data!!) - streamGptEvent - } catch (e: Exception) { - // TODO - null - } + incoming.mapNotNull { event: TypedServerSentEvent -> + deserialize(event.data) } ) } diff --git a/src/main/kotlin/App.kt b/src/main/kotlin/App.kt index 9730017..4be382c 100644 --- a/src/main/kotlin/App.kt +++ b/src/main/kotlin/App.kt @@ -14,6 +14,7 @@ import com.jakewharton.mosaic.layout.size import com.jakewharton.mosaic.layout.wrapContentSize import com.jakewharton.mosaic.layout.wrapContentWidth import com.jakewharton.mosaic.modifier.Modifier +import com.jakewharton.mosaic.ui.Color import com.jakewharton.mosaic.ui.Text import eu.m724.modifier.BorderedBoxWithTabs import eu.m724.modifier.RunTitle @@ -41,6 +42,11 @@ fun App( }, subTitle = runs[selectedTabIndex].model ?: "" ) { + Text( + value = runs[selectedTabIndex].reasoningContent.wrap(90), + color = Color(200, 200, 200) + ) + Text( value = runs[selectedTabIndex].content.wrap(90) ) diff --git a/src/main/kotlin/ViewModel.kt b/src/main/kotlin/ViewModel.kt index 95731ea..bf3ab72 100644 --- a/src/main/kotlin/ViewModel.kt +++ b/src/main/kotlin/ViewModel.kt @@ -28,10 +28,11 @@ class ViewModel { var index = 0 _runs.update { index = it.size - it + Run(RunState.Queued, "") + it + Run(RunState.Queued, "", "") } - var response = "" + var reasoningContent = "" + var responseContent = "" val timeSource = TimeSource.Monotonic val startMark by lazy { @@ -43,6 +44,7 @@ class ViewModel { GlobalScope.launch(Dispatchers.IO) { aiApi.requestCompletion( requestBody = StreamGptRequestBody( + model = "free-model", messages = listOf( ChatMessage( role = ChatMessage.Companion.Role.User, @@ -56,15 +58,25 @@ class ViewModel { is StreamGptEvent.Completion -> { val now = timeSource.markNow() + event.choices[0].delta.reasoning?.let { chunk -> + reasoningContent += chunk + tokens++ + } + event.choices[0].delta.content?.let { chunk -> - response += chunk + responseContent += chunk tokens++ } _runs.update { it.toMutableList().apply { val tps = (tokens.toDouble() / (now - startMark).inWholeSeconds).toInt() - this[index] = Run(RunState.InProgress(tps), response, model = event.model) + this[index] = Run( + state = RunState.InProgress(tps), + reasoningContent = reasoningContent, + content = responseContent, + model = event.model + ) } } } @@ -84,6 +96,7 @@ class ViewModel { data class Run( val state: RunState, + val reasoningContent: String, val content: String, val model: String? = null ) diff --git a/src/main/kotlin/modifier/RunTitle.kt b/src/main/kotlin/modifier/RunTitle.kt index 1c6af83..3c0803c 100644 --- a/src/main/kotlin/modifier/RunTitle.kt +++ b/src/main/kotlin/modifier/RunTitle.kt @@ -46,7 +46,7 @@ fun RunTitle( } ) - if (state is RunState.InProgress) { + if (state is RunState.InProgress && state.tokensPerSecond != Int.MAX_VALUE) { Text( value = " ${state.tokensPerSecond}t/s", color = Color(0, 120, 0)