make it better

so I somewhat figured it out
This commit is contained in:
Minecon724 2024-08-29 13:22:48 +02:00
parent ac3c2b0386
commit 6ce5453de0
Signed by: Minecon724
GPG key ID: 3CCC4D267742C8E8
7 changed files with 72 additions and 22 deletions

View file

@ -1,12 +1,16 @@
package eu.m724; package eu.m724;
import eu.m724.chat.Chat; import eu.m724.chat.Chat;
import eu.m724.chat.ChatEvent;
import eu.m724.chat.ChatMessage; import eu.m724.chat.ChatMessage;
import eu.m724.example.ExampleSource; import eu.m724.example.ExampleSource;
import eu.m724.responsesource.ChatResponse; import eu.m724.responsesource.ChatResponse;
import eu.m724.responsesource.ChatResponseSource; import eu.m724.responsesource.ChatResponseSource;
import groovy.lang.GroovyShell; import groovy.lang.GroovyShell;
import java.util.ArrayList;
import java.util.List;
public class Main { public class Main {
public static void main(String[] args) throws InterruptedException { public static void main(String[] args) throws InterruptedException {
ChatResponseSource source = new ExampleSource(); ChatResponseSource source = new ExampleSource();
@ -15,18 +19,33 @@ public class Main {
chat.messages.add(new ChatMessage(false, "hello")); chat.messages.add(new ChatMessage(false, "hello"));
ChatResponse chatResponse = source.ask(chat); ChatResponse chatResponse = source.ask(chat);
String token;
int tokens = 0; // I was thinking about integrating this into ChatMessage
List<String> tokens = new ArrayList<>();
List<Long> delays = new ArrayList<>();
System.out.println("Streaming response now\n"); System.out.println("Streaming response now\n");
ChatEvent token;
while (!(token = chatResponse.textQueue().take()).equals("END_OF_TEXT")) { // usually finish reason will be alongside a token but this is simpler
System.out.print(token); while ((token = chatResponse.eventQueue().take()).finishReason() == null) {
tokens++; System.out.print(token.text());
tokens.add(token.text());
long now = System.currentTimeMillis();
delays.add(now);
} }
System.out.println("\n"); System.out.println("\n");
System.out.printf("Tokens: %d\n", tokens); System.out.printf("Tokens: %d\n", tokens.size());
System.out.printf("Text: %s\n", chatResponse.message().join().text);
long time = delays.getFirst();
for (int i=0; i<tokens.size()-1; i++) {
System.out.printf("\"%s\" + %dms, ", tokens.get(i), delays.get(i+1) - time);
time = delays.get(i+1);
}
System.out.printf("\"%s\"\n\n", tokens.getLast());
System.out.printf("Text: %s\n", chatResponse.message().join().text());
} }
} }

View file

@ -12,4 +12,8 @@ public class Chat {
} }
public Chat() {} public Chat() {}
public void addMessage(ChatMessage message) {
this.messages.add(message);
}
} }

View file

@ -0,0 +1,23 @@
package eu.m724.chat;
/**
 * A single event emitted while a chat response is being streamed.
 *
 * <p>An event may carry a piece of text, a finish reason, an error, or a combination of
 * them; components that do not apply are {@code null}.
 *
 * @param text the text carried by this event, or {@code null} if it carries none
 * @param finishReason why the response ended, or {@code null} if this event does not end it
 * @param error the failure attached to this event, or {@code null} if there is none
 */
public record ChatEvent(
    String text,
    String finishReason,
    Throwable error
) {
    /**
     * Creates an event that only carries text.
     *
     * @param text the text of the event
     * @return an event with no finish reason and no error
     */
    public static ChatEvent of(String text) {
        return new ChatEvent(text, null, null);
    }

    /**
     * Creates an event that carries text together with a finish reason.
     *
     * @param text the text of the event, may be {@code null}
     * @param finishReason the finish reason, may be {@code null}
     * @return an event with the given text and finish reason and no error
     */
    public static ChatEvent of(String text, String finishReason) {
        return new ChatEvent(text, finishReason, null);
    }

    /**
     * Creates a text-less event that only carries a finish reason.
     *
     * @param finishReason the finish reason, must not be {@code null}
     * @return an event with no text and no error
     * @throws NullPointerException if {@code finishReason} is {@code null}
     */
    public static ChatEvent finished(String finishReason) {
        // A "finished" event without a reason would have every component null and
        // could not be told apart from an empty text event, so reject it up front.
        java.util.Objects.requireNonNull(finishReason, "finishReason");
        return ChatEvent.of(null, finishReason);
    }

    /**
     * Creates an event that reports a failure; its finish reason is {@code "error"}.
     *
     * @param error the failure to report, must not be {@code null}
     * @return an event with no text, finish reason {@code "error"} and the given error
     * @throws NullPointerException if {@code error} is {@code null}
     */
    public static ChatEvent of(Throwable error) {
        // An error event with nothing attached gives the consumer nothing to report.
        java.util.Objects.requireNonNull(error, "error");
        return new ChatEvent(null, "error", error);
    }
}

View file

@ -3,12 +3,5 @@ package eu.m724.chat;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Flow; import java.util.concurrent.Flow;
public class ChatMessage { public record ChatMessage(boolean assistant, String text) {
public boolean assistant;
public String text; // TODO make it private and modifiable other way
public ChatMessage(boolean assistant, String text) {
this.assistant = assistant;
this.text = text;
}
} }

View file

@ -1,6 +1,7 @@
package eu.m724.example package eu.m724.example
import eu.m724.chat.Chat import eu.m724.chat.Chat
import eu.m724.chat.ChatEvent
import eu.m724.chat.ChatMessage import eu.m724.chat.ChatMessage
import eu.m724.responsesource.ChatResponse import eu.m724.responsesource.ChatResponse
import eu.m724.responsesource.ChatResponseSource import eu.m724.responsesource.ChatResponseSource
@ -26,19 +27,19 @@ class ExampleSource implements ChatResponseSource {
} }
@Override @Override
ChatResponse ask(Chat chat) { ChatResponse onAsked(Chat chat) {
String[] parts = "hello how can I assist you today".split(" ") String[] parts = "hello how can I assist you today".split(" ")
LinkedBlockingQueue<String> queue = new LinkedBlockingQueue<>() LinkedBlockingQueue<ChatEvent> queue = new LinkedBlockingQueue<>()
CompletableFuture<ChatMessage> future = CompletableFuture.supplyAsync { CompletableFuture<ChatMessage> future = CompletableFuture.supplyAsync {
for (int i=0; i<parts.length; i++) { for (int i=0; i<parts.length; i++) {
String token = (i > 0 ? " " : "") + parts[i] String token = (i > 0 ? " " : "") + parts[i]
queue.put(token); queue.put(ChatEvent.of(token));
Thread.sleep(random.nextInt(200, 500)) Thread.sleep(random.nextInt(200, 500))
} }
queue.put("END_OF_TEXT") queue.put(ChatEvent.finished("stop"))
return new ChatMessage(true, parts.join(" ")) return new ChatMessage(true, parts.join(" "))
} }
@ -49,7 +50,7 @@ class ExampleSource implements ChatResponseSource {
} }
@Override @Override
LinkedBlockingQueue<String> textQueue() { LinkedBlockingQueue<ChatEvent> eventQueue() {
return queue return queue
} }

View file

@ -1,5 +1,6 @@
package eu.m724.responsesource; package eu.m724.responsesource;
import eu.m724.chat.ChatEvent;
import eu.m724.chat.ChatMessage; import eu.m724.chat.ChatMessage;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
@ -20,7 +21,7 @@ public interface ChatResponse {
* *
* @return the fifo queue with each element being a part. null ends the sequence * @return the fifo queue with each element being a part. null ends the sequence
*/ */
LinkedBlockingQueue<String> textQueue(); LinkedBlockingQueue<ChatEvent> eventQueue();
/** /**
* gets the resulting {@link ChatMessage} when it's ready * gets the resulting {@link ChatMessage} when it's ready

View file

@ -5,5 +5,14 @@ import eu.m724.chat.Chat;
public interface ChatResponseSource { public interface ChatResponseSource {
ChatResponseSourceInfo info(); ChatResponseSourceInfo info();
ChatResponse ask(Chat chat); ChatResponse onAsked(Chat chat);
default ChatResponse ask(Chat chat) {
ChatResponse chatResponse = onAsked(chat);
// TODO make sure it works in parallel
chatResponse.message().thenAccept(chat::addMessage);
return chatResponse;
}
} }