Basic API ability
This commit is contained in:
parent
3a821889a9
commit
4d2dbe166c
17 changed files with 291 additions and 43 deletions
2
.idea/.gitignore
generated
vendored
2
.idea/.gitignore
generated
vendored
|
@ -1,3 +1,5 @@
|
|||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
|
||||
deploymentTargetSelector.xml
|
10
.idea/deploymentTargetSelector.xml
generated
10
.idea/deploymentTargetSelector.xml
generated
|
@ -1,10 +0,0 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="deploymentTargetSelector">
|
||||
<selectionStates>
|
||||
<SelectionState runConfigName="app">
|
||||
<option name="selectionMode" value="DROPDOWN" />
|
||||
</SelectionState>
|
||||
</selectionStates>
|
||||
</component>
|
||||
</project>
|
|
@ -64,5 +64,6 @@ dependencies {
|
|||
androidTestImplementation(libs.androidx.ui.test.junit4)
|
||||
debugImplementation(libs.androidx.ui.tooling)
|
||||
debugImplementation(libs.androidx.ui.test.manifest)
|
||||
debugImplementation(libs.logging.interceptor)
|
||||
ksp(libs.hilt.compiler)
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package eu.m724.chatapp.activity.chat
|
||||
|
||||
import android.os.Bundle
|
||||
import android.widget.Toast
|
||||
import androidx.activity.ComponentActivity
|
||||
import androidx.activity.compose.setContent
|
||||
import androidx.activity.enableEdgeToEdge
|
||||
|
@ -45,6 +46,7 @@ import androidx.compose.ui.Modifier
|
|||
import androidx.compose.ui.focus.FocusRequester
|
||||
import androidx.compose.ui.focus.focusRequester
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.platform.LocalSoftwareKeyboardController
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.lifecycle.compose.collectAsStateWithLifecycle
|
||||
|
@ -73,6 +75,7 @@ fun Content(
|
|||
) {
|
||||
val uiState by viewModel.uiState.collectAsStateWithLifecycle()
|
||||
val localSoftwareKeyboardController = LocalSoftwareKeyboardController.current
|
||||
val context = LocalContext.current
|
||||
|
||||
var composerValue by remember { mutableStateOf("") }
|
||||
val composerFocusRequester = remember { FocusRequester() }
|
||||
|
@ -81,8 +84,12 @@ fun Content(
|
|||
|
||||
LaunchedEffect(uiState.requestInProgress) { // TODO probably not the best way
|
||||
if (!uiState.requestInProgress) {
|
||||
composerValue = ""
|
||||
composerFocusRequester.requestFocus() // TODO maybe make toggleable? or smart?
|
||||
if (uiState.requestLastError == null) {
|
||||
composerValue = ""
|
||||
composerFocusRequester.requestFocus() // TODO maybe make toggleable? or smart?
|
||||
} else {
|
||||
Toast.makeText(context, uiState.requestLastError, Toast.LENGTH_SHORT).show() // TODO better way of showing this
|
||||
}
|
||||
} else {
|
||||
if (!uiState.messageHistory.isEmpty()) {
|
||||
lazyListState.animateScrollToItem(uiState.messageHistory.size)
|
||||
|
@ -168,7 +175,7 @@ fun ChatTopAppBar(
|
|||
) {
|
||||
AnimatedChangingText(
|
||||
text = title,
|
||||
)
|
||||
) // TODO fade
|
||||
}
|
||||
}
|
||||
)
|
||||
|
@ -176,9 +183,9 @@ fun ChatTopAppBar(
|
|||
|
||||
@Composable
|
||||
fun MessageExchange(
|
||||
modifier: Modifier = Modifier,
|
||||
isComposing: Boolean,
|
||||
composerValue: String,
|
||||
modifier: Modifier = Modifier,
|
||||
onComposerValueChange: (String) -> Unit = {},
|
||||
composerFocusRequester: FocusRequester = FocusRequester(),
|
||||
responseValue: String = "",
|
||||
|
|
|
@ -1,10 +1,30 @@
|
|||
package eu.m724.chatapp.activity.chat
|
||||
|
||||
data class ChatActivityUiState(
|
||||
/**
|
||||
* The title of the current chat
|
||||
*/
|
||||
val chatTitle: String? = null,
|
||||
|
||||
/**
|
||||
* Whether a request is in progress (a response is streaming)
|
||||
*/
|
||||
val requestInProgress: Boolean = false,
|
||||
|
||||
/**
|
||||
* The response right now, updates when streaming
|
||||
*/
|
||||
val currentMessageResponse: String = "",
|
||||
val messageHistory: List<ChatMessageExchange> = listOf()
|
||||
|
||||
/**
|
||||
* All the messages of this chat
|
||||
*/
|
||||
val messageHistory: List<ChatMessageExchange> = listOf(),
|
||||
|
||||
/**
|
||||
* Error, if any, of the last request
|
||||
*/
|
||||
val requestLastError: String? = null
|
||||
)
|
||||
|
||||
data class ChatMessageExchange(
|
||||
|
|
|
@ -4,7 +4,8 @@ import androidx.lifecycle.ViewModel
|
|||
import androidx.lifecycle.viewModelScope
|
||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||
import eu.m724.chatapp.api.AiApiService
|
||||
import kotlinx.coroutines.delay
|
||||
import eu.m724.chatapp.api.data.ChatMessage
|
||||
import eu.m724.chatapp.api.data.request.ChatCompletionRequest
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.asStateFlow
|
||||
|
@ -19,13 +20,9 @@ class ChatActivityViewModel @Inject constructor(
|
|||
private val _uiState = MutableStateFlow(ChatActivityUiState())
|
||||
val uiState: StateFlow<ChatActivityUiState> = _uiState.asStateFlow()
|
||||
|
||||
val responses = arrayOf(
|
||||
"Hello right back at you! How can I help you today?",
|
||||
"I'm sorry, but I can't assist with that."
|
||||
)
|
||||
private val messages = mutableListOf<ChatMessage>()
|
||||
|
||||
fun sendMessage(message: String) {
|
||||
|
||||
_uiState.update {
|
||||
it.copy(
|
||||
requestInProgress = true,
|
||||
|
@ -39,26 +36,44 @@ class ChatActivityViewModel @Inject constructor(
|
|||
}
|
||||
}
|
||||
|
||||
messages.add(ChatMessage(
|
||||
role = ChatMessage.Role.USER,
|
||||
content = message
|
||||
))
|
||||
|
||||
|
||||
viewModelScope.launch {
|
||||
val response = responses.random()
|
||||
val targetResponseParts = response.split(" ")
|
||||
|
||||
for (part in targetResponseParts) {
|
||||
delay(50)
|
||||
val response = aiApiService.chatComplete(ChatCompletionRequest(
|
||||
model = "free-model",
|
||||
messages = messages,
|
||||
temperature = 1.0f,
|
||||
maxTokens = 128,
|
||||
frequencyPenalty = 0.0f,
|
||||
presencePenalty = 0.0f
|
||||
))
|
||||
|
||||
if (!response.isSuccessful || response.body() == null) {
|
||||
_uiState.update {
|
||||
it.copy(
|
||||
currentMessageResponse = it.currentMessageResponse.trim() + " $part"
|
||||
requestInProgress = false,
|
||||
requestLastError = response.code().toString()
|
||||
)
|
||||
}
|
||||
|
||||
// TODO launch toast or something
|
||||
return@launch
|
||||
}
|
||||
|
||||
val completion = response.body()!!
|
||||
val choice = completion.choices[0]
|
||||
|
||||
messages.add(choice.message)
|
||||
|
||||
_uiState.update {
|
||||
it.copy(
|
||||
requestInProgress = false,
|
||||
messageHistory = it.messageHistory.plus(
|
||||
ChatMessageExchange(message, response)
|
||||
ChatMessageExchange(message, choice.message.content)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
|
|
@ -59,7 +59,7 @@ import androidx.compose.ui.graphics.Color
|
|||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.text.style.TextOverflow
|
||||
import androidx.compose.ui.unit.dp
|
||||
import eu.m724.chatapp.api.response.LanguageModel
|
||||
import eu.m724.chatapp.api.data.response.LanguageModel
|
||||
import java.math.RoundingMode
|
||||
import java.text.DecimalFormat
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
package eu.m724.chatapp.activity.select
|
||||
|
||||
import eu.m724.chatapp.api.response.LanguageModel
|
||||
import eu.m724.chatapp.api.data.response.LanguageModel
|
||||
|
||||
data class SelectModelUiState(
|
||||
val models: List<LanguageModel> = listOf()
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
package eu.m724.chatapp.api
|
||||
|
||||
import com.google.gson.FieldNamingPolicy
|
||||
import com.google.gson.GsonBuilder
|
||||
import dagger.Module
|
||||
import dagger.Provides
|
||||
import dagger.hilt.InstallIn
|
||||
import dagger.hilt.components.SingletonComponent
|
||||
import eu.m724.chatapp.BuildConfig
|
||||
import okhttp3.OkHttpClient
|
||||
import okhttp3.logging.HttpLoggingInterceptor
|
||||
import retrofit2.Retrofit
|
||||
import retrofit2.converter.gson.GsonConverterFactory
|
||||
import javax.inject.Singleton
|
||||
|
@ -16,25 +19,33 @@ object AiApiNetworkModule {
|
|||
@Provides
|
||||
@Singleton
|
||||
fun provideOkHttpClient(): OkHttpClient {
|
||||
return OkHttpClient.Builder()
|
||||
.addInterceptor {
|
||||
it.proceed(
|
||||
it.request().newBuilder()
|
||||
.header("User-Agent", BuildConfig.USER_AGENT)
|
||||
.build()
|
||||
)
|
||||
// TODO add api key here
|
||||
}
|
||||
.build()
|
||||
val interceptor = AiApiRequestInterceptor(
|
||||
userAgent = BuildConfig.USER_AGENT,
|
||||
apiEndpoint = BuildConfig.API_ENDPOINT,
|
||||
apiKey = BuildConfig.API_KEY
|
||||
)
|
||||
|
||||
val builder = OkHttpClient.Builder()
|
||||
.addInterceptor(interceptor)
|
||||
|
||||
if (BuildConfig.DEBUG) {
|
||||
builder.addInterceptor(HttpLoggingInterceptor().setLevel(HttpLoggingInterceptor.Level.BODY))
|
||||
}
|
||||
|
||||
return builder.build()
|
||||
}
|
||||
|
||||
@Provides
|
||||
@Singleton
|
||||
fun provideRetrofit(okHttpClient: OkHttpClient): Retrofit {
|
||||
val gson = GsonBuilder()
|
||||
.setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES) // snake_case
|
||||
.create()
|
||||
|
||||
return Retrofit.Builder()
|
||||
.baseUrl(BuildConfig.API_ENDPOINT)
|
||||
.client(okHttpClient)
|
||||
.addConverterFactory(GsonConverterFactory.create())
|
||||
.addConverterFactory(GsonConverterFactory.create(gson))
|
||||
.build()
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
package eu.m724.chatapp.api
|
||||
|
||||
import okhttp3.Interceptor
|
||||
import okhttp3.Response
|
||||
|
||||
class AiApiRequestInterceptor(
|
||||
private val userAgent: String,
|
||||
private val apiEndpoint: String,
|
||||
private val apiKey: String
|
||||
) : Interceptor {
|
||||
override fun intercept(chain: Interceptor.Chain): Response {
|
||||
val builder = chain.request().newBuilder()
|
||||
.header("User-Agent", userAgent)
|
||||
|
||||
if (chain.request().url.toString().startsWith(apiEndpoint)) {
|
||||
builder.header("Authorization", "Bearer $apiKey")
|
||||
}
|
||||
|
||||
return chain.proceed(builder.build())
|
||||
}
|
||||
}
|
|
@ -1,10 +1,17 @@
|
|||
package eu.m724.chatapp.api
|
||||
|
||||
import eu.m724.chatapp.api.response.LanguageModelsResponse
|
||||
import eu.m724.chatapp.api.data.request.ChatCompletionRequest
|
||||
import eu.m724.chatapp.api.data.response.ChatCompletionResponse
|
||||
import eu.m724.chatapp.api.data.response.LanguageModelsResponse
|
||||
import retrofit2.Response
|
||||
import retrofit2.http.Body
|
||||
import retrofit2.http.GET
|
||||
import retrofit2.http.POST
|
||||
|
||||
interface AiApiService {
|
||||
@GET("models?detailed=true")
|
||||
suspend fun getModels(): Response<LanguageModelsResponse>
|
||||
|
||||
@POST("chat/completions")
|
||||
suspend fun chatComplete(@Body body: ChatCompletionRequest): Response<ChatCompletionResponse>
|
||||
}
|
12
app/src/main/java/eu/m724/chatapp/api/data/ChatMessage.kt
Normal file
12
app/src/main/java/eu/m724/chatapp/api/data/ChatMessage.kt
Normal file
|
@ -0,0 +1,12 @@
|
|||
package eu.m724.chatapp.api.data
|
||||
|
||||
data class ChatMessage(
|
||||
val role: Role,
|
||||
val content: String
|
||||
) {
|
||||
enum class Role {
|
||||
SYSTEM,
|
||||
USER,
|
||||
ASSISTANT
|
||||
}
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
package eu.m724.chatapp.api.data.request
|
||||
|
||||
import eu.m724.chatapp.api.data.ChatMessage
|
||||
|
||||
data class ChatCompletionRequest(
|
||||
/**
|
||||
* The model ID
|
||||
*/
|
||||
val model: String,
|
||||
|
||||
/**
|
||||
* The messages in the current chat. Usually a user message is the last one if making a request.
|
||||
*/
|
||||
val messages: List<ChatMessage>,
|
||||
|
||||
/**
|
||||
* Controls the "creativity" of the model.
|
||||
* Read more: https://www.iguazio.com/glossary/llm-temperature/
|
||||
*/
|
||||
val temperature: Float,
|
||||
|
||||
/**
|
||||
* The maximum amount of tokens to generate
|
||||
*/
|
||||
val maxTokens: Int,
|
||||
|
||||
/**
|
||||
* Controls the repetition of words in the generated text.
|
||||
* Applies an incremental penalty, depending on how many times a token appears in the text.
|
||||
*
|
||||
* Read more: https://www.promptitude.io/glossary/frequency-penalty
|
||||
* @see presencePenalty
|
||||
*/
|
||||
val frequencyPenalty: Float,
|
||||
|
||||
/**
|
||||
* Controls the repetition of words in the generated text.
|
||||
* Applies a constant penalty, no matter how many times a token appears in the text.
|
||||
*
|
||||
* Read more: https://www.promptitude.io/glossary/frequency-penalty
|
||||
* @see frequencyPenalty
|
||||
*/
|
||||
val presencePenalty: Float
|
||||
) {
|
||||
init {
|
||||
require(temperature >= 0.0) { "temperature must be at least 0.0" }
|
||||
require(temperature <= 2.0) { "temperature must be at most 2.0" }
|
||||
|
||||
require(maxTokens >= 0) { "maxTokens must be at least 0. If you don't want a limit here, use Int.MAX_VALUE"}
|
||||
|
||||
require(frequencyPenalty >= -2.0) { "frequencyPenalty must be at least -2.0" }
|
||||
require(frequencyPenalty <= 2.0) { "frequencyPenalty must be at most 2.0" }
|
||||
|
||||
require(presencePenalty >= -2.0) { "presencePenalty must be at least -2.0" }
|
||||
require(presencePenalty <= 2.0) { "presencePenalty must be at most 2.0" }
|
||||
}
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
package eu.m724.chatapp.api.data.response
|
||||
|
||||
import com.google.gson.annotations.JsonAdapter
|
||||
import com.google.gson.annotations.SerializedName
|
||||
import eu.m724.chatapp.api.data.ChatMessage
|
||||
import eu.m724.chatapp.api.serialize.EpochSecondToLocalDateTimeDeserializer
|
||||
import java.time.LocalDateTime
|
||||
|
||||
data class ChatCompletionResponse(
|
||||
/**
|
||||
* Request ID
|
||||
*/
|
||||
val id: String,
|
||||
|
||||
/**
|
||||
* Request time
|
||||
*/
|
||||
@SerializedName("created")
|
||||
@JsonAdapter(EpochSecondToLocalDateTimeDeserializer::class)
|
||||
val createdAt: LocalDateTime,
|
||||
|
||||
/**
|
||||
* Completion choices. Usually has only one element.
|
||||
*/
|
||||
val choices: List<CompletionChoice>,
|
||||
|
||||
/**
|
||||
* The cost (in tokens) of this completion
|
||||
*/
|
||||
@SerializedName("usage")
|
||||
val tokenUsage: CompletionTokenUsage
|
||||
)
|
||||
|
||||
data class CompletionChoice(
|
||||
val index: Int,
|
||||
|
||||
/**
|
||||
* The generated message
|
||||
*/
|
||||
val message: ChatMessage,
|
||||
|
||||
/**
|
||||
* The reason why generating the response has stopped
|
||||
*/
|
||||
val finishReason: CompletionFinishReason
|
||||
)
|
||||
|
||||
enum class CompletionFinishReason {
|
||||
/**
|
||||
* The response has stopped, because the model said so
|
||||
*/
|
||||
STOP,
|
||||
|
||||
/**
|
||||
* The response has stopped, because it got too long
|
||||
*/
|
||||
LENGTH,
|
||||
|
||||
/**
|
||||
* The response has stopped, because the content got flagged
|
||||
*/
|
||||
CONTENT_FILTER
|
||||
}
|
||||
|
||||
data class CompletionTokenUsage(
|
||||
/**
|
||||
* The amount of tokens of the prompt
|
||||
*/
|
||||
val promptTokens: Int,
|
||||
|
||||
/**
|
||||
* The amount of tokens of the generated completion
|
||||
*/
|
||||
val completionTokens: Int,
|
||||
|
||||
/**
|
||||
* The total amount of tokens processed
|
||||
*/
|
||||
val totalTokens: Int
|
||||
)
|
|
@ -1,4 +1,4 @@
|
|||
package eu.m724.chatapp.api.response
|
||||
package eu.m724.chatapp.api.data.response
|
||||
|
||||
import com.google.gson.annotations.SerializedName
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
package eu.m724.chatapp.api.serialize
|
||||
|
||||
import com.google.gson.JsonDeserializationContext
|
||||
import com.google.gson.JsonDeserializer
|
||||
import com.google.gson.JsonElement
|
||||
import com.google.gson.JsonParseException
|
||||
import java.lang.reflect.Type
|
||||
import java.time.LocalDateTime
|
||||
import java.time.ZoneOffset
|
||||
|
||||
class EpochSecondToLocalDateTimeDeserializer : JsonDeserializer<LocalDateTime> {
|
||||
override fun deserialize(
|
||||
json: JsonElement?,
|
||||
typeOfT: Type?,
|
||||
context: JsonDeserializationContext?
|
||||
): LocalDateTime? {
|
||||
json?.asLong?.let { timestamp ->
|
||||
return LocalDateTime.ofEpochSecond(timestamp, 0, ZoneOffset.UTC)
|
||||
}
|
||||
|
||||
throw JsonParseException("Error deserializing LocalDateTime from $json")
|
||||
}
|
||||
}
|
|
@ -14,6 +14,7 @@ hilt = "2.56.2"
|
|||
ksp = "2.1.21-2.0.2"
|
||||
retrofit = "3.0.0"
|
||||
secrets = "2.0.1"
|
||||
loggingInterceptor = "4.12.0"
|
||||
|
||||
[libraries]
|
||||
androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" }
|
||||
|
@ -36,6 +37,7 @@ hilt-android = { group = "com.google.dagger", name = "hilt-android", version.ref
|
|||
hilt-compiler = { group = "com.google.dagger", name = "hilt-compiler", version.ref = "hilt" }
|
||||
retrofit = { group = "com.squareup.retrofit2", name = "retrofit", version.ref = "retrofit" }
|
||||
retrofit-converter-gson = { group = "com.squareup.retrofit2", name = "converter-gson", version.ref = "retrofit"}
|
||||
logging-interceptor = { group = "com.squareup.okhttp3", name = "logging-interceptor", version.ref = "loggingInterceptor" }
|
||||
|
||||
[plugins]
|
||||
android-application = { id = "com.android.application", version.ref = "agp" }
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue