From 6bfd51baa329007e59899d5ce2f5636c9e542ac9 Mon Sep 17 00:00:00 2001 From: Minecon724 Date: Fri, 6 Sep 2024 18:19:13 +0200 Subject: [PATCH] big revamp --- .../m724/chatapi/example/ExampleSource.groovy | 32 +---- .../eu/m724/chatapi/example/OaiSource.groovy | 16 ++- src/main/java/eu/m724/chatapi/Main.java | 40 +++--- .../eu/m724/chatapi/chat/ChatMessage.java | 132 +++++++++++++++++- .../eu/m724/chatapi/source/ChatResponse.java | 34 ++--- .../eu/m724/chatapi/source/ChatSource.java | 5 +- .../source/impl/BlockingQueueConsumer.java | 20 +++ .../source/impl/NonStreamingChatResponse.java | 60 -------- .../source/impl/StreamingChatResponse.java | 77 ---------- 9 files changed, 194 insertions(+), 222 deletions(-) create mode 100644 src/main/java/eu/m724/chatapi/source/impl/BlockingQueueConsumer.java delete mode 100644 src/main/java/eu/m724/chatapi/source/impl/NonStreamingChatResponse.java delete mode 100644 src/main/java/eu/m724/chatapi/source/impl/StreamingChatResponse.java diff --git a/src/main/groovy/eu/m724/chatapi/example/ExampleSource.groovy b/src/main/groovy/eu/m724/chatapi/example/ExampleSource.groovy index 33c2ccf..1804c76 100644 --- a/src/main/groovy/eu/m724/chatapi/example/ExampleSource.groovy +++ b/src/main/groovy/eu/m724/chatapi/example/ExampleSource.groovy @@ -10,8 +10,6 @@ import eu.m724.chatapi.source.option.Options import eu.m724.chatapi.source.option.StringOption import java.util.concurrent.CompletableFuture -import java.util.concurrent.LinkedBlockingQueue - /** * an example chatresponsesource chatresponsesource ChatResponseSource CHATRESPONSESOURCE CAHTSERREPOSNECSOURCE * note to self: rename that already... @@ -40,48 +38,30 @@ class ExampleSource implements ChatSource { ChatResponse onAsked(Chat chat) { //String prompt = chat.messages.getLast().text(); String response = rollResponse(chat) - String[] parts = response.split(" ") - LinkedBlockingQueue queue = new LinkedBlockingQueue<>() + ChatMessage message = new ChatMessage(true); - CompletableFuture future = CompletableFuture.supplyAsync { + CompletableFuture.supplyAsync { for (int i=0; i 0 ? " " : "") + parts[i] - queue.put(ChatEvent.of(token)); + message.submitEvent(ChatEvent.of(token)) Thread.sleep(random.nextInt(50, 200)) } - queue.put(ChatEvent.finished("stop")) - return new ChatMessage(true, parts.join(" ")) + message.submitEvent(ChatEvent.finished("stop")) } - return new ChatResponse() { - @Override - boolean streaming() { - return false - } - - @Override - LinkedBlockingQueue eventQueue() { - return queue - } - - @Override - CompletableFuture message() { - return future - } - } + return new ChatResponse(message) } String rollResponse(Chat chat) { def special = -1 - def prompt = chat.messages.getLast().text() + def prompt = chat.messages.getLast().content() int messagesCount = (int) Math.ceil(chat.messages.size() / 2) def counter = lastCounter def response = "" - if (prompt.toLowerCase().startsWith("my name is")) { options.setValue("name", prompt.substring(11)) counter = 11 diff --git a/src/main/groovy/eu/m724/chatapi/example/OaiSource.groovy b/src/main/groovy/eu/m724/chatapi/example/OaiSource.groovy index 0e0246a..8816b03 100644 --- a/src/main/groovy/eu/m724/chatapi/example/OaiSource.groovy +++ b/src/main/groovy/eu/m724/chatapi/example/OaiSource.groovy @@ -1,11 +1,12 @@ package eu.m724.chatapi.example import eu.m724.chatapi.chat.Chat +import eu.m724.chatapi.chat.ChatEvent +import eu.m724.chatapi.chat.ChatMessage import eu.m724.chatapi.source.ChatResponse import eu.m724.chatapi.source.ChatSource import eu.m724.chatapi.source.ChatSourceInfo import eu.m724.chatapi.source.exception.HttpException -import eu.m724.chatapi.source.impl.StreamingChatResponse import eu.m724.chatapi.source.option.DoubleOption import eu.m724.chatapi.source.option.Options import eu.m724.chatapi.source.option.StringOption @@ -65,7 +66,7 @@ class OaiSource implements ChatSource { def client = HttpClient.newHttpClient() - def chatResponse = new StreamingChatResponse() + def chatMessage = new ChatMessage(true) def response = client.sendAsync(request, HttpResponse.BodyHandlers.ofLines()) @@ -96,21 +97,22 @@ class OaiSource implements ChatSource { def json = new JSONObject(data) def choice = json.getJSONArray("choices").getJSONObject(0) def finishReason = choice.get("finish_reason") + ChatEvent event if (finishReason != JSONObject.NULL) { - //System.out.println("ending"); - chatResponse.end(finishReason.toString()) + event = ChatEvent.finished(finishReason.toString()) } else { def token = choice.getJSONObject("delta").getString("content") - chatResponse.put(token) + event = ChatEvent.of(token) } + chatMessage.submitEvent(event) } } } } } - return chatResponse + return new ChatResponse(chatMessage) } static JSONArray formatChat(Chat chat) { @@ -121,7 +123,7 @@ class OaiSource implements ChatSource { } chat.messages.each { - array.put(new JSONObject().put("role", it.assistant() ? "assistant" : "user").put("content", it.text())) + array.put(new JSONObject().put("role", it.response() ? "assistant" : "user").put("content", it.content())) } return array diff --git a/src/main/java/eu/m724/chatapi/Main.java b/src/main/java/eu/m724/chatapi/Main.java index bb384b4..9084a71 100644 --- a/src/main/java/eu/m724/chatapi/Main.java +++ b/src/main/java/eu/m724/chatapi/Main.java @@ -6,7 +6,6 @@ import eu.m724.chatapi.chat.ChatMessage; import eu.m724.chatapi.example.OaiSource; import eu.m724.chatapi.source.ChatResponse; import eu.m724.chatapi.source.ChatSource; -import eu.m724.chatapi.source.exception.HttpException; import eu.m724.chatapi.source.option.Option; import eu.m724.chatapi.source.option.Options; @@ -19,7 +18,6 @@ class Main { Scanner scanner = new Scanner(System.in); String apiKey = System.getenv("API_KEY"); - boolean complainedApiKey = false; if (apiKey == null) { System.out.print("\nAPI Key: "); @@ -32,10 +30,9 @@ class Main { if (!Pattern.matches("sk-proj-.*?(?:\\s|$)", apiKey)) { System.out.println("This key looks invalid"); - complainedApiKey = true; } - //ChatSource source = new ExampleSource(); + // ChatSource source = new ExampleSource(); ChatSource source = new OaiSource(apiKey); source.options().setValue("model", "chatgpt-4o-latest"); @@ -59,28 +56,33 @@ class Main { if (!prompt.startsWith(":")) { chat.messages.add(new ChatMessage(false, prompt)); + ChatResponse chatResponse = source.ask(chat); + ChatMessage message = chatResponse.message(); + ChatEvent token; - int i = 0; - do { - token = chatResponse.eventQueue().take(); - - if (!"error".equals(token.finishReason())) { // this looks bad but at least idea doesn't nag me - if (token.text() != null) { - System.out.print(i++ % 2 == 1 ? "\033[1m" : "\033[0m"); - System.out.print(token.text()); + int[] i = {0}; // this makes no sense + message.addEventConsumer(event -> { + if (!"error".equals(event.finishReason())) { + if (event.text() != null) { + System.out.print(i[0]++ % 2 == 1 ? "\033[1m" : "\033[0m"); + System.out.print(event.text()); } } else { - System.out.print("Error: " + token.error().toString()); - if (complainedApiKey && token.error() instanceof HttpException && ((HttpException)token.error()).statusCode == 401) { - System.out.print("\nTold you"); - complainedApiKey = false; - } + System.out.print("Error: " + event.error().toString()); + // here was a funny thing, but I had to remove it } - } while (token.finishReason() == null); + }, true); + + try { + message.onComplete().join(); + } catch (Throwable e) { + e.printStackTrace(); + } System.out.println(); + } else { String[] parts = prompt.substring(1).split(" "); if (parts[0].startsWith(":")) { @@ -110,7 +112,7 @@ class Main { System.out.printf(", excluding system prompt.\nSystem prompt:\n\"\"\"%s\"\"\"\n", chat.systemPrompt); } for (ChatMessage message : chat.messages) { - System.out.printf("%s: %s\n", message.assistant() ? "ASSISTANT" : "USER", message.text()); + System.out.printf("%s: %s\n", message.response() ? "ASSISTANT" : "USER", message.content()); } break; case "options": diff --git a/src/main/java/eu/m724/chatapi/chat/ChatMessage.java b/src/main/java/eu/m724/chatapi/chat/ChatMessage.java index 77a1ede..448447d 100644 --- a/src/main/java/eu/m724/chatapi/chat/ChatMessage.java +++ b/src/main/java/eu/m724/chatapi/chat/ChatMessage.java @@ -1,7 +1,133 @@ package eu.m724.chatapi.chat; -public record ChatMessage(boolean assistant, String text) { - public static ChatMessage assistant(String text) { - return new ChatMessage(true, text); +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +/** + * represents a message, user or assistant, streaming or not
+ * this is quite a robust class, so don't overuse it + */ +public class ChatMessage { + private final boolean response; + private final CompletableFuture completedFuture = new CompletableFuture<>(); + private final Set> eventConsumers = new HashSet<>(); + + private final List events = new ArrayList<>(); + + private String content = ""; // TODO I was thinking about making this abstract so you could put different types like images documents etc + + // + + /** + * Registers a chat message that can be streamed to + * + * @param response is the message a response + */ + public ChatMessage(boolean response) { + this.response = response; + } + + /** + * Creates a completed chat message + * + * @param response is the message a response + * @param content the text + */ + public ChatMessage(boolean response, String content) { + this.response = response; + this.content = content; + + completedFuture.complete(null); + //submitEvent(ChatEvent.of(text, "end")); + } + + // + + /** + * registers a consumer for the {@link ChatEvent}s + * + * @param consumer the consumer + * @param history receive events before this signed up + */ + public void addEventConsumer(Consumer consumer, boolean history) { + eventConsumers.add(consumer); + if (history) { + for (ChatEvent event : events) + consumer.accept(event); + } + } + + // + + /** + * submits a {@link ChatEvent} to broadcast to consumers
+ * it also makes me (the {@link ChatMessage}) use it to update the final content or if streaming ended + * + * @param event the event + */ + public void submitEvent(ChatEvent event) { + if (event.text() != null) { + content += event.text(); + } + + events.add(event); + eventConsumers.forEach(c -> { + try { + c.accept(event); + } catch (Throwable e) { + System.err.println("Error distributing event:"); + e.printStackTrace(); + } + }); + + if (event.error() != null) { + completedFuture.completeExceptionally(event.error()); + } else if (event.finishReason() != null) { + completedFuture.complete(null); + } + + } + + // + + /** + * returns this message's content
+ * if called during a response, it will return what's already been generated + * + * @return this message's content + */ + public String content() { + return content; + } + + /** + * whether this message is a response
+ * a response is usually sent by the assistant, not by the user + * + * @return whether this message is a response + */ + public boolean response() { + return response; + } + + /** + * a future that completes when the response is complete
+ * it can also throw an error + * + * @return a future that completes when the response is complete + */ + public CompletableFuture onComplete() { + return completedFuture; + } + + /** + * @return is the response complete + */ + public boolean completed() { + return completedFuture.isDone(); } } diff --git a/src/main/java/eu/m724/chatapi/source/ChatResponse.java b/src/main/java/eu/m724/chatapi/source/ChatResponse.java index 49419f0..a281835 100644 --- a/src/main/java/eu/m724/chatapi/source/ChatResponse.java +++ b/src/main/java/eu/m724/chatapi/source/ChatResponse.java @@ -1,33 +1,15 @@ package eu.m724.chatapi.source; -import eu.m724.chatapi.chat.ChatEvent; import eu.m724.chatapi.chat.ChatMessage; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.LinkedBlockingQueue; +public class ChatResponse { + private final ChatMessage message; -public interface ChatResponse { - /** - * is this response streaming - * if it's not, the queue will get one element that is the whole response - * I think about replacing with an abstract class and put this in the constructor - * - * @return is this response streaming - */ - boolean streaming(); + public ChatResponse(ChatMessage message) { + this.message = message; + } - /** - * if streamed, text token by token as it goes (or other splitting depending on the source) - * if not, the {@link CompletableFuture} returns just the whole response after it's ready - * - * @return the fifo queue with each element being a part. null ends the sequence - */ - LinkedBlockingQueue eventQueue(); - - /** - * gets the resulting {@link ChatMessage} when it's ready - * - * @return the resulting {@link ChatMessage} as soon as the response is complete - */ - CompletableFuture message(); + public ChatMessage message() { + return message; + } } diff --git a/src/main/java/eu/m724/chatapi/source/ChatSource.java b/src/main/java/eu/m724/chatapi/source/ChatSource.java index 5a43e1b..aba05a9 100644 --- a/src/main/java/eu/m724/chatapi/source/ChatSource.java +++ b/src/main/java/eu/m724/chatapi/source/ChatSource.java @@ -31,10 +31,7 @@ public interface ChatSource { */ default ChatResponse ask(Chat chat) { ChatResponse chatResponse = onAsked(chat); - - chatResponse.message().thenAccept(message -> { - if (message != null) chat.addMessage(message); - }); + chat.addMessage(chatResponse.message()); return chatResponse; } diff --git a/src/main/java/eu/m724/chatapi/source/impl/BlockingQueueConsumer.java b/src/main/java/eu/m724/chatapi/source/impl/BlockingQueueConsumer.java new file mode 100644 index 0000000..7863dec --- /dev/null +++ b/src/main/java/eu/m724/chatapi/source/impl/BlockingQueueConsumer.java @@ -0,0 +1,20 @@ +package eu.m724.chatapi.source.impl; + +import java.util.concurrent.LinkedBlockingQueue; +import java.util.function.Consumer; + +public class BlockingQueueConsumer { + public final LinkedBlockingQueue queue = new LinkedBlockingQueue<>(); + + public final Consumer consumer = new Consumer<>() { + @Override + public void accept(T t) { + try { + queue.put(t); + } catch (InterruptedException e) { + // again I don't know how that is relevant + throw new RuntimeException(e); + } + } + }; +} diff --git a/src/main/java/eu/m724/chatapi/source/impl/NonStreamingChatResponse.java b/src/main/java/eu/m724/chatapi/source/impl/NonStreamingChatResponse.java deleted file mode 100644 index 0bee30f..0000000 --- a/src/main/java/eu/m724/chatapi/source/impl/NonStreamingChatResponse.java +++ /dev/null @@ -1,60 +0,0 @@ -package eu.m724.chatapi.source.impl; - -import eu.m724.chatapi.chat.ChatEvent; -import eu.m724.chatapi.chat.ChatMessage; -import eu.m724.chatapi.source.ChatResponse; - -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.LinkedBlockingQueue; - -public class NonStreamingChatResponse implements ChatResponse { - private final LinkedBlockingQueue eventQueue = new LinkedBlockingQueue<>(); - private final CompletableFuture message = new CompletableFuture<>(); - - @Override - public boolean streaming() { - return false; - } - - @Override - public LinkedBlockingQueue eventQueue() { - return eventQueue; - } - - @Override - public CompletableFuture message() { - return message; - } - - public boolean complete(String content) { - if (message.isDone()) return false; - - try { - eventQueue.put(ChatEvent.of(content, "stop")); - } catch (InterruptedException e) { - // I don't know what this exception means - // and I don't think how will it cause me problems - // so ignoring it for now - throw new RuntimeException(e); - } - - message.complete(new ChatMessage(true, content)); - - return true; - } - - public boolean completeExceptionally(Throwable throwable) { - if (message.isDone()) return false; - - try { - eventQueue.put(ChatEvent.of(throwable)); - } catch (InterruptedException e) { - // again - throw new RuntimeException(e); - } - - message.complete(null); - - return true; - } -} diff --git a/src/main/java/eu/m724/chatapi/source/impl/StreamingChatResponse.java b/src/main/java/eu/m724/chatapi/source/impl/StreamingChatResponse.java deleted file mode 100644 index 8bcb5b1..0000000 --- a/src/main/java/eu/m724/chatapi/source/impl/StreamingChatResponse.java +++ /dev/null @@ -1,77 +0,0 @@ -package eu.m724.chatapi.source.impl; - -import eu.m724.chatapi.chat.ChatEvent; -import eu.m724.chatapi.chat.ChatMessage; -import eu.m724.chatapi.source.ChatResponse; - -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.LinkedBlockingQueue; - -public class StreamingChatResponse implements ChatResponse { - private final boolean streaming; - private final LinkedBlockingQueue eventQueue; - private final CompletableFuture message; - - private String total = ""; - - public StreamingChatResponse() { - this.streaming = true; - this.eventQueue = new LinkedBlockingQueue<>(); - this.message = new CompletableFuture<>(); - } - - public void put(String token, String finishReason) { - try { - eventQueue.put(ChatEvent.of(token, finishReason)); - } catch (InterruptedException e) { - // I don't know what this exception means - // and I don't think how will it cause me problems - // so ignoring it for now - throw new RuntimeException(e); - } - - if (token != null) { - total += token; - } - - if (finishReason != null) { - message.complete(ChatMessage.assistant(total)); - } - } - - public void put(String token) { - put(token, null); - } - - public void end(String finishReason) { - put(null, finishReason); - } - - public void error(Throwable throwable) { - try { - eventQueue.put(ChatEvent.of(throwable)); - } catch (InterruptedException e) { - // again - throw new RuntimeException(e); - } - - message.complete(ChatMessage.assistant(total)); - } - - // - - @Override - public boolean streaming() { - return streaming; - } - - @Override - public LinkedBlockingQueue eventQueue() { - return eventQueue; - } - - @Override - public CompletableFuture message() { - return message; - } -}