Make it better
This commit is contained in:
parent
feec030726
commit
db2cf5b7cb
96 changed files with 2465 additions and 1068 deletions
1
aiapi/.gitignore
vendored
Normal file
1
aiapi/.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
/build
|
32
aiapi/build.gradle.kts
Normal file
32
aiapi/build.gradle.kts
Normal file
|
@ -0,0 +1,32 @@
|
|||
plugins {
|
||||
id("java-library")
|
||||
alias(libs.plugins.kotlin.jvm)
|
||||
alias(libs.plugins.ksp)
|
||||
}
|
||||
|
||||
java {
|
||||
sourceCompatibility = JavaVersion.VERSION_11
|
||||
targetCompatibility = JavaVersion.VERSION_11
|
||||
}
|
||||
|
||||
kotlin {
|
||||
compilerOptions {
|
||||
jvmTarget = org.jetbrains.kotlin.gradle.dsl.JvmTarget.JVM_11
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation(libs.moshi)
|
||||
implementation(libs.moshi.kotlin)
|
||||
implementation(libs.retrofit)
|
||||
implementation(libs.retrofit.converter.moshi)
|
||||
implementation(libs.logging.interceptor)
|
||||
implementation(libs.kotlinx.coroutines.core)
|
||||
implementation(libs.okhttp.sse)
|
||||
testImplementation(libs.junit)
|
||||
testImplementation(libs.kotlinx.coroutines.test)
|
||||
testImplementation(libs.mockwebserver)
|
||||
testImplementation(libs.truth)
|
||||
testImplementation(libs.turbine)
|
||||
ksp(libs.moshi.kotlin.codegen)
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
package eu.m724.newchat.aiapi
|
||||
|
||||
interface AiApiConfiguration {
|
||||
val apiKey: String
|
||||
// TODO consider HttpUrl for this one, however that would require okhttp on the client
|
||||
// TODO maybe an option to set multiple domains
|
||||
val endpoint: String
|
||||
val userAgent: String
|
||||
val isDebug: Boolean
|
||||
}
|
|
@ -0,0 +1,94 @@
|
|||
package eu.m724.newchat.aiapi
|
||||
|
||||
import com.squareup.moshi.Moshi
|
||||
import eu.m724.newchat.aiapi.repository.AiApiRepository
|
||||
import eu.m724.newchat.aiapi.repository.AiApiRepositoryImpl
|
||||
import eu.m724.newchat.aiapi.retrofit.AiApiRequestExceptionInterceptor
|
||||
import eu.m724.newchat.aiapi.retrofit.RequestHeadersInterceptor
|
||||
import eu.m724.newchat.aiapi.retrofit.sse.SseCallAdapterFactory
|
||||
import okhttp3.OkHttpClient
|
||||
import okhttp3.logging.HttpLoggingInterceptor
|
||||
import retrofit2.Retrofit
|
||||
import retrofit2.converter.moshi.MoshiConverterFactory
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
object AiApiDataLayerFactory {
|
||||
fun createApiRepository(
|
||||
configuration: AiApiConfiguration
|
||||
): AiApiRepository {
|
||||
val apiService = createApiService(configuration)
|
||||
return AiApiRepositoryImpl(apiService)
|
||||
}
|
||||
|
||||
private fun createApiService(
|
||||
configuration: AiApiConfiguration
|
||||
): AiApiService {
|
||||
return createRetrofit(configuration).create(AiApiService::class.java)
|
||||
}
|
||||
|
||||
private fun createRetrofit(
|
||||
configuration: AiApiConfiguration
|
||||
): Retrofit {
|
||||
val moshi = Moshi.Builder()
|
||||
.build()
|
||||
|
||||
val okHttpClientBuilder = createOkHttpClientBuilder(configuration)
|
||||
val standardOkHttpClient = createStandardOkHttpClient(okHttpClientBuilder, moshi)
|
||||
val longLivedOkHttpClient = createLongLivedOkHttpClient(okHttpClientBuilder)
|
||||
|
||||
return Retrofit.Builder()
|
||||
.baseUrl(configuration.endpoint)
|
||||
.client(standardOkHttpClient) // Use the standard client by default
|
||||
.addCallAdapterFactory(
|
||||
SseCallAdapterFactory(
|
||||
longLivedOkHttpClient,
|
||||
moshi,
|
||||
configuration.isDebug
|
||||
)
|
||||
) // This intercepts SSE requests and makes them use the long-lived client
|
||||
.addConverterFactory(MoshiConverterFactory.create(moshi))
|
||||
.build()
|
||||
}
|
||||
|
||||
private fun createOkHttpClientBuilder(
|
||||
configuration: AiApiConfiguration
|
||||
): OkHttpClient.Builder {
|
||||
val interceptor = RequestHeadersInterceptor(
|
||||
endpoint = configuration.endpoint,
|
||||
apiKey = configuration.apiKey,
|
||||
userAgent = configuration.userAgent
|
||||
)
|
||||
|
||||
val builder = OkHttpClient.Builder()
|
||||
.addInterceptor(interceptor)
|
||||
|
||||
if (configuration.isDebug) {
|
||||
// Level.BODY buffers the entire response, which destroys SSE
|
||||
builder.addInterceptor(
|
||||
HttpLoggingInterceptor().apply {
|
||||
level = HttpLoggingInterceptor.Level.HEADERS
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
return builder
|
||||
}
|
||||
|
||||
private fun createStandardOkHttpClient(
|
||||
builder: OkHttpClient.Builder,
|
||||
moshi: Moshi
|
||||
): OkHttpClient {
|
||||
return builder
|
||||
.addInterceptor(AiApiRequestExceptionInterceptor(moshi))
|
||||
.build()
|
||||
}
|
||||
|
||||
private fun createLongLivedOkHttpClient(
|
||||
builder: OkHttpClient.Builder
|
||||
): OkHttpClient {
|
||||
// TODO figure out why can't we just use AiApiRequestExceptionInterceptor
|
||||
return builder
|
||||
.readTimeout(0, TimeUnit.SECONDS)
|
||||
.build()
|
||||
}
|
||||
}
|
23
aiapi/src/main/java/eu/m724/newchat/aiapi/AiApiService.kt
Normal file
23
aiapi/src/main/java/eu/m724/newchat/aiapi/AiApiService.kt
Normal file
|
@ -0,0 +1,23 @@
|
|||
package eu.m724.newchat.aiapi
|
||||
|
||||
import eu.m724.newchat.aiapi.models.dto.lm.LanguageModelResponseDto
|
||||
import eu.m724.newchat.aiapi.models.dto.chat.ChatCompletionRequestDto
|
||||
import eu.m724.newchat.aiapi.models.dto.chat.ChatCompletionResponseEventDto
|
||||
import eu.m724.newchat.aiapi.retrofit.sse.SseEvent
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import retrofit2.http.Body
|
||||
import retrofit2.http.GET
|
||||
import retrofit2.http.Headers
|
||||
import retrofit2.http.POST
|
||||
import retrofit2.http.Streaming
|
||||
|
||||
interface AiApiService {
|
||||
@GET("models?detailed=true")
|
||||
@Headers("Accept: application/json")
|
||||
suspend fun getLanguageModels(): LanguageModelResponseDto
|
||||
|
||||
@POST("chat/completions")
|
||||
@Headers("Accept: text/event-stream")
|
||||
@Streaming
|
||||
fun streamChatCompletion(@Body body: ChatCompletionRequestDto): Flow<SseEvent<ChatCompletionResponseEventDto>>
|
||||
}
|
|
@ -1,28 +1,34 @@
|
|||
package eu.m724.chatapp.api.data.request.completion
|
||||
package eu.m724.newchat.aiapi.models.dto.chat
|
||||
|
||||
import eu.m724.chatapp.api.data.response.completion.ChatMessage
|
||||
import com.squareup.moshi.Json
|
||||
import com.squareup.moshi.JsonClass
|
||||
|
||||
data class ChatCompletionRequest(
|
||||
@JsonClass(generateAdapter = true)
|
||||
data class ChatCompletionRequestDto(
|
||||
/**
|
||||
* The model ID
|
||||
*/
|
||||
@Json(name = "model")
|
||||
val model: String,
|
||||
|
||||
/**
|
||||
* The messages in the current chat. Usually a user message is the last one if making a request.
|
||||
*/
|
||||
val messages: List<ChatMessage>,
|
||||
@Json(name = "messages")
|
||||
val messages: List<ChatMessageDto>,
|
||||
|
||||
/**
|
||||
* The maximum amount of tokens to generate
|
||||
*/
|
||||
@Json(name = "max_tokens")
|
||||
val maxTokens: Int = Int.MAX_VALUE,
|
||||
|
||||
/**
|
||||
* Controls the "creativity" of the model.
|
||||
* Read more: https://www.iguazio.com/glossary/llm-temperature/
|
||||
*/
|
||||
val temperature: Float,
|
||||
|
||||
/**
|
||||
* The maximum amount of tokens to generate
|
||||
*/
|
||||
val maxTokens: Int,
|
||||
@Json(name = "temperature")
|
||||
val temperature: Float = 0.0f,
|
||||
|
||||
/**
|
||||
* Controls the repetition of words in the generated text.
|
||||
|
@ -31,7 +37,8 @@ data class ChatCompletionRequest(
|
|||
* Read more: https://www.promptitude.io/glossary/frequency-penalty
|
||||
* @see presencePenalty
|
||||
*/
|
||||
val frequencyPenalty: Float,
|
||||
@Json(name = "frequency_penalty")
|
||||
val frequencyPenalty: Float = 0.0f,
|
||||
|
||||
/**
|
||||
* Controls the repetition of words in the generated text.
|
||||
|
@ -40,9 +47,23 @@ data class ChatCompletionRequest(
|
|||
* Read more: https://www.promptitude.io/glossary/frequency-penalty
|
||||
* @see frequencyPenalty
|
||||
*/
|
||||
val presencePenalty: Float,
|
||||
@Json(name = "presence_penalty")
|
||||
val presencePenalty: Float = 0.0f,
|
||||
|
||||
/**
|
||||
* Do not change please
|
||||
* TODO make not changeable
|
||||
*/
|
||||
@Json(name = "stream")
|
||||
val stream: Boolean = true,
|
||||
|
||||
/**
|
||||
* Do not change please
|
||||
* TODO make not changeable
|
||||
*/
|
||||
@Json(name = "stream_options")
|
||||
val streamOptions: Map<String, Boolean> = mapOf("include_usage" to true),
|
||||
|
||||
val stream: Boolean = true
|
||||
) {
|
||||
init {
|
||||
require(temperature >= 0.0) { "temperature must be at least 0.0" }
|
||||
|
@ -55,5 +76,8 @@ data class ChatCompletionRequest(
|
|||
|
||||
require(presencePenalty >= -2.0) { "presencePenalty must be at least -2.0" }
|
||||
require(presencePenalty <= 2.0) { "presencePenalty must be at most 2.0" }
|
||||
|
||||
require(stream) { "stream must be true, I told you not to change it" }
|
||||
require(streamOptions["include_usage"] == true) { "streamOptions must include include_usage, I told you not to change it" }
|
||||
}
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
package eu.m724.newchat.aiapi.models.dto.chat
|
||||
|
||||
import com.squareup.moshi.Json
|
||||
import com.squareup.moshi.JsonClass
|
||||
|
||||
@JsonClass(generateAdapter = true)
|
||||
data class ChatCompletionResponseEventDto(
|
||||
@Json(name = "id")
|
||||
val id: String,
|
||||
|
||||
@Json(name = "object")
|
||||
val objectType: String,
|
||||
|
||||
@Json(name = "choices")
|
||||
val choices: List<CompletionChoiceDto>,
|
||||
|
||||
/**
|
||||
* Only at end
|
||||
*/
|
||||
@Json(name = "usage")
|
||||
val usage: CompletionUsageDto?
|
||||
)
|
|
@ -0,0 +1,18 @@
|
|||
package eu.m724.newchat.aiapi.models.dto.chat
|
||||
|
||||
import com.squareup.moshi.Json
|
||||
import com.squareup.moshi.JsonClass
|
||||
|
||||
@JsonClass(generateAdapter = true)
|
||||
data class ChatMessageDto(
|
||||
// TODO make enum
|
||||
@Json(name = "role")
|
||||
val role: String,
|
||||
|
||||
@Json(name = "content")
|
||||
val content: String
|
||||
) {
|
||||
init {
|
||||
require(role == "user" || role == "assistant" || role == "system") { "role must be user, assistant or system" }
|
||||
}
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
package eu.m724.newchat.aiapi.models.dto.chat
|
||||
|
||||
import com.squareup.moshi.Json
|
||||
import com.squareup.moshi.JsonClass
|
||||
|
||||
@JsonClass(generateAdapter = true)
|
||||
data class CompletionChoiceDto(
|
||||
@Json(name = "index")
|
||||
val index: Int,
|
||||
|
||||
/**
|
||||
* The generated message delta, you should merge it with the previous delta
|
||||
*/
|
||||
@Json(name = "delta")
|
||||
val delta: CompletionChoiceDeltaDto,
|
||||
|
||||
/**
|
||||
* The reason why generating the response has stopped. null if the response hasn't finished yet.
|
||||
*/
|
||||
@Json(name = "finish_reason")
|
||||
val finishReason: FinishReason?
|
||||
) {
|
||||
enum class FinishReason {
|
||||
/**
|
||||
* The response has stopped, because the model said so
|
||||
*/
|
||||
@Json(name = "stop")
|
||||
Stop,
|
||||
|
||||
/**
|
||||
* The response has stopped, because it got too long
|
||||
*/
|
||||
@Json(name = "length")
|
||||
Length,
|
||||
|
||||
/**
|
||||
* The response has stopped, because the content got flagged
|
||||
*/
|
||||
@Json(name = "content_filter")
|
||||
ContentFilter
|
||||
}
|
||||
}
|
||||
|
||||
@JsonClass(generateAdapter = true)
|
||||
data class CompletionChoiceDeltaDto(
|
||||
@Json(name = "content")
|
||||
/** The next generated token, may be null if the response just finished */
|
||||
val content: String?
|
||||
)
|
||||
|
||||
@JsonClass(generateAdapter = true)
|
||||
data class CompletionUsageDto(
|
||||
/**
|
||||
* The amount of tokens of the prompt
|
||||
*/
|
||||
@Json(name = "prompt_tokens")
|
||||
val promptTokens: Int,
|
||||
|
||||
/**
|
||||
* The amount of tokens of the generated completion
|
||||
*/
|
||||
@Json(name = "completion_tokens")
|
||||
val completionTokens: Int,
|
||||
|
||||
/**
|
||||
* The total amount of tokens processed
|
||||
*/
|
||||
@Json(name = "total_tokens")
|
||||
val totalTokens: Int,
|
||||
|
||||
@Json(name = "cost")
|
||||
val cost: Double,
|
||||
|
||||
@Json(name = "currency")
|
||||
val currency: Currency
|
||||
) {
|
||||
enum class Currency {
|
||||
USD,
|
||||
XNO
|
||||
}
|
||||
}
|
|
@ -0,0 +1,73 @@
|
|||
package eu.m724.newchat.aiapi.models.dto.lm
|
||||
|
||||
import com.squareup.moshi.Json
|
||||
import com.squareup.moshi.JsonClass
|
||||
|
||||
/**
|
||||
* A language model as defined by the API
|
||||
*/
|
||||
@JsonClass(generateAdapter = true)
|
||||
data class LanguageModelDto(
|
||||
/**
|
||||
* The ID of the language model
|
||||
*/
|
||||
@Json(name = "id")
|
||||
val id: String,
|
||||
|
||||
/**
|
||||
* The pretty name of the language model
|
||||
*/
|
||||
@Json(name = "name")
|
||||
val name: String,
|
||||
|
||||
/**
|
||||
* The description of the language model
|
||||
*/
|
||||
@Json(name = "description")
|
||||
val description: String,
|
||||
|
||||
/**
|
||||
* The context length of the language model
|
||||
*/
|
||||
@Json(name = "context_length")
|
||||
val contextLength: Int,
|
||||
|
||||
/**
|
||||
* The pricing of the language model
|
||||
*/
|
||||
@Json(name = "pricing")
|
||||
val pricing: LanguageModelPricingDto,
|
||||
|
||||
/**
|
||||
* The URL of the icon of the language model
|
||||
*/
|
||||
@Json(name = "icon_url")
|
||||
val iconUrl: String?
|
||||
|
||||
) {
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (javaClass != other?.javaClass) return false
|
||||
|
||||
other as LanguageModelDto
|
||||
|
||||
if (contextLength != other.contextLength) return false
|
||||
if (id != other.id) return false
|
||||
if (name != other.name) return false
|
||||
if (description != other.description) return false
|
||||
if (pricing != other.pricing) return false
|
||||
if (iconUrl != other.iconUrl) return false
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
var result = contextLength
|
||||
result = 31 * result + id.hashCode()
|
||||
result = 31 * result + name.hashCode()
|
||||
result = 31 * result + description.hashCode()
|
||||
result = 31 * result + pricing.hashCode()
|
||||
result = 31 * result + iconUrl.hashCode()
|
||||
return result
|
||||
}
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
package eu.m724.newchat.aiapi.models.dto.lm
|
||||
|
||||
import com.squareup.moshi.Json
|
||||
import com.squareup.moshi.JsonClass
|
||||
|
||||
/**
|
||||
* The pricing of a language model as defined by the API
|
||||
*/
|
||||
@JsonClass(generateAdapter = true)
|
||||
data class LanguageModelPricingDto(
|
||||
/**
|
||||
* The cost per million input tokens of the language model
|
||||
*/
|
||||
@Json(name = "prompt")
|
||||
val inputCostPerMillionTokens: Double,
|
||||
|
||||
/**
|
||||
* The cost per million output tokens of the language model
|
||||
*/
|
||||
@Json(name = "completion")
|
||||
val completionCostPerMillionTokens: Double,
|
||||
) {
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (javaClass != other?.javaClass) return false
|
||||
|
||||
other as LanguageModelPricingDto
|
||||
|
||||
if (inputCostPerMillionTokens != other.inputCostPerMillionTokens) return false
|
||||
if (completionCostPerMillionTokens != other.completionCostPerMillionTokens) return false
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
var result = inputCostPerMillionTokens.hashCode()
|
||||
result = 31 * result + completionCostPerMillionTokens.hashCode()
|
||||
return result
|
||||
}
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
package eu.m724.newchat.aiapi.models.dto.lm
|
||||
|
||||
import com.squareup.moshi.Json
|
||||
import com.squareup.moshi.JsonClass
|
||||
|
||||
// TODO remove this
|
||||
@JsonClass(generateAdapter = true)
|
||||
data class LanguageModelResponseDto(
|
||||
@Json(name = "data")
|
||||
val data: List<LanguageModelDto>
|
||||
)
|
|
@ -0,0 +1,24 @@
|
|||
package eu.m724.newchat.aiapi.models.repo
|
||||
|
||||
import eu.m724.newchat.aiapi.models.dto.chat.CompletionChoiceDto
|
||||
import eu.m724.newchat.aiapi.models.dto.chat.CompletionUsageDto
|
||||
|
||||
sealed class ChatCompletionResponseChunk {
|
||||
data class PartialContent(
|
||||
val completionPart: String
|
||||
) : ChatCompletionResponseChunk()
|
||||
|
||||
data class Finish(
|
||||
val finishReason: CompletionChoiceDto.FinishReason,
|
||||
val cost: Cost
|
||||
) : ChatCompletionResponseChunk() {
|
||||
data class Cost(
|
||||
val promptTokens: Int,
|
||||
val completionTokens: Int,
|
||||
|
||||
val total: Double,
|
||||
val currency: CompletionUsageDto.Currency
|
||||
)
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
package eu.m724.newchat.aiapi.repository
|
||||
|
||||
import eu.m724.newchat.aiapi.models.dto.lm.LanguageModelResponseDto
|
||||
import eu.m724.newchat.aiapi.models.dto.chat.ChatCompletionRequestDto
|
||||
import eu.m724.newchat.aiapi.models.repo.ChatCompletionResponseChunk
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
|
||||
interface AiApiRepository {
|
||||
suspend fun getLanguageModels(): Result<LanguageModelResponseDto> // TODO cache or something
|
||||
fun streamChatCompletion(request: ChatCompletionRequestDto): Flow<ChatCompletionResponseChunk>
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
package eu.m724.newchat.aiapi.repository
|
||||
|
||||
import eu.m724.newchat.aiapi.AiApiService
|
||||
import eu.m724.newchat.aiapi.models.dto.chat.ChatCompletionRequestDto
|
||||
import eu.m724.newchat.aiapi.models.dto.chat.ChatCompletionResponseEventDto
|
||||
import eu.m724.newchat.aiapi.models.dto.lm.LanguageModelResponseDto
|
||||
import eu.m724.newchat.aiapi.models.repo.ChatCompletionResponseChunk
|
||||
import eu.m724.newchat.aiapi.retrofit.sse.SseEvent
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.mapNotNull
|
||||
|
||||
internal class AiApiRepositoryImpl(
|
||||
private val apiService: AiApiService
|
||||
) : AiApiRepository {
|
||||
override suspend fun getLanguageModels(): Result<LanguageModelResponseDto> {
|
||||
return try {
|
||||
val models = apiService.getLanguageModels()
|
||||
|
||||
Result.success(models)
|
||||
} catch (e: Exception) {
|
||||
Result.failure(e)
|
||||
}
|
||||
}
|
||||
|
||||
override fun streamChatCompletion(request: ChatCompletionRequestDto): Flow<ChatCompletionResponseChunk> {
|
||||
var finished = false
|
||||
|
||||
return apiService.streamChatCompletion(request).mapNotNull { event ->
|
||||
when (event) {
|
||||
is SseEvent.Open -> null
|
||||
is SseEvent.Event<ChatCompletionResponseEventDto> -> {
|
||||
check (!finished) { "Received an event after the stream was marked as finished." }
|
||||
|
||||
require(event.data.choices.size <= 1) { "Handling multiple choices is not supported." }
|
||||
|
||||
val choice = event.data.choices.firstOrNull() ?: return@mapNotNull null
|
||||
|
||||
when {
|
||||
choice.finishReason != null -> {
|
||||
val usage = event.data.usage
|
||||
|
||||
if (usage == null) {
|
||||
// bug in api where finish reason is repeated TODO maybe report it
|
||||
null
|
||||
} else {
|
||||
// Mark the stream as finished and create the final event.
|
||||
finished = true
|
||||
ChatCompletionResponseChunk.Finish(
|
||||
finishReason = choice.finishReason,
|
||||
cost = ChatCompletionResponseChunk.Finish.Cost(
|
||||
promptTokens = usage.promptTokens,
|
||||
completionTokens = usage.completionTokens,
|
||||
total = usage.cost,
|
||||
currency = usage.currency
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
choice.delta.content != null -> {
|
||||
ChatCompletionResponseChunk.PartialContent(
|
||||
completionPart = choice.delta.content
|
||||
)
|
||||
}
|
||||
|
||||
else -> {
|
||||
println("Ignoring unhandleable event: ${event.data}")
|
||||
return@mapNotNull null
|
||||
//throw NotImplementedError("Cannot handle this event: ${event.data}")
|
||||
}
|
||||
}
|
||||
}
|
||||
is SseEvent.Close -> null
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
package eu.m724.newchat.aiapi.retrofit
|
||||
|
||||
import com.squareup.moshi.Moshi
|
||||
import eu.m724.newchat.aiapi.retrofit.exception.ApiException
|
||||
import eu.m724.newchat.aiapi.retrofit.exception.ApiExceptionWrappedDto
|
||||
import okhttp3.Interceptor
|
||||
import okhttp3.Response
|
||||
|
||||
class AiApiRequestExceptionInterceptor(
|
||||
moshi: Moshi
|
||||
) : Interceptor {
|
||||
private val adapter = moshi.adapter(ApiExceptionWrappedDto::class.java) // TODO weird
|
||||
|
||||
override fun intercept(chain: Interceptor.Chain): Response {
|
||||
val request = chain.request()
|
||||
val response = chain.proceed(request)
|
||||
|
||||
if (!response.isSuccessful) {
|
||||
// TIP you can't do that you can't read string twice it'll break
|
||||
// println("Error and response body: ${response.body.string()}")
|
||||
|
||||
val details = try {
|
||||
adapter.fromJson(response.body.string())!!.error
|
||||
} catch (e: Exception) {
|
||||
println("Wow an exception parsing the error: $e")
|
||||
null
|
||||
}
|
||||
|
||||
if (details != null) {
|
||||
throw ApiException(
|
||||
message = details.message,
|
||||
errorType = details.errorType,
|
||||
errorSubType = details.errorSubType,
|
||||
clientError = response.code < 500
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
}
|
|
@ -1,18 +1,21 @@
|
|||
package eu.m724.chatapp.api.retrofit.interceptor
|
||||
package eu.m724.newchat.aiapi.retrofit
|
||||
|
||||
import okhttp3.Interceptor
|
||||
import okhttp3.Response
|
||||
|
||||
class AiApiRequestHeadersInterceptor(
|
||||
private val userAgent: String,
|
||||
private val apiEndpoint: String,
|
||||
private val apiKey: String
|
||||
) : Interceptor {
|
||||
/**
|
||||
* Adds API-specific headers to requests, like API key and User-Agent.
|
||||
*/
|
||||
class RequestHeadersInterceptor(
|
||||
private val endpoint: String,
|
||||
private val apiKey: String,
|
||||
private val userAgent: String
|
||||
) : Interceptor {
|
||||
override fun intercept(chain: Interceptor.Chain): Response {
|
||||
val builder = chain.request().newBuilder()
|
||||
.header("User-Agent", userAgent)
|
||||
|
||||
if (chain.request().url.toString().startsWith(apiEndpoint)) {
|
||||
if (chain.request().url.toString().startsWith(endpoint)) {
|
||||
builder.header("Authorization", "Bearer $apiKey")
|
||||
}
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
package eu.m724.newchat.aiapi.retrofit.exception
|
||||
|
||||
import java.io.IOException
|
||||
|
||||
open class ApiException(
|
||||
override val message: String,
|
||||
val errorType: String,
|
||||
val errorSubType: String?,
|
||||
|
||||
/**
|
||||
* :|
|
||||
*/
|
||||
val clientError: Boolean
|
||||
): IOException(message)
|
|
@ -0,0 +1,19 @@
|
|||
package eu.m724.newchat.aiapi.retrofit.exception
|
||||
|
||||
import com.squareup.moshi.Json
|
||||
import com.squareup.moshi.JsonClass
|
||||
|
||||
@JsonClass(generateAdapter = true)
|
||||
data class ApiExceptionDetailsDto(
|
||||
@Json(name = "message")
|
||||
val message: String,
|
||||
|
||||
@Json(name = "type")
|
||||
val errorType: String,
|
||||
|
||||
/**
|
||||
* Actually "code", renamed for clarity
|
||||
*/
|
||||
@Json(name = "code")
|
||||
val errorSubType: String?
|
||||
)
|
|
@ -0,0 +1,10 @@
|
|||
package eu.m724.newchat.aiapi.retrofit.exception
|
||||
|
||||
import com.squareup.moshi.Json
|
||||
import com.squareup.moshi.JsonClass
|
||||
|
||||
@JsonClass(generateAdapter = true)
|
||||
internal data class ApiExceptionWrappedDto(
|
||||
@Json(name = "error")
|
||||
val error: ApiExceptionDetailsDto
|
||||
)
|
|
@ -1,8 +1,8 @@
|
|||
package eu.m724.chatapp.api.retrofit.sse
|
||||
package eu.m724.newchat.aiapi.retrofit.sse
|
||||
|
||||
import com.google.gson.Gson
|
||||
import eu.m724.chatapp.api.data.AiApiException
|
||||
import eu.m724.chatapp.api.data.AiApiExceptionDataWrapper
|
||||
import com.squareup.moshi.Moshi
|
||||
import eu.m724.newchat.aiapi.retrofit.exception.ApiException
|
||||
import eu.m724.newchat.aiapi.retrofit.exception.ApiExceptionWrappedDto
|
||||
import kotlinx.coroutines.channels.awaitClose
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.callbackFlow
|
||||
|
@ -18,10 +18,13 @@ import java.lang.reflect.Type
|
|||
|
||||
class SseCallAdapter<T : Any>(
|
||||
private val client: OkHttpClient,
|
||||
private val gson: Gson,
|
||||
moshi: Moshi,
|
||||
private val eventType: Type,
|
||||
private val debug: Boolean
|
||||
) : CallAdapter<T, Flow<SseEvent<T>>> {
|
||||
private val eventDataAdapter = moshi.adapter<T>(eventType)
|
||||
private val exceptionAdapter = moshi.adapter(ApiExceptionWrappedDto::class.java) // TODO weird
|
||||
|
||||
override fun responseType(): Type = eventType
|
||||
|
||||
override fun adapt(call: Call<T>): Flow<SseEvent<T>> {
|
||||
|
@ -38,29 +41,32 @@ class SseCallAdapter<T : Any>(
|
|||
data: String
|
||||
) {
|
||||
if (debug) {
|
||||
println("raw sse data: " +data)
|
||||
println("Received SSE data: $data")
|
||||
}
|
||||
|
||||
if (data.trim() == "[DONE]") {
|
||||
// The server is about to close the connection
|
||||
// TODO should *we* close here?
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
val eventData = gson.fromJson(data, eventType) as T?
|
||||
val eventData = eventDataAdapter.fromJson(data)
|
||||
if (eventData != null) {
|
||||
trySend(SseEvent.Event(id, type, eventData))
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
val failure = SseEvent.Failure(e, null)
|
||||
trySend(failure)
|
||||
close(e)
|
||||
close(e) // we sent Failure here too with null body
|
||||
}
|
||||
}
|
||||
|
||||
override fun onClosed(eventSource: EventSource) {
|
||||
trySend(SseEvent.Closed)
|
||||
close() // Close the flow
|
||||
if (debug) {
|
||||
println("SSE stream onClosed")
|
||||
}
|
||||
|
||||
trySend(SseEvent.Close)
|
||||
close()
|
||||
}
|
||||
|
||||
override fun onFailure(
|
||||
|
@ -68,25 +74,32 @@ class SseCallAdapter<T : Any>(
|
|||
t: Throwable?,
|
||||
response: Response?
|
||||
) {
|
||||
// TODO find a nicer way to handle errors
|
||||
val exc = if (response != null) {
|
||||
println("sse response: " + response.code)
|
||||
println("sse response body: " + response.body?.string())
|
||||
println("SSE failure full response code: " + response.code)
|
||||
//println("SSE failure full response body: " + response.body.string())
|
||||
|
||||
val apiError =
|
||||
try {
|
||||
gson.fromJson(response.body!!.string(), AiApiExceptionDataWrapper::class.java)
|
||||
} catch (_: Exception) {
|
||||
null
|
||||
}?.error
|
||||
|
||||
AiApiException(response.code, apiError)
|
||||
try {
|
||||
exceptionAdapter.fromJson(response.body.string())!!.error
|
||||
} catch (e: Exception) {
|
||||
println("Wow an exception parsing the error: $e")
|
||||
null
|
||||
}?.let { details ->
|
||||
ApiException(
|
||||
message = details.message,
|
||||
errorType = details.errorType,
|
||||
errorSubType = details.errorSubType,
|
||||
clientError = response.code < 500 // TODO might be inaccurate if it's 200 at first but something goes wrong during the stream
|
||||
)
|
||||
} ?: t
|
||||
} else {
|
||||
println("SSE error: " + (t?.message ?: "Unknown"))
|
||||
|
||||
t
|
||||
}
|
||||
|
||||
val error = exc ?: RuntimeException("Unknown SSE error")
|
||||
trySend(SseEvent.Failure(error, response))
|
||||
close(error)
|
||||
close(error) // TODO before it was a Failure with optional response, maybe we could provide the response here too
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -96,6 +109,10 @@ class SseCallAdapter<T : Any>(
|
|||
|
||||
// This block is called when the Flow is cancelled
|
||||
awaitClose {
|
||||
if (debug) {
|
||||
println("SSE flow cancelled")
|
||||
}
|
||||
|
||||
eventSource.cancel()
|
||||
}
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
package eu.m724.chatapp.api.retrofit.sse
|
||||
package eu.m724.newchat.aiapi.retrofit.sse
|
||||
|
||||
import com.google.gson.Gson
|
||||
import com.squareup.moshi.Moshi
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import okhttp3.OkHttpClient
|
||||
import retrofit2.CallAdapter
|
||||
|
@ -10,7 +10,7 @@ import java.lang.reflect.Type
|
|||
|
||||
class SseCallAdapterFactory(
|
||||
private val client: OkHttpClient,
|
||||
private val gson: Gson,
|
||||
private val moshi: Moshi,
|
||||
private val debug: Boolean
|
||||
) : CallAdapter.Factory() {
|
||||
override fun get(
|
||||
|
@ -18,20 +18,19 @@ class SseCallAdapterFactory(
|
|||
annotations: Array<out Annotation>,
|
||||
retrofit: Retrofit
|
||||
): CallAdapter<*, *>? {
|
||||
// Ensure the return type is a Flow
|
||||
// Ensure the return type is a Flow (ensure this is a SSE request)
|
||||
if (getRawType(returnType) != Flow::class.java) {
|
||||
println("wrong type: " + getRawType(returnType))
|
||||
return null
|
||||
}
|
||||
|
||||
// Ensure the Flow's generic type is SseEvent
|
||||
// Ensure the Flow's generic type is SseEvent (idk what for)
|
||||
val flowType = getParameterUpperBound(0, returnType as ParameterizedType)
|
||||
if (getRawType(flowType) != SseEvent::class.java) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Get the generic type of SseEvent<T>
|
||||
// Get the generic type of SseEvent<T> (dunno why either)
|
||||
val eventType = getParameterUpperBound(0, flowType as ParameterizedType)
|
||||
return SseCallAdapter<Any>(client, gson, eventType, debug)
|
||||
return SseCallAdapter<Any>(client, moshi, eventType, debug)
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package eu.m724.chatapp.api.retrofit.sse
|
||||
package eu.m724.newchat.aiapi.retrofit.sse
|
||||
|
||||
import okhttp3.Response
|
||||
import okhttp3.sse.EventSource
|
||||
|
@ -15,8 +15,5 @@ sealed class SseEvent<out T> {
|
|||
data class Event<out T>(val id: String?, val name: String?, val data: T) : SseEvent<T>()
|
||||
|
||||
// The connection was closed, either by the server or client
|
||||
object Closed : SseEvent<Nothing>()
|
||||
|
||||
// An unrecoverable error occurred
|
||||
data class Failure(val error: Throwable, val response: Response?) : SseEvent<Nothing>()
|
||||
object Close : SseEvent<Nothing>()
|
||||
}
|
|
@ -0,0 +1,239 @@
|
|||
package eu.m724.newchat.aiapi.endpoints
|
||||
|
||||
import app.cash.turbine.test
|
||||
import com.google.common.truth.Truth.assertThat
|
||||
import com.squareup.moshi.JsonEncodingException
|
||||
import eu.m724.newchat.aiapi.models.dto.chat.ChatCompletionRequestDto
|
||||
import eu.m724.newchat.aiapi.models.dto.chat.ChatCompletionResponseEventDto
|
||||
import eu.m724.newchat.aiapi.models.dto.chat.ChatMessageDto
|
||||
import eu.m724.newchat.aiapi.retrofit.exception.ApiException
|
||||
import eu.m724.newchat.aiapi.retrofit.sse.SseEvent
|
||||
import eu.m724.newchat.aiapi.rules.MockApiRule
|
||||
import kotlinx.coroutines.test.runTest
|
||||
import okhttp3.mockwebserver.MockResponse
|
||||
import okhttp3.mockwebserver.SocketPolicy
|
||||
import org.junit.Rule
|
||||
import org.junit.Test
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
class ChatCompletionsTest {
|
||||
@Rule
|
||||
@JvmField
|
||||
val mockApi = MockApiRule(
|
||||
apiKey = "totally-valid-key",
|
||||
userAgent = "ExpectedUserAgent/1.0"
|
||||
)
|
||||
|
||||
val requestBody = ChatCompletionRequestDto(
|
||||
model = "deepseek-v3.1",
|
||||
messages = listOf(
|
||||
ChatMessageDto(
|
||||
role = "system",
|
||||
content = "You are a helpful assistant."
|
||||
),
|
||||
ChatMessageDto(
|
||||
role = "user",
|
||||
content = "Hello"
|
||||
)
|
||||
),
|
||||
maxTokens = 4000,
|
||||
temperature = 0.5f,
|
||||
frequencyPenalty = 0.0f,
|
||||
presencePenalty = 0.0f
|
||||
)
|
||||
|
||||
private fun getResource(path: String): String {
|
||||
return this.javaClass.classLoader!!.getResource(path)!!.readText()
|
||||
}
|
||||
|
||||
private fun chunkedResponse(body: String): MockResponse {
|
||||
return MockResponse()
|
||||
.setHeader("Content-Type", "text/event-stream")
|
||||
.setResponseCode(200)
|
||||
.setSocketPolicy(SocketPolicy.KEEP_OPEN)
|
||||
.setChunkedBody(body, 64)
|
||||
.throttleBody(64, 10, TimeUnit.MILLISECONDS)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `streamChatCompletion success - returns a flow of chat completion events`() = runTest {
|
||||
// TODO don't do this
|
||||
val responseBody = getResource("responses/chat/chat_completion_events_success.txt")
|
||||
|
||||
val messageCount = responseBody.split("\n\n").size - 1 // -1 because [DONE]
|
||||
|
||||
val response = chunkedResponse(responseBody)
|
||||
|
||||
mockApi.server.enqueue(response)
|
||||
|
||||
val expectedMessageParts = listOf(
|
||||
"Hey",
|
||||
"!",
|
||||
" What",
|
||||
"'s",
|
||||
"up",
|
||||
"?"
|
||||
)
|
||||
|
||||
val receivedEvents = mutableListOf<ChatCompletionResponseEventDto>()
|
||||
|
||||
mockApi.repository.streamChatCompletion(requestBody).test {
|
||||
assertThat(awaitItem()).isInstanceOf(SseEvent.Open::class.java)
|
||||
|
||||
repeat(messageCount) { index ->
|
||||
val event = awaitItem()
|
||||
assertThat(event).isInstanceOf(SseEvent.Event::class.java)
|
||||
|
||||
val completionEvent = (event as SseEvent.Event<ChatCompletionResponseEventDto>).data
|
||||
assertThat(completionEvent.id).isEqualTo("chatcmpl-1234567890abcdefgh")
|
||||
assertThat(completionEvent.objectType).isEqualTo("chat.completion.chunk")
|
||||
receivedEvents.add(completionEvent)
|
||||
|
||||
val choice = completionEvent.choices[0]
|
||||
if (choice.finishReason == null) {
|
||||
assertThat(choice.delta.content).isEqualTo(expectedMessageParts[index])
|
||||
} else {
|
||||
assertThat(choice.delta.content).isNull()
|
||||
assertThat(completionEvent.usage!!.totalTokens).isEqualTo(7)
|
||||
}
|
||||
}
|
||||
|
||||
assertThat(awaitItem()).isInstanceOf(SseEvent.Close::class.java)
|
||||
|
||||
awaitComplete()
|
||||
}
|
||||
|
||||
assertThat(receivedEvents.size).isEqualTo(messageCount)
|
||||
|
||||
val recorded = mockApi.server.takeRequest()
|
||||
assertThat(recorded.method).isEqualTo("POST")
|
||||
assertThat(recorded.path).isEqualTo("/chat/completions")
|
||||
assertThat(recorded.headers["Authorization"]).isEqualTo("Bearer totally-valid-key")
|
||||
assertThat(recorded.headers["User-Agent"]).isEqualTo("ExpectedUserAgent/1.0")
|
||||
|
||||
// TODO overall kinda weak test because we're not checking anything
|
||||
// TODO consider checking request body here
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `streamChatCompletion failure - invalid request`() = runTest {
|
||||
val response = MockResponse()
|
||||
.setHeader("Content-Type", "application/problem+json")
|
||||
.setResponseCode(400)
|
||||
.setBody(getResource("responses/chat/chat_completion_failure_invalid_request.json"))
|
||||
mockApi.server.enqueue(response)
|
||||
|
||||
mockApi.repository.streamChatCompletion(requestBody).test {
|
||||
val error = awaitError()
|
||||
assertThat(error).isInstanceOf(ApiException::class.java)
|
||||
|
||||
error as ApiException
|
||||
assertThat(error.message).isEqualTo("Missing or invalid required parameter: messages")
|
||||
assertThat(error.errorType).isEqualTo("invalid_request_error")
|
||||
assertThat(error.errorSubType).isEqualTo(null)
|
||||
assertThat(error.clientError).isEqualTo(true)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `streamChatCompletion failure - invalid response`() = runTest {
|
||||
val response = MockResponse()
|
||||
.setHeader("Content-Type", "text/event-stream")
|
||||
.setResponseCode(200)
|
||||
.setBody(getResource("responses/chat/chat_completion_failure_invalid_response.txt"))
|
||||
mockApi.server.enqueue(response)
|
||||
|
||||
mockApi.repository.streamChatCompletion(requestBody).test {
|
||||
assertThat(awaitItem()).isInstanceOf(SseEvent.Open::class.java)
|
||||
|
||||
for (i in 1..7) {
|
||||
assertThat(awaitItem()).isInstanceOf(SseEvent.Event::class.java)
|
||||
}
|
||||
|
||||
val error = awaitError()
|
||||
assertThat(error).isInstanceOf(JsonEncodingException::class.java)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `streamChatCompletion failure - server error`() = runTest {
|
||||
val response = MockResponse()
|
||||
.setHeader("Content-Type", "application/problem+json")
|
||||
.setResponseCode(503)
|
||||
.setBody(getResource("responses/error_services_unavailable.json"))
|
||||
mockApi.server.enqueue(response)
|
||||
|
||||
mockApi.repository.streamChatCompletion(requestBody).test {
|
||||
val error = awaitError()
|
||||
assertThat(error).isInstanceOf(ApiException::class.java)
|
||||
|
||||
error as ApiException
|
||||
assertThat(error.message).isEqualTo("All available services are currently unavailable. Please try again later.")
|
||||
assertThat(error.errorType).isEqualTo("service_unavailable")
|
||||
assertThat(error.errorSubType).isEqualTo("all_fallbacks_failed")
|
||||
assertThat(error.clientError).isEqualTo(false)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `streamChatCompletion failure - incomplete response`() = runTest {
|
||||
val responseBody = getResource("responses/chat/chat_completion_events_failure_incomplete.txt")
|
||||
|
||||
val messageCount = responseBody.split("\n\n").size - 1
|
||||
|
||||
val response = MockResponse()
|
||||
.setHeader("Content-Type", "text/event-stream")
|
||||
.setResponseCode(200)
|
||||
.setSocketPolicy(SocketPolicy.KEEP_OPEN)
|
||||
.setChunkedBody(responseBody, 64)
|
||||
.throttleBody(64, 10, TimeUnit.MILLISECONDS)
|
||||
mockApi.server.enqueue(response)
|
||||
|
||||
val receivedEvents = mutableListOf<ChatCompletionResponseEventDto>()
|
||||
|
||||
mockApi.repository.streamChatCompletion(requestBody).test {
|
||||
println("Consumer: Awaiting Open event...")
|
||||
assertThat(awaitItem()).isInstanceOf(SseEvent.Open::class.java)
|
||||
|
||||
repeat(messageCount) { index ->
|
||||
println("Consumer: Awaiting event #${index + 1}")
|
||||
|
||||
val event = awaitItem()
|
||||
assertThat(event).isInstanceOf(SseEvent.Event::class.java)
|
||||
|
||||
val completionEvent = (event as SseEvent.Event<ChatCompletionResponseEventDto>).data
|
||||
assertThat(completionEvent.id).isEqualTo("chatcmpl-1234567890abcdefgh")
|
||||
assertThat(completionEvent.objectType).isEqualTo("chat.completion.chunk")
|
||||
receivedEvents.add(completionEvent)
|
||||
|
||||
// TODO verify that
|
||||
val choice = completionEvent.choices[0]
|
||||
println("Received partial content: ${choice.delta.content}")
|
||||
if (choice.finishReason != null) {
|
||||
println("Finish reason: ${choice.finishReason}")
|
||||
}
|
||||
}
|
||||
|
||||
println("Consumer: Awaiting error...")
|
||||
assertThat(awaitError()).isInstanceOf(IllegalArgumentException::class.java)
|
||||
}
|
||||
|
||||
assertThat(receivedEvents.size).isEqualTo(messageCount)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `streamChatComplete failure - JSON response`() = runTest {
|
||||
val response = MockResponse()
|
||||
.setHeader("Content-Type", "application/json")
|
||||
.setResponseCode(200)
|
||||
.setSocketPolicy(SocketPolicy.KEEP_OPEN)
|
||||
.setBody("{}")
|
||||
mockApi.server.enqueue(response)
|
||||
|
||||
mockApi.repository.streamChatCompletion(requestBody).test {
|
||||
val error = awaitError()
|
||||
assertThat(error).isInstanceOf(IllegalStateException::class.java)
|
||||
assertThat(error.message).isEqualTo("Invalid content-type: application/json")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,163 @@
|
|||
package eu.m724.newchat.aiapi.endpoints
|
||||
|
||||
import com.google.common.base.Preconditions.checkNotNull
|
||||
import com.google.common.truth.Truth.assertThat
|
||||
import com.squareup.moshi.JsonDataException
|
||||
import com.squareup.moshi.JsonEncodingException
|
||||
import eu.m724.newchat.aiapi.models.dto.lm.LanguageModelDto
|
||||
import eu.m724.newchat.aiapi.models.dto.lm.LanguageModelPricingDto
|
||||
import eu.m724.newchat.aiapi.retrofit.exception.ApiException
|
||||
import eu.m724.newchat.aiapi.rules.MockApiRule
|
||||
import kotlinx.coroutines.test.runTest
|
||||
import okhttp3.mockwebserver.MockResponse
|
||||
import okhttp3.mockwebserver.SocketPolicy
|
||||
import org.junit.Rule
|
||||
import org.junit.Test
|
||||
import java.io.IOException
|
||||
|
||||
class LanguageModelsTest {
|
||||
@Rule
|
||||
@JvmField
|
||||
val mockApi = MockApiRule(
|
||||
apiKey = "totally-valid-key",
|
||||
userAgent = "ExpectedUserAgent/1.0"
|
||||
)
|
||||
|
||||
private fun getResource(path: String): String {
|
||||
return checkNotNull(this.javaClass.classLoader!!.getResource(path), "No such test asset $path").readText()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getLanguageModels success - returns the list of available language models`() = runTest {
|
||||
val expectedModels = listOf(
|
||||
LanguageModelDto(
|
||||
id = "model-alpha",
|
||||
name = "The First Model",
|
||||
description = "The first model ever created",
|
||||
contextLength = 3000,
|
||||
pricing = LanguageModelPricingDto(
|
||||
inputCostPerMillionTokens = 11.0,
|
||||
completionCostPerMillionTokens = 3.0
|
||||
),
|
||||
iconUrl = "/icons/icon-192x192.png"
|
||||
),
|
||||
LanguageModelDto(
|
||||
id = "beta-model",
|
||||
name = "The Second Model",
|
||||
description = "Most intelligent model",
|
||||
contextLength = 2147483647,
|
||||
pricing = LanguageModelPricingDto(
|
||||
inputCostPerMillionTokens = 9876543.21,
|
||||
completionCostPerMillionTokens = 9999999.9
|
||||
),
|
||||
iconUrl = "/icons/icon-192x192.png"
|
||||
)
|
||||
)
|
||||
|
||||
val mockResponse = MockResponse()
|
||||
.setResponseCode(200)
|
||||
.setBody(getResource("responses/lm/get_models_success.json"))
|
||||
mockApi.server.enqueue(mockResponse)
|
||||
|
||||
val result = mockApi.repository.getLanguageModels()
|
||||
|
||||
assertThat(result.isSuccess).isTrue()
|
||||
|
||||
val models = result.getOrThrow().data
|
||||
assertThat(models).containsExactlyElementsIn(expectedModels)
|
||||
|
||||
val recorded = mockApi.server.takeRequest()
|
||||
assertThat(recorded.method).isEqualTo("GET")
|
||||
assertThat(recorded.path).isEqualTo("/models?detailed=true")
|
||||
assertThat(recorded.headers["Authorization"]).isEqualTo("Bearer totally-valid-key")
|
||||
assertThat(recorded.headers["User-Agent"]).isEqualTo("ExpectedUserAgent/1.0")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getLanguageModels success - empty list`() = runTest {
|
||||
val mockResponse = MockResponse()
|
||||
.setResponseCode(200)
|
||||
.setBody(getResource("responses/lm/get_models_success_empty.json"))
|
||||
mockApi.server.enqueue(mockResponse)
|
||||
|
||||
val result = mockApi.repository.getLanguageModels()
|
||||
|
||||
assertThat(result.isSuccess).isTrue()
|
||||
|
||||
val models = result.getOrThrow().data
|
||||
assertThat(models).hasSize(0)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getLanguageModels failure - invalid response`() = runTest {
|
||||
val mockResponse = MockResponse()
|
||||
.setResponseCode(200)
|
||||
.setBody(getResource("responses/lm/get_models_failure_invalid_response.json"))
|
||||
mockApi.server.enqueue(mockResponse)
|
||||
|
||||
val result = mockApi.repository.getLanguageModels()
|
||||
|
||||
assertThat(result.isFailure).isTrue()
|
||||
assertThat(result.getOrNull()).isEqualTo(null)
|
||||
assertThat(result.exceptionOrNull()).isInstanceOf(JsonDataException::class.java)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getLanguageModels failure - non-JSON response`() = runTest {
|
||||
val mockResponse = MockResponse()
|
||||
.setResponseCode(200)
|
||||
.setBody("""hello world""")
|
||||
mockApi.server.enqueue(mockResponse)
|
||||
|
||||
val result = mockApi.repository.getLanguageModels()
|
||||
|
||||
assertThat(result.isFailure).isTrue()
|
||||
assertThat(result.getOrNull()).isEqualTo(null)
|
||||
assertThat(result.exceptionOrNull()).isInstanceOf(JsonEncodingException::class.java)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getLanguageModels failure - server error`() = runTest {
|
||||
val mockResponse = MockResponse()
|
||||
.setResponseCode(503)
|
||||
.setBody(getResource("responses/error_services_unavailable.json"))
|
||||
mockApi.server.enqueue(mockResponse)
|
||||
|
||||
val result = mockApi.repository.getLanguageModels()
|
||||
assertThat(result.isFailure).isTrue()
|
||||
assertThat(result.getOrNull()).isEqualTo(null)
|
||||
|
||||
val exception = result.exceptionOrNull()
|
||||
assertThat(exception).isInstanceOf(ApiException::class.java)
|
||||
|
||||
val apiException = exception as ApiException
|
||||
assertThat(apiException.message).isEqualTo("All available services are currently unavailable. Please try again later.")
|
||||
assertThat(apiException.errorType).isEqualTo("service_unavailable")
|
||||
assertThat(apiException.errorSubType).isEqualTo("all_fallbacks_failed")
|
||||
assertThat(apiException.clientError).isEqualTo(false)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getLanguageModels failure - network disconnect`() = runTest {
|
||||
val mockResponse = MockResponse()
|
||||
.setSocketPolicy(SocketPolicy.DISCONNECT_AT_START)
|
||||
mockApi.server.enqueue(mockResponse)
|
||||
|
||||
val result = mockApi.repository.getLanguageModels()
|
||||
|
||||
assertThat(result.isFailure).isTrue()
|
||||
assertThat(result.exceptionOrNull()).isInstanceOf(IOException::class.java)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getLanguageModels failure - 204 no response`() = runTest {
|
||||
val mockResponse = MockResponse()
|
||||
.setResponseCode(204)
|
||||
mockApi.server.enqueue(mockResponse)
|
||||
|
||||
val result = mockApi.repository.getLanguageModels()
|
||||
|
||||
assertThat(result.isFailure).isTrue()
|
||||
assertThat(result.exceptionOrNull()).isInstanceOf(KotlinNullPointerException::class.java)
|
||||
}
|
||||
}
|
25
aiapi/src/test/java/eu/m724/newchat/aiapi/models/DtoTest.kt
Normal file
25
aiapi/src/test/java/eu/m724/newchat/aiapi/models/DtoTest.kt
Normal file
|
@ -0,0 +1,25 @@
|
|||
package eu.m724.newchat.aiapi.models
|
||||
|
||||
import eu.m724.newchat.aiapi.models.dto.chat.ChatMessageDto
|
||||
import org.junit.Test
|
||||
|
||||
class DtoTest {
|
||||
@Test
|
||||
fun `ChatMessageDto - test role validation`() {
|
||||
val validRoles = listOf("user", "assistant", "system")
|
||||
val invalidRoles = listOf("invalid1", "invalid2", "invalid3")
|
||||
|
||||
// Test valid roles
|
||||
validRoles.forEach { role ->
|
||||
ChatMessageDto(role = role, content = "Hello")
|
||||
}
|
||||
|
||||
// Test invalid roles
|
||||
invalidRoles.forEach { role ->
|
||||
try {
|
||||
ChatMessageDto(role = role, content = "Hello")
|
||||
} catch (_: IllegalArgumentException) { }
|
||||
}
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
package eu.m724.newchat.aiapi.rules
|
||||
|
||||
import eu.m724.newchat.aiapi.AiApiConfiguration
|
||||
import eu.m724.newchat.aiapi.AiApiDataLayerFactory
|
||||
import eu.m724.newchat.aiapi.repository.AiApiRepository
|
||||
import okhttp3.mockwebserver.MockWebServer
|
||||
import org.junit.rules.ExternalResource
|
||||
|
||||
class MockApiRule(
|
||||
private val apiKey: String,
|
||||
private val userAgent: String
|
||||
) : ExternalResource() {
|
||||
lateinit var server: MockWebServer
|
||||
private set
|
||||
lateinit var repository: AiApiRepository
|
||||
private set
|
||||
|
||||
override fun before() {
|
||||
server = MockWebServer()
|
||||
|
||||
val configuration = object : AiApiConfiguration {
|
||||
override val apiKey: String = this@MockApiRule.apiKey
|
||||
override val endpoint: String = server.url("/").toString()
|
||||
override val userAgent: String = this@MockApiRule.userAgent
|
||||
override val isDebug: Boolean = true
|
||||
}
|
||||
|
||||
repository = AiApiDataLayerFactory.createApiRepository(configuration)
|
||||
}
|
||||
|
||||
override fun after() {
|
||||
server.shutdown()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":"Hey"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":" 😊"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":" I"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":"'m"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":" doing"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":" great"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":"—"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":"just"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":" here"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":" and"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":123
|
|
@ -0,0 +1,15 @@
|
|||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":"Hey"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":" What"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":"'s"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":" up"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":"?"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":6,"total_tokens":7}}
|
||||
|
||||
data: [DONE]
|
|
@ -0,0 +1,8 @@
|
|||
{
|
||||
"error": {
|
||||
"message": "Missing or invalid required parameter: messages",
|
||||
"type": "invalid_request_error",
|
||||
"param": "messages",
|
||||
"code": null
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":"Hey"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":" 😊"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":" I"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":"'m"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":" doing"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":" great"},"finish_reason":null}]}
|
||||
|
||||
data: hello world
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":"just"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":" here"},"finish_reason":null}]}
|
||||
|
||||
data: {"id""chatcmpl-1234567890abcdefgh","object":"chat.completion.chunk","created":1234567890,"model":"deepseek-v3.1","choices":[{"index":0,"delta":{"content":" and"},"finish_reason":null}]}
|
||||
|
||||
data: [DONE]
|
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"error": {
|
||||
"message": "All available services are currently unavailable. Please try again later.",
|
||||
"status": 503,
|
||||
"type": "service_unavailable",
|
||||
"param" :null,
|
||||
"code": "all_fallbacks_failed"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"data": "yes wrong type"
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "model-alpha",
|
||||
"object": "model",
|
||||
"created": 1755681330,
|
||||
"owned_by": "organization-owner",
|
||||
"name": "The First Model",
|
||||
"description": "The first model ever created",
|
||||
"context_length": 3000,
|
||||
"pricing": {
|
||||
"prompt": 11.0,
|
||||
"completion": 3.0,
|
||||
"currency": "USD",
|
||||
"unit": "per_million_tokens"
|
||||
},
|
||||
"cost_estimate": 0.01,
|
||||
"icon_url": "/icons/icon-192x192.png"
|
||||
},
|
||||
{
|
||||
"id": "beta-model",
|
||||
"object": "model",
|
||||
"created": 1755681330,
|
||||
"owned_by": "organization-owner",
|
||||
"name": "The Second Model",
|
||||
"description": "Most intelligent model",
|
||||
"context_length": 2147483647,
|
||||
"pricing": {
|
||||
"prompt": 9876543.21,
|
||||
"completion": 9999999.9,
|
||||
"currency": "USD",
|
||||
"unit": "per_million_tokens"
|
||||
},
|
||||
"cost_estimate": 0.01,
|
||||
"icon_url": "/icons/icon-192x192.png",
|
||||
"extra field": "we don't care about"
|
||||
}
|
||||
],
|
||||
"extra field": "we don't care about"
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
{
|
||||
"object": "list",
|
||||
"data": [],
|
||||
"extra_field": "we don't care about"
|
||||
}
|
|
@ -5,12 +5,11 @@ plugins {
|
|||
alias(libs.plugins.hilt.android)
|
||||
alias(libs.plugins.ksp)
|
||||
alias(libs.plugins.secrets)
|
||||
alias(libs.plugins.parcelize)
|
||||
}
|
||||
|
||||
android {
|
||||
namespace = "eu.m724.chatapp"
|
||||
compileSdk = 35
|
||||
compileSdk = 36
|
||||
|
||||
defaultConfig {
|
||||
applicationId = "eu.m724.chatapp"
|
||||
|
@ -33,12 +32,12 @@ android {
|
|||
}
|
||||
|
||||
compileOptions {
|
||||
sourceCompatibility = JavaVersion.VERSION_11
|
||||
targetCompatibility = JavaVersion.VERSION_11
|
||||
sourceCompatibility = JavaVersion.VERSION_17
|
||||
targetCompatibility = JavaVersion.VERSION_17
|
||||
}
|
||||
|
||||
kotlinOptions {
|
||||
jvmTarget = "11"
|
||||
jvmTarget = "17"
|
||||
}
|
||||
|
||||
buildFeatures {
|
||||
|
@ -52,6 +51,8 @@ android {
|
|||
}
|
||||
|
||||
dependencies {
|
||||
implementation(project(":aiapi"))
|
||||
implementation(project(":storage"))
|
||||
implementation(libs.androidx.core.ktx)
|
||||
implementation(libs.androidx.appcompat)
|
||||
implementation(libs.material)
|
||||
|
@ -63,16 +64,11 @@ dependencies {
|
|||
implementation(libs.androidx.ui.tooling.preview)
|
||||
implementation(libs.androidx.material3)
|
||||
implementation(libs.hilt.android)
|
||||
implementation(libs.retrofit)
|
||||
implementation(libs.retrofit.converter.gson)
|
||||
implementation(libs.androidx.material3.window.size.class1)
|
||||
implementation(libs.okhttp.sse)
|
||||
implementation(libs.androidx.datastore)
|
||||
implementation(libs.hilt.navigation.compose)
|
||||
implementation(libs.androidx.room.runtime)
|
||||
implementation(libs.androidx.room.compiler)
|
||||
implementation(libs.androidx.room.paging)
|
||||
implementation(libs.androidx.room.ktx)
|
||||
implementation(libs.paging.runtime.ktx)
|
||||
implementation(libs.paging.compose)
|
||||
testImplementation(libs.junit)
|
||||
androidTestImplementation(libs.androidx.junit)
|
||||
androidTestImplementation(libs.androidx.espresso.core)
|
||||
|
@ -80,6 +76,5 @@ dependencies {
|
|||
androidTestImplementation(libs.androidx.ui.test.junit4)
|
||||
debugImplementation(libs.androidx.ui.tooling)
|
||||
debugImplementation(libs.androidx.ui.test.manifest)
|
||||
debugImplementation(libs.logging.interceptor)
|
||||
ksp(libs.hilt.compiler)
|
||||
}
|
|
@ -18,7 +18,6 @@
|
|||
<activity
|
||||
android:name=".activity.main.MainActivity"
|
||||
android:exported="true"
|
||||
android:label="@string/title_activity_main"
|
||||
android:theme="@style/Theme.ChatApp">
|
||||
<intent-filter>
|
||||
<action android:name="android.intent.action.MAIN" />
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package eu.m724.chatapp.activity.chat
|
||||
|
||||
import android.os.Bundle
|
||||
import android.util.Log
|
||||
import androidx.activity.ComponentActivity
|
||||
import androidx.activity.compose.setContent
|
||||
import androidx.activity.enableEdgeToEdge
|
||||
|
@ -19,7 +20,6 @@ import androidx.compose.foundation.layout.padding
|
|||
import androidx.compose.foundation.layout.width
|
||||
import androidx.compose.foundation.lazy.LazyColumn
|
||||
import androidx.compose.foundation.lazy.LazyListState
|
||||
import androidx.compose.foundation.lazy.items
|
||||
import androidx.compose.foundation.lazy.rememberLazyListState
|
||||
import androidx.compose.material3.CenterAlignedTopAppBar
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
|
@ -50,17 +50,17 @@ import androidx.lifecycle.compose.collectAsStateWithLifecycle
|
|||
import dagger.hilt.android.AndroidEntryPoint
|
||||
import eu.m724.chatapp.R
|
||||
import eu.m724.chatapp.activity.chat.composable.ChatToolBar
|
||||
import eu.m724.chatapp.activity.ui.composable.AnimatedChangingText
|
||||
import eu.m724.chatapp.activity.chat.composable.LanguageModelMistakeWarning
|
||||
import eu.m724.chatapp.activity.chat.composable.thread.ChatMessageComposer
|
||||
import eu.m724.chatapp.activity.chat.composable.thread.ChatResponseErrorNotice
|
||||
import eu.m724.chatapp.activity.chat.quick_settings.ChatQuickSettingsEvent
|
||||
import eu.m724.chatapp.activity.chat.state.ChatComposerState
|
||||
import eu.m724.chatapp.activity.chat.state.rememberChatComposerState
|
||||
import eu.m724.chatapp.activity.ui.composable.AnimatedChangingText
|
||||
import eu.m724.chatapp.activity.ui.composable.disableBringIntoViewOnFocus
|
||||
import eu.m724.chatapp.activity.ui.composable.hideKeyboardOnScrollUp
|
||||
import eu.m724.chatapp.activity.ui.theme.ChatAppTheme
|
||||
import eu.m724.chatapp.api.data.response.completion.ChatMessage
|
||||
import eu.m724.chatapp.model.ChatMessage
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
@AndroidEntryPoint
|
||||
|
@ -94,12 +94,15 @@ class ChatActivity : ComponentActivity() {
|
|||
}
|
||||
}
|
||||
|
||||
var savedFocus by remember { mutableStateOf(false) }
|
||||
val chat by viewModel.chat.collectAsStateWithLifecycle()
|
||||
val messages by viewModel.messages.collectAsStateWithLifecycle()
|
||||
|
||||
ChatScreen(
|
||||
windowSizeClass = windowSizeClass,
|
||||
uiState = uiState,
|
||||
chatComposerState = chatState,
|
||||
chatTitle = chat.title ?: stringResource(R.string.title_new_conversation),
|
||||
messages = messages,
|
||||
threadViewLazyListState = threadViewLazyListState,
|
||||
snackbarHostState = snackbarHostState,
|
||||
onSend = onSend,
|
||||
|
@ -108,8 +111,8 @@ class ChatActivity : ComponentActivity() {
|
|||
|
||||
coroutineScope.launch {
|
||||
if (threadViewLazyListState.layoutInfo.visibleItemsInfo.find { it.key == "composer" } == null) {
|
||||
if (uiState.chat.messages.isNotEmpty()) {
|
||||
threadViewLazyListState.animateScrollToItem(uiState.chat.messages.size)
|
||||
if (messages.isNotEmpty()) {
|
||||
threadViewLazyListState.animateScrollToItem(messages.size)
|
||||
// TODO this makes the composer full screen but if the condition above is false it doesn't, that may be kind of unintuitive
|
||||
}
|
||||
}
|
||||
|
@ -131,15 +134,17 @@ class ChatActivity : ComponentActivity() {
|
|||
)
|
||||
|
||||
LaunchedEffect(uiState.requestInProgress) {
|
||||
Log.d("ChatActivity", "Request in progress: ${uiState.requestInProgress}, messages: ${messages.size}")
|
||||
|
||||
if (uiState.requestInProgress) {
|
||||
chatState.composerValue = ""
|
||||
|
||||
// scroll to the last user message
|
||||
threadViewLazyListState.animateScrollToItem(uiState.chat.messages.size - 2)
|
||||
threadViewLazyListState.animateScrollToItem(messages.size - 1)
|
||||
} else {
|
||||
if (uiState.chat.messages.size > 1) {
|
||||
if (messages.size > 1) {
|
||||
// scroll to the last user message too
|
||||
threadViewLazyListState.animateScrollToItem(uiState.chat.messages.size - 2)
|
||||
threadViewLazyListState.animateScrollToItem(messages.size - 2)
|
||||
}
|
||||
|
||||
if (uiState.lastResponseError == null) {
|
||||
|
@ -168,6 +173,19 @@ class ChatActivity : ComponentActivity() {
|
|||
withDismissAction = true
|
||||
)
|
||||
}
|
||||
is ChatActivityUiEvent.Loaded -> {
|
||||
val lastMessage = messages.lastOrNull()
|
||||
if (lastMessage != null) {
|
||||
// ¯\_(ツ)_/¯
|
||||
if (lastMessage.role == ChatMessage.Role.Assistant) {
|
||||
threadViewLazyListState.animateScrollToItem(messages.size - 2)
|
||||
} else if (messages.size > 2) {
|
||||
threadViewLazyListState.animateScrollToItem(messages.size - 3)
|
||||
} else {
|
||||
threadViewLazyListState.animateScrollToItem(messages.size - 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -181,6 +199,8 @@ fun ChatScreen(
|
|||
uiState: ChatActivityUiState,
|
||||
chatComposerState: ChatComposerState,
|
||||
threadViewLazyListState: LazyListState,
|
||||
chatTitle: String,
|
||||
messages: List<ChatMessage>,
|
||||
snackbarHostState: SnackbarHostState,
|
||||
onSend: () -> Unit,
|
||||
onRequestFocus: () -> Unit,
|
||||
|
@ -193,7 +213,7 @@ fun ChatScreen(
|
|||
modifier = Modifier.fillMaxSize(),
|
||||
topBar = {
|
||||
ChatTopAppBar(
|
||||
title = uiState.chat.title ?: stringResource(R.string.title_new_conversation)
|
||||
title = chatTitle
|
||||
)
|
||||
},
|
||||
snackbarHost = {
|
||||
|
@ -201,7 +221,7 @@ fun ChatScreen(
|
|||
hostState = snackbarHostState,
|
||||
modifier = Modifier
|
||||
.imePadding()
|
||||
.padding(bottom = 80.dp) // Excuse the magic value. This is the approximate height of the toolbar + AI warning.
|
||||
.padding(bottom = 80.dp) // This is the approximate height of the toolbar + AI warning. TODO not do that
|
||||
)
|
||||
}
|
||||
) { innerPadding ->
|
||||
|
@ -212,6 +232,7 @@ fun ChatScreen(
|
|||
isTablet = isTablet,
|
||||
uiState = uiState,
|
||||
chatComposerState = chatComposerState,
|
||||
messages = messages,
|
||||
threadViewLazyListState = threadViewLazyListState,
|
||||
onSend = onSend,
|
||||
onRequestFocus = onRequestFocus,
|
||||
|
@ -227,6 +248,7 @@ fun ChatScreenContent(
|
|||
isTablet: Boolean,
|
||||
uiState: ChatActivityUiState,
|
||||
chatComposerState: ChatComposerState,
|
||||
messages: List<ChatMessage>,
|
||||
threadViewLazyListState: LazyListState,
|
||||
onSend: () -> Unit,
|
||||
onRequestFocus: () -> Unit,
|
||||
|
@ -262,9 +284,9 @@ fun ChatScreenContent(
|
|||
.fillMaxSize()
|
||||
.padding(horizontal = 24.dp),
|
||||
lazyListState = threadViewLazyListState,
|
||||
messages = uiState.chat.messages,
|
||||
uiState = uiState,
|
||||
chatComposerState = chatComposerState
|
||||
chatComposerState = chatComposerState,
|
||||
messages = messages
|
||||
)
|
||||
},
|
||||
{
|
||||
|
@ -295,9 +317,9 @@ fun ChatScreenContent(
|
|||
@Composable
|
||||
fun ThreadView(
|
||||
lazyListState: LazyListState,
|
||||
messages: List<ChatMessage>,
|
||||
uiState: ChatActivityUiState,
|
||||
chatComposerState: ChatComposerState,
|
||||
messages: List<ChatMessage>,
|
||||
modifier: Modifier = Modifier
|
||||
) {
|
||||
val localSoftwareKeyboardController = LocalSoftwareKeyboardController.current
|
||||
|
@ -307,7 +329,11 @@ fun ThreadView(
|
|||
.hideKeyboardOnScrollUp(localSoftwareKeyboardController!!),
|
||||
state = lazyListState
|
||||
) {
|
||||
items(messages) { message ->
|
||||
items(
|
||||
count = messages.size
|
||||
) { index ->
|
||||
val message = messages[index]
|
||||
|
||||
if (message.role == ChatMessage.Role.User) {
|
||||
ChatMessagePrompt(
|
||||
modifier = Modifier.padding(vertical = 10.dp),
|
||||
|
@ -315,12 +341,20 @@ fun ThreadView(
|
|||
)
|
||||
} else if (message.role == ChatMessage.Role.Assistant) {
|
||||
ChatMessageResponse(
|
||||
modifier = Modifier.animateContentSize(),
|
||||
content = message.content
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if (uiState.requestInProgress) {
|
||||
item(key = "liveResponse") {
|
||||
ChatMessageResponse(
|
||||
modifier = Modifier.animateContentSize(),
|
||||
content = uiState.liveResponse
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if (uiState.lastResponseError != null) {
|
||||
item(key = "error") {
|
||||
ChatResponseErrorNotice(
|
||||
|
@ -329,23 +363,25 @@ fun ThreadView(
|
|||
}
|
||||
}
|
||||
|
||||
item(key = "composer") {
|
||||
if (!uiState.requestInProgress) {
|
||||
ChatMessageComposer(
|
||||
modifier = Modifier
|
||||
.fillParentMaxHeight() // so that you can click anywhere on the screen to focus the text field
|
||||
.disableBringIntoViewOnFocus()
|
||||
.focusRequester(chatComposerState.focusRequester),
|
||||
value = chatComposerState.composerValue,
|
||||
onValueChange = {
|
||||
chatComposerState.composerValue = it
|
||||
}
|
||||
)
|
||||
} else {
|
||||
// so basically if this was absent, there would be no space anymore to scroll below, and if the conversation were short enough, it would jump to the top because you can't scroll to something that's not there.
|
||||
Spacer(
|
||||
modifier = Modifier.fillParentMaxHeight()
|
||||
)
|
||||
if (uiState.enableComposer) {
|
||||
item(key = "composer") {
|
||||
if (!uiState.requestInProgress) {
|
||||
ChatMessageComposer(
|
||||
modifier = Modifier
|
||||
.fillParentMaxHeight() // so that you can click anywhere on the screen to focus the text field
|
||||
.disableBringIntoViewOnFocus()
|
||||
.focusRequester(chatComposerState.focusRequester),
|
||||
value = chatComposerState.composerValue,
|
||||
onValueChange = {
|
||||
chatComposerState.composerValue = it
|
||||
}
|
||||
)
|
||||
} else {
|
||||
// so basically if this was absent, there would be no space anymore to scroll below, and if the conversation were short enough, it would jump to the top because you can't scroll to something that's not there.
|
||||
Spacer(
|
||||
modifier = Modifier.fillParentMaxHeight()
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -356,7 +392,7 @@ fun ChatMessagePrompt(
|
|||
content: String,
|
||||
modifier: Modifier = Modifier
|
||||
) {
|
||||
// TODO
|
||||
// TODO not do this
|
||||
var animate by rememberSaveable { mutableStateOf(false) }
|
||||
|
||||
val textPadding by animateDpAsState(
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package eu.m724.chatapp.activity.chat
|
||||
|
||||
sealed interface ChatActivityUiEvent {
|
||||
data object Loaded : ChatActivityUiEvent
|
||||
data class Error(val error: String): ChatActivityUiEvent
|
||||
}
|
|
@ -1,15 +1,19 @@
|
|||
package eu.m724.chatapp.activity.chat
|
||||
|
||||
import eu.m724.chatapp.store.data.Chat
|
||||
import eu.m724.chatapp.store.data.ChatResponseError
|
||||
import eu.m724.chatapp.model.ChatCompletionError
|
||||
|
||||
data class ChatActivityUiState(
|
||||
val chat: Chat,
|
||||
|
||||
/**
|
||||
* Whether a request is in progress (a response is streaming)
|
||||
*/
|
||||
val requestInProgress: Boolean = false,
|
||||
|
||||
val lastResponseError: ChatResponseError? = null
|
||||
/**
|
||||
* The live response, streaming
|
||||
*/
|
||||
val liveResponse: String = "",
|
||||
|
||||
val lastResponseError: ChatCompletionError? = null,
|
||||
|
||||
val enableComposer: Boolean = true
|
||||
)
|
|
@ -1,168 +1,273 @@
|
|||
package eu.m724.chatapp.activity.chat
|
||||
|
||||
import android.util.Log
|
||||
import androidx.lifecycle.SavedStateHandle
|
||||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||
import eu.m724.chatapp.api.AiApiService
|
||||
import eu.m724.chatapp.api.data.request.completion.ChatCompletionRequest
|
||||
import eu.m724.chatapp.api.data.response.completion.ChatCompletionResponseEvent
|
||||
import eu.m724.chatapp.api.data.response.completion.ChatMessage
|
||||
import eu.m724.chatapp.api.data.response.completion.CompletionFinishReason
|
||||
import eu.m724.chatapp.api.data.response.models.LanguageModel
|
||||
import eu.m724.chatapp.api.retrofit.sse.SseEvent
|
||||
import eu.m724.chatapp.store.data.Chat
|
||||
import eu.m724.chatapp.store.data.ChatResponseError
|
||||
import eu.m724.chat.storage.repository.ChatStorageRepository
|
||||
import eu.m724.chatapp.BuildConfig
|
||||
import eu.m724.chatapp.model.Chat
|
||||
import eu.m724.chatapp.model.ChatCompletionError
|
||||
import eu.m724.chatapp.model.ChatMessage
|
||||
import eu.m724.chatapp.model.toChat
|
||||
import eu.m724.chatapp.model.toChatMessage
|
||||
import eu.m724.chatapp.model.toDto
|
||||
import eu.m724.chatapp.model.toEntity
|
||||
import eu.m724.newchat.aiapi.models.dto.chat.ChatCompletionRequestDto
|
||||
import eu.m724.newchat.aiapi.models.dto.chat.CompletionChoiceDto
|
||||
import eu.m724.newchat.aiapi.models.dto.lm.LanguageModelDto
|
||||
import eu.m724.newchat.aiapi.models.repo.ChatCompletionResponseChunk
|
||||
import eu.m724.newchat.aiapi.repository.AiApiRepository
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.channels.Channel
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.asStateFlow
|
||||
import kotlinx.coroutines.flow.catch
|
||||
import kotlinx.coroutines.flow.firstOrNull
|
||||
import kotlinx.coroutines.flow.launchIn
|
||||
import kotlinx.coroutines.flow.onCompletion
|
||||
import kotlinx.coroutines.flow.onEach
|
||||
import kotlinx.coroutines.flow.receiveAsFlow
|
||||
import kotlinx.coroutines.flow.update
|
||||
import kotlinx.coroutines.launch
|
||||
import java.time.Instant
|
||||
import javax.inject.Inject
|
||||
import kotlin.random.Random
|
||||
|
||||
@HiltViewModel
|
||||
class ChatActivityViewModel @Inject constructor(
|
||||
private val aiApiService: AiApiService,
|
||||
private val aiApiRepository: AiApiRepository,
|
||||
private val chatStorageRepository: ChatStorageRepository,
|
||||
savedStateHandle: SavedStateHandle
|
||||
) : ViewModel() {
|
||||
private val _uiState = MutableStateFlow(ChatActivityUiState(
|
||||
chat = savedStateHandle.get<Chat>("chat") ?: throw IllegalStateException("Chat not provided")
|
||||
))
|
||||
private val chatId: Long = savedStateHandle["chatId"] ?: Random.nextLong()
|
||||
private val dummyChat = Chat(
|
||||
id = chatId,
|
||||
lastUpdated = Instant.MIN,
|
||||
title = null,
|
||||
subtitle = null,
|
||||
selectedModel = BuildConfig.DEFAULT_MODEL
|
||||
)
|
||||
|
||||
private val _uiState = MutableStateFlow(ChatActivityUiState())
|
||||
|
||||
val uiState: StateFlow<ChatActivityUiState> = _uiState.asStateFlow()
|
||||
|
||||
private val _uiEvents = Channel<ChatActivityUiEvent>()
|
||||
val uiEvents = _uiEvents.receiveAsFlow()
|
||||
|
||||
private val messages = mutableListOf<ChatMessage>()
|
||||
private val _chat = MutableStateFlow(dummyChat)
|
||||
val chat = _chat.asStateFlow()
|
||||
|
||||
/*val chat = chatStorageRepository.getChat(chatId)
|
||||
.map { it?.toChat() ?: Chat(chatId, null, "Mistral-Nemo-12B-Instruct-2407") }
|
||||
.stateIn(
|
||||
scope = viewModelScope,
|
||||
started = SharingStarted.WhileSubscribed(5000),
|
||||
initialValue = Chat(chatId, null, "Mistral-Nemo-12B-Instruct-2407")
|
||||
)*/
|
||||
|
||||
private val _messages = MutableStateFlow(emptyList<ChatMessage>())
|
||||
val messages = _messages.asStateFlow()
|
||||
|
||||
init {
|
||||
Log.d("ChatActivityViewModel", "Loading chat with ID $chatId")
|
||||
|
||||
viewModelScope.launch {
|
||||
Log.d("ChatActivityViewModel", "Loading messages")
|
||||
_messages.value = chatStorageRepository.listMessages(chatId).firstOrNull()
|
||||
?.map { it.toChatMessage() } ?: emptyList()
|
||||
|
||||
chatStorageRepository.getChat(chatId).firstOrNull()?.let {
|
||||
_chat.value = it.toChat()
|
||||
}
|
||||
|
||||
_uiEvents.send(ChatActivityUiEvent.Loaded)
|
||||
|
||||
Log.d("ChatActivityViewModel", "Loaded ${_messages.value.size} messages")
|
||||
|
||||
_chat.collect {
|
||||
if (it == dummyChat) return@collect
|
||||
Log.d("ChatActivityViewModel", "Persisting chat")
|
||||
chatStorageRepository.updateChat(it.copy(lastUpdated = Instant.now()).toEntity())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Repeat the last prompt. There is no reliable way to continue a response.
|
||||
* Repeat the last prompt.
|
||||
* There is no reliable way to continue a response.
|
||||
*/
|
||||
fun repeatLastRequest() {
|
||||
var lastUserMessage = messages.removeLast()
|
||||
|
||||
if (lastUserMessage.role == ChatMessage.Role.Assistant) {
|
||||
// If we just removed an Assistant message, we must also remove the respective User message
|
||||
messages.removeLast()
|
||||
if (messages.value.isEmpty()) {
|
||||
throw IllegalStateException("No message to repeat")
|
||||
}
|
||||
|
||||
val lastUserMessage = messages.value.last()
|
||||
|
||||
viewModelScope.launch {
|
||||
if (lastUserMessage.role == ChatMessage.Role.Assistant) {
|
||||
removeMessage(lastUserMessage)
|
||||
}
|
||||
|
||||
performRequest(this)
|
||||
}
|
||||
|
||||
sendMessage(lastUserMessage.content)
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a new message.
|
||||
*/
|
||||
fun sendMessage(promptContent: String) {
|
||||
var responseContent = ""
|
||||
var error: ChatResponseError? = null
|
||||
|
||||
if (messages.lastOrNull()?.role == ChatMessage.Role.User) {
|
||||
messages.removeLast()
|
||||
} // If there was an error and no response was generated, this shouldn't be a follow-up
|
||||
|
||||
messages.add(
|
||||
ChatMessage(
|
||||
role = ChatMessage.Role.User,
|
||||
content = promptContent
|
||||
)
|
||||
val message = ChatMessage(
|
||||
// TODO id
|
||||
role = ChatMessage.Role.User,
|
||||
content = promptContent
|
||||
)
|
||||
|
||||
viewModelScope.launch {
|
||||
_chat.update {
|
||||
it.copy(
|
||||
title = it.title
|
||||
?: (promptContent.take(30) + if (promptContent.length > 30) "\u2026" else "")
|
||||
)
|
||||
}
|
||||
|
||||
// If there was an error and no response was generated, this shouldn't be a follow-up
|
||||
val lastMessage = messages.value.lastOrNull()
|
||||
if (lastMessage?.role == ChatMessage.Role.User) {
|
||||
removeMessage(lastMessage)
|
||||
}
|
||||
|
||||
addMessage(message)
|
||||
|
||||
performRequest(this)
|
||||
}
|
||||
}
|
||||
|
||||
fun selectModel(model: LanguageModelDto) {
|
||||
_chat.update {
|
||||
it.copy(
|
||||
selectedModel = model.id
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private fun performRequest(scope: CoroutineScope): Job {
|
||||
_uiState.update {
|
||||
it.copy(
|
||||
requestInProgress = true,
|
||||
lastResponseError = null,
|
||||
chat = it.chat.copy(
|
||||
title = it.chat.title ?: promptContent,
|
||||
messages = messages + ChatMessage(
|
||||
role = ChatMessage.Role.Assistant,
|
||||
content = responseContent
|
||||
)
|
||||
)
|
||||
liveResponse = "",
|
||||
enableComposer = true
|
||||
)
|
||||
}
|
||||
|
||||
aiApiService.getChatCompletion(
|
||||
ChatCompletionRequest(
|
||||
model = _uiState.value.chat.model.id,
|
||||
messages = messages,
|
||||
var responseContent = ""
|
||||
var error: ChatCompletionError? = null
|
||||
|
||||
val messageDtoList = messages.value.map { it.toDto() }
|
||||
|
||||
Log.d("ChatActivityViewModel", "Sending request with ${messageDtoList.size} messages:")
|
||||
messageDtoList.forEach { message ->
|
||||
Log.d("ChatActivityViewModel", "Message: ${message.role}, ${message.content}")
|
||||
}
|
||||
|
||||
return aiApiRepository.streamChatCompletion(
|
||||
ChatCompletionRequestDto(
|
||||
model = chat.value.selectedModel,
|
||||
messages = messageDtoList,
|
||||
temperature = 1.0f,
|
||||
maxTokens = 128,
|
||||
frequencyPenalty = 0.0f,
|
||||
presencePenalty = 0.0f
|
||||
)
|
||||
).onEach { event ->
|
||||
when (event) {
|
||||
is SseEvent.Open -> {
|
||||
// There is nothing to do here
|
||||
}
|
||||
is SseEvent.Event<ChatCompletionResponseEvent> -> {
|
||||
event.data.choices?.firstOrNull()?.let { choice ->
|
||||
if (choice.delta.content != null) {
|
||||
responseContent += choice.delta.content
|
||||
).onEach { chunk: ChatCompletionResponseChunk ->
|
||||
when (chunk) {
|
||||
is ChatCompletionResponseChunk.PartialContent -> {
|
||||
responseContent += chunk.completionPart
|
||||
|
||||
_uiState.update {
|
||||
it.copy(
|
||||
chat = it.chat.copy(
|
||||
messages = messages + ChatMessage(
|
||||
role = ChatMessage.Role.Assistant,
|
||||
content = responseContent
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if (choice.finishReason == CompletionFinishReason.Length) {
|
||||
error = ChatResponseError.LengthLimit
|
||||
}
|
||||
_uiState.update {
|
||||
it.copy(
|
||||
liveResponse = responseContent
|
||||
)
|
||||
}
|
||||
}
|
||||
is SseEvent.Closed -> {
|
||||
// Closed is not used in case of an error
|
||||
|
||||
messages.add(
|
||||
ChatMessage(
|
||||
role = ChatMessage.Role.Assistant,
|
||||
content = responseContent
|
||||
)
|
||||
)
|
||||
}
|
||||
is SseEvent.Failure -> {
|
||||
// The below should do. More investigation is needed but I believe this should do
|
||||
is ChatCompletionResponseChunk.Finish -> {
|
||||
if (chunk.finishReason == CompletionChoiceDto.FinishReason.Length) {
|
||||
error = ChatCompletionError.LengthLimit
|
||||
}
|
||||
}
|
||||
}
|
||||
}.catch { exception ->
|
||||
// a message is not added here
|
||||
|
||||
error = ChatResponseError.Error(exception.message)
|
||||
error = ChatCompletionError.Other(exception.message)
|
||||
|
||||
_uiEvents.send(ChatActivityUiEvent.Error(exception.toString()))
|
||||
}.onCompletion {
|
||||
val message = ChatMessage(
|
||||
role = ChatMessage.Role.Assistant,
|
||||
content = responseContent
|
||||
)
|
||||
|
||||
val fatalError = responseContent.isEmpty() && error != null
|
||||
|
||||
if (!fatalError) {
|
||||
addMessage(message)
|
||||
}
|
||||
|
||||
_uiState.update {
|
||||
it.copy(
|
||||
requestInProgress = false,
|
||||
lastResponseError = error,
|
||||
chat = it.chat.copy(
|
||||
messages = messages.toList()
|
||||
)
|
||||
enableComposer = !fatalError
|
||||
)
|
||||
}
|
||||
}.launchIn(viewModelScope)
|
||||
|
||||
val truncated = responseContent.truncateAtWord(50)
|
||||
|
||||
_chat.update {
|
||||
it.copy(
|
||||
subtitle = truncated + if (responseContent.length > truncated.length) "\u2026" else "",
|
||||
lastUpdated = Instant.now()
|
||||
)
|
||||
}
|
||||
}.launchIn(scope)
|
||||
}
|
||||
|
||||
fun selectModel(model: LanguageModel) {
|
||||
_uiState.update {
|
||||
it.copy(
|
||||
chat = it.chat.copy(
|
||||
model = model
|
||||
)
|
||||
)
|
||||
private suspend fun addMessage(message: ChatMessage): Long {
|
||||
val messageId = chatStorageRepository.addMessage(message.toEntity(chatId))
|
||||
|
||||
_messages.update { it + message.copy(id = messageId) }
|
||||
|
||||
Log.d("ChatActivityViewModel", "Added message with ID ${messageId}, new size: ${_messages.value.size}")
|
||||
|
||||
return messageId
|
||||
}
|
||||
|
||||
private suspend fun removeMessage(message: ChatMessage) {
|
||||
_messages.update { messages -> messages.filterNot { it.id == message.id } }
|
||||
chatStorageRepository.deleteMessageById(message.id!!)
|
||||
|
||||
Log.d("ChatActivityViewModel", "Removed message with ID ${message.id}, new size: ${_messages.value.size}")
|
||||
}
|
||||
|
||||
// TODO expose this maybe
|
||||
private fun String.truncateAtWord(maxLength: Int): String {
|
||||
if (maxLength <= 0 || this.length <= maxLength) {
|
||||
return this
|
||||
}
|
||||
|
||||
val endIndex = this.indexOf(' ', startIndex = maxLength)
|
||||
|
||||
if (endIndex == -1) {
|
||||
return this
|
||||
}
|
||||
|
||||
return this.substring(0, endIndex)
|
||||
}
|
||||
}
|
|
@ -63,7 +63,7 @@ fun ChatToolBar(
|
|||
Column {
|
||||
AnimatedVisibility(settingsOpened) {
|
||||
ChatQuickSettings(
|
||||
modifier = Modifier.padding(16.dp), // To match the rounded corners
|
||||
modifier = Modifier.padding(horizontal = 16.dp), // To match the rounded corners
|
||||
onModelSelected = {
|
||||
settingsOpened = false
|
||||
onSettingsEvent(ChatQuickSettingsEvent.ModelSelected(it))
|
||||
|
|
|
@ -14,11 +14,11 @@ import androidx.compose.ui.Modifier
|
|||
import androidx.compose.ui.res.stringResource
|
||||
import androidx.compose.ui.unit.dp
|
||||
import eu.m724.chatapp.R
|
||||
import eu.m724.chatapp.store.data.ChatResponseError
|
||||
import eu.m724.chatapp.model.ChatCompletionError
|
||||
|
||||
@Composable
|
||||
fun ChatResponseErrorNotice(
|
||||
error: ChatResponseError,
|
||||
error: ChatCompletionError,
|
||||
modifier: Modifier = Modifier
|
||||
) {
|
||||
Row(
|
||||
|
@ -35,8 +35,8 @@ fun ChatResponseErrorNotice(
|
|||
)
|
||||
|
||||
val errorMessage = when (error) {
|
||||
is ChatResponseError.LengthLimit -> stringResource(R.string.response_error_length_limit)
|
||||
is ChatResponseError.Error -> stringResource(R.string.response_error_generic)
|
||||
is ChatCompletionError.LengthLimit -> stringResource(R.string.response_error_length_limit)
|
||||
is ChatCompletionError.Other -> stringResource(R.string.response_error_generic)
|
||||
}
|
||||
|
||||
Text(
|
||||
|
|
|
@ -1,153 +1,76 @@
|
|||
package eu.m724.chatapp.activity.chat.quick_settings
|
||||
|
||||
import androidx.compose.animation.core.animateDpAsState
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.lazy.LazyColumn
|
||||
import androidx.compose.foundation.lazy.items
|
||||
import androidx.compose.foundation.lazy.rememberLazyListState
|
||||
import androidx.compose.material3.CircularProgressIndicator
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.derivedStateOf
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.geometry.Offset
|
||||
import androidx.compose.ui.input.nestedscroll.NestedScrollConnection
|
||||
import androidx.compose.ui.input.nestedscroll.NestedScrollSource
|
||||
import androidx.compose.ui.input.nestedscroll.nestedScroll
|
||||
import androidx.compose.ui.res.stringResource
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.unit.min
|
||||
import androidx.hilt.navigation.compose.hiltViewModel
|
||||
import androidx.lifecycle.compose.collectAsStateWithLifecycle
|
||||
import eu.m724.chatapp.R
|
||||
import eu.m724.chatapp.activity.chat.quick_settings.composable.DismissableLazyColumn
|
||||
import eu.m724.chatapp.activity.chat.quick_settings.composable.ModelCard
|
||||
import eu.m724.chatapp.api.data.response.models.LanguageModel
|
||||
import eu.m724.newchat.aiapi.models.dto.lm.LanguageModelDto
|
||||
|
||||
@Composable
|
||||
fun ChatQuickSettings(
|
||||
modifier: Modifier = Modifier,
|
||||
onModelSelected: (LanguageModel) -> Unit,
|
||||
onModelSelected: (LanguageModelDto) -> Unit,
|
||||
onDismiss: () -> Unit,
|
||||
viewModel: ChatQuickSettingsViewModel = hiltViewModel(),
|
||||
) {
|
||||
val uiState by viewModel.uiState.collectAsStateWithLifecycle()
|
||||
|
||||
Column(
|
||||
DismissableLazyColumn(
|
||||
modifier = modifier,
|
||||
horizontalAlignment = Alignment.CenterHorizontally
|
||||
onDismiss = onDismiss
|
||||
) {
|
||||
if (uiState.modelsLoaded) {
|
||||
ModelList(
|
||||
models = uiState.models,
|
||||
onModelSelected = onModelSelected,
|
||||
onDismiss = onDismiss
|
||||
)
|
||||
} else {
|
||||
CircularProgressIndicator()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
fun ModelList(
|
||||
models: List<LanguageModel>,
|
||||
onModelSelected: (LanguageModel) -> Unit,
|
||||
onDismiss: () -> Unit
|
||||
) {
|
||||
val minHeight = 250.dp
|
||||
val maxHeight = 600.dp // TODO
|
||||
|
||||
var targetHeight by remember { mutableStateOf(minHeight) }
|
||||
val height by animateDpAsState(targetHeight)
|
||||
|
||||
var listState = rememberLazyListState()
|
||||
|
||||
val isScrolledToTop by remember {
|
||||
derivedStateOf {
|
||||
listState.firstVisibleItemIndex == 0 && listState.firstVisibleItemScrollOffset == 0
|
||||
}
|
||||
}
|
||||
|
||||
LaunchedEffect(listState.isScrollInProgress) {
|
||||
if (!listState.isScrollInProgress) {
|
||||
if (targetHeight < 50.dp) {
|
||||
onDismiss()
|
||||
} else if (targetHeight < minHeight + 10.dp) {
|
||||
targetHeight = minHeight
|
||||
} else {
|
||||
targetHeight = maxHeight
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This is the connection that will intercept scroll events.
|
||||
val nestedScrollConnection = remember {
|
||||
object : NestedScrollConnection {
|
||||
// onPreScroll is called before the child (LazyColumn) gets to scroll.
|
||||
override fun onPreScroll(available: Offset, source: NestedScrollSource): Offset {
|
||||
if (available.y < 0) { // scroll down (content down, finger up)
|
||||
if (targetHeight < maxHeight) {
|
||||
targetHeight = min(targetHeight - available.y.dp / 2, maxHeight)
|
||||
return available
|
||||
}
|
||||
}
|
||||
|
||||
// scroll up (content up, finger down)
|
||||
if (available.y > 0 && isScrolledToTop) {
|
||||
targetHeight = targetHeight - available.y.dp / 2
|
||||
return available
|
||||
}
|
||||
|
||||
return Offset.Zero
|
||||
}
|
||||
|
||||
override fun onPostScroll(
|
||||
consumed: Offset,
|
||||
available: Offset,
|
||||
source: NestedScrollSource
|
||||
): Offset {
|
||||
return super.onPostScroll(consumed, available, source)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LazyColumn(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.height(height)
|
||||
.nestedScroll(nestedScrollConnection),
|
||||
state = listState
|
||||
) { // TODO make this rounded
|
||||
item {
|
||||
item(
|
||||
key = "header"
|
||||
) {
|
||||
Text(
|
||||
modifier = Modifier.padding(top = 16.dp), // TODO remove this hack
|
||||
text = stringResource(R.string.quick_settings_select_model),
|
||||
style = MaterialTheme.typography.titleLarge
|
||||
) // TODO center this maybe? but this looks cool too
|
||||
}
|
||||
|
||||
items(
|
||||
items = models,
|
||||
key = { it.id }
|
||||
) { model ->
|
||||
ModelCard(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(horizontal = 16.dp, vertical = 8.dp),
|
||||
model = model,
|
||||
onSelected = {
|
||||
onModelSelected(model)
|
||||
if (uiState.models.isNotEmpty()) {
|
||||
items(
|
||||
items = uiState.models,
|
||||
key = { it.id }
|
||||
) { model ->
|
||||
ModelCard(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(horizontal = 16.dp, vertical = 8.dp),
|
||||
model = model,
|
||||
onSelected = {
|
||||
onModelSelected(model)
|
||||
}
|
||||
)
|
||||
}
|
||||
} else {
|
||||
item(
|
||||
key = "loading"
|
||||
) {
|
||||
Box(
|
||||
modifier = Modifier.fillMaxSize().height(200.dp),
|
||||
contentAlignment = Alignment.Center
|
||||
) {
|
||||
CircularProgressIndicator()
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,9 +1,9 @@
|
|||
package eu.m724.chatapp.activity.chat.quick_settings
|
||||
|
||||
import eu.m724.chatapp.api.data.response.models.LanguageModel
|
||||
import eu.m724.newchat.aiapi.models.dto.lm.LanguageModelDto
|
||||
|
||||
sealed interface ChatQuickSettingsEvent {
|
||||
data class Visibility(val visibility: ChatQuickSettingsVisibility): ChatQuickSettingsEvent
|
||||
data class ModelSelected(val model: LanguageModel): ChatQuickSettingsEvent
|
||||
data class ModelSelected(val model: LanguageModelDto): ChatQuickSettingsEvent
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
package eu.m724.chatapp.activity.chat.quick_settings
|
||||
|
||||
import eu.m724.chatapp.api.data.response.models.LanguageModel
|
||||
import eu.m724.newchat.aiapi.models.dto.lm.LanguageModelDto
|
||||
|
||||
data class ChatQuickSettingsUiState(
|
||||
val modelsLoaded: Boolean = false,
|
||||
|
@ -8,5 +8,5 @@ data class ChatQuickSettingsUiState(
|
|||
/**
|
||||
* A list of all available language models.
|
||||
*/
|
||||
val models: List<LanguageModel> = listOf()
|
||||
val models: List<LanguageModelDto> = listOf()
|
||||
)
|
|
@ -3,7 +3,7 @@ package eu.m724.chatapp.activity.chat.quick_settings
|
|||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||
import eu.m724.chatapp.api.AiApiService
|
||||
import eu.m724.newchat.aiapi.repository.AiApiRepository
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.asStateFlow
|
||||
|
@ -13,7 +13,7 @@ import javax.inject.Inject
|
|||
|
||||
@HiltViewModel
|
||||
class ChatQuickSettingsViewModel @Inject constructor(
|
||||
val aiApiService: AiApiService
|
||||
val aiApiRepository: AiApiRepository
|
||||
) : ViewModel() {
|
||||
private val _uiState = MutableStateFlow(ChatQuickSettingsUiState())
|
||||
val uiState: StateFlow<ChatQuickSettingsUiState> = _uiState.asStateFlow()
|
||||
|
@ -24,19 +24,12 @@ class ChatQuickSettingsViewModel @Inject constructor(
|
|||
|
||||
private fun loadModels() {
|
||||
viewModelScope.launch {
|
||||
val modelsResponse = try {
|
||||
aiApiService.getModels()
|
||||
} catch (e: Exception) {
|
||||
// TODO
|
||||
return@launch
|
||||
}
|
||||
|
||||
val models = modelsResponse.body()!!.data
|
||||
val models = aiApiRepository.getLanguageModels()
|
||||
|
||||
_uiState.update {
|
||||
it.copy(
|
||||
modelsLoaded = true,
|
||||
models = models
|
||||
models = models.getOrThrow().data
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
package eu.m724.chatapp.activity.chat.quick_settings.composable
|
||||
|
||||
import androidx.compose.animation.core.animateDpAsState
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.lazy.LazyColumn
|
||||
import androidx.compose.foundation.lazy.LazyListScope
|
||||
import androidx.compose.foundation.lazy.rememberLazyListState
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.derivedStateOf
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.geometry.Offset
|
||||
import androidx.compose.ui.input.nestedscroll.NestedScrollConnection
|
||||
import androidx.compose.ui.input.nestedscroll.NestedScrollSource
|
||||
import androidx.compose.ui.input.nestedscroll.nestedScroll
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.unit.min
|
||||
|
||||
@Composable
|
||||
fun DismissableLazyColumn(
|
||||
modifier: Modifier = Modifier,
|
||||
onDismiss: () -> Unit,
|
||||
content: LazyListScope.() -> Unit,
|
||||
) {
|
||||
val minHeight = 250.dp
|
||||
val maxHeight = 600.dp // TODO
|
||||
|
||||
var targetHeight by remember { mutableStateOf(minHeight) }
|
||||
val height by animateDpAsState(targetHeight)
|
||||
|
||||
var listState = rememberLazyListState()
|
||||
|
||||
val isScrolledToTop by remember {
|
||||
derivedStateOf {
|
||||
listState.firstVisibleItemIndex == 0 && listState.firstVisibleItemScrollOffset == 0
|
||||
}
|
||||
}
|
||||
|
||||
LaunchedEffect(listState.isScrollInProgress) {
|
||||
if (!listState.isScrollInProgress) {
|
||||
if (targetHeight < 50.dp) {
|
||||
onDismiss()
|
||||
} else if (targetHeight < minHeight + 10.dp) {
|
||||
targetHeight = minHeight
|
||||
} else {
|
||||
targetHeight = maxHeight
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This is the connection that will intercept scroll events.
|
||||
val nestedScrollConnection = remember {
|
||||
object : NestedScrollConnection {
|
||||
// onPreScroll is called before the child (LazyColumn) gets to scroll.
|
||||
override fun onPreScroll(available: Offset, source: NestedScrollSource): Offset {
|
||||
if (available.y < 0) { // scroll down (content down, finger up)
|
||||
if (targetHeight < maxHeight) {
|
||||
targetHeight = min(targetHeight - available.y.dp / 2, maxHeight)
|
||||
return available
|
||||
}
|
||||
}
|
||||
|
||||
// scroll up (content up, finger down)
|
||||
if (available.y > 0 && isScrolledToTop) {
|
||||
targetHeight = targetHeight - available.y.dp / 2
|
||||
return available
|
||||
}
|
||||
|
||||
return Offset.Zero
|
||||
}
|
||||
|
||||
override fun onPostScroll(
|
||||
consumed: Offset,
|
||||
available: Offset,
|
||||
source: NestedScrollSource
|
||||
): Offset {
|
||||
return super.onPostScroll(consumed, available, source)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LazyColumn(
|
||||
modifier = modifier
|
||||
.fillMaxWidth()
|
||||
.height(height)
|
||||
.nestedScroll(nestedScrollConnection),
|
||||
state = listState,
|
||||
content = content
|
||||
)
|
||||
}
|
|
@ -6,6 +6,7 @@ import androidx.compose.animation.core.animateFloatAsState
|
|||
import androidx.compose.animation.core.tween
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.FlowRow
|
||||
import androidx.compose.foundation.layout.PaddingValues
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.Spacer
|
||||
|
@ -39,13 +40,13 @@ import androidx.compose.ui.text.font.FontWeight
|
|||
import androidx.compose.ui.text.style.TextOverflow
|
||||
import androidx.compose.ui.unit.dp
|
||||
import eu.m724.chatapp.R
|
||||
import eu.m724.chatapp.api.data.response.models.LanguageModel
|
||||
import eu.m724.newchat.aiapi.models.dto.lm.LanguageModelDto
|
||||
import java.math.RoundingMode
|
||||
import java.text.DecimalFormat
|
||||
|
||||
@Composable
|
||||
fun ModelCard(
|
||||
model: LanguageModel,
|
||||
model: LanguageModelDto,
|
||||
modifier: Modifier = Modifier,
|
||||
onSelected: () -> Unit = {},
|
||||
) {
|
||||
|
@ -102,6 +103,7 @@ fun ModelCard(
|
|||
horizontalArrangement = Arrangement.SpaceBetween
|
||||
) {
|
||||
Text(
|
||||
modifier = Modifier.weight(1f),
|
||||
text = model.name,
|
||||
style = MaterialTheme.typography.titleLarge,
|
||||
)
|
||||
|
@ -114,7 +116,7 @@ fun ModelCard(
|
|||
if (model.description != null) {
|
||||
Spacer(modifier = Modifier.height(4.dp))
|
||||
Text(
|
||||
text = model.description,
|
||||
text = model.description!!,
|
||||
modifier = Modifier.animateContentSize(),
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant,
|
||||
|
@ -127,14 +129,15 @@ fun ModelCard(
|
|||
|
||||
Spacer(modifier = Modifier.height(12.dp))
|
||||
|
||||
Row(
|
||||
FlowRow(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
verticalArrangement = Arrangement.Center,
|
||||
horizontalArrangement = Arrangement.End
|
||||
) {
|
||||
PriceItem(
|
||||
icon = Icons.Default.KeyboardArrowUp,
|
||||
label = stringResource(R.string.model_card_price_input),
|
||||
price = model.pricing.pricePerMillionInputTokens,
|
||||
price = model.pricing.inputCostPerMillionTokens,
|
||||
contentDescription = stringResource(R.string.model_card_price_million_input_icon_description)
|
||||
)
|
||||
|
||||
|
@ -143,7 +146,7 @@ fun ModelCard(
|
|||
PriceItem(
|
||||
icon = Icons.Default.KeyboardArrowDown,
|
||||
label = stringResource(R.string.model_card_price_output),
|
||||
price = model.pricing.pricePerMillionOutputTokens,
|
||||
price = model.pricing.completionCostPerMillionTokens,
|
||||
contentDescription = stringResource(R.string.model_card_price_million_output_icon_description)
|
||||
)
|
||||
|
||||
|
|
|
@ -2,98 +2,199 @@ package eu.m724.chatapp.activity.main
|
|||
|
||||
import android.content.Intent
|
||||
import android.os.Bundle
|
||||
import android.util.Log
|
||||
import androidx.activity.ComponentActivity
|
||||
import androidx.activity.compose.setContent
|
||||
import androidx.activity.enableEdgeToEdge
|
||||
import androidx.compose.animation.animateContentSize
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.activity.viewModels
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.LinearProgressIndicator
|
||||
import androidx.compose.foundation.layout.width
|
||||
import androidx.compose.foundation.lazy.LazyColumn
|
||||
import androidx.compose.foundation.lazy.items
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.filled.Create
|
||||
import androidx.compose.material3.FloatingActionButton
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Scaffold
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.res.stringResource
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.text.style.TextOverflow
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.unit.sp
|
||||
import androidx.hilt.navigation.compose.hiltViewModel
|
||||
import androidx.lifecycle.lifecycleScope
|
||||
import dagger.hilt.android.AndroidEntryPoint
|
||||
import eu.m724.chatapp.R
|
||||
import eu.m724.chatapp.activity.chat.ChatActivity
|
||||
import eu.m724.chatapp.activity.ui.theme.ChatAppTheme
|
||||
import eu.m724.chatapp.model.Chat
|
||||
import kotlinx.coroutines.launch
|
||||
import java.time.ZoneId
|
||||
import java.time.format.DateTimeFormatter
|
||||
import java.time.format.FormatStyle
|
||||
|
||||
@AndroidEntryPoint
|
||||
class MainActivity : ComponentActivity() {
|
||||
val viewModel: MainActivityViewModel by viewModels()
|
||||
|
||||
override fun onCreate(savedInstanceState: Bundle?) {
|
||||
super.onCreate(savedInstanceState)
|
||||
enableEdgeToEdge()
|
||||
setContent {
|
||||
ChatAppTheme {
|
||||
Scaffold(
|
||||
modifier = Modifier.fillMaxSize()
|
||||
) { innerPadding ->
|
||||
Content(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.padding(innerPadding)
|
||||
)
|
||||
|
||||
lifecycleScope.launch {
|
||||
viewModel.uiEvents.collect { event ->
|
||||
when (event) {
|
||||
is MainActivityUiEvent.StartChat -> {
|
||||
startChat(event.chatId)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enableEdgeToEdge()
|
||||
setContent {
|
||||
Content(
|
||||
onChatSelected = { id ->
|
||||
startChat(id)
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fun startChat(chatId: Long?) {
|
||||
Log.d("MainActivity", "Starting chat with id $chatId")
|
||||
|
||||
val intent = Intent(this@MainActivity, ChatActivity::class.java).apply {
|
||||
putExtra("chatId", chatId)
|
||||
}
|
||||
|
||||
startActivity(intent)
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
fun Content(
|
||||
modifier: Modifier = Modifier,
|
||||
onChatSelected: (Long?) -> Unit,
|
||||
viewModel: MainActivityViewModel = hiltViewModel()
|
||||
) {
|
||||
val uiState by viewModel.uiState.collectAsState()
|
||||
val context = LocalContext.current
|
||||
val chats by viewModel.chats.collectAsState()
|
||||
|
||||
LaunchedEffect(Unit) {
|
||||
viewModel.uiEvents.collect { event ->
|
||||
when (event) {
|
||||
is MainActivityUiEvent.StartChat -> {
|
||||
val intent = Intent(context, ChatActivity::class.java).apply {
|
||||
putExtra("chat", event.chat)
|
||||
}
|
||||
|
||||
context.startActivity(intent)
|
||||
ChatAppTheme {
|
||||
Scaffold(
|
||||
modifier = Modifier.fillMaxSize(),
|
||||
floatingActionButton = {
|
||||
FloatingActionButton(
|
||||
onClick = {
|
||||
onChatSelected(null)
|
||||
},
|
||||
// TODO maybe disable ripple
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Create,
|
||||
contentDescription = stringResource(R.string.start_new_conversation)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Column(
|
||||
modifier = modifier,
|
||||
horizontalAlignment = Alignment.CenterHorizontally,
|
||||
verticalArrangement = Arrangement.SpaceEvenly
|
||||
) {
|
||||
Text(
|
||||
text = stringResource(R.string.welcome)
|
||||
)
|
||||
|
||||
Button(
|
||||
modifier = Modifier.animateContentSize(),
|
||||
onClick = {
|
||||
viewModel.startConversation()
|
||||
},
|
||||
enabled = !uiState.loading
|
||||
) {
|
||||
if (uiState.loading) {
|
||||
LinearProgressIndicator()
|
||||
} else {
|
||||
) { innerPadding ->
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.padding(innerPadding),
|
||||
horizontalAlignment = Alignment.CenterHorizontally
|
||||
) {
|
||||
Text(
|
||||
text = stringResource(R.string.start_new_conversation)
|
||||
modifier = Modifier.padding(8.dp),
|
||||
text = stringResource(R.string.welcome)
|
||||
)
|
||||
|
||||
ChatList(
|
||||
modifier = Modifier
|
||||
.fillMaxSize(),
|
||||
onChatSelected = onChatSelected,
|
||||
chats = chats
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Composable
|
||||
fun ChatList(
|
||||
modifier: Modifier = Modifier,
|
||||
onChatSelected: (Long) -> Unit,
|
||||
chats: List<Chat>
|
||||
) {
|
||||
LazyColumn(
|
||||
modifier = modifier
|
||||
) {
|
||||
items(
|
||||
items = chats,
|
||||
key = { chat -> chat.id }
|
||||
) { chat ->
|
||||
ChatListEntry(
|
||||
chat = chat,
|
||||
onClick = {
|
||||
onChatSelected(chat.id)
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
fun ChatListEntry(
|
||||
modifier: Modifier = Modifier,
|
||||
chat: Chat,
|
||||
onClick: () -> Unit
|
||||
) {
|
||||
Box(
|
||||
modifier = modifier
|
||||
.fillMaxWidth()
|
||||
.padding(8.dp)
|
||||
.clickable(onClick = onClick),
|
||||
contentAlignment = Alignment.Center
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.width(350.dp)
|
||||
) {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
Text(
|
||||
modifier = Modifier.padding(horizontal = 8.dp),
|
||||
text = chat.title ?: "Untitled",
|
||||
fontWeight = FontWeight.Bold,
|
||||
fontSize = 18.sp
|
||||
)
|
||||
|
||||
Text(
|
||||
text = chat.lastUpdated.atZone(ZoneId.systemDefault()).format(
|
||||
DateTimeFormatter.ofLocalizedTime(FormatStyle.SHORT)),
|
||||
color = MaterialTheme.colorScheme.secondary,
|
||||
fontSize = 12.sp
|
||||
)
|
||||
}
|
||||
|
||||
Text(
|
||||
text = chat.subtitle ?: "Chat is empty",
|
||||
maxLines = 1,
|
||||
overflow = TextOverflow.Ellipsis,
|
||||
fontSize = 12.sp
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,9 +1,7 @@
|
|||
package eu.m724.chatapp.activity.main
|
||||
|
||||
import eu.m724.chatapp.store.data.Chat
|
||||
|
||||
sealed interface MainActivityUiEvent {
|
||||
data class StartChat(
|
||||
val chat: Chat,
|
||||
val chatId: Long,
|
||||
): MainActivityUiEvent
|
||||
}
|
|
@ -1,5 +0,0 @@
|
|||
package eu.m724.chatapp.activity.main
|
||||
|
||||
data class MainActivityUiState(
|
||||
val loading: Boolean = false
|
||||
)
|
|
@ -3,51 +3,30 @@ package eu.m724.chatapp.activity.main
|
|||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||
import eu.m724.chatapp.api.AiApiService
|
||||
import eu.m724.chatapp.store.data.Chat
|
||||
import eu.m724.chatapp.store.room.ChatDao
|
||||
import eu.m724.chat.storage.repository.ChatStorageRepository
|
||||
import eu.m724.chatapp.model.toChat
|
||||
import kotlinx.coroutines.channels.Channel
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.asStateFlow
|
||||
import kotlinx.coroutines.flow.map
|
||||
import kotlinx.coroutines.flow.receiveAsFlow
|
||||
import kotlinx.coroutines.flow.update
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.flow.stateIn
|
||||
import javax.inject.Inject
|
||||
|
||||
@HiltViewModel
|
||||
class MainActivityViewModel @Inject constructor(
|
||||
val aiApiService: AiApiService,
|
||||
val chatDao: ChatDao
|
||||
chatStorageRepository: ChatStorageRepository
|
||||
) : ViewModel() {
|
||||
private val _uiState = MutableStateFlow(MainActivityUiState())
|
||||
val uiState: StateFlow<MainActivityUiState> = _uiState.asStateFlow()
|
||||
|
||||
private val _uiEvents = Channel<MainActivityUiEvent>()
|
||||
val uiEvents = _uiEvents.receiveAsFlow()
|
||||
|
||||
fun startConversation() {
|
||||
_uiState.update {
|
||||
it.copy(
|
||||
loading = true
|
||||
)
|
||||
val chats = chatStorageRepository.listChats()
|
||||
.map { chats ->
|
||||
chats.map { chatEntity ->
|
||||
chatEntity.toChat()
|
||||
}
|
||||
}
|
||||
|
||||
viewModelScope.launch {
|
||||
val modelsResponse = aiApiService.getModels()
|
||||
val models = modelsResponse.body()!!.data
|
||||
val model = models.find { it.id == "meta-llama/llama-3.2-3b-instruct" }
|
||||
println(models)
|
||||
|
||||
val chat = Chat(
|
||||
title = null,
|
||||
model = model!!,
|
||||
messages = emptyList()
|
||||
)
|
||||
|
||||
chatDao.insertChat()
|
||||
|
||||
_uiEvents.send(MainActivityUiEvent.StartChat(chat))
|
||||
}
|
||||
}
|
||||
.stateIn(
|
||||
scope = viewModelScope,
|
||||
started = kotlinx.coroutines.flow.SharingStarted.WhileSubscribed(5000),
|
||||
initialValue = emptyList()
|
||||
)
|
||||
}
|
|
@ -1,138 +0,0 @@
|
|||
package eu.m724.chatapp.api
|
||||
|
||||
import com.google.gson.FieldNamingPolicy
|
||||
import com.google.gson.Gson
|
||||
import com.google.gson.GsonBuilder
|
||||
import dagger.Module
|
||||
import dagger.Provides
|
||||
import dagger.hilt.InstallIn
|
||||
import dagger.hilt.components.SingletonComponent
|
||||
import eu.m724.chatapp.BuildConfig
|
||||
import eu.m724.chatapp.api.retrofit.interceptor.AiApiRequestExceptionInterceptor
|
||||
import eu.m724.chatapp.api.retrofit.interceptor.AiApiRequestHeadersInterceptor
|
||||
import eu.m724.chatapp.api.retrofit.sse.SseCallAdapterFactory
|
||||
import okhttp3.OkHttpClient
|
||||
import okhttp3.logging.HttpLoggingInterceptor
|
||||
import retrofit2.Retrofit
|
||||
import retrofit2.converter.gson.GsonConverterFactory
|
||||
import java.util.concurrent.TimeUnit
|
||||
import javax.inject.Named
|
||||
import javax.inject.Qualifier
|
||||
import javax.inject.Singleton
|
||||
|
||||
@Module
|
||||
@InstallIn(SingletonComponent::class)
|
||||
object AiApiNetworkModule {
|
||||
@Provides
|
||||
@Named("apiKey")
|
||||
fun provideApiKey(): String = BuildConfig.API_KEY
|
||||
|
||||
@Provides
|
||||
@Named("apiEndpoint")
|
||||
fun provideApiEndpoint(): String = BuildConfig.API_ENDPOINT
|
||||
|
||||
@Provides
|
||||
@Named("userAgent")
|
||||
fun provideUserAgent(): String = BuildConfig.USER_AGENT
|
||||
|
||||
@Provides
|
||||
@Named("isDebug")
|
||||
fun provideIsDebug(): Boolean = BuildConfig.DEBUG
|
||||
|
||||
@Provides
|
||||
fun provideOkHttpClientBuilder(
|
||||
@Named("apiKey") apiKey: String,
|
||||
@Named("apiEndpoint") apiEndpoint: String,
|
||||
@Named("userAgent") userAgent: String,
|
||||
@Named("isDebug") isDebug: Boolean
|
||||
): OkHttpClient.Builder {
|
||||
val interceptor = AiApiRequestHeadersInterceptor(
|
||||
userAgent = userAgent,
|
||||
apiEndpoint = apiEndpoint,
|
||||
apiKey = apiKey
|
||||
)
|
||||
|
||||
val builder = OkHttpClient.Builder()
|
||||
.addInterceptor(interceptor)
|
||||
|
||||
if (isDebug) {
|
||||
// level body makes the response buffered which nukes sse
|
||||
builder.addInterceptor(HttpLoggingInterceptor().apply { level = HttpLoggingInterceptor.Level.HEADERS })
|
||||
}
|
||||
|
||||
return builder
|
||||
}
|
||||
|
||||
/**
|
||||
* The standard client is used for standard (non-SSE) requests
|
||||
*/
|
||||
@Provides
|
||||
@Singleton
|
||||
@StandardClient
|
||||
fun provideStandardOkHttpClient(
|
||||
builder: OkHttpClient.Builder,
|
||||
gson: Gson
|
||||
): OkHttpClient {
|
||||
val interceptor = AiApiRequestExceptionInterceptor(
|
||||
gson = gson
|
||||
)
|
||||
|
||||
return builder
|
||||
.addInterceptor(interceptor)
|
||||
.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* The long-lived client is used for long-lived (SSE) requests
|
||||
*/
|
||||
@Provides
|
||||
@Singleton
|
||||
@LongLivedClient
|
||||
fun provideLongLivedOkHttpClient(
|
||||
builder: OkHttpClient.Builder
|
||||
): OkHttpClient {
|
||||
return builder
|
||||
.readTimeout(0, TimeUnit.SECONDS) // Apparently there are other safe guards against a zombie connection, but I don't know in practice
|
||||
.connectTimeout(30, TimeUnit.SECONDS)
|
||||
.build()
|
||||
}
|
||||
|
||||
@Provides
|
||||
@Singleton
|
||||
fun provideGson(): Gson {
|
||||
return GsonBuilder()
|
||||
.setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES) // snake_case
|
||||
.create()
|
||||
}
|
||||
|
||||
@Provides
|
||||
@Singleton
|
||||
fun provideRetrofit(
|
||||
@StandardClient standardOkHttpClient: OkHttpClient,
|
||||
@LongLivedClient longLivedOkHttpClient: OkHttpClient,
|
||||
gson: Gson,
|
||||
@Named("apiEndpoint") apiEndpoint: String,
|
||||
@Named("isDebug") isDebug: Boolean
|
||||
): Retrofit {
|
||||
return Retrofit.Builder()
|
||||
.baseUrl(apiEndpoint)
|
||||
.client(standardOkHttpClient) // Use the standard client by default
|
||||
.addCallAdapterFactory(SseCallAdapterFactory(longLivedOkHttpClient, gson, isDebug)) // this intercepts SSE requests and makes them use the long-lived client
|
||||
.addConverterFactory(GsonConverterFactory.create(gson))
|
||||
.build()
|
||||
}
|
||||
|
||||
@Provides
|
||||
@Singleton
|
||||
fun provideAiApiService(retrofit: Retrofit): AiApiService {
|
||||
return retrofit.create(AiApiService::class.java)
|
||||
}
|
||||
}
|
||||
|
||||
@Qualifier
|
||||
@Retention(AnnotationRetention.BINARY)
|
||||
annotation class StandardClient
|
||||
|
||||
@Qualifier
|
||||
@Retention(AnnotationRetention.BINARY)
|
||||
annotation class LongLivedClient
|
|
@ -1,21 +0,0 @@
|
|||
package eu.m724.chatapp.api
|
||||
|
||||
import eu.m724.chatapp.api.data.request.completion.ChatCompletionRequest
|
||||
import eu.m724.chatapp.api.data.response.completion.ChatCompletionResponseEvent
|
||||
import eu.m724.chatapp.api.data.response.models.LanguageModelsResponse
|
||||
import eu.m724.chatapp.api.retrofit.sse.SseEvent
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import retrofit2.Response
|
||||
import retrofit2.http.Body
|
||||
import retrofit2.http.GET
|
||||
import retrofit2.http.POST
|
||||
import retrofit2.http.Streaming
|
||||
|
||||
interface AiApiService {
|
||||
@GET("models?detailed=true")
|
||||
suspend fun getModels(): Response<LanguageModelsResponse>
|
||||
|
||||
@POST("chat/completions")
|
||||
@Streaming
|
||||
fun getChatCompletion(@Body body: ChatCompletionRequest): Flow<SseEvent<ChatCompletionResponseEvent>>
|
||||
}
|
|
@ -1,17 +0,0 @@
|
|||
package eu.m724.chatapp.api.data
|
||||
|
||||
data class AiApiExceptionData(
|
||||
override val message: String,
|
||||
val type: String,
|
||||
val param: String,
|
||||
val code: Int
|
||||
) : Exception(message)
|
||||
|
||||
data class AiApiExceptionDataWrapper(
|
||||
val error: AiApiExceptionData
|
||||
)
|
||||
|
||||
class AiApiException(
|
||||
val httpCode: Int,
|
||||
val error: AiApiExceptionData?
|
||||
) : Exception("API problem: ${error?.message} (code $httpCode)")
|
|
@ -1,77 +0,0 @@
|
|||
package eu.m724.chatapp.api.data.response.completion
|
||||
|
||||
import com.google.gson.annotations.SerializedName
|
||||
|
||||
data class ChatCompletionResponseEvent(
|
||||
/**
|
||||
* Request ID
|
||||
*/
|
||||
val id: String,
|
||||
|
||||
/**
|
||||
* Completion choices. Usually has only one element.
|
||||
*/
|
||||
val choices: List<CompletionChoice>,
|
||||
|
||||
/**
|
||||
* The cost (in tokens) of this completion
|
||||
*/
|
||||
@SerializedName("usage")
|
||||
val tokenUsage: CompletionTokenUsage
|
||||
)
|
||||
|
||||
data class CompletionChoice(
|
||||
val index: Int,
|
||||
|
||||
/**
|
||||
* The generated message delta, you should merge it with the previous delta
|
||||
*/
|
||||
val delta: CompletionChoiceDelta,
|
||||
|
||||
/**
|
||||
* The reason why generating the response has stopped. null if the response hasn't finished yet.
|
||||
*/
|
||||
val finishReason: CompletionFinishReason?
|
||||
)
|
||||
|
||||
data class CompletionChoiceDelta(
|
||||
/** The next generated token, may be null if the response just finished */
|
||||
val content: String?
|
||||
)
|
||||
|
||||
enum class CompletionFinishReason {
|
||||
/**
|
||||
* The response has stopped, because the model said so
|
||||
*/
|
||||
@SerializedName("stop")
|
||||
Stop,
|
||||
|
||||
/**
|
||||
* The response has stopped, because it got too long
|
||||
*/
|
||||
@SerializedName("length")
|
||||
Length,
|
||||
|
||||
/**
|
||||
* The response has stopped, because the content got flagged
|
||||
*/
|
||||
@SerializedName("content_filter")
|
||||
ContentFilter
|
||||
}
|
||||
|
||||
data class CompletionTokenUsage(
|
||||
/**
|
||||
* The amount of tokens of the prompt
|
||||
*/
|
||||
val promptTokens: Int,
|
||||
|
||||
/**
|
||||
* The amount of tokens of the generated completion
|
||||
*/
|
||||
val completionTokens: Int,
|
||||
|
||||
/**
|
||||
* The total amount of tokens processed
|
||||
*/
|
||||
val totalTokens: Int
|
||||
)
|
|
@ -1,22 +0,0 @@
|
|||
package eu.m724.chatapp.api.data.response.completion
|
||||
|
||||
import android.os.Parcelable
|
||||
import com.google.gson.annotations.SerializedName
|
||||
import kotlinx.parcelize.Parcelize
|
||||
|
||||
@Parcelize
|
||||
data class ChatMessage(
|
||||
val role: Role,
|
||||
val content: String
|
||||
) : Parcelable {
|
||||
enum class Role {
|
||||
@SerializedName("system")
|
||||
System,
|
||||
|
||||
@SerializedName("user")
|
||||
User,
|
||||
|
||||
@SerializedName("assistant")
|
||||
Assistant
|
||||
}
|
||||
}
|
|
@ -1,60 +0,0 @@
|
|||
package eu.m724.chatapp.api.data.response.models
|
||||
|
||||
import android.os.Parcelable
|
||||
import com.google.gson.annotations.SerializedName
|
||||
import kotlinx.parcelize.Parcelize
|
||||
|
||||
/**
|
||||
* Represents a language model.
|
||||
*/
|
||||
@Parcelize
|
||||
data class LanguageModel(
|
||||
/**
|
||||
* The ID of this model.
|
||||
*/
|
||||
val id: String,
|
||||
|
||||
/**
|
||||
* The readable name of this model.
|
||||
*/
|
||||
val name: String,
|
||||
|
||||
/**
|
||||
* The description of this model. TODO make it null if it equals model name
|
||||
*/
|
||||
val description: String?,
|
||||
|
||||
/**
|
||||
* The maximum amount of tokens this model can handle in one sitting.
|
||||
*/
|
||||
@SerializedName("context_length")
|
||||
val contextLength: Int,
|
||||
|
||||
/**
|
||||
* The pricing of this model
|
||||
*/
|
||||
val pricing: LanguageModelPricing
|
||||
) : Parcelable
|
||||
|
||||
/**
|
||||
* Represents the pricing of a language model.
|
||||
*/
|
||||
@Parcelize
|
||||
data class LanguageModelPricing(
|
||||
/**
|
||||
* The price per million input tokens
|
||||
*/
|
||||
@SerializedName("prompt")
|
||||
val pricePerMillionInputTokens: Double,
|
||||
|
||||
/**
|
||||
* The price per million output tokens
|
||||
*/
|
||||
@SerializedName("completion")
|
||||
val pricePerMillionOutputTokens: Double,
|
||||
|
||||
/**
|
||||
* Currency, as a code, like USD
|
||||
*/
|
||||
val currency: String
|
||||
) : Parcelable
|
|
@ -1,5 +0,0 @@
|
|||
package eu.m724.chatapp.api.data.response.models
|
||||
|
||||
data class LanguageModelsResponse(
|
||||
val data: List<LanguageModel>
|
||||
)
|
|
@ -1,23 +0,0 @@
|
|||
package eu.m724.chatapp.api.data.serialize
|
||||
|
||||
import com.google.gson.JsonDeserializationContext
|
||||
import com.google.gson.JsonDeserializer
|
||||
import com.google.gson.JsonElement
|
||||
import com.google.gson.JsonParseException
|
||||
import java.lang.reflect.Type
|
||||
import java.time.LocalDateTime
|
||||
import java.time.ZoneOffset
|
||||
|
||||
class EpochSecondToLocalDateTimeDeserializer : JsonDeserializer<LocalDateTime> {
|
||||
override fun deserialize(
|
||||
json: JsonElement?,
|
||||
typeOfT: Type?,
|
||||
context: JsonDeserializationContext?
|
||||
): LocalDateTime? {
|
||||
json?.asLong?.let { timestamp ->
|
||||
return LocalDateTime.ofEpochSecond(timestamp, 0, ZoneOffset.UTC)
|
||||
}
|
||||
|
||||
throw JsonParseException("Error deserializing LocalDateTime from $json")
|
||||
}
|
||||
}
|
|
@ -1,31 +0,0 @@
|
|||
package eu.m724.chatapp.api.retrofit.interceptor
|
||||
|
||||
import com.google.gson.Gson
|
||||
import eu.m724.chatapp.api.data.AiApiException
|
||||
import eu.m724.chatapp.api.data.AiApiExceptionDataWrapper
|
||||
import okhttp3.Interceptor
|
||||
import okhttp3.Response
|
||||
|
||||
class AiApiRequestExceptionInterceptor(
|
||||
private val gson: Gson
|
||||
) : Interceptor {
|
||||
override fun intercept(chain: Interceptor.Chain): Response {
|
||||
val request = chain.request()
|
||||
val response = chain.proceed(request)
|
||||
|
||||
if (response.isSuccessful) {
|
||||
return response
|
||||
}
|
||||
|
||||
response.close()
|
||||
|
||||
val apiError =
|
||||
try {
|
||||
gson.fromJson(response.body!!.string(), AiApiExceptionDataWrapper::class.java)
|
||||
} catch (_: Exception) {
|
||||
null
|
||||
}?.error
|
||||
|
||||
throw AiApiException(response.code, apiError)
|
||||
}
|
||||
}
|
45
app/src/main/java/eu/m724/chatapp/model/Chat.kt
Normal file
45
app/src/main/java/eu/m724/chatapp/model/Chat.kt
Normal file
|
@ -0,0 +1,45 @@
|
|||
package eu.m724.chatapp.model
|
||||
|
||||
import eu.m724.chat.storage.entity.ChatEntity
|
||||
import java.time.Instant
|
||||
|
||||
data class Chat(
|
||||
val id: Long,
|
||||
|
||||
val lastUpdated: Instant,
|
||||
|
||||
/**
|
||||
* The chat title, null if not set
|
||||
*/
|
||||
val title: String?,
|
||||
|
||||
/**
|
||||
* The chat subtitle, usually the last message, null if not set
|
||||
*/
|
||||
val subtitle: String?,
|
||||
|
||||
/**
|
||||
* The ID of the selected model
|
||||
*/
|
||||
val selectedModel: String
|
||||
)
|
||||
|
||||
fun ChatEntity.toChat(): Chat {
|
||||
return Chat(
|
||||
id = this.id,
|
||||
lastUpdated = Instant.ofEpochMilli(this.lastUpdated),
|
||||
title = this.title,
|
||||
subtitle = this.subtitle,
|
||||
selectedModel = this.model
|
||||
)
|
||||
}
|
||||
|
||||
fun Chat.toEntity(): ChatEntity {
|
||||
return ChatEntity(
|
||||
id = this.id,
|
||||
lastUpdated = this.lastUpdated.toEpochMilli(),
|
||||
title = this.title,
|
||||
subtitle = this.subtitle,
|
||||
model = this.selectedModel
|
||||
)
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
package eu.m724.chatapp.model
|
||||
|
||||
sealed interface ChatCompletionError {
|
||||
data object LengthLimit: ChatCompletionError
|
||||
data class Other(val message: String?): ChatCompletionError
|
||||
}
|
47
app/src/main/java/eu/m724/chatapp/model/ChatMessage.kt
Normal file
47
app/src/main/java/eu/m724/chatapp/model/ChatMessage.kt
Normal file
|
@ -0,0 +1,47 @@
|
|||
package eu.m724.chatapp.model
|
||||
|
||||
import eu.m724.chat.storage.entity.MessageEntity
|
||||
import eu.m724.newchat.aiapi.models.dto.chat.ChatMessageDto
|
||||
|
||||
data class ChatMessage(
|
||||
val id: Long? = null,
|
||||
val role: Role,
|
||||
val content: String
|
||||
) {
|
||||
enum class Role {
|
||||
System,
|
||||
User,
|
||||
Assistant
|
||||
}
|
||||
}
|
||||
|
||||
fun ChatMessage.toDto(): ChatMessageDto {
|
||||
return ChatMessageDto(
|
||||
role = this.role.name.lowercase(),
|
||||
content = this.content
|
||||
)
|
||||
}
|
||||
|
||||
fun ChatMessage.toEntity(chatId: Long): MessageEntity {
|
||||
return MessageEntity(
|
||||
id = this.id ?: 0,
|
||||
chatId = chatId,
|
||||
assistant = this.role == ChatMessage.Role.Assistant,
|
||||
content = this.content
|
||||
)
|
||||
}
|
||||
|
||||
fun ChatMessageDto.toChatMessage(): ChatMessage {
|
||||
return ChatMessage(
|
||||
role = ChatMessage.Role.valueOf(this.role),
|
||||
content = this.content
|
||||
)
|
||||
}
|
||||
|
||||
fun MessageEntity.toChatMessage(): ChatMessage {
|
||||
return ChatMessage(
|
||||
id = this.id,
|
||||
role = if (this.assistant) ChatMessage.Role.Assistant else ChatMessage.Role.User,
|
||||
content = this.content
|
||||
)
|
||||
}
|
28
app/src/main/java/eu/m724/chatapp/module/AiApiModule.kt
Normal file
28
app/src/main/java/eu/m724/chatapp/module/AiApiModule.kt
Normal file
|
@ -0,0 +1,28 @@
|
|||
package eu.m724.chatapp.module
|
||||
|
||||
import dagger.Module
|
||||
import dagger.Provides
|
||||
import dagger.hilt.InstallIn
|
||||
import dagger.hilt.components.SingletonComponent
|
||||
import eu.m724.chatapp.BuildConfig
|
||||
import eu.m724.newchat.aiapi.AiApiConfiguration
|
||||
import eu.m724.newchat.aiapi.AiApiDataLayerFactory
|
||||
import eu.m724.newchat.aiapi.repository.AiApiRepository
|
||||
import javax.inject.Singleton
|
||||
|
||||
@Module
|
||||
@InstallIn(SingletonComponent::class)
|
||||
object AiApiModule {
|
||||
@Provides
|
||||
@Singleton
|
||||
fun provideAiApiRepository(): AiApiRepository {
|
||||
val configuration = object : AiApiConfiguration {
|
||||
override val apiKey: String = BuildConfig.API_KEY
|
||||
override val endpoint: String = BuildConfig.API_ENDPOINT
|
||||
override val userAgent: String = BuildConfig.USER_AGENT
|
||||
override val isDebug: Boolean = BuildConfig.DEBUG
|
||||
}
|
||||
|
||||
return AiApiDataLayerFactory.createApiRepository(configuration)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package eu.m724.chatapp.module
|
||||
|
||||
import android.content.Context
|
||||
import dagger.Module
|
||||
import dagger.Provides
|
||||
import dagger.hilt.InstallIn
|
||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||
import dagger.hilt.components.SingletonComponent
|
||||
import eu.m724.chat.storage.ChatStorageDataLayerFactory
|
||||
import eu.m724.chat.storage.repository.ChatStorageRepository
|
||||
import javax.inject.Singleton
|
||||
|
||||
@Module
|
||||
@InstallIn(SingletonComponent::class)
|
||||
object ChatStorageModule {
|
||||
@Provides
|
||||
@Singleton
|
||||
fun provideChatStorageRepository(
|
||||
@ApplicationContext applicationContext: Context
|
||||
): ChatStorageRepository {
|
||||
return ChatStorageDataLayerFactory.createChatStorageRepository(applicationContext)
|
||||
}
|
||||
}
|
|
@ -1,29 +0,0 @@
|
|||
package eu.m724.chatapp.store.data
|
||||
|
||||
import android.os.Parcelable
|
||||
import eu.m724.chatapp.api.data.response.completion.ChatMessage
|
||||
import eu.m724.chatapp.api.data.response.models.LanguageModel
|
||||
import kotlinx.parcelize.Parcelize
|
||||
|
||||
@Parcelize
|
||||
data class Chat(
|
||||
/**
|
||||
* The unique identifier of this chat.
|
||||
*/
|
||||
val id: Int,
|
||||
|
||||
/**
|
||||
* The title of this chat.
|
||||
*/
|
||||
val title: String?,
|
||||
|
||||
/**
|
||||
* The model used in this chat.
|
||||
*/
|
||||
val model: LanguageModel,
|
||||
|
||||
/**
|
||||
* The messages in this chat.
|
||||
*/
|
||||
val messages: List<ChatMessage>
|
||||
) : Parcelable
|
|
@ -1,6 +0,0 @@
|
|||
package eu.m724.chatapp.store.data
|
||||
|
||||
sealed interface ChatResponseError { // TODO does this belong here?
|
||||
data object LengthLimit: ChatResponseError
|
||||
data class Error(val message: String?): ChatResponseError
|
||||
}
|
|
@ -1,7 +0,0 @@
|
|||
package eu.m724.chatapp.store.proto
|
||||
|
||||
import eu.m724.chatapp.proto.Chat
|
||||
|
||||
class DataStoreModule {
|
||||
val a: Chat
|
||||
}
|
|
@ -1,20 +0,0 @@
|
|||
package eu.m724.chatapp.store.proto
|
||||
|
||||
import androidx.datastore.core.Serializer
|
||||
import eu.m724.chatapp.proto.ProtoChat
|
||||
|
||||
object ProtoChatSerializer : Serializer<ProtoChat> {
|
||||
override val defaultValue: ProtoChat = Settings.getDefaultInstance()
|
||||
|
||||
override suspend fun readFrom(input: InputStream): Settings {
|
||||
try {
|
||||
return Settings.parseFrom(input)
|
||||
} catch (exception: InvalidProtocolBufferException) {
|
||||
throw CorruptionException("Cannot read proto.", exception)
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun writeTo(
|
||||
t: Settings,
|
||||
output: OutputStream) = t.writeTo(output)
|
||||
}
|
|
@ -1,30 +0,0 @@
|
|||
package eu.m724.chatapp.store.room
|
||||
|
||||
import androidx.room.Dao
|
||||
import androidx.room.Insert
|
||||
import androidx.room.Query
|
||||
import androidx.room.Update
|
||||
import eu.m724.chatapp.store.room.entity.ChatEntity
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
|
||||
@Dao
|
||||
interface ChatDao {
|
||||
@Query("SELECT * FROM chats")
|
||||
fun getAllChats(): List<ChatEntity>
|
||||
|
||||
@Query("SELECT * FROM chats WHERE id = :id")
|
||||
fun getChatById(id: Int): ChatEntity?
|
||||
|
||||
@Query("""
|
||||
SELECT * FROM chats
|
||||
JOIN chats_fts ON chats.id = chats_fts.rowid
|
||||
WHERE chats_fts MATCH :query
|
||||
""")
|
||||
fun searchChats(query: String): Flow<List<ChatEntity>>
|
||||
|
||||
@Insert
|
||||
fun insertChat(chat: ChatEntity)
|
||||
|
||||
@Update
|
||||
fun updateChat(chat: ChatEntity)
|
||||
}
|
|
@ -1,21 +0,0 @@
|
|||
package eu.m724.chatapp.store.room
|
||||
|
||||
import androidx.room.Dao
|
||||
import androidx.room.Insert
|
||||
import androidx.room.Query
|
||||
import androidx.room.Update
|
||||
import eu.m724.chatapp.store.room.entity.MessageEntity
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
|
||||
@Dao
|
||||
interface MessageDao {
|
||||
@Insert
|
||||
suspend fun insertMessage(message: MessageEntity)
|
||||
|
||||
@Update
|
||||
suspend fun updateMessage(message: MessageEntity)
|
||||
|
||||
@Query("SELECT * FROM messages WHERE chatId = :chatId ORDER BY index ASC")
|
||||
fun getMessagesForChat(chatId: Int): Flow<List<MessageEntity>>
|
||||
|
||||
}
|
|
@ -1,19 +0,0 @@
|
|||
package eu.m724.chatapp.store.room.database
|
||||
|
||||
import androidx.room.Database
|
||||
import androidx.room.RoomDatabase
|
||||
import eu.m724.chatapp.store.room.ChatDao
|
||||
import eu.m724.chatapp.store.room.MessageDao
|
||||
import eu.m724.chatapp.store.room.entity.ChatEntity
|
||||
import eu.m724.chatapp.store.room.entity.ChatEntityFts
|
||||
import eu.m724.chatapp.store.room.entity.MessageEntity
|
||||
|
||||
@Database(entities = [
|
||||
ChatEntity::class,
|
||||
ChatEntityFts::class,
|
||||
MessageEntity::class
|
||||
], version = 1)
|
||||
abstract class AppDatabase : RoomDatabase() {
|
||||
abstract fun chatDao(): ChatDao
|
||||
abstract fun messageDao(): MessageDao
|
||||
}
|
|
@ -1,36 +0,0 @@
|
|||
package eu.m724.chatapp.store.room.database
|
||||
|
||||
import android.content.Context
|
||||
import androidx.room.Room
|
||||
import dagger.Module
|
||||
import dagger.Provides
|
||||
import dagger.hilt.InstallIn
|
||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||
import dagger.hilt.components.SingletonComponent
|
||||
import eu.m724.chatapp.store.room.ChatDao
|
||||
import eu.m724.chatapp.store.room.MessageDao
|
||||
import javax.inject.Singleton
|
||||
|
||||
@Module
|
||||
@InstallIn(SingletonComponent::class)
|
||||
object DatabaseModule {
|
||||
@Provides
|
||||
@Singleton
|
||||
fun provideAppDatabase(@ApplicationContext context: Context): AppDatabase {
|
||||
return Room.databaseBuilder(
|
||||
context,
|
||||
AppDatabase::class.java,
|
||||
"chatapp-database"
|
||||
).build()
|
||||
}
|
||||
|
||||
@Provides
|
||||
fun provideChatDao(appDatabase: AppDatabase): ChatDao {
|
||||
return appDatabase.chatDao()
|
||||
}
|
||||
|
||||
@Provides
|
||||
fun provideMessageDao(appDatabase: AppDatabase): MessageDao {
|
||||
return appDatabase.messageDao()
|
||||
}
|
||||
}
|
|
@ -1,13 +0,0 @@
|
|||
package eu.m724.chatapp.store.room.entity
|
||||
|
||||
import androidx.room.ColumnInfo
|
||||
import androidx.room.Entity
|
||||
import androidx.room.Fts4
|
||||
|
||||
@Entity(tableName = "chats_fts")
|
||||
@Fts4
|
||||
data class ChatEntityFts(
|
||||
@ColumnInfo(name = "title")
|
||||
val title: String
|
||||
|
||||
)
|
|
@ -3,10 +3,19 @@
|
|||
<string name="app_name">Babilo</string>
|
||||
<string name="title_new_conversation">Komencu novan konversacion</string>
|
||||
<string name="button_send_icon_description">Sendu mesaĝon</string>
|
||||
<string name="button_send_restart_icon_description">Rekomencu la respondon</string>
|
||||
<string name="button_send_restart_icon_description">Redemandu</string>
|
||||
<string name="composer_placeholder_type">Tajpu vian mesaĝon…</string>
|
||||
<string name="response_error_icon_description">Eraro okazis dum respondado</string>
|
||||
<string name="response_error_length_limit">Tro longa</string>
|
||||
<string name="response_error_generic">Fatala eraro</string>
|
||||
<string name="ai_mistake_warning">AI povas erari, duoble kontrolu.</string>
|
||||
<string name="model_card_price_million_tokens">/ 1M signoj</string>
|
||||
<string name="model_card_toggle_icon_description">Ŝaltu detalojn</string>
|
||||
<string name="welcome">Bonvenon al ChatAppo!</string>
|
||||
<string name="model_card_price_input">Demando:</string>
|
||||
<string name="model_card_price_output">Kompletigo:</string>
|
||||
<string name="button_settings_icon_description">Agordoj</string>
|
||||
<string name="quick_settings_select_model">Elektu modelon</string>
|
||||
<string name="model_card_select">Elektu</string>
|
||||
<string name="model_card_icon_description">Ikono por %1$s</string>
|
||||
</resources>
|
|
@ -5,18 +5,17 @@
|
|||
<string name="button_send_restart_icon_description">Restart response</string>
|
||||
<string name="composer_placeholder_type">Type your message…</string>
|
||||
<string name="response_error_icon_description">Error responding</string>
|
||||
<string name="response_error_length_limit">Too long</string>
|
||||
<string name="response_error_length_limit">Length limit exceeded</string>
|
||||
<string name="response_error_generic">Fatal error</string>
|
||||
<string name="ai_mistake_warning">AI can make mistakes, double-check.</string>
|
||||
<string name="button_settings_icon_description">Settings</string>
|
||||
<string name="title_activity_main">MainActivity</string>
|
||||
<string name="quick_settings_select_model">Select model</string>
|
||||
<string name="model_card_select">Select</string>
|
||||
<string name="model_card_icon_description">Icon for %1$s</string>
|
||||
<string name="model_card_price_million_input_icon_description">Price per million input tokens</string>
|
||||
<string name="model_card_price_million_output_icon_description">Price per million output tokens</string>
|
||||
<string name="model_card_price_input">Input:</string>
|
||||
<string name="model_card_price_output">Output:</string>
|
||||
<string name="model_card_price_million_input_icon_description">Price per million prompt tokens</string>
|
||||
<string name="model_card_price_million_output_icon_description">Price per million completion tokens</string>
|
||||
<string name="model_card_price_input">Prompt:</string>
|
||||
<string name="model_card_price_output">Completion:</string>
|
||||
<string name="model_card_price_million_tokens">/ 1M tokens</string>
|
||||
<string name="model_card_toggle_icon_description">Toggle details</string>
|
||||
<string name="welcome">Welcome to ChatApp!</string>
|
||||
|
|
|
@ -6,5 +6,6 @@ plugins {
|
|||
alias(libs.plugins.hilt.android) apply false
|
||||
alias(libs.plugins.ksp) apply false
|
||||
alias(libs.plugins.secrets) apply false
|
||||
alias(libs.plugins.parcelize) apply false
|
||||
alias(libs.plugins.android.library) apply false
|
||||
alias(libs.plugins.kotlin.jvm) apply false
|
||||
}
|
|
@ -1,32 +1,29 @@
|
|||
[versions]
|
||||
agp = "8.10.1"
|
||||
kotlin = "2.1.21"
|
||||
coreKtx = "1.16.0"
|
||||
agp = "8.11.1"
|
||||
kotlin = "2.2.10"
|
||||
coreKtx = "1.17.0"
|
||||
junit = "4.13.2"
|
||||
junitVersion = "1.2.1"
|
||||
espressoCore = "3.6.1"
|
||||
junitVersion = "1.3.0"
|
||||
espressoCore = "3.7.0"
|
||||
appcompat = "1.7.1"
|
||||
material = "1.12.0"
|
||||
lifecycleRuntimeKtx = "2.9.1"
|
||||
material = "1.13.0"
|
||||
lifecycleRuntimeKtx = "2.9.3"
|
||||
activityCompose = "1.10.1"
|
||||
composeBom = "2025.06.01"
|
||||
hiltAndroid = "2.56.2"
|
||||
hiltCompiler = "2.56.2"
|
||||
ksp = "2.1.21-2.0.2"
|
||||
composeBom = "2025.08.01"
|
||||
hilt = "2.57.1"
|
||||
ksp = "2.2.10-2.0.2"
|
||||
retrofit = "3.0.0"
|
||||
secrets = "2.0.1"
|
||||
loggingInterceptor = "4.12.0"
|
||||
okhttp = "5.1.0"
|
||||
material3WindowSizeClass = "1.3.2"
|
||||
okhttpSse = "4.12.0"
|
||||
parcelize = "2.1.21"
|
||||
datastore = "1.1.7"
|
||||
hiltNavigationCompose = "1.2.0"
|
||||
roomRuntime = "2.7.2"
|
||||
roomCompiler = "2.7.2"
|
||||
roomPaging = "2.7.2"
|
||||
roomKtx = "2.7.2"
|
||||
pagingRuntime = "3.3.6"
|
||||
pagingCompose = "3.3.6"
|
||||
room = "2.7.2"
|
||||
kotlinxCoroutines = "1.10.2"
|
||||
paging = "3.3.6"
|
||||
moshi = "1.15.2"
|
||||
truth = "1.4.4"
|
||||
turbine = "1.2.1"
|
||||
|
||||
[libraries]
|
||||
androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" }
|
||||
|
@ -45,27 +42,37 @@ androidx-ui-tooling-preview = { group = "androidx.compose.ui", name = "ui-toolin
|
|||
androidx-ui-test-manifest = { group = "androidx.compose.ui", name = "ui-test-manifest" }
|
||||
androidx-ui-test-junit4 = { group = "androidx.compose.ui", name = "ui-test-junit4" }
|
||||
androidx-material3 = { group = "androidx.compose.material3", name = "material3" }
|
||||
hilt-android = { group = "com.google.dagger", name = "hilt-android", version.ref = "hiltAndroid" }
|
||||
hilt-compiler = { group = "com.google.dagger", name = "hilt-compiler", version.ref = "hiltCompiler" }
|
||||
hilt-android = { group = "com.google.dagger", name = "hilt-android", version.ref = "hilt" }
|
||||
hilt-compiler = { group = "com.google.dagger", name = "hilt-compiler", version.ref = "hilt" }
|
||||
retrofit = { group = "com.squareup.retrofit2", name = "retrofit", version.ref = "retrofit" }
|
||||
retrofit-converter-gson = { group = "com.squareup.retrofit2", name = "converter-gson", version.ref = "retrofit"}
|
||||
logging-interceptor = { group = "com.squareup.okhttp3", name = "logging-interceptor", version.ref = "loggingInterceptor" }
|
||||
retrofit-converter-moshi = { group = "com.squareup.retrofit2", name = "converter-moshi", version.ref = "retrofit" }
|
||||
logging-interceptor = { group = "com.squareup.okhttp3", name = "logging-interceptor", version.ref = "okhttp" }
|
||||
okhttp-sse = { group = "com.squareup.okhttp3", name = "okhttp-sse", version.ref = "okhttp" }
|
||||
mockwebserver = { group = "com.squareup.okhttp3", name = "mockwebserver", version.ref = "okhttp" }
|
||||
androidx-material3-window-size-class1 = { group = "androidx.compose.material3", name = "material3-window-size-class", version.ref = "material3WindowSizeClass" }
|
||||
okhttp-sse = { group = "com.squareup.okhttp3", name = "okhttp-sse", version.ref = "okhttpSse" }
|
||||
androidx-datastore = { group = "androidx.datastore", name = "datastore", version.ref = "datastore" }
|
||||
hilt-navigation-compose = { group = "androidx.hilt", name = "hilt-navigation-compose", version.ref = "hiltNavigationCompose" }
|
||||
androidx-room-runtime = { group = "androidx.room", name = "room-runtime", version.ref = "roomRuntime" }
|
||||
androidx-room-compiler = { group = "androidx.room", name = "room-compiler", version.ref = "roomCompiler" }
|
||||
androidx-room-paging = { group = "androidx.room", name = "room-paging", version.ref = "roomPaging" }
|
||||
androidx-room-ktx = { group = "androidx.room", name = "room-ktx", version.ref = "roomKtx" }
|
||||
androidx-paging-runtime = { group = "androidx.paging", name = "paging-runtime", version.ref = "pagingRuntime" }
|
||||
androidx-paging-compose = { group = "androidx.paging", name = "paging-compose", version.ref = "pagingCompose" }
|
||||
androidx-room-runtime = { group = "androidx.room", name = "room-runtime", version.ref = "room" }
|
||||
androidx-room-compiler = { group = "androidx.room", name = "room-compiler", version.ref = "room" }
|
||||
androidx-room-paging = { group = "androidx.room", name = "room-paging", version.ref = "room" }
|
||||
androidx-room-ktx = { group = "androidx.room", name = "room-ktx", version.ref = "room" }
|
||||
kotlinx-coroutines = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines", version.ref = "kotlinxCoroutines" }
|
||||
kotlinx-coroutines-core = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-core", version.ref = "kotlinxCoroutines" }
|
||||
kotlinx-coroutines-test = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-test", version.ref = "kotlinxCoroutines" }
|
||||
paging-runtime-ktx = { group = "androidx.paging", name = "paging-runtime-ktx", version.ref = "paging" }
|
||||
paging-compose = { group = "androidx.paging", name = "paging-compose", version = "3.4.0-alpha03" }
|
||||
moshi = { group = "com.squareup.moshi", name = "moshi", version.ref = "moshi" }
|
||||
moshi-kotlin = { group = "com.squareup.moshi", name = "moshi-kotlin", version.ref = "moshi" }
|
||||
moshi-kotlin-codegen = { group = "com.squareup.moshi", name = "moshi-kotlin-codegen", version.ref = "moshi" }
|
||||
truth = { group = "com.google.truth", name = "truth", version.ref = "truth" }
|
||||
turbine = { group = "app.cash.turbine", name = "turbine", version.ref = "turbine" }
|
||||
|
||||
[plugins]
|
||||
android-application = { id = "com.android.application", version.ref = "agp" }
|
||||
kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }
|
||||
kotlin-compose = { id = "org.jetbrains.kotlin.plugin.compose", version.ref = "kotlin" }
|
||||
hilt-android = { id = "com.google.dagger.hilt.android", version.ref = "hiltAndroid" }
|
||||
kotlin-jvm = { id = "org.jetbrains.kotlin.jvm", version.ref = "kotlin" }
|
||||
hilt-android = { id = "com.google.dagger.hilt.android", version.ref = "hilt" }
|
||||
ksp = { id = "com.google.devtools.ksp", version.ref = "ksp" }
|
||||
secrets = { id = "com.google.android.libraries.mapsplatform.secrets-gradle-plugin", version.ref = "secrets" }
|
||||
parcelize = { id = "org.jetbrains.kotlin.plugin.parcelize", version.ref = "parcelize" }
|
||||
android-library = { id = "com.android.library", version.ref = "agp" }
|
||||
|
|
2
gradle/wrapper/gradle-wrapper.properties
vendored
2
gradle/wrapper/gradle-wrapper.properties
vendored
|
@ -1,6 +1,6 @@
|
|||
#Fri Jun 13 13:15:28 CEST 2025
|
||||
distributionBase=GRADLE_USER_HOME
|
||||
distributionPath=wrapper/dists
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-8.11.1-bin.zip
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-8.13-bin.zip
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
zipStorePath=wrapper/dists
|
||||
|
|
|
@ -21,4 +21,6 @@ dependencyResolutionManagement {
|
|||
|
||||
rootProject.name = "Chat App"
|
||||
include(":app")
|
||||
|
||||
include(":api")
|
||||
include(":storage")
|
||||
include(":aiapi")
|
||||
|
|
1
storage/.gitignore
vendored
Normal file
1
storage/.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
/build
|
47
storage/build.gradle.kts
Normal file
47
storage/build.gradle.kts
Normal file
|
@ -0,0 +1,47 @@
|
|||
plugins {
|
||||
alias(libs.plugins.android.library)
|
||||
alias(libs.plugins.kotlin.android)
|
||||
alias(libs.plugins.ksp)
|
||||
}
|
||||
|
||||
android {
|
||||
namespace = "eu.m724.chat.storage"
|
||||
compileSdk = 36
|
||||
|
||||
defaultConfig {
|
||||
minSdk = 35
|
||||
|
||||
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
|
||||
consumerProguardFiles("consumer-rules.pro")
|
||||
}
|
||||
|
||||
buildTypes {
|
||||
release {
|
||||
isMinifyEnabled = false
|
||||
proguardFiles(
|
||||
getDefaultProguardFile("proguard-android-optimize.txt"),
|
||||
"proguard-rules.pro"
|
||||
)
|
||||
}
|
||||
}
|
||||
compileOptions {
|
||||
sourceCompatibility = JavaVersion.VERSION_17
|
||||
targetCompatibility = JavaVersion.VERSION_17
|
||||
}
|
||||
kotlinOptions {
|
||||
jvmTarget = "17"
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation(libs.androidx.core.ktx)
|
||||
implementation(libs.androidx.appcompat)
|
||||
implementation(libs.material)
|
||||
implementation(libs.androidx.room.runtime)
|
||||
implementation(libs.androidx.room.paging)
|
||||
implementation(libs.androidx.room.ktx)
|
||||
testImplementation(libs.junit)
|
||||
androidTestImplementation(libs.androidx.junit)
|
||||
androidTestImplementation(libs.androidx.espresso.core)
|
||||
ksp(libs.androidx.room.compiler)
|
||||
}
|
0
storage/consumer-rules.pro
Normal file
0
storage/consumer-rules.pro
Normal file
21
storage/proguard-rules.pro
vendored
Normal file
21
storage/proguard-rules.pro
vendored
Normal file
|
@ -0,0 +1,21 @@
|
|||
# Add project specific ProGuard rules here.
|
||||
# You can control the set of applied configuration files using the
|
||||
# proguardFiles setting in build.gradle.
|
||||
#
|
||||
# For more details, see
|
||||
# http://developer.android.com/guide/developing/tools/proguard.html
|
||||
|
||||
# If your project uses WebView with JS, uncomment the following
|
||||
# and specify the fully qualified class name to the JavaScript interface
|
||||
# class:
|
||||
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
|
||||
# public *;
|
||||
#}
|
||||
|
||||
# Uncomment this to preserve the line number information for
|
||||
# debugging stack traces.
|
||||
#-keepattributes SourceFile,LineNumberTable
|
||||
|
||||
# If you keep the line number information, uncomment this to
|
||||
# hide the original source file name.
|
||||
#-renamesourcefileattribute SourceFile
|
|
@ -0,0 +1,24 @@
|
|||
package eu.m724.chat.storage
|
||||
|
||||
import androidx.test.platform.app.InstrumentationRegistry
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4
|
||||
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
|
||||
import org.junit.Assert.*
|
||||
|
||||
/**
|
||||
* Instrumented test, which will execute on an Android device.
|
||||
*
|
||||
* See [testing documentation](http://d.android.com/tools/testing).
|
||||
*/
|
||||
@RunWith(AndroidJUnit4::class)
|
||||
class ExampleInstrumentedTest {
|
||||
@Test
|
||||
fun useAppContext() {
|
||||
// Context of the app under test.
|
||||
val appContext = InstrumentationRegistry.getInstrumentation().targetContext
|
||||
assertEquals("eu.m724.chat.storage.test", appContext.packageName)
|
||||
}
|
||||
}
|
17
storage/src/main/java/eu/m724/chat/storage/AppDatabase.kt
Normal file
17
storage/src/main/java/eu/m724/chat/storage/AppDatabase.kt
Normal file
|
@ -0,0 +1,17 @@
|
|||
package eu.m724.chat.storage
|
||||
|
||||
import androidx.room.Database
|
||||
import androidx.room.RoomDatabase
|
||||
import eu.m724.chat.storage.dao.ChatDao
|
||||
import eu.m724.chat.storage.dao.MessageDao
|
||||
import eu.m724.chat.storage.entity.ChatEntity
|
||||
import eu.m724.chat.storage.entity.MessageEntity
|
||||
|
||||
@Database(entities = [
|
||||
ChatEntity::class,
|
||||
MessageEntity::class
|
||||
], version = 1)
|
||||
abstract class AppDatabase : RoomDatabase() {
|
||||
abstract fun chatDao(): ChatDao
|
||||
abstract fun messageDao(): MessageDao
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
package eu.m724.chat.storage
|
||||
|
||||
import android.content.Context
|
||||
import androidx.room.Room
|
||||
import eu.m724.chat.storage.repository.ChatStorageRepository
|
||||
import eu.m724.chat.storage.repository.ChatStorageRepositoryImpl
|
||||
|
||||
object ChatStorageDataLayerFactory {
|
||||
private fun createAppDatabase(context: Context): AppDatabase {
|
||||
return Room.databaseBuilder(
|
||||
context,
|
||||
AppDatabase::class.java,
|
||||
"chatapp-database"
|
||||
).build()
|
||||
}
|
||||
|
||||
fun createChatStorageRepository(context: Context): ChatStorageRepository {
|
||||
val database = createAppDatabase(context)
|
||||
|
||||
return ChatStorageRepositoryImpl(database)
|
||||
}
|
||||
}
|
27
storage/src/main/java/eu/m724/chat/storage/dao/ChatDao.kt
Normal file
27
storage/src/main/java/eu/m724/chat/storage/dao/ChatDao.kt
Normal file
|
@ -0,0 +1,27 @@
|
|||
package eu.m724.chat.storage.dao
|
||||
|
||||
import androidx.room.Dao
|
||||
import androidx.room.Delete
|
||||
import androidx.room.Insert
|
||||
import androidx.room.OnConflictStrategy
|
||||
import androidx.room.Query
|
||||
import eu.m724.chat.storage.entity.ChatEntity
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
|
||||
@Dao
|
||||
interface ChatDao {
|
||||
@Query("SELECT * FROM chats ORDER BY lastUpdated DESC")
|
||||
fun getAllChats(): Flow<List<ChatEntity>>
|
||||
|
||||
@Query("SELECT * FROM chats WHERE id = :id")
|
||||
fun getChatById(id: Long): Flow<ChatEntity?>
|
||||
|
||||
@Insert
|
||||
suspend fun insertChat(chat: ChatEntity): Long
|
||||
|
||||
@Insert(onConflict = OnConflictStrategy.REPLACE)
|
||||
suspend fun updateChat(chat: ChatEntity)
|
||||
|
||||
@Delete
|
||||
suspend fun deleteChat(chat: ChatEntity)
|
||||
}
|
31
storage/src/main/java/eu/m724/chat/storage/dao/MessageDao.kt
Normal file
31
storage/src/main/java/eu/m724/chat/storage/dao/MessageDao.kt
Normal file
|
@ -0,0 +1,31 @@
|
|||
package eu.m724.chat.storage.dao
|
||||
|
||||
import androidx.paging.PagingSource
|
||||
import androidx.room.Dao
|
||||
import androidx.room.Delete
|
||||
import androidx.room.Insert
|
||||
import androidx.room.Query
|
||||
import androidx.room.Update
|
||||
import eu.m724.chat.storage.entity.MessageEntity
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
|
||||
@Dao
|
||||
interface MessageDao {
|
||||
@Query("SELECT * FROM messages WHERE chatId = :chatId")
|
||||
fun pagingSource(chatId: Long): PagingSource<Int, MessageEntity>
|
||||
|
||||
@Query("SELECT * FROM messages WHERE chatId = :chatId")
|
||||
fun getAllMessages(chatId: Long): Flow<List<MessageEntity>>
|
||||
|
||||
@Insert
|
||||
suspend fun insertMessage(message: MessageEntity): Long
|
||||
|
||||
@Update
|
||||
suspend fun updateMessage(message: MessageEntity)
|
||||
|
||||
@Delete
|
||||
suspend fun deleteMessage(message: MessageEntity)
|
||||
|
||||
@Query("DELETE FROM messages WHERE id = :id")
|
||||
suspend fun deleteMessageById(id: Long)
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package eu.m724.chatapp.store.room.entity
|
||||
package eu.m724.chat.storage.entity
|
||||
|
||||
import androidx.room.Entity
|
||||
import androidx.room.PrimaryKey
|
||||
|
@ -8,14 +8,24 @@ data class ChatEntity(
|
|||
/**
|
||||
* The unique identifier of this chat.
|
||||
*/
|
||||
@PrimaryKey(autoGenerate = true)
|
||||
val id: Int = 0,
|
||||
@PrimaryKey
|
||||
val id: Long,
|
||||
|
||||
/**
|
||||
* The last time this chat was updated (clicked) in milliseconds since epoch.
|
||||
*/
|
||||
val lastUpdated: Long,
|
||||
|
||||
/**
|
||||
* The title of this chat, null if not set.
|
||||
*/
|
||||
val title: String?,
|
||||
|
||||
/**
|
||||
* The subtitle of this chat, usually the last message, null if not set.
|
||||
*/
|
||||
val subtitle: String?,
|
||||
|
||||
/**
|
||||
* The model ID used in this chat.
|
||||
*/
|
|
@ -1,4 +1,4 @@
|
|||
package eu.m724.chatapp.store.room.entity
|
||||
package eu.m724.chat.storage.entity
|
||||
|
||||
import androidx.room.Entity
|
||||
import androidx.room.Index
|
||||
|
@ -12,20 +12,21 @@ import androidx.room.PrimaryKey
|
|||
)
|
||||
data class MessageEntity(
|
||||
/**
|
||||
* The unique identifier of this message. TODO make random perhaps
|
||||
* The identifier of this message.
|
||||
*/
|
||||
@PrimaryKey(autoGenerate = true)
|
||||
val id: Int = 0,
|
||||
|
||||
/**
|
||||
* The index of this message in the chat.
|
||||
*/
|
||||
val index: Int,
|
||||
val id: Long = 0,
|
||||
|
||||
/**
|
||||
* The ID of the chat this message belongs to.
|
||||
* TODO relation
|
||||
*/
|
||||
val chatId: Int,
|
||||
val chatId: Long,
|
||||
|
||||
/**
|
||||
* The role of this message.
|
||||
*/
|
||||
val assistant: Boolean,
|
||||
|
||||
/**
|
||||
* The content of this message.
|
|
@ -0,0 +1,77 @@
|
|||
package eu.m724.chat.storage.repository
|
||||
|
||||
import androidx.paging.PagingData
|
||||
import eu.m724.chat.storage.entity.ChatEntity
|
||||
import eu.m724.chat.storage.entity.MessageEntity
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
|
||||
interface ChatStorageRepository {
|
||||
fun pagedMessages(chatId: Long): Flow<PagingData<MessageEntity>>
|
||||
|
||||
/**
|
||||
* Get all chats from the database.
|
||||
*
|
||||
* @return A flow of lists of chats.
|
||||
*/
|
||||
fun listChats(): Flow<List<ChatEntity>>
|
||||
|
||||
/**
|
||||
* Get a chat from the database.
|
||||
*
|
||||
* @param chatId The ID of the chat to get.
|
||||
* @return The chat, or null if it doesn't exist.
|
||||
*/
|
||||
fun getChat(chatId: Long): Flow<ChatEntity?>
|
||||
|
||||
/**
|
||||
* Add a new chat to the database.
|
||||
*
|
||||
* @param chatEntity The chat to add.
|
||||
* @return The ID of the newly added chat.
|
||||
*/
|
||||
suspend fun addChat(chatEntity: ChatEntity): Long
|
||||
|
||||
/**
|
||||
* Update a chat (or insert a new one) in the database.
|
||||
*
|
||||
* @param chatEntity The chat to update.
|
||||
*/
|
||||
suspend fun updateChat(chatEntity: ChatEntity)
|
||||
|
||||
/**
|
||||
* Delete a chat from the database.
|
||||
*
|
||||
* @param chatEntity The chat to delete.
|
||||
*/
|
||||
suspend fun deleteChat(chatEntity: ChatEntity)
|
||||
|
||||
/**
|
||||
* Get all messages from a chat.
|
||||
*
|
||||
* @param chatId The ID of the chat to get messages from.
|
||||
* @return A list of messages.
|
||||
*/
|
||||
fun listMessages(chatId: Long): Flow<List<MessageEntity>>
|
||||
|
||||
/**
|
||||
* Add a new message to the database.
|
||||
*
|
||||
* @param messageEntity The message to add.
|
||||
* @return The ID of the newly added message.
|
||||
*/
|
||||
suspend fun addMessage(messageEntity: MessageEntity): Long
|
||||
|
||||
/**
|
||||
* Delete a message from the database.
|
||||
*
|
||||
* @param messageEntity The message to delete.
|
||||
*/
|
||||
suspend fun deleteMessage(messageEntity: MessageEntity)
|
||||
|
||||
/**
|
||||
* Delete a message from the database.
|
||||
*
|
||||
* @param id The ID of the message to delete.
|
||||
*/
|
||||
suspend fun deleteMessageById(id: Long)
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
package eu.m724.chat.storage.repository
|
||||
|
||||
import androidx.paging.Pager
|
||||
import androidx.paging.PagingConfig
|
||||
import androidx.paging.PagingData
|
||||
import eu.m724.chat.storage.AppDatabase
|
||||
import eu.m724.chat.storage.entity.ChatEntity
|
||||
import eu.m724.chat.storage.entity.MessageEntity
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
|
||||
// TODO make this interface
|
||||
internal class ChatStorageRepositoryImpl(
|
||||
private val database: AppDatabase
|
||||
) : ChatStorageRepository {
|
||||
override fun pagedMessages(chatId: Long): Flow<PagingData<MessageEntity>> = Pager(
|
||||
config = PagingConfig(pageSize = 20),
|
||||
pagingSourceFactory = { database.messageDao().pagingSource(chatId) }
|
||||
).flow
|
||||
override fun listChats(): Flow<List<ChatEntity>> {
|
||||
return database.chatDao().getAllChats()
|
||||
}
|
||||
|
||||
override fun getChat(chatId: Long): Flow<ChatEntity?> {
|
||||
return database.chatDao().getChatById(chatId)
|
||||
}
|
||||
|
||||
override suspend fun addChat(chatEntity: ChatEntity): Long {
|
||||
return database.chatDao().insertChat(chatEntity)
|
||||
}
|
||||
|
||||
override suspend fun updateChat(chatEntity: ChatEntity) {
|
||||
return database.chatDao().updateChat(chatEntity)
|
||||
}
|
||||
|
||||
override suspend fun deleteChat(chatEntity: ChatEntity) {
|
||||
database.chatDao().deleteChat(chatEntity)
|
||||
}
|
||||
|
||||
override suspend fun addMessage(messageEntity: MessageEntity): Long {
|
||||
return database.messageDao().insertMessage(messageEntity)
|
||||
}
|
||||
|
||||
override suspend fun deleteMessage(messageEntity: MessageEntity) {
|
||||
database.messageDao().deleteMessage(messageEntity)
|
||||
}
|
||||
|
||||
override suspend fun deleteMessageById(id: Long) {
|
||||
database.messageDao().deleteMessageById(id)
|
||||
}
|
||||
|
||||
override fun listMessages(chatId: Long): Flow<List<MessageEntity>> {
|
||||
return database.messageDao().getAllMessages(chatId)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
package eu.m724.chat.storage
|
||||
|
||||
import org.junit.Test
|
||||
|
||||
import org.junit.Assert.*
|
||||
|
||||
/**
|
||||
* Example local unit test, which will execute on the development machine (host).
|
||||
*
|
||||
* See [testing documentation](http://d.android.com/tools/testing).
|
||||
*/
|
||||
class ExampleUnitTest {
|
||||
@Test
|
||||
fun addition_isCorrect() {
|
||||
assertEquals(4, 2 + 2)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue