From 1dcf09614e86849f41c05ab5c89cfef9726f786c Mon Sep 17 00:00:00 2001 From: Minecon724 Date: Sat, 31 Aug 2024 19:12:55 +0200 Subject: [PATCH] make it stream and other changes --- src/main/java/eu/m724/Main.java | 31 +++++------ src/main/java/eu/m724/chat/ChatMessage.java | 6 +- .../java/eu/m724/example/OaiSource.groovy | 35 ++++++++++-- src/main/java/eu/m724/source/ChatSource.java | 20 ++++++- .../m724/source/NonStreamingChatResponse.java | 4 +- .../eu/m724/source/SimpleChatResponse.java | 55 ++++++++++++++++++- thinkings/thinking.txt | 4 ++ 7 files changed, 125 insertions(+), 30 deletions(-) diff --git a/src/main/java/eu/m724/Main.java b/src/main/java/eu/m724/Main.java index 1ab3341..5f81b69 100644 --- a/src/main/java/eu/m724/Main.java +++ b/src/main/java/eu/m724/Main.java @@ -3,20 +3,14 @@ package eu.m724; import eu.m724.chat.Chat; import eu.m724.chat.ChatEvent; import eu.m724.chat.ChatMessage; -import eu.m724.example.ExampleSource; import eu.m724.example.OaiSource; import eu.m724.source.ChatResponse; import eu.m724.source.ChatSource; import eu.m724.source.option.Option; -import java.io.BufferedReader; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.net.URL; -import java.nio.charset.StandardCharsets; -import java.util.*; -import java.util.stream.Collectors; -import java.util.stream.LongStream; +import java.util.Arrays; +import java.util.Map; +import java.util.Scanner; public class Main { public static void main(String[] args) throws InterruptedException { @@ -24,12 +18,12 @@ public class Main { //source.options().setValue("name", "nekalakininahappenenawiwanatin"); ChatSource source = new OaiSource(System.getenv("API_KEY")); - //source.options().setValue("model", "chatgpt-4o-latest"); + source.options().setValue("model", "chatgpt-4o-latest"); Chat chat = new Chat("Speak in super wuper uwu wanguage."); System.out.println("Welcome to CHUT chat. Say something after the \033[1m>\033[0m, or type \033[1m:help\033[0m to see available commands"); - System.out.println("Source: " + source.getClass().getName()); + System.out.printf("Source: \033[1m%s\033[0m %s (%d) by %s\n", source.info().name(), source.info().versionName(), source.info().version(), source.info().author()); Scanner scanner = new Scanner(System.in); @@ -47,8 +41,10 @@ public class Main { token = chatResponse.eventQueue().take(); if (token.finishReason() != "error") { - System.out.print(i++ % 2 == 1 ? "\033[1m" : "\033[0m"); - System.out.print(token.text()); + if (token.text() != null) { + System.out.print(i++ % 2 == 1 ? "\033[1m" : "\033[0m"); + System.out.print(token.text()); + } } else { System.out.print("Error: " + token.error().toString()); } @@ -75,12 +71,13 @@ public class Main { break; case "dump": case "d": + int size = chat.messages.size(); + System.out.printf("This chat has %d messages, or %d pairs", size, size / 2); + if (chat.systemPrompt == null) { - System.out.printf("This chat has %d messages.\n", chat.messages.size()); - System.out.println("There's no system prompt."); + System.out.println(".\nThere's no system prompt."); } else { - System.out.printf("This chat has %d messages, excluding system prompt.\n", chat.messages.size()); - System.out.printf("System prompt:\n\"\"\"%s\"\"\"\n", chat.systemPrompt); + System.out.printf(", excluding system prompt.\nSystem prompt:\n\"\"\"%s\"\"\"\n", chat.systemPrompt); } for (ChatMessage message : chat.messages) { System.out.printf("%s: %s\n", message.assistant() ? "ASSISTANT" : "USER", message.text()); diff --git a/src/main/java/eu/m724/chat/ChatMessage.java b/src/main/java/eu/m724/chat/ChatMessage.java index b840bfb..538b278 100644 --- a/src/main/java/eu/m724/chat/ChatMessage.java +++ b/src/main/java/eu/m724/chat/ChatMessage.java @@ -1,7 +1,7 @@ package eu.m724.chat; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Flow; - public record ChatMessage(boolean assistant, String text) { + public static ChatMessage assistant(String text) { + return new ChatMessage(true, text); + } } diff --git a/src/main/java/eu/m724/example/OaiSource.groovy b/src/main/java/eu/m724/example/OaiSource.groovy index cf31804..e2fee6a 100644 --- a/src/main/java/eu/m724/example/OaiSource.groovy +++ b/src/main/java/eu/m724/example/OaiSource.groovy @@ -4,7 +4,7 @@ import eu.m724.chat.Chat import eu.m724.source.ChatResponse import eu.m724.source.ChatSource import eu.m724.source.ChatSourceInfo -import eu.m724.source.NonStreamingChatResponse +import eu.m724.source.SimpleChatResponse import eu.m724.source.option.DoubleOption import eu.m724.source.option.Options import eu.m724.source.option.StringOption @@ -14,7 +14,6 @@ import org.json.JSONObject import java.net.http.HttpClient import java.net.http.HttpRequest import java.net.http.HttpResponse - // for now let's not focus on readability // this is more about find out what is common and should be included in the common api class OaiSource implements ChatSource { @@ -50,6 +49,7 @@ class OaiSource implements ChatSource { .put("model", options.getStringValue("model")) .put("temperature", options.getOptionValue("temperature", Double::class)) .put("presence_penalty", 0.1) + .put("stream", true) .put("messages", formatChat(chat)).toString() def request = HttpRequest.newBuilder() @@ -61,9 +61,10 @@ class OaiSource implements ChatSource { def client = HttpClient.newHttpClient() - ChatResponse chatResponse = new NonStreamingChatResponse() + def chatResponse = new SimpleChatResponse() + + def response = client.sendAsync(request, HttpResponse.BodyHandlers.ofLines()) - def response = client.sendAsync(request, HttpResponse.BodyHandlers.ofString()) response.thenAccept { Exception exception = null @@ -72,6 +73,30 @@ class OaiSource implements ChatSource { exception = new Exception("Non 200 status code: %d".formatted(it.statusCode())) } + it.body().forEach { + //System.out.println(it); + for (String line : it.split("\n")) { + if (line.startsWith("data: ")) { + def data = line.substring(6) + + if (data != "[DONE]") { + def json = new JSONObject(data) + def choice = json.getJSONArray("choices").getJSONObject(0) + def finishReason = choice.get("finish_reason") + + if (finishReason != JSONObject.NULL) { + //System.out.println("ending"); + chatResponse.end(finishReason.toString()) + } else { + def token = choice.getJSONObject("delta").getString("content") + chatResponse.put(token) + } + } + } + } + } +/* + if (exception != null) { chatResponse.completeExceptionally(exception) } else { @@ -79,7 +104,7 @@ class OaiSource implements ChatSource { //System.out.println(json) def completion = json.getJSONArray("choices").getJSONObject(0).getJSONObject("message").getString("content") chatResponse.complete(completion) - } + }*/ } diff --git a/src/main/java/eu/m724/source/ChatSource.java b/src/main/java/eu/m724/source/ChatSource.java index fdedde3..257088e 100644 --- a/src/main/java/eu/m724/source/ChatSource.java +++ b/src/main/java/eu/m724/source/ChatSource.java @@ -1,7 +1,6 @@ package eu.m724.source; import eu.m724.chat.Chat; -import eu.m724.source.option.Option; import eu.m724.source.option.Options; public interface ChatSource { @@ -33,6 +32,25 @@ public interface ChatSource { default ChatResponse ask(Chat chat) { ChatResponse chatResponse = onAsked(chat); + /* if (chatResponse.streaming()) { + StringBuilder total = new StringBuilder(); + CompletableFuture.runAsync(() -> { + ChatEvent event; + while (true) { + try { + event = chatResponse.eventQueue().take(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + total.append(event.text()); + if (event.finishReason() != null) { + chatResponse.message().complete(new ChatMessage(true, total.toString())); + break; + } + } + }); + } this was a draft I'm keeping because I might change my mind */ + // TODO make sure it works in parallel chatResponse.message().thenAccept(message -> { if (message != null) chat.addMessage(message); diff --git a/src/main/java/eu/m724/source/NonStreamingChatResponse.java b/src/main/java/eu/m724/source/NonStreamingChatResponse.java index 029fd5d..aceb0c3 100644 --- a/src/main/java/eu/m724/source/NonStreamingChatResponse.java +++ b/src/main/java/eu/m724/source/NonStreamingChatResponse.java @@ -7,8 +7,8 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.LinkedBlockingQueue; public class NonStreamingChatResponse implements ChatResponse { - private LinkedBlockingQueue eventQueue = new LinkedBlockingQueue<>(); - private CompletableFuture message = new CompletableFuture<>(); + private final LinkedBlockingQueue eventQueue = new LinkedBlockingQueue<>(); + private final CompletableFuture message = new CompletableFuture<>(); @Override public boolean streaming() { diff --git a/src/main/java/eu/m724/source/SimpleChatResponse.java b/src/main/java/eu/m724/source/SimpleChatResponse.java index 170d63f..d50001e 100644 --- a/src/main/java/eu/m724/source/SimpleChatResponse.java +++ b/src/main/java/eu/m724/source/SimpleChatResponse.java @@ -6,6 +6,57 @@ import eu.m724.chat.ChatMessage; import java.util.concurrent.CompletableFuture; import java.util.concurrent.LinkedBlockingQueue; -public record SimpleChatResponse(boolean streaming, LinkedBlockingQueue eventQueue, - CompletableFuture message) implements ChatResponse { +public class SimpleChatResponse implements ChatResponse { + private final boolean streaming; + private final LinkedBlockingQueue eventQueue; + private final CompletableFuture message; + + private String total = ""; + + public SimpleChatResponse() { + this.streaming = true; + this.eventQueue = new LinkedBlockingQueue<>(); + this.message = new CompletableFuture<>(); + } + + public void put(String token, String finishReason) { + //System.out.println(System.currentTimeMillis()); + + try { + eventQueue.put(ChatEvent.of(token, finishReason)); + } catch (InterruptedException e) { + throw new RuntimeException(e); // TODO I don't know what this is + } + + if (token != null) { + total += token; + } + + if (finishReason != null) { + message.complete(ChatMessage.assistant(total)); + } + } + + public void put(String token) { + put(token, null); + } + + public void end(String finishReason) { + put(null, finishReason); + } + + @Override + public boolean streaming() { + return streaming; + } + + @Override + public LinkedBlockingQueue eventQueue() { + return eventQueue; + } + + @Override + public CompletableFuture message() { + return message; + } } diff --git a/thinkings/thinking.txt b/thinkings/thinking.txt index a7f79f5..8f68839 100644 --- a/thinkings/thinking.txt +++ b/thinkings/thinking.txt @@ -3,6 +3,10 @@ so, the api is mostly good now, but new things to focus on: - how to do requests - how to make it friendly +delay +right now streaming seems delayed and batched, I don't know why +maybe it's just me + network every platform has different libraries for that, so abstraction how? I still don't know