diff --git a/app/SseCallAdapterFactory.kt b/app/SseCallAdapterFactory.kt
new file mode 100644
index 0000000..88483e9
--- /dev/null
+++ b/app/SseCallAdapterFactory.kt
@@ -0,0 +1,40 @@
+class SseCallAdapterFactory(
+    private val client: OkHttpClient,
+    private val moshi: Moshi // Or Gson
+) : CallAdapter.Factory() {
+
+    /**
+     * Returns an [SseCallAdapter] for return types declared as `Flow<SseEvent<T>>`,
+     * or null when the return type is not handled by this factory.
+     */
+    override fun get(
+        returnType: Type,
+        annotations: Array<out Annotation>,
+        retrofit: Retrofit
+    ): CallAdapter<*, *>? {
+        // Only Flow return types are handled here
+        if (getRawType(returnType) != Flow::class.java) {
+            return null
+        }
+
+        // A raw Flow has no element type; fail with a readable message instead of a ClassCastException
+        check(returnType is ParameterizedType) {
+            "Flow return type must be parameterized as Flow<SseEvent<Foo>>"
+        }
+
+        // Ensure the Flow's generic type is SseEvent
+        val flowType = getParameterUpperBound(0, returnType)
+        if (getRawType(flowType) != SseEvent::class.java) {
+            return null
+        }
+
+        // A raw SseEvent has no payload type to hand to the adapter
+        check(flowType is ParameterizedType) {
+            "SseEvent must be parameterized as SseEvent<Foo>"
+        }
+
+        // Get the generic type of SseEvent
+        val eventType = getParameterUpperBound(0, flowType)
+        return SseCallAdapter(client, moshi, eventType)
+    }
+}
\ No newline at end of file
diff --git a/app/build.gradle.kts b/app/build.gradle.kts
index 1a10b24..032e261 100644
--- a/app/build.gradle.kts
+++ b/app/build.gradle.kts
@@ -58,6 +58,7 @@ dependencies {
     implementation(libs.retrofit)
     implementation(libs.retrofit.converter.gson)
     implementation(libs.androidx.material3.window.size.class1)
+    implementation(libs.okhttp.sse)
     testImplementation(libs.junit)
     androidTestImplementation(libs.androidx.junit)
     androidTestImplementation(libs.androidx.espresso.core)
diff --git a/app/src/main/java/eu/m724/chatapp/activity/chat/ChatActivity.kt b/app/src/main/java/eu/m724/chatapp/activity/chat/ChatActivity.kt
index 6f029bd..344c524 100644
--- a/app/src/main/java/eu/m724/chatapp/activity/chat/ChatActivity.kt
+++ b/app/src/main/java/eu/m724/chatapp/activity/chat/ChatActivity.kt
@@ -15,11 +15,13 @@ import androidx.compose.foundation.layout.Box
 import androidx.compose.foundation.layout.Column
 import androidx.compose.foundation.layout.PaddingValues
 import androidx.compose.foundation.layout.Row
+import androidx.compose.foundation.layout.Spacer
 import 
androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.imePadding import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.width import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.LazyListState @@ -45,7 +47,11 @@ import androidx.compose.material3.windowsizeclass.calculateWindowSizeClass import androidx.compose.runtime.Composable import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember import androidx.compose.runtime.rememberCoroutineScope +import androidx.compose.runtime.saveable.rememberSaveable +import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.focus.focusRequester @@ -62,7 +68,7 @@ import eu.m724.chatapp.activity.chat.composable.NestedScrollKeyboardHider import eu.m724.chatapp.activity.chat.composable.SimpleTextFieldWithPadding import eu.m724.chatapp.activity.chat.composable.disableBringIntoViewOnFocus import eu.m724.chatapp.activity.ui.theme.ChatAppTheme -import eu.m724.chatapp.api.data.ChatMessage +import eu.m724.chatapp.api.data.response.completion.ChatMessage import kotlinx.coroutines.launch @AndroidEntryPoint @@ -76,30 +82,36 @@ class ChatActivity : ComponentActivity() { enableEdgeToEdge() setContent { val windowSizeClass = calculateWindowSizeClass(this) + val uiState by viewModel.uiState.collectAsStateWithLifecycle() + val softwareKeyboardController = LocalSoftwareKeyboardController.current val context = LocalContext.current + val coroutineScope = rememberCoroutineScope() - - val chatState = rememberChatState( - requestInProgress = uiState.requestInProgress, - onSend = { message -> - 
viewModel.sendMessage(message) - } - ) - + val chatState = rememberChatState() val threadViewLazyListState = rememberLazyListState() + val onSend = { + if (chatState.composerValue.isNotBlank() && !uiState.requestInProgress) { + viewModel.sendMessage(chatState.composerValue) + } + } + ChatScreen( windowSizeClass = windowSizeClass, uiState = uiState, chatState = chatState, threadViewLazyListState = threadViewLazyListState, + onSend = onSend, onRequestFocus = { + if (uiState.requestInProgress) return@ChatScreen + coroutineScope.launch { if (threadViewLazyListState.layoutInfo.visibleItemsInfo.find { it.key == "composer" } == null) { if (uiState.messages.isNotEmpty()) { threadViewLazyListState.animateScrollToItem(uiState.messages.size) + // TODO this makes the composer full screen but if the condition above is false it doesn't, that may be kind of unintuitive } } @@ -109,26 +121,30 @@ class ChatActivity : ComponentActivity() { } ) + LaunchedEffect(uiState.requestInProgress) { + if (uiState.requestInProgress) { + chatState.composerValue = "" + + // scroll to the last user message too + threadViewLazyListState.animateScrollToItem(uiState.messages.size - 2) + } else if (uiState.messages.isNotEmpty()) { + // scroll to the last user message too + threadViewLazyListState.animateScrollToItem(uiState.messages.size - 2) + + // if the composer is visible (message is short enough), focus on it + // if the message is long, we let the user read it + threadViewLazyListState.layoutInfo.visibleItemsInfo.firstOrNull { + it.key == "composer" + }?.let { + chatState.requestFocus() + softwareKeyboardController?.show() // TODO perhaps it's pointless to focus since we can click on the toolbar? 
maybe make it configurable + } + } + } + LaunchedEffect(Unit) { viewModel.uiEvents.collect { event -> when (event) { - is ChatActivityUiEvent.ProcessingRequest -> { - threadViewLazyListState.animateScrollToItem(uiState.messages.size) - } - is ChatActivityUiEvent.SuccessfulResponse -> { - chatState.composerValue = "" - - if (uiState.messages.isNotEmpty()) { - threadViewLazyListState.animateScrollToItem(uiState.messages.size - 2) - } - - threadViewLazyListState.layoutInfo.visibleItemsInfo.firstOrNull { - it.key == "composer" - }?.let { - chatState.requestFocus() - softwareKeyboardController?.show() // TODO perhaps it's pointless to focus since we can click on the toolbar? - } - } is ChatActivityUiEvent.Error -> { Toast.makeText(context, event.error, Toast.LENGTH_SHORT) .show() // TODO better way of showing this. snackbar? @@ -146,6 +162,7 @@ fun ChatScreen( uiState: ChatActivityUiState, chatState: ChatState, threadViewLazyListState: LazyListState, + onSend: () -> Unit, onRequestFocus: () -> Unit ) { val isTablet = windowSizeClass.widthSizeClass > WindowWidthSizeClass.Compact @@ -158,12 +175,14 @@ fun ChatScreen( } ) { innerPadding -> ChatScreenContent( - modifier = Modifier.fillMaxSize().padding(innerPadding), + modifier = Modifier + .fillMaxSize() + .padding(innerPadding), isTablet = isTablet, - messages = uiState.messages, - liveResponse = uiState.liveResponse, + uiState = uiState, chatState = chatState, threadViewLazyListState = threadViewLazyListState, + onSend = onSend, onRequestFocus = onRequestFocus ) } @@ -174,10 +193,10 @@ fun ChatScreen( fun ChatScreenContent( modifier: Modifier = Modifier, isTablet: Boolean, - messages: List, - liveResponse: String, + uiState: ChatActivityUiState, chatState: ChatState, threadViewLazyListState: LazyListState, + onSend: () -> Unit, onRequestFocus: () -> Unit ) { val layout: @Composable (@Composable () -> Unit, @Composable () -> Unit) -> Unit = @@ -206,10 +225,12 @@ fun ChatScreenContent( layout( { ThreadView( - modifier = 
Modifier.fillMaxSize().padding(horizontal = 24.dp), + modifier = Modifier + .fillMaxSize() + .padding(horizontal = 24.dp), lazyListState = threadViewLazyListState, - messages = messages, - liveResponse = liveResponse, + messages = uiState.messages, + uiState = uiState, chatState = chatState ) }, @@ -222,13 +243,14 @@ fun ChatScreenContent( horizontalAlignment = Alignment.CenterHorizontally ) { ChatToolBar( - chatState = chatState, + canSend = chatState.composerValue.isNotBlank() && !uiState.requestInProgress, + onSend = onSend, onEmptySpaceClick = onRequestFocus ) LanguageModelMistakeWarning( modifier = Modifier - .padding(vertical = 10.dp) // TODO this is troublesome if there's navigation bar below or any kind of padding + .padding(vertical = 10.dp) // TODO this is troublesome if there's navigation bar below or any kind of padding. But without it, it looks even worse ) } } @@ -239,18 +261,12 @@ fun ChatScreenContent( fun ThreadView( lazyListState: LazyListState, messages: List, - liveResponse: String, + uiState: ChatActivityUiState, chatState: ChatState, modifier: Modifier = Modifier ) { val localSoftwareKeyboardController = LocalSoftwareKeyboardController.current - val composerValue = if (chatState.requestInProgress) { - chatState.lastPrompt - } else { - chatState.composerValue - } - LazyColumn( modifier = modifier .nestedScroll( // Hides the keyboard when scrolling @@ -272,23 +288,21 @@ fun ThreadView( } item(key = "composer") { - ChatMessageComposer( - modifier = Modifier - .fillParentMaxHeight() // so that you can click anywhere on the screen to focus the text field - .disableBringIntoViewOnFocus() - .focusRequester(chatState.focusRequester), - value = composerValue, - onValueChange = { - chatState.composerValue = it - }, - submitted = chatState.requestInProgress - ) - } - - if (chatState.requestInProgress) { - item { - ChatMessageResponse( - content = liveResponse + if (!uiState.requestInProgress) { + ChatMessageComposer( + modifier = Modifier + 
.fillParentMaxHeight() // so that you can click anywhere on the screen to focus the text field + .disableBringIntoViewOnFocus() + .focusRequester(chatState.focusRequester), + value = chatState.composerValue, + onValueChange = { + chatState.composerValue = it + } + ) + } else { + // so basically if this was absent, there would be no space anymore to scroll below, and if the conversation were short enough, it would jump to the top because you can't scroll to something that's not there. + Spacer( + modifier = Modifier.fillParentMaxHeight() ) } } @@ -300,11 +314,28 @@ fun ChatMessagePrompt( content: String, modifier: Modifier = Modifier ) { + // TODO + var animate by rememberSaveable { mutableStateOf(false) } + + val textPadding by animateDpAsState( + targetValue = if (animate) 16.dp else 0.dp, + label = "composerTextPaddingAnimation" + ) + + val textOpacity by animateFloatAsState( + targetValue = if (animate) 0.7f else 1.0f, + label = "composerTextOpacityAnimation" + ) + + LaunchedEffect(Unit) { + animate = true + } + Text( text = content, modifier = modifier - .padding(horizontal = 16.dp), - color = MaterialTheme.colorScheme.onSurface.copy(alpha = 0.7f) + .padding(horizontal = textPadding), + color = MaterialTheme.colorScheme.onSurface.copy(alpha = textOpacity) ) } @@ -323,30 +354,18 @@ fun ChatMessageResponse( fun ChatMessageComposer( value: String, onValueChange: (String) -> Unit, - submitted: Boolean, modifier: Modifier = Modifier ) { - val textPadding by animateDpAsState( - targetValue = if (submitted) 16.dp else 0.dp, - label = "composerTextPaddingAnimation" - ) - - val textOpacity by animateFloatAsState( - targetValue = if (submitted) 0.7f else 1.0f, - label = "composerTextOpacityAnimation" - ) - SimpleTextFieldWithPadding( modifier = modifier, value = value, onValueChange = onValueChange, - enabled = !submitted, placeholder = { Text("Type your message...") // TODO hide when just browsing history? 
}, - padding = PaddingValues(vertical = 10.dp, horizontal = textPadding), + padding = PaddingValues(vertical = 10.dp), textStyle = LocalTextStyle.current.copy( - color = MaterialTheme.colorScheme.onSurface.copy(alpha = textOpacity) + color = MaterialTheme.colorScheme.onSurface ) ) } @@ -354,11 +373,12 @@ fun ChatMessageComposer( @Composable fun ChatToolBar( modifier: Modifier = Modifier, - chatState: ChatState, + canSend: Boolean, + onSend: () -> Unit, onEmptySpaceClick: () -> Unit ) { val sendButtonColor by animateColorAsState( - targetValue = if (chatState.canSend) { + targetValue = if (canSend) { IconButtonDefaults.iconButtonColors().contentColor } else { IconButtonDefaults.iconButtonColors().disabledContentColor @@ -376,11 +396,11 @@ fun ChatToolBar( horizontalArrangement = Arrangement.End ) { IconButton( - onClick = chatState::performSend, + onClick = onSend, modifier = Modifier .height(48.dp) .padding(horizontal = 8.dp), - enabled = chatState.canSend, + enabled = canSend, colors = IconButtonDefaults.iconButtonColors( contentColor = sendButtonColor, disabledContentColor = sendButtonColor diff --git a/app/src/main/java/eu/m724/chatapp/activity/chat/ChatActivityUiEvent.kt b/app/src/main/java/eu/m724/chatapp/activity/chat/ChatActivityUiEvent.kt index 0e18fc1..55b16c9 100644 --- a/app/src/main/java/eu/m724/chatapp/activity/chat/ChatActivityUiEvent.kt +++ b/app/src/main/java/eu/m724/chatapp/activity/chat/ChatActivityUiEvent.kt @@ -1,7 +1,5 @@ package eu.m724.chatapp.activity.chat sealed interface ChatActivityUiEvent { - data object ProcessingRequest : ChatActivityUiEvent - data class SuccessfulResponse(val message: String): ChatActivityUiEvent data class Error(val error: String): ChatActivityUiEvent } \ No newline at end of file diff --git a/app/src/main/java/eu/m724/chatapp/activity/chat/ChatActivityUiState.kt b/app/src/main/java/eu/m724/chatapp/activity/chat/ChatActivityUiState.kt index 2f11741..1cbb3b5 100644 --- 
a/app/src/main/java/eu/m724/chatapp/activity/chat/ChatActivityUiState.kt
+++ b/app/src/main/java/eu/m724/chatapp/activity/chat/ChatActivityUiState.kt
@@ -1,6 +1,6 @@
 package eu.m724.chatapp.activity.chat
 
-import eu.m724.chatapp.api.data.ChatMessage
+import eu.m724.chatapp.api.data.response.completion.ChatMessage
 
 data class ChatActivityUiState(
     /**
@@ -13,13 +13,5 @@ data class ChatActivityUiState(
      */
     val requestInProgress: Boolean = false,
 
-    /**
-     * The response right now, updates when streaming
-     */
-    val liveResponse: String = "",
-
-    /**
-     * All the messages of this chat
-     */
-    val messages: List<ChatMessage> = listOf()
+    val messages: List<ChatMessage> = emptyList()
 )
\ No newline at end of file
diff --git a/app/src/main/java/eu/m724/chatapp/activity/chat/ChatActivityViewModel.kt b/app/src/main/java/eu/m724/chatapp/activity/chat/ChatActivityViewModel.kt
index 85922c1..5c32674 100644
--- a/app/src/main/java/eu/m724/chatapp/activity/chat/ChatActivityViewModel.kt
+++ b/app/src/main/java/eu/m724/chatapp/activity/chat/ChatActivityViewModel.kt
@@ -4,88 +4,124 @@ import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
 import dagger.hilt.android.lifecycle.HiltViewModel
 import eu.m724.chatapp.api.AiApiService
-import eu.m724.chatapp.api.data.ChatMessage
-import eu.m724.chatapp.api.data.request.ChatCompletionRequest
+import eu.m724.chatapp.api.data.request.completion.ChatCompletionRequest
+import eu.m724.chatapp.api.data.response.completion.ChatCompletionResponseEvent
+import eu.m724.chatapp.api.data.response.completion.ChatMessage
+import eu.m724.chatapp.api.retrofit.sse.SseEvent
 import kotlinx.coroutines.channels.Channel
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.flow.asStateFlow
+import kotlinx.coroutines.flow.catch
+import kotlinx.coroutines.flow.launchIn
+import kotlinx.coroutines.flow.onEach
 import kotlinx.coroutines.flow.receiveAsFlow
 import kotlinx.coroutines.flow.update
-import kotlinx.coroutines.launch
 import javax.inject.Inject
 
 @HiltViewModel
 class ChatActivityViewModel @Inject constructor(
     private val aiApiService: AiApiService
 ) : ViewModel() {
-    private val messages = mutableListOf<ChatMessage>()
-
     private val _uiState = MutableStateFlow(ChatActivityUiState())
     val uiState: StateFlow<ChatActivityUiState> = _uiState.asStateFlow()
 
     private val _uiEvents = Channel<ChatActivityUiEvent>()
     val uiEvents = _uiEvents.receiveAsFlow()
 
-    fun sendMessage(message: String) {
-        _uiState.update {
-            var uiState = it.copy(
-                requestInProgress = true,
-                liveResponse = "",
-            )
-
-            if (it.chatTitle == null) {
-                uiState = uiState.copy(chatTitle = message)
-            }
-
-            uiState
-        }
+    private val messages = mutableListOf<ChatMessage>()
 
+    /**
+     * Records [promptContent] as a user message and streams the completion, publishing the
+     * growing assistant reply to [uiState] as each delta arrives.
+     */
+    fun sendMessage(promptContent: String) {
         messages.add(ChatMessage(
             role = ChatMessage.Role.User,
-            content = message
+            content = promptContent
         ))
 
-        viewModelScope.launch {
-            _uiEvents.send(ChatActivityUiEvent.ProcessingRequest)
+        var responseContent = ""
 
-            val response = aiApiService.chatComplete(ChatCompletionRequest(
-                model = "free-model",
-                messages = messages,
-                temperature = 1.0f,
-                maxTokens = 128,
-                frequencyPenalty = 0.0f,
-                presencePenalty = 0.0f
-            ))
+        _uiState.update {
+            it.copy(
+                requestInProgress = true,
+                messages = messages + ChatMessage(
+                    role = ChatMessage.Role.Assistant,
+                    content = responseContent
+                ),
+                chatTitle = it.chatTitle ?: promptContent,
+            )
+        }
 
-            if (!response.isSuccessful || response.body() == null) {
-                messages.removeLast()
+        aiApiService.getChatCompletion(ChatCompletionRequest(
+            model = "free-model",
+            messages = messages.toList(), // snapshot: the backing list is mutated when the stream closes
+            temperature = 1.0f,
+            maxTokens = 128,
+            frequencyPenalty = 0.0f,
+            presencePenalty = 0.0f
+        )).onEach { event ->
+            when (event) {
+                is SseEvent.Open -> {
 
-                _uiState.update {
-                    it.copy(
-                        requestInProgress = false
-                    )
                 }
+                is SseEvent.Event -> {
+                    event.data.choices?.firstOrNull()?.let { choice ->
+                        if (choice.delta.content != null) {
+                            responseContent += choice.delta.content
 
-                _uiEvents.send(ChatActivityUiEvent.Error(response.code().toString()))
+                            _uiState.update {
+                                it.copy(
+                                    messages = messages + ChatMessage(
+                                        role = ChatMessage.Role.Assistant,
+                                        content = responseContent
+                                    )
+                                )
+                            }
+                        }
+                    }
+                }
+                is SseEvent.Closed -> {
+                    messages.add(ChatMessage(
+                        role = ChatMessage.Role.Assistant,
+                        content = responseContent
+                    ))
 
-                // TODO launch toast or something
-                return@launch
+                    _uiState.update {
+                        it.copy(
+                            requestInProgress = false,
+                            messages = messages.toList()
+                        )
+                    }
+                }
+                is SseEvent.Failure -> {
+                    // TODO here
+                    println(event.response?.message)
+
+                    // Unlock the UI here: sending stays blocked while requestInProgress is true,
+                    // and a later Closed event or exception is not guaranteed after a Failure
+                    _uiState.update {
+                        it.copy(
+                            requestInProgress = false
+                        )
+                    }
+
+                    _uiEvents.send(ChatActivityUiEvent.Error(event.error.toString()))
+
+                    // TODO investigate if closed is called here too
+                }
             }
-
-            val completion = response.body()!!
-            val choice = completion.choices[0]
-
-            messages.add(choice.message)
-
+        }.catch { exception ->
             _uiState.update {
                 it.copy(
-                    requestInProgress = false,
-                    messages = messages.toList()
+                    requestInProgress = false
                 )
             }
 
-            _uiEvents.send(ChatActivityUiEvent.SuccessfulResponse(message))
-        }
+            _uiEvents.send(ChatActivityUiEvent.Error(exception.toString()))
+
+            // TODO investigate if closed or failure is called
+        }.launchIn(viewModelScope)
     }
 }
\ No newline at end of file
diff --git a/app/src/main/java/eu/m724/chatapp/activity/chat/ChatState.kt b/app/src/main/java/eu/m724/chatapp/activity/chat/ChatState.kt
index 541c948..bf775d7 100644
--- a/app/src/main/java/eu/m724/chatapp/activity/chat/ChatState.kt
+++ b/app/src/main/java/eu/m724/chatapp/activity/chat/ChatState.kt
@@ -1,7 +1,6 @@
 package eu.m724.chatapp.activity.chat
 
 import androidx.compose.runtime.Composable
-import androidx.compose.runtime.LaunchedEffect
 import androidx.compose.runtime.getValue
 import androidx.compose.runtime.mutableStateOf
 import androidx.compose.runtime.remember
@@ -9,25 +8,9 @@ import androidx.compose.runtime.setValue
 import androidx.compose.ui.focus.FocusRequester
 
 class ChatState(
-    val focusRequester: FocusRequester,
-    private val onSend: (String) -> Unit, // 
Store the lambda - initialRequestInProgress: Boolean + val focusRequester: FocusRequester ) { var composerValue by mutableStateOf("") - var lastPrompt by mutableStateOf("") - var requestInProgress by mutableStateOf(initialRequestInProgress) - - val canSend: Boolean - get() = composerValue.isNotBlank() && !requestInProgress - - // This method will be called by the UI (e.g., the send button) - fun performSend() { - if (canSend) { - lastPrompt = composerValue - onSend(composerValue) - composerValue = "" - } - } fun requestFocus() { focusRequester.requestFocus() @@ -36,24 +19,13 @@ class ChatState( companion object { @Composable fun rememberChatState( - requestInProgress: Boolean, - onSend: (String) -> Unit, // Takes the message string as a parameter focusRequester: FocusRequester = remember { FocusRequester() } ): ChatState { - val state = remember { - ChatState( - focusRequester = focusRequester, - onSend = onSend, // Pass the lambda directly - initialRequestInProgress = requestInProgress - ) + return remember { + ChatState(focusRequester = focusRequester) } - - LaunchedEffect(requestInProgress) { - state.requestInProgress = requestInProgress - } - - return state } } } + diff --git a/app/src/main/java/eu/m724/chatapp/activity/chat/composable/SimpleTextFieldWithPadding.kt b/app/src/main/java/eu/m724/chatapp/activity/chat/composable/SimpleTextFieldWithPadding.kt index 9b7c52a..0ce166f 100644 --- a/app/src/main/java/eu/m724/chatapp/activity/chat/composable/SimpleTextFieldWithPadding.kt +++ b/app/src/main/java/eu/m724/chatapp/activity/chat/composable/SimpleTextFieldWithPadding.kt @@ -19,11 +19,11 @@ import androidx.compose.ui.text.input.VisualTransformation fun SimpleTextFieldWithPadding( value: String, onValueChange: (String) -> Unit, - enabled: Boolean, placeholder: @Composable () -> Unit, padding: PaddingValues, textStyle: TextStyle, - modifier: Modifier = Modifier + modifier: Modifier = Modifier, + enabled: Boolean = true ) { val interactionSource = remember { 
MutableInteractionSource() } diff --git a/app/src/main/java/eu/m724/chatapp/activity/select/SelectModelActivity.kt b/app/src/main/java/eu/m724/chatapp/activity/select/SelectModelActivity.kt index 880f906..246c03e 100644 --- a/app/src/main/java/eu/m724/chatapp/activity/select/SelectModelActivity.kt +++ b/app/src/main/java/eu/m724/chatapp/activity/select/SelectModelActivity.kt @@ -59,7 +59,7 @@ import androidx.compose.ui.graphics.Color import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.unit.dp -import eu.m724.chatapp.api.data.response.LanguageModel +import eu.m724.chatapp.api.data.response.models.LanguageModel import java.math.RoundingMode import java.text.DecimalFormat diff --git a/app/src/main/java/eu/m724/chatapp/activity/select/SelectModelUiState.kt b/app/src/main/java/eu/m724/chatapp/activity/select/SelectModelUiState.kt index cd8cdf6..46468d9 100644 --- a/app/src/main/java/eu/m724/chatapp/activity/select/SelectModelUiState.kt +++ b/app/src/main/java/eu/m724/chatapp/activity/select/SelectModelUiState.kt @@ -1,6 +1,6 @@ package eu.m724.chatapp.activity.select -import eu.m724.chatapp.api.data.response.LanguageModel +import eu.m724.chatapp.api.data.response.models.LanguageModel data class SelectModelUiState( val models: List = listOf() diff --git a/app/src/main/java/eu/m724/chatapp/api/AiApiNetworkModule.kt b/app/src/main/java/eu/m724/chatapp/api/AiApiNetworkModule.kt index 708d09f..a6372f8 100644 --- a/app/src/main/java/eu/m724/chatapp/api/AiApiNetworkModule.kt +++ b/app/src/main/java/eu/m724/chatapp/api/AiApiNetworkModule.kt @@ -1,50 +1,123 @@ package eu.m724.chatapp.api import com.google.gson.FieldNamingPolicy +import com.google.gson.Gson import com.google.gson.GsonBuilder import dagger.Module import dagger.Provides import dagger.hilt.InstallIn import dagger.hilt.components.SingletonComponent import eu.m724.chatapp.BuildConfig +import 
eu.m724.chatapp.api.retrofit.interceptor.AiApiRequestExceptionInterceptor +import eu.m724.chatapp.api.retrofit.interceptor.AiApiRequestHeadersInterceptor +import eu.m724.chatapp.api.retrofit.sse.SseCallAdapterFactory import okhttp3.OkHttpClient import okhttp3.logging.HttpLoggingInterceptor import retrofit2.Retrofit import retrofit2.converter.gson.GsonConverterFactory +import java.util.concurrent.TimeUnit +import javax.inject.Named +import javax.inject.Qualifier import javax.inject.Singleton @Module @InstallIn(SingletonComponent::class) object AiApiNetworkModule { @Provides - @Singleton - fun provideOkHttpClient(): OkHttpClient { - val interceptor = AiApiRequestInterceptor( - userAgent = BuildConfig.USER_AGENT, - apiEndpoint = BuildConfig.API_ENDPOINT, - apiKey = BuildConfig.API_KEY + @Named("apiKey") + fun provideApiKey(): String = BuildConfig.API_KEY + + @Provides + @Named("apiEndpoint") + fun provideApiEndpoint(): String = BuildConfig.API_ENDPOINT + + @Provides + @Named("userAgent") + fun provideUserAgent(): String = BuildConfig.USER_AGENT + + @Provides + @Named("isDebug") + fun provideIsDebug(): Boolean = BuildConfig.DEBUG + + @Provides + fun provideOkHttpClientBuilder( + @Named("apiKey") apiKey: String, + @Named("apiEndpoint") apiEndpoint: String, + @Named("userAgent") userAgent: String, + @Named("isDebug") isDebug: Boolean + ): OkHttpClient.Builder { + val interceptor = AiApiRequestHeadersInterceptor( + userAgent = userAgent, + apiEndpoint = apiEndpoint, + apiKey = apiKey ) val builder = OkHttpClient.Builder() .addInterceptor(interceptor) - if (BuildConfig.DEBUG) { - builder.addInterceptor(HttpLoggingInterceptor().setLevel(HttpLoggingInterceptor.Level.BODY)) + if (isDebug) { + // level body makes the response buffered which nukes sse + builder.addInterceptor(HttpLoggingInterceptor().apply { level = HttpLoggingInterceptor.Level.HEADERS }) } - return builder.build() + return builder + } + + /** + * The standard client is used for standard (non-SSE) requests + */ 
+ @Provides + @Singleton + @StandardClient + fun provideStandardOkHttpClient( + builder: OkHttpClient.Builder, + gson: Gson + ): OkHttpClient { + val interceptor = AiApiRequestExceptionInterceptor( + gson = gson + ) + + return builder + .addInterceptor(interceptor) + .build() + } + + /** + * The long-lived client is used for long-lived (SSE) requests + */ + @Provides + @Singleton + @LongLivedClient + fun provideLongLivedOkHttpClient( + builder: OkHttpClient.Builder + ): OkHttpClient { + return builder + .readTimeout(0, TimeUnit.SECONDS) // Apparently there are other safe guards against a zombie connection, but I don't know in practice + .connectTimeout(30, TimeUnit.SECONDS) + .build() } @Provides @Singleton - fun provideRetrofit(okHttpClient: OkHttpClient): Retrofit { - val gson = GsonBuilder() + fun provideGson(): Gson { + return GsonBuilder() .setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES) // snake_case .create() + } + @Provides + @Singleton + fun provideRetrofit( + @StandardClient standardOkHttpClient: OkHttpClient, + @LongLivedClient longLivedOkHttpClient: OkHttpClient, + gson: Gson, + @Named("apiEndpoint") apiEndpoint: String, + @Named("isDebug") isDebug: Boolean + ): Retrofit { return Retrofit.Builder() - .baseUrl(BuildConfig.API_ENDPOINT) - .client(okHttpClient) + .baseUrl(apiEndpoint) + .client(standardOkHttpClient) // Use the standard client by default + .addCallAdapterFactory(SseCallAdapterFactory(longLivedOkHttpClient, gson, isDebug)) // this intercepts SSE requests and makes them use the long-lived client .addConverterFactory(GsonConverterFactory.create(gson)) .build() } @@ -54,4 +127,12 @@ object AiApiNetworkModule { fun provideAiApiService(retrofit: Retrofit): AiApiService { return retrofit.create(AiApiService::class.java) } -} \ No newline at end of file +} + +@Qualifier +@Retention(AnnotationRetention.BINARY) +annotation class StandardClient + +@Qualifier +@Retention(AnnotationRetention.BINARY) +annotation class LongLivedClient 
\ No newline at end of file
diff --git a/app/src/main/java/eu/m724/chatapp/api/AiApiService.kt b/app/src/main/java/eu/m724/chatapp/api/AiApiService.kt
index b4fb195..efc69c4 100644
--- a/app/src/main/java/eu/m724/chatapp/api/AiApiService.kt
+++ b/app/src/main/java/eu/m724/chatapp/api/AiApiService.kt
@@ -1,17 +1,21 @@
 package eu.m724.chatapp.api
 
-import eu.m724.chatapp.api.data.request.ChatCompletionRequest
-import eu.m724.chatapp.api.data.response.ChatCompletionResponse
-import eu.m724.chatapp.api.data.response.LanguageModelsResponse
+import eu.m724.chatapp.api.data.request.completion.ChatCompletionRequest
+import eu.m724.chatapp.api.data.response.completion.ChatCompletionResponseEvent
+import eu.m724.chatapp.api.data.response.models.LanguageModelsResponse
+import eu.m724.chatapp.api.retrofit.sse.SseEvent
+import kotlinx.coroutines.flow.Flow
 import retrofit2.Response
 import retrofit2.http.Body
 import retrofit2.http.GET
 import retrofit2.http.POST
+import retrofit2.http.Streaming
 
 interface AiApiService {
     @GET("models?detailed=true")
     suspend fun getModels(): Response<LanguageModelsResponse>
 
     @POST("chat/completions")
-    suspend fun chatComplete(@Body body: ChatCompletionRequest): Response<ChatCompletionResponse>
+    @Streaming
+    fun getChatCompletion(@Body body: ChatCompletionRequest): Flow<SseEvent<ChatCompletionResponseEvent>>
 }
\ No newline at end of file
diff --git a/app/src/main/java/eu/m724/chatapp/api/data/AiApiException.kt b/app/src/main/java/eu/m724/chatapp/api/data/AiApiException.kt
new file mode 100644
index 0000000..90630a5
--- /dev/null
+++ b/app/src/main/java/eu/m724/chatapp/api/data/AiApiException.kt
@@ -0,0 +1,25 @@
+package eu.m724.chatapp.api.data
+
+import java.io.IOException
+
+data class AiApiExceptionData(
+    override val message: String,
+    val type: String,
+    val param: String,
+    val code: Int
+) : Exception(message)
+
+data class AiApiExceptionDataWrapper(
+    val error: AiApiExceptionData
+)
+
+/**
+ * Raised for an unsuccessful HTTP response; [error] carries the decoded API error payload, if any.
+ *
+ * Extends [IOException] because OkHttp only delivers IOExceptions thrown by interceptors to
+ * asynchronous callers as a regular failure; any other throwable escapes on the dispatcher thread.
+ */
+class AiApiException(
+    val httpCode: Int,
+    val error: AiApiExceptionData?
+) : IOException("SSE HTTP Error: $httpCode")
\ No newline at end of file
diff --git a/app/src/main/java/eu/m724/chatapp/api/data/request/ChatCompletionRequest.kt b/app/src/main/java/eu/m724/chatapp/api/data/request/completion/ChatCompletionRequest.kt
similarity index 90%
rename from app/src/main/java/eu/m724/chatapp/api/data/request/ChatCompletionRequest.kt
rename to app/src/main/java/eu/m724/chatapp/api/data/request/completion/ChatCompletionRequest.kt
index 85eb40f..cc3c5dc 100644
--- a/app/src/main/java/eu/m724/chatapp/api/data/request/ChatCompletionRequest.kt
+++ b/app/src/main/java/eu/m724/chatapp/api/data/request/completion/ChatCompletionRequest.kt
@@ -1,6 +1,6 @@
-package eu.m724.chatapp.api.data.request
+package eu.m724.chatapp.api.data.request.completion
 
-import eu.m724.chatapp.api.data.ChatMessage
+import eu.m724.chatapp.api.data.response.completion.ChatMessage
 
 data class ChatCompletionRequest(
     /**
@@ -40,7 +40,9 @@ data class ChatCompletionRequest(
      * Read more: https://www.promptitude.io/glossary/frequency-penalty
      * @see frequencyPenalty
      */
-    val presencePenalty: Float
+    val presencePenalty: Float,
+
+    val stream: Boolean = true
 ) {
     init {
         require(temperature >= 0.0) { "temperature must be at least 0.0" }
diff --git a/app/src/main/java/eu/m724/chatapp/api/data/response/ChatCompletionResponse.kt b/app/src/main/java/eu/m724/chatapp/api/data/response/completion/ChatCompletionResponseEvent.kt
similarity index 65%
rename from app/src/main/java/eu/m724/chatapp/api/data/response/ChatCompletionResponse.kt
rename to app/src/main/java/eu/m724/chatapp/api/data/response/completion/ChatCompletionResponseEvent.kt
index 60e6b9b..0315fd5 100644
--- a/app/src/main/java/eu/m724/chatapp/api/data/response/ChatCompletionResponse.kt
+++ b/app/src/main/java/eu/m724/chatapp/api/data/response/completion/ChatCompletionResponseEvent.kt
@@ -1,24 +1,13 @@
-package eu.m724.chatapp.api.data.response
+package eu.m724.chatapp.api.data.response.completion
 
-import 
com.google.gson.annotations.JsonAdapter import com.google.gson.annotations.SerializedName -import eu.m724.chatapp.api.data.ChatMessage -import eu.m724.chatapp.api.serialize.EpochSecondToLocalDateTimeDeserializer -import java.time.LocalDateTime -data class ChatCompletionResponse( +data class ChatCompletionResponseEvent( /** * Request ID */ val id: String, - /** - * Request time - */ - @SerializedName("created") - @JsonAdapter(EpochSecondToLocalDateTimeDeserializer::class) - val createdAt: LocalDateTime, - /** * Completion choices. Usually has only one element. */ @@ -35,14 +24,19 @@ data class CompletionChoice( val index: Int, /** - * The generated message + * The generated message delta, you should merge it with the previous delta */ - val message: ChatMessage, + val delta: CompletionChoiceDelta, /** - * The reason why generating the response has stopped + * The reason why generating the response has stopped. null if the response hasn't finished yet. */ - val finishReason: CompletionFinishReason + val finishReason: CompletionFinishReason? +) + +data class CompletionChoiceDelta( + /** The next generated token, may be null if the response just finished */ + val content: String? 
) enum class CompletionFinishReason { diff --git a/app/src/main/java/eu/m724/chatapp/api/data/ChatMessage.kt b/app/src/main/java/eu/m724/chatapp/api/data/response/completion/ChatMessage.kt similarity index 85% rename from app/src/main/java/eu/m724/chatapp/api/data/ChatMessage.kt rename to app/src/main/java/eu/m724/chatapp/api/data/response/completion/ChatMessage.kt index d4887f7..b85971c 100644 --- a/app/src/main/java/eu/m724/chatapp/api/data/ChatMessage.kt +++ b/app/src/main/java/eu/m724/chatapp/api/data/response/completion/ChatMessage.kt @@ -1,4 +1,4 @@ -package eu.m724.chatapp.api.data +package eu.m724.chatapp.api.data.response.completion import com.google.gson.annotations.SerializedName diff --git a/app/src/main/java/eu/m724/chatapp/api/data/response/LanguageModelsResponse.kt b/app/src/main/java/eu/m724/chatapp/api/data/response/models/LanguageModelsResponse.kt similarity index 95% rename from app/src/main/java/eu/m724/chatapp/api/data/response/LanguageModelsResponse.kt rename to app/src/main/java/eu/m724/chatapp/api/data/response/models/LanguageModelsResponse.kt index 1a35506..4e8d69b 100644 --- a/app/src/main/java/eu/m724/chatapp/api/data/response/LanguageModelsResponse.kt +++ b/app/src/main/java/eu/m724/chatapp/api/data/response/models/LanguageModelsResponse.kt @@ -1,4 +1,4 @@ -package eu.m724.chatapp.api.data.response +package eu.m724.chatapp.api.data.response.models import com.google.gson.annotations.SerializedName diff --git a/app/src/main/java/eu/m724/chatapp/api/serialize/EpochSecondToLocalDateTimeDeserializer.kt b/app/src/main/java/eu/m724/chatapp/api/data/serialize/EpochSecondToLocalDateTimeDeserializer.kt similarity index 94% rename from app/src/main/java/eu/m724/chatapp/api/serialize/EpochSecondToLocalDateTimeDeserializer.kt rename to app/src/main/java/eu/m724/chatapp/api/data/serialize/EpochSecondToLocalDateTimeDeserializer.kt index ceafc3b..3471cae 100644 --- 
a/app/src/main/java/eu/m724/chatapp/api/serialize/EpochSecondToLocalDateTimeDeserializer.kt +++ b/app/src/main/java/eu/m724/chatapp/api/data/serialize/EpochSecondToLocalDateTimeDeserializer.kt @@ -1,4 +1,4 @@ -package eu.m724.chatapp.api.serialize +package eu.m724.chatapp.api.data.serialize import com.google.gson.JsonDeserializationContext import com.google.gson.JsonDeserializer diff --git a/app/src/main/java/eu/m724/chatapp/api/retrofit/interceptor/AiApiRequestExceptionInterceptor.kt b/app/src/main/java/eu/m724/chatapp/api/retrofit/interceptor/AiApiRequestExceptionInterceptor.kt new file mode 100644 index 0000000..2d10844 --- /dev/null +++ b/app/src/main/java/eu/m724/chatapp/api/retrofit/interceptor/AiApiRequestExceptionInterceptor.kt @@ -0,0 +1,32 @@ +package eu.m724.chatapp.api.retrofit.interceptor + +import com.google.gson.Gson +import eu.m724.chatapp.api.data.AiApiException +import eu.m724.chatapp.api.data.AiApiExceptionDataWrapper +import okhttp3.Interceptor +import okhttp3.Response + +/** Passes successful responses through; otherwise throws [AiApiException] with the HTTP code and the error parsed from the body (null if it cannot be parsed). */ +class AiApiRequestExceptionInterceptor( + private val gson: Gson +) : Interceptor { + override fun intercept(chain: Interceptor.Chain): Response { + val request = chain.request() + val response = chain.proceed(request) + + if (response.isSuccessful) { + return response + } + + // Read the error body before closing the response: a closed body can no longer be read + val apiError = + try { + gson.fromJson(response.body!!.string(), AiApiExceptionDataWrapper::class.java) + } catch (_: Exception) { + null + }?.error + + response.close() + throw AiApiException(response.code, apiError) + } +} \ No newline at end of file diff --git a/app/src/main/java/eu/m724/chatapp/api/AiApiRequestInterceptor.kt b/app/src/main/java/eu/m724/chatapp/api/retrofit/interceptor/AiApiRequestHeadersInterceptor.kt similarity index 85% rename from app/src/main/java/eu/m724/chatapp/api/AiApiRequestInterceptor.kt rename to app/src/main/java/eu/m724/chatapp/api/retrofit/interceptor/AiApiRequestHeadersInterceptor.kt index 8056556..0c4f453 100644 ---
a/app/src/main/java/eu/m724/chatapp/api/AiApiRequestInterceptor.kt +++ b/app/src/main/java/eu/m724/chatapp/api/retrofit/interceptor/AiApiRequestHeadersInterceptor.kt @@ -1,9 +1,9 @@ -package eu.m724.chatapp.api +package eu.m724.chatapp.api.retrofit.interceptor import okhttp3.Interceptor import okhttp3.Response -class AiApiRequestInterceptor( +class AiApiRequestHeadersInterceptor( private val userAgent: String, private val apiEndpoint: String, private val apiKey: String diff --git a/app/src/main/java/eu/m724/chatapp/api/retrofit/sse/SseCallAdapter.kt b/app/src/main/java/eu/m724/chatapp/api/retrofit/sse/SseCallAdapter.kt new file mode 100644 index 0000000..0d9267f --- /dev/null +++ b/app/src/main/java/eu/m724/chatapp/api/retrofit/sse/SseCallAdapter.kt @@ -0,0 +1,86 @@ +package eu.m724.chatapp.api.retrofit.sse + +import com.google.gson.Gson +import kotlinx.coroutines.channels.awaitClose +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.callbackFlow +import okhttp3.OkHttpClient +import okhttp3.Request +import okhttp3.Response +import okhttp3.sse.EventSource +import okhttp3.sse.EventSourceListener +import okhttp3.sse.EventSources +import retrofit2.Call +import retrofit2.CallAdapter +import java.lang.reflect.Type + +class SseCallAdapter( + private val client: OkHttpClient, + private val gson: Gson, + private val eventType: Type, + private val debug: Boolean +) : CallAdapter>> { + override fun responseType(): Type = eventType + + override fun adapt(call: Call): Flow> { + return callbackFlow { + val listener = object : EventSourceListener() { + override fun onOpen(eventSource: EventSource, response: Response) { + trySend(SseEvent.Open(eventSource, response)) + } + + override fun onEvent( + eventSource: EventSource, + id: String?, + type: String?, + data: String + ) { + if (debug) { + println("raw sse data: " +data) + } + + if (data.trim() == "[DONE]") { + // The server is about to close the connection + return + } + + try { + val eventData = 
gson.fromJson(data, eventType) as T? + if (eventData != null) { + trySend(SseEvent.Event(id, type, eventData)) + } + } catch (e: Exception) { + val failure = SseEvent.Failure(e, null) + trySend(failure) + close(e) + } + } + + override fun onClosed(eventSource: EventSource) { + trySend(SseEvent.Closed) + close() // Close the flow + } + + override fun onFailure( + eventSource: EventSource, + t: Throwable?, + response: Response? + ) { + // TODO aiapiexception here + val error = t ?: RuntimeException("Unknown SSE error") + trySend(SseEvent.Failure(error, response)) + close(error) + } + } + + // Create a new request from the Retrofit call + val request: Request = call.request() + val eventSource = EventSources.createFactory(client).newEventSource(request, listener) + + // This block is called when the Flow is cancelled + awaitClose { + eventSource.cancel() + } + } + } +} \ No newline at end of file diff --git a/app/src/main/java/eu/m724/chatapp/api/retrofit/sse/SseCallAdapterFactory.kt b/app/src/main/java/eu/m724/chatapp/api/retrofit/sse/SseCallAdapterFactory.kt new file mode 100644 index 0000000..b956554 --- /dev/null +++ b/app/src/main/java/eu/m724/chatapp/api/retrofit/sse/SseCallAdapterFactory.kt @@ -0,0 +1,37 @@ +package eu.m724.chatapp.api.retrofit.sse + +import com.google.gson.Gson +import kotlinx.coroutines.flow.Flow +import okhttp3.OkHttpClient +import retrofit2.CallAdapter +import retrofit2.Retrofit +import java.lang.reflect.ParameterizedType +import java.lang.reflect.Type + +class SseCallAdapterFactory( + private val client: OkHttpClient, + private val gson: Gson, + private val debug: Boolean +) : CallAdapter.Factory() { + override fun get( + returnType: Type, + annotations: Array, + retrofit: Retrofit + ): CallAdapter<*, *>? 
{ + // Ensure the return type is a Flow + if (getRawType(returnType) != Flow::class.java) { + if (debug) println("wrong type: " + getRawType(returnType)) + return null + } + + // Ensure the Flow's generic type is SseEvent + val flowType = getParameterUpperBound(0, returnType as? ParameterizedType ?: return null) + if (getRawType(flowType) != SseEvent::class.java) { + return null + } + + // Get the generic type of SseEvent + val eventType = getParameterUpperBound(0, flowType as? ParameterizedType ?: return null) + return SseCallAdapter(client, gson, eventType, debug) + } +} \ No newline at end of file diff --git a/app/src/main/java/eu/m724/chatapp/api/retrofit/sse/SseEvent.kt b/app/src/main/java/eu/m724/chatapp/api/retrofit/sse/SseEvent.kt new file mode 100644 index 0000000..58421f6 --- /dev/null +++ b/app/src/main/java/eu/m724/chatapp/api/retrofit/sse/SseEvent.kt @@ -0,0 +1,22 @@ +package eu.m724.chatapp.api.retrofit.sse + +import okhttp3.Response +import okhttp3.sse.EventSource + +/** + * A sealed class representing the different states of an SSE connection. + * @param T The type of the data expected in the event payload. + */ +sealed class SseEvent { + // The connection was successfully opened + data class Open(val eventSource: EventSource, val response: Response) : SseEvent() + + // A new event (message) was received from the server + data class Event(val id: String?, val name: String?, val data: T) : SseEvent() + + // The connection was closed, either by the server or client + object Closed : SseEvent() + + // An unrecoverable error occurred + data class Failure(val error: Throwable, val response: Response?)
: SseEvent() +} \ No newline at end of file diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index c73221a..12301fa 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -16,6 +16,7 @@ retrofit = "3.0.0" secrets = "2.0.1" loggingInterceptor = "4.12.0" material3WindowSizeClass = "1.3.2" +okhttpSse = "4.12.0" [libraries] androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" } @@ -40,6 +41,7 @@ retrofit = { group = "com.squareup.retrofit2", name = "retrofit", version.ref = retrofit-converter-gson = { group = "com.squareup.retrofit2", name = "converter-gson", version.ref = "retrofit"} logging-interceptor = { group = "com.squareup.okhttp3", name = "logging-interceptor", version.ref = "loggingInterceptor" } androidx-material3-window-size-class1 = { group = "androidx.compose.material3", name = "material3-window-size-class", version.ref = "material3WindowSizeClass" } +okhttp-sse = { group = "com.squareup.okhttp3", name = "okhttp-sse", version.ref = "okhttpSse" } [plugins] android-application = { id = "com.android.application", version.ref = "agp" }