diff --git a/src/main/java/eu/m724/Main.java b/src/main/java/eu/m724/Main.java index c278343..622b2ff 100644 --- a/src/main/java/eu/m724/Main.java +++ b/src/main/java/eu/m724/Main.java @@ -4,27 +4,29 @@ import eu.m724.chat.Chat; import eu.m724.chat.ChatEvent; import eu.m724.chat.ChatMessage; import eu.m724.example.ExampleSource; +import eu.m724.example.OaiSource; import eu.m724.source.ChatResponse; import eu.m724.source.ChatSource; +import eu.m724.source.option.Option; import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; import java.net.URL; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Scanner; +import java.util.*; import java.util.stream.Collectors; import java.util.stream.LongStream; public class Main { public static void main(String[] args) throws InterruptedException { - ChatSource source = new ExampleSource(); - source.options().setValue("name", "nekalakininahappenenawiwanatin"); + //ChatSource source = new ExampleSource(); + //source.options().setValue("name", "nekalakininahappenenawiwanatin"); - Chat chat = new Chat(); + ChatSource source = new OaiSource(System.getenv("API_KEY")); + //source.options().setValue("model", "chatgpt-4o-latest"); + + Chat chat = new Chat("Speak in uwu wanguage."); System.out.println("Welcome to CHUT chat. Say something after the \033[1m>\033[0m, or type \033[1m:help\033[0m to see available commands"); Scanner scanner = new Scanner(System.in); @@ -39,14 +41,21 @@ public class Main { ChatEvent token; int i = 0; - while ((token = chatResponse.eventQueue().take()).finishReason() == null) { - System.out.print(i++ % 2 == 1 ? "\033[1m" : "\033[0m"); - System.out.print(token.text()); - } + do { + token = chatResponse.eventQueue().take(); + + if (token.finishReason() != "error") { + System.out.print(i++ % 2 == 1 ? "\033[1m" : "\033[0m"); + System.out.print(token.text()); + } else { + System.out.print("Error: " + token.error().toString()); + } + } while (token.finishReason() == null); System.out.println(); } else { String[] parts = prompt.substring(1).split(" "); + switch (parts[0]) { case "help": case "h": @@ -54,6 +63,7 @@ public class Main { System.out.println("Available commands:"); System.out.println(":help - this"); System.out.println(":dump - recap of the chat"); + System.out.println(":opt - change source options"); break; case "dump": case "d": @@ -68,6 +78,21 @@ public class Main { System.out.printf("%s: %s\n", message.assistant() ? "ASSISTANT" : "USER", message.text()); } break; + case "options": + case "opt": + case "o": + if (parts.length < 3 || Arrays.stream(source.options().keys()).noneMatch(parts[1]::equals)) { + System.out.println("Available options:"); + for (Map.Entry> entry : source.options().getOptions().entrySet()) { + System.out.printf("%s (%s) = %s\n", entry.getValue().label, entry.getKey(), entry.getValue().getValue().toString()); + } + } else { + Option option = source.options().getOptions().get(parts[1]); + Object value = option.fromString(parts[2]); + option.setValue(value); + System.out.printf("Set %s to %s\n", option.label, option.getValue()); + } + break; default: System.out.println("Invalid command: " + parts[0]); break; diff --git a/src/main/java/eu/m724/example/OaiSource.groovy b/src/main/java/eu/m724/example/OaiSource.groovy index 834d05e..019ef04 100644 --- a/src/main/java/eu/m724/example/OaiSource.groovy +++ b/src/main/java/eu/m724/example/OaiSource.groovy @@ -6,7 +6,10 @@ import eu.m724.chat.ChatMessage import eu.m724.source.ChatResponse import eu.m724.source.ChatSource import eu.m724.source.ChatSourceInfo +import eu.m724.source.NonStreamingChatResponse import eu.m724.source.SimpleChatResponse +import eu.m724.source.option.DoubleOption +import eu.m724.source.option.NumberOption import eu.m724.source.option.Options import eu.m724.source.option.StringOption import org.json.JSONArray @@ -21,11 +24,19 @@ import java.util.concurrent.LinkedBlockingQueue // for now let's not focus on readability // this is more about find out what is common and should be included in the common api class OaiSource implements ChatSource { + private final String apiKey + + OaiSource(String apiKey) { + this.apiKey = apiKey + } + private def info = new ChatSourceInfo("oai source", "me", "1.0", 1) private def options = new Options( - new StringOption("apiKey", "API key", null), + new StringOption("apiKey", "API key", apiKey), + new StringOption("model", "Model", "gpt-4o-mini"), + new DoubleOption("temperature", "Temperature", 1.2) ) @Override @@ -40,8 +51,13 @@ class OaiSource implements ChatSource { @Override ChatResponse onAsked(Chat chat) { - def apiKey = options.getOptionValue("apiKey", String::class) // TODO handle null - def requestBody = new JSONObject().put("model", "gpt-4o-mini").put("messages", formatChat(chat)).toString() + def apiKey = options.getStringValue("apiKey") // TODO handle null + def requestBody = new JSONObject() + .put("model", options.getStringValue("model")) + .put("temperature", options.getOptionValue("temperature", Double::class)) + .put("presence_penalty", 0.1) + .put("frequency_penalty", 0.1) + .put("messages", formatChat(chat)).toString() def request = HttpRequest.newBuilder() .uri(URI.create("https://api.openai.com/v1/chat/completions")) @@ -52,8 +68,7 @@ class OaiSource implements ChatSource { def client = HttpClient.newHttpClient() - LinkedBlockingQueue eventQueue = new LinkedBlockingQueue<>() - CompletableFuture messageFuture = new CompletableFuture<>() + ChatResponse chatResponse = new NonStreamingChatResponse() def response = client.sendAsync(request, HttpResponse.BodyHandlers.ofString()) @@ -65,20 +80,26 @@ class OaiSource implements ChatSource { } if (exception != null) { - eventQueue.put(ChatEvent.of(exception)) - messageFuture.completeExceptionally(exception) + chatResponse.completeExceptionally(exception) } else { - + def json = new JSONObject(it.body()) + //System.out.println(json) + def completion = json.getJSONArray("choices").getJSONObject(0).getJSONObject("message").getString("content") + chatResponse.complete(completion) } } - return new SimpleChatResponse(false, eventQueue, messageFuture) + return chatResponse } static JSONArray formatChat(Chat chat) { def array = new JSONArray() + if (chat.systemPrompt != null) { + array.put(new JSONObject().put("role", "system").put("content", chat.systemPrompt)) + } + chat.messages.each { array.put(new JSONObject().put("role", it.assistant() ? "assistant" : "user").put("content", it.text())) } diff --git a/src/main/java/eu/m724/source/ChatSource.java b/src/main/java/eu/m724/source/ChatSource.java index 5e6a39b..fdedde3 100644 --- a/src/main/java/eu/m724/source/ChatSource.java +++ b/src/main/java/eu/m724/source/ChatSource.java @@ -34,7 +34,9 @@ public interface ChatSource { ChatResponse chatResponse = onAsked(chat); // TODO make sure it works in parallel - chatResponse.message().thenAccept(chat::addMessage); + chatResponse.message().thenAccept(message -> { + if (message != null) chat.addMessage(message); + }); return chatResponse; } diff --git a/src/main/java/eu/m724/source/NonStreamingChatResponse.java b/src/main/java/eu/m724/source/NonStreamingChatResponse.java new file mode 100644 index 0000000..029fd5d --- /dev/null +++ b/src/main/java/eu/m724/source/NonStreamingChatResponse.java @@ -0,0 +1,53 @@ +package eu.m724.source; + +import eu.m724.chat.ChatEvent; +import eu.m724.chat.ChatMessage; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.LinkedBlockingQueue; + +public class NonStreamingChatResponse implements ChatResponse { + private LinkedBlockingQueue eventQueue = new LinkedBlockingQueue<>(); + private CompletableFuture message = new CompletableFuture<>(); + + @Override + public boolean streaming() { + return false; + } + + @Override + public LinkedBlockingQueue eventQueue() { + return eventQueue; + } + + @Override + public CompletableFuture message() { + return message; + } + + public boolean complete(String content) { + if (message.isDone()) return false; + + try { + eventQueue.put(ChatEvent.of(content, "stop")); + } catch (InterruptedException e) { + throw new RuntimeException(e); // TODO probably should handle this + } + message.complete(new ChatMessage(true, content)); + + return true; + } + + public boolean completeExceptionally(Throwable throwable) { + if (message.isDone()) return false; + + try { + eventQueue.put(ChatEvent.of(throwable)); + } catch (InterruptedException e) { + throw new RuntimeException(e); // TODO probably should handle this + } + message.complete(null); + + return true; + } +} diff --git a/src/main/java/eu/m724/source/option/DoubleOption.java b/src/main/java/eu/m724/source/option/DoubleOption.java new file mode 100644 index 0000000..00c6c52 --- /dev/null +++ b/src/main/java/eu/m724/source/option/DoubleOption.java @@ -0,0 +1,26 @@ +package eu.m724.source.option; + +public class DoubleOption extends Option { + private double minValue = Double.MIN_VALUE; + private double maxValue = Double.MAX_VALUE; + + public DoubleOption(String id, String label, Double value) { + super(id, label, value); + } + + public DoubleOption(String id, String label, Double value, double minValue, double maxValue) { + super(id, label, value); + this.minValue = minValue; + this.maxValue = maxValue; + } + + @Override + boolean isValid(Double value) { + return value >= minValue && value <= maxValue; + } + + @Override + public Double fromString(String text) { + return Double.valueOf(text); + } +} diff --git a/src/main/java/eu/m724/source/option/NumberOption.java b/src/main/java/eu/m724/source/option/NumberOption.java index 7755c8a..293e638 100644 --- a/src/main/java/eu/m724/source/option/NumberOption.java +++ b/src/main/java/eu/m724/source/option/NumberOption.java @@ -18,4 +18,9 @@ public class NumberOption extends Option { boolean isValid(Integer value) { return value >= minValue && value <= maxValue; } + + @Override + public Integer fromString(String text) { + return Integer.valueOf(text); + } } diff --git a/src/main/java/eu/m724/source/option/Option.java b/src/main/java/eu/m724/source/option/Option.java index 4e464bd..942894f 100644 --- a/src/main/java/eu/m724/source/option/Option.java +++ b/src/main/java/eu/m724/source/option/Option.java @@ -1,7 +1,6 @@ package eu.m724.source.option; import java.lang.reflect.ParameterizedType; -import java.lang.reflect.Type; /** * represents an option that is a text label and value of any type @@ -49,5 +48,18 @@ public abstract class Option { // TODO I'm not a fan of that, probably should address the warnings + /** + * checks if the option is valid given current constraints + * @param value the checked value + * @return is it valid + */ abstract boolean isValid(T value); + + /** + * convert a string to the type of this option + * TODO fix english + * @param text a text representation of a value + * @return a value in an acceptable type + */ + public abstract T fromString(String text); } diff --git a/src/main/java/eu/m724/source/option/Options.java b/src/main/java/eu/m724/source/option/Options.java index f5f8cd4..b7d9dc8 100644 --- a/src/main/java/eu/m724/source/option/Options.java +++ b/src/main/java/eu/m724/source/option/Options.java @@ -34,6 +34,10 @@ public class Options { return (T) options.get(id).getValue(); } + public String getStringValue(String id) { + return (String) getOptionValue(id, String.class); + } + /** * set a value of an option * @param id the option id diff --git a/src/main/java/eu/m724/source/option/StringOption.java b/src/main/java/eu/m724/source/option/StringOption.java index e3e8974..08e1296 100644 --- a/src/main/java/eu/m724/source/option/StringOption.java +++ b/src/main/java/eu/m724/source/option/StringOption.java @@ -23,4 +23,9 @@ public class StringOption extends Option { boolean isValid(String value) { return pattern == null || pattern.matcher(value).matches(); } + + @Override + public String fromString(String text) { + return text; + } }