diff --git a/src/main/java/eu/m724/Main.java b/src/main/java/eu/m724/Main.java index a72f30a..219b1b2 100644 --- a/src/main/java/eu/m724/Main.java +++ b/src/main/java/eu/m724/Main.java @@ -1,12 +1,16 @@ package eu.m724; import eu.m724.chat.Chat; +import eu.m724.chat.ChatEvent; import eu.m724.chat.ChatMessage; import eu.m724.example.ExampleSource; import eu.m724.responsesource.ChatResponse; import eu.m724.responsesource.ChatResponseSource; import groovy.lang.GroovyShell; +import java.util.ArrayList; +import java.util.List; + public class Main { public static void main(String[] args) throws InterruptedException { ChatResponseSource source = new ExampleSource(); @@ -15,18 +19,33 @@ public class Main { chat.messages.add(new ChatMessage(false, "hello")); ChatResponse chatResponse = source.ask(chat); - String token; - int tokens = 0; + + // I was thinking about integrating this into ChatMessage + List tokens = new ArrayList<>(); + List delays = new ArrayList<>(); System.out.println("Streaming response now\n"); + ChatEvent token; - while (!(token = chatResponse.textQueue().take()).equals("END_OF_TEXT")) { - System.out.print(token); - tokens++; + // usually finish reason will be alongside a token but this is simpler + while ((token = chatResponse.eventQueue().take()).finishReason() == null) { + System.out.print(token.text()); + tokens.add(token.text()); + + long now = System.currentTimeMillis(); + delays.add(now); } System.out.println("\n"); - System.out.printf("Tokens: %d\n", tokens); - System.out.printf("Text: %s\n", chatResponse.message().join().text); + System.out.printf("Tokens: %d\n", tokens.size()); + + long time = delays.getFirst(); + for (int i=0; i queue = new LinkedBlockingQueue<>() + LinkedBlockingQueue queue = new LinkedBlockingQueue<>() CompletableFuture future = CompletableFuture.supplyAsync { for (int i=0; i 0 ? " " : "") + parts[i] - queue.put(token); + queue.put(ChatEvent.of(token)); Thread.sleep(random.nextInt(200, 500)) } - queue.put("END_OF_TEXT") + queue.put(ChatEvent.finished("stop")) return new ChatMessage(true, parts.join(" ")) } @@ -49,7 +50,7 @@ class ExampleSource implements ChatResponseSource { } @Override - LinkedBlockingQueue textQueue() { + LinkedBlockingQueue eventQueue() { return queue } diff --git a/src/main/java/eu/m724/responsesource/ChatResponse.java b/src/main/java/eu/m724/responsesource/ChatResponse.java index 9f36909..e7d2379 100644 --- a/src/main/java/eu/m724/responsesource/ChatResponse.java +++ b/src/main/java/eu/m724/responsesource/ChatResponse.java @@ -1,5 +1,6 @@ package eu.m724.responsesource; +import eu.m724.chat.ChatEvent; import eu.m724.chat.ChatMessage; import java.util.concurrent.CompletableFuture; @@ -20,7 +21,7 @@ public interface ChatResponse { * * @return the fifo queue with each element being a part. null ends the sequence */ - LinkedBlockingQueue textQueue(); + LinkedBlockingQueue eventQueue(); /** * gets the resulting {@link ChatMessage} when it's ready diff --git a/src/main/java/eu/m724/responsesource/ChatResponseSource.java b/src/main/java/eu/m724/responsesource/ChatResponseSource.java index 2f9a2c6..585c993 100644 --- a/src/main/java/eu/m724/responsesource/ChatResponseSource.java +++ b/src/main/java/eu/m724/responsesource/ChatResponseSource.java @@ -5,5 +5,14 @@ import eu.m724.chat.Chat; public interface ChatResponseSource { ChatResponseSourceInfo info(); - ChatResponse ask(Chat chat); + ChatResponse onAsked(Chat chat); + + default ChatResponse ask(Chat chat) { + ChatResponse chatResponse = onAsked(chat); + + // TODO make sure it works in parallel + chatResponse.message().thenAccept(chat::addMessage); + + return chatResponse; + } }