Response streaming

Also had to rework how messages are displayed a bit, to accommodate streaming.
This commit is contained in:
Minecon724 2025-06-21 15:41:04 +02:00
commit 48f721e756
Signed by: Minecon724
GPG key ID: A02E6E67AB961189
24 changed files with 528 additions and 219 deletions

View file

@ -0,0 +1,26 @@
/**
 * Retrofit [CallAdapter.Factory] that claims service methods whose declared
 * return type is `Flow<SseEvent<T>>` and hands them to an [SseCallAdapter].
 *
 * For any other return type [get] answers `null`, which tells Retrofit that
 * this factory does not handle the method.
 *
 * @property client HTTP client forwarded to the created [SseCallAdapter].
 * @property moshi JSON (de)serializer forwarded to the created [SseCallAdapter].
 */
class SseCallAdapterFactory(
    private val client: OkHttpClient,
    private val moshi: Moshi // Or Gson
) : CallAdapter.Factory() {
    /**
     * @return an [SseCallAdapter] for `Flow<SseEvent<T>>` return types, `null` otherwise.
     */
    override fun get(
        returnType: Type,
        annotations: Array<out Annotation>,
        retrofit: Retrofit
    ): CallAdapter<*, *>? {
        // Only Flow<...> return types are candidates.
        if (getRawType(returnType) != Flow::class.java) return null

        // The element type emitted by the Flow has to be SseEvent<...>.
        val emittedType = getParameterUpperBound(0, returnType as ParameterizedType)
        if (getRawType(emittedType) != SseEvent::class.java) return null

        // The T of SseEvent<T> is the payload type the adapter will decode.
        val payloadType = getParameterUpperBound(0, emittedType as ParameterizedType)
        return SseCallAdapter<Any>(client, moshi, payloadType)
    }
}

View file

@ -58,6 +58,7 @@ dependencies {
implementation(libs.retrofit)
implementation(libs.retrofit.converter.gson)
implementation(libs.androidx.material3.window.size.class1)
implementation(libs.okhttp.sse)
testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit)
androidTestImplementation(libs.androidx.espresso.core)

View file

@ -15,11 +15,13 @@ import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.PaddingValues
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.imePadding
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.LazyListState
@ -45,7 +47,11 @@ import androidx.compose.material3.windowsizeclass.calculateWindowSizeClass
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.saveable.rememberSaveable
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.focus.focusRequester
@ -62,7 +68,7 @@ import eu.m724.chatapp.activity.chat.composable.NestedScrollKeyboardHider
import eu.m724.chatapp.activity.chat.composable.SimpleTextFieldWithPadding
import eu.m724.chatapp.activity.chat.composable.disableBringIntoViewOnFocus
import eu.m724.chatapp.activity.ui.theme.ChatAppTheme
import eu.m724.chatapp.api.data.ChatMessage
import eu.m724.chatapp.api.data.response.completion.ChatMessage
import kotlinx.coroutines.launch
@AndroidEntryPoint
@ -76,30 +82,36 @@ class ChatActivity : ComponentActivity() {
enableEdgeToEdge()
setContent {
val windowSizeClass = calculateWindowSizeClass(this)
val uiState by viewModel.uiState.collectAsStateWithLifecycle()
val softwareKeyboardController = LocalSoftwareKeyboardController.current
val context = LocalContext.current
val coroutineScope = rememberCoroutineScope()
val chatState = rememberChatState(
requestInProgress = uiState.requestInProgress,
onSend = { message ->
viewModel.sendMessage(message)
}
)
val chatState = rememberChatState()
val threadViewLazyListState = rememberLazyListState()
val onSend = {
if (chatState.composerValue.isNotBlank() && !uiState.requestInProgress) {
viewModel.sendMessage(chatState.composerValue)
}
}
ChatScreen(
windowSizeClass = windowSizeClass,
uiState = uiState,
chatState = chatState,
threadViewLazyListState = threadViewLazyListState,
onSend = onSend,
onRequestFocus = {
if (uiState.requestInProgress) return@ChatScreen
coroutineScope.launch {
if (threadViewLazyListState.layoutInfo.visibleItemsInfo.find { it.key == "composer" } == null) {
if (uiState.messages.isNotEmpty()) {
threadViewLazyListState.animateScrollToItem(uiState.messages.size)
// TODO this makes the composer full screen but if the condition above is false it doesn't, that may be kind of unintuitive
}
}
@ -109,26 +121,30 @@ class ChatActivity : ComponentActivity() {
}
)
LaunchedEffect(uiState.requestInProgress) {
if (uiState.requestInProgress) {
chatState.composerValue = ""
// scroll to the last user message too
threadViewLazyListState.animateScrollToItem(uiState.messages.size - 2)
} else if (uiState.messages.isNotEmpty()) {
// scroll to the last user message too
threadViewLazyListState.animateScrollToItem(uiState.messages.size - 2)
// if the composer is visible (message is short enough), focus on it
// if the message is long, we let the user read it
threadViewLazyListState.layoutInfo.visibleItemsInfo.firstOrNull {
it.key == "composer"
}?.let {
chatState.requestFocus()
softwareKeyboardController?.show() // TODO perhaps it's pointless to focus since we can click on the toolbar? maybe make it configurable
}
}
}
LaunchedEffect(Unit) {
viewModel.uiEvents.collect { event ->
when (event) {
is ChatActivityUiEvent.ProcessingRequest -> {
threadViewLazyListState.animateScrollToItem(uiState.messages.size)
}
is ChatActivityUiEvent.SuccessfulResponse -> {
chatState.composerValue = ""
if (uiState.messages.isNotEmpty()) {
threadViewLazyListState.animateScrollToItem(uiState.messages.size - 2)
}
threadViewLazyListState.layoutInfo.visibleItemsInfo.firstOrNull {
it.key == "composer"
}?.let {
chatState.requestFocus()
softwareKeyboardController?.show() // TODO perhaps it's pointless to focus since we can click on the toolbar?
}
}
is ChatActivityUiEvent.Error -> {
Toast.makeText(context, event.error, Toast.LENGTH_SHORT)
.show() // TODO better way of showing this. snackbar?
@ -146,6 +162,7 @@ fun ChatScreen(
uiState: ChatActivityUiState,
chatState: ChatState,
threadViewLazyListState: LazyListState,
onSend: () -> Unit,
onRequestFocus: () -> Unit
) {
val isTablet = windowSizeClass.widthSizeClass > WindowWidthSizeClass.Compact
@ -158,12 +175,14 @@ fun ChatScreen(
}
) { innerPadding ->
ChatScreenContent(
modifier = Modifier.fillMaxSize().padding(innerPadding),
modifier = Modifier
.fillMaxSize()
.padding(innerPadding),
isTablet = isTablet,
messages = uiState.messages,
liveResponse = uiState.liveResponse,
uiState = uiState,
chatState = chatState,
threadViewLazyListState = threadViewLazyListState,
onSend = onSend,
onRequestFocus = onRequestFocus
)
}
@ -174,10 +193,10 @@ fun ChatScreen(
fun ChatScreenContent(
modifier: Modifier = Modifier,
isTablet: Boolean,
messages: List<ChatMessage>,
liveResponse: String,
uiState: ChatActivityUiState,
chatState: ChatState,
threadViewLazyListState: LazyListState,
onSend: () -> Unit,
onRequestFocus: () -> Unit
) {
val layout: @Composable (@Composable () -> Unit, @Composable () -> Unit) -> Unit =
@ -206,10 +225,12 @@ fun ChatScreenContent(
layout(
{
ThreadView(
modifier = Modifier.fillMaxSize().padding(horizontal = 24.dp),
modifier = Modifier
.fillMaxSize()
.padding(horizontal = 24.dp),
lazyListState = threadViewLazyListState,
messages = messages,
liveResponse = liveResponse,
messages = uiState.messages,
uiState = uiState,
chatState = chatState
)
},
@ -222,13 +243,14 @@ fun ChatScreenContent(
horizontalAlignment = Alignment.CenterHorizontally
) {
ChatToolBar(
chatState = chatState,
canSend = chatState.composerValue.isNotBlank() && !uiState.requestInProgress,
onSend = onSend,
onEmptySpaceClick = onRequestFocus
)
LanguageModelMistakeWarning(
modifier = Modifier
.padding(vertical = 10.dp) // TODO this is troublesome if there's navigation bar below or any kind of padding
.padding(vertical = 10.dp) // TODO this is troublesome if there's navigation bar below or any kind of padding. But without it, it looks even worse
)
}
}
@ -239,18 +261,12 @@ fun ChatScreenContent(
fun ThreadView(
lazyListState: LazyListState,
messages: List<ChatMessage>,
liveResponse: String,
uiState: ChatActivityUiState,
chatState: ChatState,
modifier: Modifier = Modifier
) {
val localSoftwareKeyboardController = LocalSoftwareKeyboardController.current
val composerValue = if (chatState.requestInProgress) {
chatState.lastPrompt
} else {
chatState.composerValue
}
LazyColumn(
modifier = modifier
.nestedScroll( // Hides the keyboard when scrolling
@ -272,23 +288,21 @@ fun ThreadView(
}
item(key = "composer") {
ChatMessageComposer(
modifier = Modifier
.fillParentMaxHeight() // so that you can click anywhere on the screen to focus the text field
.disableBringIntoViewOnFocus()
.focusRequester(chatState.focusRequester),
value = composerValue,
onValueChange = {
chatState.composerValue = it
},
submitted = chatState.requestInProgress
)
}
if (chatState.requestInProgress) {
item {
ChatMessageResponse(
content = liveResponse
if (!uiState.requestInProgress) {
ChatMessageComposer(
modifier = Modifier
.fillParentMaxHeight() // so that you can click anywhere on the screen to focus the text field
.disableBringIntoViewOnFocus()
.focusRequester(chatState.focusRequester),
value = chatState.composerValue,
onValueChange = {
chatState.composerValue = it
}
)
} else {
// so basically if this was absent, there would be no space anymore to scroll below, and if the conversation were short enough, it would jump to the top because you can't scroll to something that's not there.
Spacer(
modifier = Modifier.fillParentMaxHeight()
)
}
}
@ -300,11 +314,28 @@ fun ChatMessagePrompt(
content: String,
modifier: Modifier = Modifier
) {
// TODO
var animate by rememberSaveable { mutableStateOf(false) }
val textPadding by animateDpAsState(
targetValue = if (animate) 16.dp else 0.dp,
label = "composerTextPaddingAnimation"
)
val textOpacity by animateFloatAsState(
targetValue = if (animate) 0.7f else 1.0f,
label = "composerTextOpacityAnimation"
)
LaunchedEffect(Unit) {
animate = true
}
Text(
text = content,
modifier = modifier
.padding(horizontal = 16.dp),
color = MaterialTheme.colorScheme.onSurface.copy(alpha = 0.7f)
.padding(horizontal = textPadding),
color = MaterialTheme.colorScheme.onSurface.copy(alpha = textOpacity)
)
}
@ -323,30 +354,18 @@ fun ChatMessageResponse(
fun ChatMessageComposer(
value: String,
onValueChange: (String) -> Unit,
submitted: Boolean,
modifier: Modifier = Modifier
) {
val textPadding by animateDpAsState(
targetValue = if (submitted) 16.dp else 0.dp,
label = "composerTextPaddingAnimation"
)
val textOpacity by animateFloatAsState(
targetValue = if (submitted) 0.7f else 1.0f,
label = "composerTextOpacityAnimation"
)
SimpleTextFieldWithPadding(
modifier = modifier,
value = value,
onValueChange = onValueChange,
enabled = !submitted,
placeholder = {
Text("Type your message...") // TODO hide when just browsing history?
},
padding = PaddingValues(vertical = 10.dp, horizontal = textPadding),
padding = PaddingValues(vertical = 10.dp),
textStyle = LocalTextStyle.current.copy(
color = MaterialTheme.colorScheme.onSurface.copy(alpha = textOpacity)
color = MaterialTheme.colorScheme.onSurface
)
)
}
@ -354,11 +373,12 @@ fun ChatMessageComposer(
@Composable
fun ChatToolBar(
modifier: Modifier = Modifier,
chatState: ChatState,
canSend: Boolean,
onSend: () -> Unit,
onEmptySpaceClick: () -> Unit
) {
val sendButtonColor by animateColorAsState(
targetValue = if (chatState.canSend) {
targetValue = if (canSend) {
IconButtonDefaults.iconButtonColors().contentColor
} else {
IconButtonDefaults.iconButtonColors().disabledContentColor
@ -376,11 +396,11 @@ fun ChatToolBar(
horizontalArrangement = Arrangement.End
) {
IconButton(
onClick = chatState::performSend,
onClick = onSend,
modifier = Modifier
.height(48.dp)
.padding(horizontal = 8.dp),
enabled = chatState.canSend,
enabled = canSend,
colors = IconButtonDefaults.iconButtonColors(
contentColor = sendButtonColor,
disabledContentColor = sendButtonColor

View file

@ -1,7 +1,5 @@
package eu.m724.chatapp.activity.chat
sealed interface ChatActivityUiEvent {
data object ProcessingRequest : ChatActivityUiEvent
data class SuccessfulResponse(val message: String): ChatActivityUiEvent
data class Error(val error: String): ChatActivityUiEvent
}

View file

@ -1,6 +1,6 @@
package eu.m724.chatapp.activity.chat
import eu.m724.chatapp.api.data.ChatMessage
import eu.m724.chatapp.api.data.response.completion.ChatMessage
data class ChatActivityUiState(
/**
@ -13,13 +13,5 @@ data class ChatActivityUiState(
*/
val requestInProgress: Boolean = false,
/**
* The response right now, updates when streaming
*/
val liveResponse: String = "",
/**
* All the messages of this chat
*/
val messages: List<ChatMessage> = listOf()
val messages: List<ChatMessage> = emptyList()
)

View file

@ -4,88 +4,112 @@ import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import dagger.hilt.android.lifecycle.HiltViewModel
import eu.m724.chatapp.api.AiApiService
import eu.m724.chatapp.api.data.ChatMessage
import eu.m724.chatapp.api.data.request.ChatCompletionRequest
import eu.m724.chatapp.api.data.request.completion.ChatCompletionRequest
import eu.m724.chatapp.api.data.response.completion.ChatCompletionResponseEvent
import eu.m724.chatapp.api.data.response.completion.ChatMessage
import eu.m724.chatapp.api.retrofit.sse.SseEvent
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.receiveAsFlow
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
import javax.inject.Inject
@HiltViewModel
class ChatActivityViewModel @Inject constructor(
private val aiApiService: AiApiService
) : ViewModel() {
private val messages = mutableListOf<ChatMessage>()
private val _uiState = MutableStateFlow(ChatActivityUiState())
val uiState: StateFlow<ChatActivityUiState> = _uiState.asStateFlow()
private val _uiEvents = Channel<ChatActivityUiEvent>()
val uiEvents = _uiEvents.receiveAsFlow()
fun sendMessage(message: String) {
_uiState.update {
var uiState = it.copy(
requestInProgress = true,
liveResponse = "",
)
if (it.chatTitle == null) {
uiState = uiState.copy(chatTitle = message)
}
uiState
}
private val messages = mutableListOf<ChatMessage>()
fun sendMessage(promptContent: String) {
messages.add(ChatMessage(
role = ChatMessage.Role.User,
content = message
content = promptContent
))
viewModelScope.launch {
_uiEvents.send(ChatActivityUiEvent.ProcessingRequest)
var responseContent = ""
val response = aiApiService.chatComplete(ChatCompletionRequest(
model = "free-model",
messages = messages,
temperature = 1.0f,
maxTokens = 128,
frequencyPenalty = 0.0f,
presencePenalty = 0.0f
))
_uiState.update {
it.copy(
requestInProgress = true,
messages = messages + ChatMessage(
role = ChatMessage.Role.Assistant,
content = responseContent
),
chatTitle = it.chatTitle ?: promptContent,
)
}
if (!response.isSuccessful || response.body() == null) {
messages.removeLast()
aiApiService.getChatCompletion(ChatCompletionRequest(
model = "free-model",
messages = messages,
temperature = 1.0f,
maxTokens = 128,
frequencyPenalty = 0.0f,
presencePenalty = 0.0f
)).onEach { event ->
when (event) {
is SseEvent.Open -> {
_uiState.update {
it.copy(
requestInProgress = false
)
}
is SseEvent.Event<ChatCompletionResponseEvent> -> {
event.data.choices?.firstOrNull()?.let { choice ->
if (choice.delta.content != null) {
responseContent += choice.delta.content
_uiEvents.send(ChatActivityUiEvent.Error(response.code().toString()))
_uiState.update {
it.copy(
messages = messages + ChatMessage(
role = ChatMessage.Role.Assistant,
content = responseContent
)
)
}
}
}
}
is SseEvent.Closed -> {
messages.add(ChatMessage(
role = ChatMessage.Role.Assistant,
content = responseContent
))
// TODO launch toast or something
return@launch
_uiState.update {
it.copy(
requestInProgress = false,
messages = messages.toList()
)
}
}
is SseEvent.Failure -> {
// TODO here
println(event.response?.message)
_uiEvents.send(ChatActivityUiEvent.Error(event.error.toString()))
// TODO investigate if closed is called here too
}
}
val completion = response.body()!!
val choice = completion.choices[0]
messages.add(choice.message)
}.catch { exception ->
_uiState.update {
it.copy(
requestInProgress = false,
messages = messages.toList()
requestInProgress = false
)
}
_uiEvents.send(ChatActivityUiEvent.SuccessfulResponse(message))
}
_uiEvents.send(ChatActivityUiEvent.Error(exception.toString()))
// TODO investigate if closed or failure is called
}.launchIn(viewModelScope)
}
}

View file

@ -1,7 +1,6 @@
package eu.m724.chatapp.activity.chat
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
@ -9,25 +8,9 @@ import androidx.compose.runtime.setValue
import androidx.compose.ui.focus.FocusRequester
class ChatState(
val focusRequester: FocusRequester,
private val onSend: (String) -> Unit, // Store the lambda
initialRequestInProgress: Boolean
val focusRequester: FocusRequester
) {
var composerValue by mutableStateOf("")
var lastPrompt by mutableStateOf("")
var requestInProgress by mutableStateOf(initialRequestInProgress)
val canSend: Boolean
get() = composerValue.isNotBlank() && !requestInProgress
// This method will be called by the UI (e.g., the send button)
fun performSend() {
if (canSend) {
lastPrompt = composerValue
onSend(composerValue)
composerValue = ""
}
}
fun requestFocus() {
focusRequester.requestFocus()
@ -36,24 +19,13 @@ class ChatState(
companion object {
@Composable
fun rememberChatState(
requestInProgress: Boolean,
onSend: (String) -> Unit, // Takes the message string as a parameter
focusRequester: FocusRequester = remember { FocusRequester() }
): ChatState {
val state = remember {
ChatState(
focusRequester = focusRequester,
onSend = onSend, // Pass the lambda directly
initialRequestInProgress = requestInProgress
)
return remember {
ChatState(focusRequester = focusRequester)
}
LaunchedEffect(requestInProgress) {
state.requestInProgress = requestInProgress
}
return state
}
}
}

View file

@ -19,11 +19,11 @@ import androidx.compose.ui.text.input.VisualTransformation
fun SimpleTextFieldWithPadding(
value: String,
onValueChange: (String) -> Unit,
enabled: Boolean,
placeholder: @Composable () -> Unit,
padding: PaddingValues,
textStyle: TextStyle,
modifier: Modifier = Modifier
modifier: Modifier = Modifier,
enabled: Boolean = true
) {
val interactionSource = remember { MutableInteractionSource() }

View file

@ -59,7 +59,7 @@ import androidx.compose.ui.graphics.Color
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.unit.dp
import eu.m724.chatapp.api.data.response.LanguageModel
import eu.m724.chatapp.api.data.response.models.LanguageModel
import java.math.RoundingMode
import java.text.DecimalFormat

View file

@ -1,6 +1,6 @@
package eu.m724.chatapp.activity.select
import eu.m724.chatapp.api.data.response.LanguageModel
import eu.m724.chatapp.api.data.response.models.LanguageModel
data class SelectModelUiState(
val models: List<LanguageModel> = listOf()

View file

@ -1,50 +1,123 @@
package eu.m724.chatapp.api
import com.google.gson.FieldNamingPolicy
import com.google.gson.Gson
import com.google.gson.GsonBuilder
import dagger.Module
import dagger.Provides
import dagger.hilt.InstallIn
import dagger.hilt.components.SingletonComponent
import eu.m724.chatapp.BuildConfig
import eu.m724.chatapp.api.retrofit.interceptor.AiApiRequestExceptionInterceptor
import eu.m724.chatapp.api.retrofit.interceptor.AiApiRequestHeadersInterceptor
import eu.m724.chatapp.api.retrofit.sse.SseCallAdapterFactory
import okhttp3.OkHttpClient
import okhttp3.logging.HttpLoggingInterceptor
import retrofit2.Retrofit
import retrofit2.converter.gson.GsonConverterFactory
import java.util.concurrent.TimeUnit
import javax.inject.Named
import javax.inject.Qualifier
import javax.inject.Singleton
@Module
@InstallIn(SingletonComponent::class)
object AiApiNetworkModule {
@Provides
@Singleton
fun provideOkHttpClient(): OkHttpClient {
val interceptor = AiApiRequestInterceptor(
userAgent = BuildConfig.USER_AGENT,
apiEndpoint = BuildConfig.API_ENDPOINT,
apiKey = BuildConfig.API_KEY
@Named("apiKey")
fun provideApiKey(): String = BuildConfig.API_KEY
@Provides
@Named("apiEndpoint")
fun provideApiEndpoint(): String = BuildConfig.API_ENDPOINT
@Provides
@Named("userAgent")
fun provideUserAgent(): String = BuildConfig.USER_AGENT
@Provides
@Named("isDebug")
fun provideIsDebug(): Boolean = BuildConfig.DEBUG
@Provides
fun provideOkHttpClientBuilder(
@Named("apiKey") apiKey: String,
@Named("apiEndpoint") apiEndpoint: String,
@Named("userAgent") userAgent: String,
@Named("isDebug") isDebug: Boolean
): OkHttpClient.Builder {
val interceptor = AiApiRequestHeadersInterceptor(
userAgent = userAgent,
apiEndpoint = apiEndpoint,
apiKey = apiKey
)
val builder = OkHttpClient.Builder()
.addInterceptor(interceptor)
if (BuildConfig.DEBUG) {
builder.addInterceptor(HttpLoggingInterceptor().setLevel(HttpLoggingInterceptor.Level.BODY))
if (isDebug) {
// level body makes the response buffered which nukes sse
builder.addInterceptor(HttpLoggingInterceptor().apply { level = HttpLoggingInterceptor.Level.HEADERS })
}
return builder.build()
return builder
}
/**
* The standard client is used for standard (non-SSE) requests
*/
@Provides
@Singleton
@StandardClient
fun provideStandardOkHttpClient(
builder: OkHttpClient.Builder,
gson: Gson
): OkHttpClient {
val interceptor = AiApiRequestExceptionInterceptor(
gson = gson
)
return builder
.addInterceptor(interceptor)
.build()
}
/**
* The long-lived client is used for long-lived (SSE) requests
*/
@Provides
@Singleton
@LongLivedClient
fun provideLongLivedOkHttpClient(
builder: OkHttpClient.Builder
): OkHttpClient {
return builder
.readTimeout(0, TimeUnit.SECONDS) // Apparently there are other safe guards against a zombie connection, but I don't know in practice
.connectTimeout(30, TimeUnit.SECONDS)
.build()
}
@Provides
@Singleton
fun provideRetrofit(okHttpClient: OkHttpClient): Retrofit {
val gson = GsonBuilder()
fun provideGson(): Gson {
return GsonBuilder()
.setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES) // snake_case
.create()
}
@Provides
@Singleton
fun provideRetrofit(
@StandardClient standardOkHttpClient: OkHttpClient,
@LongLivedClient longLivedOkHttpClient: OkHttpClient,
gson: Gson,
@Named("apiEndpoint") apiEndpoint: String,
@Named("isDebug") isDebug: Boolean
): Retrofit {
return Retrofit.Builder()
.baseUrl(BuildConfig.API_ENDPOINT)
.client(okHttpClient)
.baseUrl(apiEndpoint)
.client(standardOkHttpClient) // Use the standard client by default
.addCallAdapterFactory(SseCallAdapterFactory(longLivedOkHttpClient, gson, isDebug)) // this intercepts SSE requests and makes them use the long-lived client
.addConverterFactory(GsonConverterFactory.create(gson))
.build()
}
@ -54,4 +127,12 @@ object AiApiNetworkModule {
fun provideAiApiService(retrofit: Retrofit): AiApiService {
return retrofit.create(AiApiService::class.java)
}
}
}
@Qualifier
@Retention(AnnotationRetention.BINARY)
annotation class StandardClient
@Qualifier
@Retention(AnnotationRetention.BINARY)
annotation class LongLivedClient

View file

@ -1,17 +1,21 @@
package eu.m724.chatapp.api
import eu.m724.chatapp.api.data.request.ChatCompletionRequest
import eu.m724.chatapp.api.data.response.ChatCompletionResponse
import eu.m724.chatapp.api.data.response.LanguageModelsResponse
import eu.m724.chatapp.api.data.request.completion.ChatCompletionRequest
import eu.m724.chatapp.api.data.response.completion.ChatCompletionResponseEvent
import eu.m724.chatapp.api.data.response.models.LanguageModelsResponse
import eu.m724.chatapp.api.retrofit.sse.SseEvent
import kotlinx.coroutines.flow.Flow
import retrofit2.Response
import retrofit2.http.Body
import retrofit2.http.GET
import retrofit2.http.POST
import retrofit2.http.Streaming
interface AiApiService {
@GET("models?detailed=true")
suspend fun getModels(): Response<LanguageModelsResponse>
@POST("chat/completions")
suspend fun chatComplete(@Body body: ChatCompletionRequest): Response<ChatCompletionResponse>
@Streaming
fun getChatCompletion(@Body body: ChatCompletionRequest): Flow<SseEvent<ChatCompletionResponseEvent>>
}

View file

@ -0,0 +1,17 @@
package eu.m724.chatapp.api.data
/**
 * Error details carried in the `error` member of [AiApiExceptionDataWrapper].
 *
 * The fields presumably mirror the keys of the API's JSON error object
 * (verify against the API documentation).
 *
 * NOTE(review): instances are created by Gson, which does not run Kotlin's
 * constructor null checks, so a key that is missing from the JSON can leave a
 * "non-null" property null at runtime. Confirm the API always sends all four.
 */
data class AiApiExceptionData(
// Human-readable error text; also used as the exception message.
override val message: String,
val type: String,
val param: String,
val code: Int
) : Exception(message)
/**
 * Top-level shape of an API error body: a single `error` member holding the
 * actual [AiApiExceptionData]. Used as the Gson target type when decoding a
 * failed response.
 */
data class AiApiExceptionDataWrapper(
val error: AiApiExceptionData
)
/**
 * Signals that the API answered with a non-successful HTTP status.
 *
 * This is an [java.io.IOException] on purpose: it is thrown from inside an
 * OkHttp interceptor, and OkHttp only routes `IOException`s to the call's
 * failure callback. Any other exception type is rethrown on OkHttp's
 * dispatcher thread for asynchronous calls, which terminates the process.
 * `IOException` is still an [Exception], so existing `catch (e: Exception)`
 * and `is AiApiException` checks keep working.
 *
 * @property httpCode HTTP status code of the failed response.
 * @property error decoded error body, or `null` when it was absent or could
 *   not be parsed.
 */
class AiApiException(
    val httpCode: Int,
    val error: AiApiExceptionData?
) : java.io.IOException("AI API HTTP error: $httpCode")

View file

@ -1,6 +1,6 @@
package eu.m724.chatapp.api.data.request
package eu.m724.chatapp.api.data.request.completion
import eu.m724.chatapp.api.data.ChatMessage
import eu.m724.chatapp.api.data.response.completion.ChatMessage
data class ChatCompletionRequest(
/**
@ -40,7 +40,9 @@ data class ChatCompletionRequest(
* Read more: https://www.promptitude.io/glossary/frequency-penalty
* @see frequencyPenalty
*/
val presencePenalty: Float
val presencePenalty: Float,
val stream: Boolean = true
) {
init {
require(temperature >= 0.0) { "temperature must be at least 0.0" }

View file

@ -1,24 +1,13 @@
package eu.m724.chatapp.api.data.response
package eu.m724.chatapp.api.data.response.completion
import com.google.gson.annotations.JsonAdapter
import com.google.gson.annotations.SerializedName
import eu.m724.chatapp.api.data.ChatMessage
import eu.m724.chatapp.api.serialize.EpochSecondToLocalDateTimeDeserializer
import java.time.LocalDateTime
data class ChatCompletionResponse(
data class ChatCompletionResponseEvent(
/**
* Request ID
*/
val id: String,
/**
* Request time
*/
@SerializedName("created")
@JsonAdapter(EpochSecondToLocalDateTimeDeserializer::class)
val createdAt: LocalDateTime,
/**
* Completion choices. Usually has only one element.
*/
@ -35,14 +24,19 @@ data class CompletionChoice(
val index: Int,
/**
* The generated message
* The generated message delta, you should merge it with the previous delta
*/
val message: ChatMessage,
val delta: CompletionChoiceDelta,
/**
* The reason why generating the response has stopped
* The reason why generating the response has stopped. null if the response hasn't finished yet.
*/
val finishReason: CompletionFinishReason
val finishReason: CompletionFinishReason?
)
data class CompletionChoiceDelta(
/** The next generated token, may be null if the response just finished */
val content: String?
)
enum class CompletionFinishReason {

View file

@ -1,4 +1,4 @@
package eu.m724.chatapp.api.data
package eu.m724.chatapp.api.data.response.completion
import com.google.gson.annotations.SerializedName

View file

@ -1,4 +1,4 @@
package eu.m724.chatapp.api.data.response
package eu.m724.chatapp.api.data.response.models
import com.google.gson.annotations.SerializedName

View file

@ -1,4 +1,4 @@
package eu.m724.chatapp.api.serialize
package eu.m724.chatapp.api.data.serialize
import com.google.gson.JsonDeserializationContext
import com.google.gson.JsonDeserializer

View file

@ -0,0 +1,31 @@
package eu.m724.chatapp.api.retrofit.interceptor
import com.google.gson.Gson
import eu.m724.chatapp.api.data.AiApiException
import eu.m724.chatapp.api.data.AiApiExceptionDataWrapper
import okhttp3.Interceptor
import okhttp3.Response
/**
 * OkHttp interceptor that turns every non-successful HTTP response into an
 * [AiApiException].
 *
 * Successful responses are passed through untouched. For a failed response
 * the body is decoded (best effort) as an [AiApiExceptionDataWrapper] so the
 * thrown exception can carry the API's own error details.
 *
 * @property gson used to decode the error body.
 */
class AiApiRequestExceptionInterceptor(
    private val gson: Gson
) : Interceptor {
    /**
     * @return the response unchanged when it is successful.
     * @throws AiApiException when the response status is not successful; its
     *   `error` is `null` if the body is missing or cannot be decoded.
     */
    override fun intercept(chain: Interceptor.Chain): Response {
        val response = chain.proceed(chain.request())
        if (response.isSuccessful) {
            return response
        }

        // The body must be consumed BEFORE the response is closed: reading a
        // closed body throws, which previously made the parsed error always
        // null. `use` closes the response once the body has been read, on
        // both the normal and the exceptional path.
        val apiError = response.use { failedResponse ->
            try {
                failedResponse.body
                    ?.string()
                    ?.let { json -> gson.fromJson(json, AiApiExceptionDataWrapper::class.java) }
                    ?.error
            } catch (_: Exception) {
                // Best effort only: an unreadable or non-JSON error body must
                // not hide the HTTP failure itself.
                null
            }
        }

        throw AiApiException(response.code, apiError)
    }
}

View file

@ -1,9 +1,9 @@
package eu.m724.chatapp.api
package eu.m724.chatapp.api.retrofit.interceptor
import okhttp3.Interceptor
import okhttp3.Response
class AiApiRequestInterceptor(
class AiApiRequestHeadersInterceptor(
private val userAgent: String,
private val apiEndpoint: String,
private val apiKey: String

View file

@ -0,0 +1,86 @@
package eu.m724.chatapp.api.retrofit.sse
import com.google.gson.Gson
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.buffer
import kotlinx.coroutines.flow.callbackFlow
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.Response
import okhttp3.sse.EventSource
import okhttp3.sse.EventSourceListener
import okhttp3.sse.EventSources
import retrofit2.Call
import retrofit2.CallAdapter
import java.lang.reflect.Type
/**
 * Retrofit [CallAdapter] that exposes a server-sent-events endpoint as a cold
 * [Flow] of [SseEvent]s.
 *
 * Each collection of the returned flow opens a new [EventSource] on [client]
 * using the request built by Retrofit, and cancels it when collection stops.
 *
 * @param T payload type each event's `data` field is decoded into.
 * @property client OkHttp client used to open the event source.
 * @property gson decoder for the event payloads.
 * @property eventType reflective type of [T], passed to Gson.
 * @property debug when `true`, the raw `data` of every event is printed.
 */
class SseCallAdapter<T : Any>(
    private val client: OkHttpClient,
    private val gson: Gson,
    private val eventType: Type,
    private val debug: Boolean
) : CallAdapter<T, Flow<SseEvent<T>>> {
    override fun responseType(): Type = eventType

    /**
     * Builds the event flow for [call]. Only [Call.request] is used; the
     * Retrofit call itself is never executed.
     *
     * The flow emits [SseEvent.Open], then one [SseEvent.Event] per decoded
     * message, then [SseEvent.Closed] and completes. On a transport or
     * decoding error it emits [SseEvent.Failure] and then fails with the same
     * throwable.
     */
    override fun adapt(call: Call<T>): Flow<SseEvent<T>> {
        return callbackFlow<SseEvent<T>> {
            val listener = object : EventSourceListener() {
                override fun onOpen(eventSource: EventSource, response: Response) {
                    trySend(SseEvent.Open(eventSource, response))
                }

                override fun onEvent(
                    eventSource: EventSource,
                    id: String?,
                    type: String?,
                    data: String
                ) {
                    if (debug) {
                        println("raw sse data: $data")
                    }
                    if (data.trim() == "[DONE]") {
                        // The server is about to close the connection
                        return
                    }
                    try {
                        // Typed overload instead of an unchecked `as T?` cast;
                        // Gson returns null for an empty/"null" document.
                        val eventData: T? = gson.fromJson<T>(data, eventType)
                        if (eventData != null) {
                            trySend(SseEvent.Event(id, type, eventData))
                        }
                    } catch (e: Exception) {
                        trySend(SseEvent.Failure(e, null))
                        close(e)
                    }
                }

                override fun onClosed(eventSource: EventSource) {
                    trySend(SseEvent.Closed)
                    close() // Close the flow
                }

                override fun onFailure(
                    eventSource: EventSource,
                    t: Throwable?,
                    response: Response?
                ) {
                    // TODO aiapiexception here
                    val error = t ?: RuntimeException("Unknown SSE error")
                    trySend(SseEvent.Failure(error, response))
                    close(error)
                }
            }

            // Create a new request from the Retrofit call
            val request: Request = call.request()
            val eventSource = EventSources.createFactory(client).newEventSource(request, listener)

            // Runs when the flow is cancelled or closed from a callback above
            awaitClose {
                eventSource.cancel()
            }
        }
            // The listener callbacks cannot suspend, so they use trySend and
            // ignore its result. With callbackFlow's default bounded buffer a
            // slow collector would make trySend fail and silently lose stream
            // chunks (or the terminal Closed/Failure event). An unlimited
            // buffer guarantees every event reaches the collector.
            .buffer(Channel.UNLIMITED)
    }
}

View file

@ -0,0 +1,37 @@
package eu.m724.chatapp.api.retrofit.sse
import com.google.gson.Gson
import kotlinx.coroutines.flow.Flow
import okhttp3.OkHttpClient
import retrofit2.CallAdapter
import retrofit2.Retrofit
import java.lang.reflect.ParameterizedType
import java.lang.reflect.Type
/**
 * Retrofit [CallAdapter.Factory] that claims service methods whose declared
 * return type is `Flow<SseEvent<T>>` and adapts them with an [SseCallAdapter].
 *
 * @property client OkHttp client the created adapters use for SSE connections.
 * @property gson decoder the created adapters use for event payloads.
 * @property debug enables diagnostic `println` output here and in the adapters.
 */
class SseCallAdapterFactory(
    private val client: OkHttpClient,
    private val gson: Gson,
    private val debug: Boolean
) : CallAdapter.Factory() {
    /**
     * @return an [SseCallAdapter] when [returnType] is `Flow<SseEvent<T>>`;
     *   `null` for every other type so that another factory can handle it.
     */
    override fun get(
        returnType: Type,
        annotations: Array<out Annotation>,
        retrofit: Retrofit
    ): CallAdapter<*, *>? {
        // Ensure the return type is a Flow
        if (getRawType(returnType) != Flow::class.java) {
            // This branch is hit for every ordinary (non-SSE) service method,
            // so the diagnostic is only printed in debug mode.
            if (debug) {
                println("wrong type: " + getRawType(returnType))
            }
            return null
        }

        // Ensure the Flow's generic type is SseEvent. A raw (unparameterized)
        // Flow is declined instead of failing with a ClassCastException.
        val flowType = (returnType as? ParameterizedType)
            ?.let { getParameterUpperBound(0, it) }
            ?: return null
        if (getRawType(flowType) != SseEvent::class.java) {
            return null
        }

        // Get the generic type of SseEvent<T>; a raw SseEvent is declined too.
        val eventType = (flowType as? ParameterizedType)
            ?.let { getParameterUpperBound(0, it) }
            ?: return null

        return SseCallAdapter<Any>(client, gson, eventType, debug)
    }
}

View file

@ -0,0 +1,22 @@
package eu.m724.chatapp.api.retrofit.sse
import okhttp3.Response
import okhttp3.sse.EventSource
/**
 * A sealed class representing the different states of an SSE connection.
 *
 * Subtypes that carry no payload are typed as `SseEvent<Nothing>`, so thanks
 * to the `out` variance they can be emitted in a stream of any `SseEvent<T>`.
 *
 * @param T The type of the data expected in the event payload.
 */
sealed class SseEvent<out T> {
/**
 * The connection was successfully opened.
 *
 * @property eventSource the event source that was opened.
 * @property response the HTTP response that opened the stream.
 */
data class Open(val eventSource: EventSource, val response: Response) : SseEvent<Nothing>()
/**
 * A new event (message) was received from the server.
 *
 * @property id the event's id, if the server sent one.
 * @property name the event's type/name, if the server sent one.
 * @property data the decoded payload.
 */
data class Event<out T>(val id: String?, val name: String?, val data: T) : SseEvent<T>()
/** The connection was closed, either by the server or client. */
object Closed : SseEvent<Nothing>()
/**
 * An unrecoverable error occurred.
 *
 * @property error the cause of the failure.
 * @property response the HTTP response associated with the failure, if any.
 */
data class Failure(val error: Throwable, val response: Response?) : SseEvent<Nothing>()
}

View file

@ -16,6 +16,7 @@ retrofit = "3.0.0"
secrets = "2.0.1"
loggingInterceptor = "4.12.0"
material3WindowSizeClass = "1.3.2"
okhttpSse = "4.12.0"
[libraries]
androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" }
@ -40,6 +41,7 @@ retrofit = { group = "com.squareup.retrofit2", name = "retrofit", version.ref =
retrofit-converter-gson = { group = "com.squareup.retrofit2", name = "converter-gson", version.ref = "retrofit"}
logging-interceptor = { group = "com.squareup.okhttp3", name = "logging-interceptor", version.ref = "loggingInterceptor" }
androidx-material3-window-size-class1 = { group = "androidx.compose.material3", name = "material3-window-size-class", version.ref = "material3WindowSizeClass" }
okhttp-sse = { group = "com.squareup.okhttp3", name = "okhttp-sse", version.ref = "okhttpSse" }
[plugins]
android-application = { id = "com.android.application", version.ref = "agp" }