switches and toggles

This commit is contained in:
Minecon724 2024-08-31 15:38:29 +02:00
parent 1eb880be4f
commit f75d1ca715
Signed by: Minecon724
GPG key ID: 3CCC4D267742C8E8
9 changed files with 175 additions and 22 deletions

View file

@ -4,27 +4,29 @@ import eu.m724.chat.Chat;
import eu.m724.chat.ChatEvent;
import eu.m724.chat.ChatMessage;
import eu.m724.example.ExampleSource;
import eu.m724.example.OaiSource;
import eu.m724.source.ChatResponse;
import eu.m724.source.ChatSource;
import eu.m724.source.option.Option;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Scanner;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
public class Main {
public static void main(String[] args) throws InterruptedException {
ChatSource source = new ExampleSource();
source.options().setValue("name", "nekalakininahappenenawiwanatin");
//ChatSource source = new ExampleSource();
//source.options().setValue("name", "nekalakininahappenenawiwanatin");
Chat chat = new Chat();
ChatSource source = new OaiSource(System.getenv("API_KEY"));
//source.options().setValue("model", "chatgpt-4o-latest");
Chat chat = new Chat("Speak in uwu wanguage.");
System.out.println("Welcome to CHUT chat. Say something after the \033[1m>\033[0m, or type \033[1m:help\033[0m to see available commands");
Scanner scanner = new Scanner(System.in);
@ -39,14 +41,21 @@ public class Main {
ChatEvent token;
int i = 0;
while ((token = chatResponse.eventQueue().take()).finishReason() == null) {
do {
token = chatResponse.eventQueue().take();
if (token.finishReason() != "error") {
System.out.print(i++ % 2 == 1 ? "\033[1m" : "\033[0m");
System.out.print(token.text());
} else {
System.out.print("Error: " + token.error().toString());
}
} while (token.finishReason() == null);
System.out.println();
} else {
String[] parts = prompt.substring(1).split(" ");
switch (parts[0]) {
case "help":
case "h":
@ -54,6 +63,7 @@ public class Main {
System.out.println("Available commands:");
System.out.println(":help - this");
System.out.println(":dump - recap of the chat");
System.out.println(":opt - change source options");
break;
case "dump":
case "d":
@ -68,6 +78,21 @@ public class Main {
System.out.printf("%s: %s\n", message.assistant() ? "ASSISTANT" : "USER", message.text());
}
break;
case "options":
case "opt":
case "o":
if (parts.length < 3 || Arrays.stream(source.options().keys()).noneMatch(parts[1]::equals)) {
System.out.println("Available options:");
for (Map.Entry<String, Option<?>> entry : source.options().getOptions().entrySet()) {
System.out.printf("%s (%s) = %s\n", entry.getValue().label, entry.getKey(), entry.getValue().getValue().toString());
}
} else {
Option<?> option = source.options().getOptions().get(parts[1]);
Object value = option.fromString(parts[2]);
option.setValue(value);
System.out.printf("Set %s to %s\n", option.label, option.getValue());
}
break;
default:
System.out.println("Invalid command: " + parts[0]);
break;

View file

@ -6,7 +6,10 @@ import eu.m724.chat.ChatMessage
import eu.m724.source.ChatResponse
import eu.m724.source.ChatSource
import eu.m724.source.ChatSourceInfo
import eu.m724.source.NonStreamingChatResponse
import eu.m724.source.SimpleChatResponse
import eu.m724.source.option.DoubleOption
import eu.m724.source.option.NumberOption
import eu.m724.source.option.Options
import eu.m724.source.option.StringOption
import org.json.JSONArray
@ -21,11 +24,19 @@ import java.util.concurrent.LinkedBlockingQueue
// for now let's not focus on readability
// this is more about find out what is common and should be included in the common api
class OaiSource implements ChatSource {
private final String apiKey
OaiSource(String apiKey) {
this.apiKey = apiKey
}
private def info =
new ChatSourceInfo("oai source", "me", "1.0", 1)
private def options = new Options(
new StringOption("apiKey", "API key", null),
new StringOption("apiKey", "API key", apiKey),
new StringOption("model", "Model", "gpt-4o-mini"),
new DoubleOption("temperature", "Temperature", 1.2)
)
@Override
@ -40,8 +51,13 @@ class OaiSource implements ChatSource {
@Override
ChatResponse onAsked(Chat chat) {
def apiKey = options.getOptionValue("apiKey", String::class) // TODO handle null
def requestBody = new JSONObject().put("model", "gpt-4o-mini").put("messages", formatChat(chat)).toString()
def apiKey = options.getStringValue("apiKey") // TODO handle null
def requestBody = new JSONObject()
.put("model", options.getStringValue("model"))
.put("temperature", options.getOptionValue("temperature", Double::class))
.put("presence_penalty", 0.1)
.put("frequency_penalty", 0.1)
.put("messages", formatChat(chat)).toString()
def request = HttpRequest.newBuilder()
.uri(URI.create("https://api.openai.com/v1/chat/completions"))
@ -52,8 +68,7 @@ class OaiSource implements ChatSource {
def client = HttpClient.newHttpClient()
LinkedBlockingQueue<ChatEvent> eventQueue = new LinkedBlockingQueue<>()
CompletableFuture<ChatMessage> messageFuture = new CompletableFuture<>()
ChatResponse chatResponse = new NonStreamingChatResponse()
def response = client.sendAsync(request, HttpResponse.BodyHandlers.ofString())
@ -65,20 +80,26 @@ class OaiSource implements ChatSource {
}
if (exception != null) {
eventQueue.put(ChatEvent.of(exception))
messageFuture.completeExceptionally(exception)
chatResponse.completeExceptionally(exception)
} else {
def json = new JSONObject(it.body())
//System.out.println(json)
def completion = json.getJSONArray("choices").getJSONObject(0).getJSONObject("message").getString("content")
chatResponse.complete(completion)
}
}
return new SimpleChatResponse(false, eventQueue, messageFuture)
return chatResponse
}
static JSONArray formatChat(Chat chat) {
def array = new JSONArray()
if (chat.systemPrompt != null) {
array.put(new JSONObject().put("role", "system").put("content", chat.systemPrompt))
}
chat.messages.each {
array.put(new JSONObject().put("role", it.assistant() ? "assistant" : "user").put("content", it.text()))
}

View file

@ -34,7 +34,9 @@ public interface ChatSource {
ChatResponse chatResponse = onAsked(chat);
// TODO make sure it works in parallel
chatResponse.message().thenAccept(chat::addMessage);
chatResponse.message().thenAccept(message -> {
if (message != null) chat.addMessage(message);
});
return chatResponse;
}

View file

@ -0,0 +1,53 @@
package eu.m724.source;
import eu.m724.chat.ChatEvent;
import eu.m724.chat.ChatMessage;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.LinkedBlockingQueue;
public class NonStreamingChatResponse implements ChatResponse {
private LinkedBlockingQueue<ChatEvent> eventQueue = new LinkedBlockingQueue<>();
private CompletableFuture<ChatMessage> message = new CompletableFuture<>();
@Override
public boolean streaming() {
return false;
}
@Override
public LinkedBlockingQueue<ChatEvent> eventQueue() {
return eventQueue;
}
@Override
public CompletableFuture<ChatMessage> message() {
return message;
}
public boolean complete(String content) {
if (message.isDone()) return false;
try {
eventQueue.put(ChatEvent.of(content, "stop"));
} catch (InterruptedException e) {
throw new RuntimeException(e); // TODO probably should handle this
}
message.complete(new ChatMessage(true, content));
return true;
}
public boolean completeExceptionally(Throwable throwable) {
if (message.isDone()) return false;
try {
eventQueue.put(ChatEvent.of(throwable));
} catch (InterruptedException e) {
throw new RuntimeException(e); // TODO probably should handle this
}
message.complete(null);
return true;
}
}

View file

@ -0,0 +1,26 @@
package eu.m724.source.option;
public class DoubleOption extends Option<Double> {
private double minValue = Double.MIN_VALUE;
private double maxValue = Double.MAX_VALUE;
public DoubleOption(String id, String label, Double value) {
super(id, label, value);
}
public DoubleOption(String id, String label, Double value, double minValue, double maxValue) {
super(id, label, value);
this.minValue = minValue;
this.maxValue = maxValue;
}
@Override
boolean isValid(Double value) {
return value >= minValue && value <= maxValue;
}
@Override
public Double fromString(String text) {
return Double.valueOf(text);
}
}

View file

@ -18,4 +18,9 @@ public class NumberOption extends Option<Integer> {
boolean isValid(Integer value) {
return value >= minValue && value <= maxValue;
}
@Override
public Integer fromString(String text) {
return Integer.valueOf(text);
}
}

View file

@ -1,7 +1,6 @@
package eu.m724.source.option;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
/**
* represents an option that is a text label and value of any type
@ -49,5 +48,18 @@ public abstract class Option<T> {
// TODO I'm not a fan of that, probably should address the warnings
/**
* checks if the option is valid given current constraints
* @param value the checked value
* @return is it valid
*/
abstract boolean isValid(T value);
/**
* convert a string to the type of this option
* TODO fix english
* @param text a text representation of a value
* @return a value in an acceptable type
*/
public abstract T fromString(String text);
}

View file

@ -34,6 +34,10 @@ public class Options {
return (T) options.get(id).getValue();
}
public String getStringValue(String id) {
return (String) getOptionValue(id, String.class);
}
/**
* set a value of an option
* @param id the option id

View file

@ -23,4 +23,9 @@ public class StringOption extends Option<String> {
boolean isValid(String value) {
return pattern == null || pattern.matcher(value).matches();
}
@Override
public String fromString(String text) {
return text;
}
}