make it better
so I somewhat figured it
This commit is contained in:
parent
ac3c2b0386
commit
6ce5453de0
7 changed files with 72 additions and 22 deletions
|
@ -1,12 +1,16 @@
|
|||
package eu.m724;
|
||||
|
||||
import eu.m724.chat.Chat;
|
||||
import eu.m724.chat.ChatEvent;
|
||||
import eu.m724.chat.ChatMessage;
|
||||
import eu.m724.example.ExampleSource;
|
||||
import eu.m724.responsesource.ChatResponse;
|
||||
import eu.m724.responsesource.ChatResponseSource;
|
||||
import groovy.lang.GroovyShell;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class Main {
|
||||
public static void main(String[] args) throws InterruptedException {
|
||||
ChatResponseSource source = new ExampleSource();
|
||||
|
@ -15,18 +19,33 @@ public class Main {
|
|||
chat.messages.add(new ChatMessage(false, "hello"));
|
||||
|
||||
ChatResponse chatResponse = source.ask(chat);
|
||||
String token;
|
||||
int tokens = 0;
|
||||
|
||||
// I was thinking about integrating this into ChatMessage
|
||||
List<String> tokens = new ArrayList<>();
|
||||
List<Long> delays = new ArrayList<>();
|
||||
|
||||
System.out.println("Streaming response now\n");
|
||||
ChatEvent token;
|
||||
|
||||
while (!(token = chatResponse.textQueue().take()).equals("END_OF_TEXT")) {
|
||||
System.out.print(token);
|
||||
tokens++;
|
||||
// usually finish reason will be alongside a token but this is simpler
|
||||
while ((token = chatResponse.eventQueue().take()).finishReason() == null) {
|
||||
System.out.print(token.text());
|
||||
tokens.add(token.text());
|
||||
|
||||
long now = System.currentTimeMillis();
|
||||
delays.add(now);
|
||||
}
|
||||
|
||||
System.out.println("\n");
|
||||
System.out.printf("Tokens: %d\n", tokens);
|
||||
System.out.printf("Text: %s\n", chatResponse.message().join().text);
|
||||
System.out.printf("Tokens: %d\n", tokens.size());
|
||||
|
||||
long time = delays.getFirst();
|
||||
for (int i=0; i<tokens.size()-1; i++) {
|
||||
System.out.printf("\"%s\" + %dms, ", tokens.get(i), delays.get(i+1) - time);
|
||||
time = delays.get(i+1);
|
||||
}
|
||||
System.out.printf("\"%s\"\n\n", tokens.getLast());
|
||||
|
||||
System.out.printf("Text: %s\n", chatResponse.message().join().text());
|
||||
}
|
||||
}
|
|
@ -12,4 +12,8 @@ public class Chat {
|
|||
}
|
||||
|
||||
/** Creates a chat; this constructor itself performs no initialization. */
public Chat() {}
|
||||
|
||||
public void addMessage(ChatMessage message) {
|
||||
this.messages.add(message);
|
||||
}
|
||||
}
|
||||
|
|
23
src/main/java/eu/m724/chat/ChatEvent.java
Normal file
23
src/main/java/eu/m724/chat/ChatEvent.java
Normal file
|
@ -0,0 +1,23 @@
|
|||
package eu.m724.chat;
|
||||
|
||||
/**
 * A single event emitted while a chat response is being streamed.
 *
 * <p>An event may carry a text fragment, a finish reason, an error, or a
 * combination of them. A consumer can treat the first event whose
 * {@link #finishReason()} is non-null as the end of the stream.
 *
 * @param text         the streamed text fragment, or {@code null} if the event carries no text
 * @param finishReason why the stream ended, or {@code null} if this event does not end it
 * @param error        the failure that ended the stream, or {@code null} if there was none
 */
public record ChatEvent(
        String text,
        String finishReason,
        Throwable error
) {
    /**
     * Creates an event that only carries text.
     *
     * @param text the text fragment
     * @return an event with no finish reason and no error
     */
    public static ChatEvent of(String text) {
        return new ChatEvent(text, null, null);
    }

    /**
     * Creates an event that carries text together with a finish reason.
     *
     * @param text         the text fragment, may be {@code null}
     * @param finishReason the finish reason, may be {@code null}
     * @return an event with the given text and finish reason and no error
     */
    public static ChatEvent of(String text, String finishReason) {
        return new ChatEvent(text, finishReason, null);
    }

    /**
     * Creates a text-less event that marks the end of the stream.
     *
     * @param finishReason why the stream ended; must not be {@code null}
     * @return an event with only the finish reason set
     * @throws NullPointerException if {@code finishReason} is {@code null}
     */
    public static ChatEvent finished(String finishReason) {
        // A "finished" event without a reason would be indistinguishable from an
        // empty token, so a consumer waiting for a finish reason would never stop.
        java.util.Objects.requireNonNull(finishReason, "finishReason");
        return ChatEvent.of(null, finishReason);
    }

    /**
     * Creates an event that ends the stream because of a failure. The finish
     * reason of such an event is the string {@code "error"}.
     *
     * @param error the failure; must not be {@code null}
     * @return an event with finish reason {@code "error"} and the given error
     * @throws NullPointerException if {@code error} is {@code null}
     */
    public static ChatEvent of(Throwable error) {
        // An error event must actually carry the error it reports.
        java.util.Objects.requireNonNull(error, "error");
        return new ChatEvent(null, "error", error);
    }
}
|
|
@ -3,12 +3,5 @@ package eu.m724.chat;
|
|||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.Flow;
|
||||
|
||||
public class ChatMessage {
|
||||
public boolean assistant;
|
||||
public String text; // TODO make it private and modifiable other way
|
||||
|
||||
public ChatMessage(boolean assistant, String text) {
|
||||
this.assistant = assistant;
|
||||
this.text = text;
|
||||
}
|
||||
public record ChatMessage(boolean assistant, String text) {
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package eu.m724.example
|
||||
|
||||
import eu.m724.chat.Chat
|
||||
import eu.m724.chat.ChatEvent
|
||||
import eu.m724.chat.ChatMessage
|
||||
import eu.m724.responsesource.ChatResponse
|
||||
import eu.m724.responsesource.ChatResponseSource
|
||||
|
@ -26,19 +27,19 @@ class ExampleSource implements ChatResponseSource {
|
|||
}
|
||||
|
||||
@Override
|
||||
ChatResponse ask(Chat chat) {
|
||||
ChatResponse onAsked(Chat chat) {
|
||||
String[] parts = "hello how can I assist you today".split(" ")
|
||||
|
||||
LinkedBlockingQueue<String> queue = new LinkedBlockingQueue<>()
|
||||
LinkedBlockingQueue<ChatEvent> queue = new LinkedBlockingQueue<>()
|
||||
|
||||
CompletableFuture<ChatMessage> future = CompletableFuture.supplyAsync {
|
||||
for (int i=0; i<parts.length; i++) {
|
||||
String token = (i > 0 ? " " : "") + parts[i]
|
||||
queue.put(token);
|
||||
queue.put(ChatEvent.of(token));
|
||||
Thread.sleep(random.nextInt(200, 500))
|
||||
}
|
||||
|
||||
queue.put("END_OF_TEXT")
|
||||
queue.put(ChatEvent.finished("stop"))
|
||||
return new ChatMessage(true, parts.join(" "))
|
||||
}
|
||||
|
||||
|
@ -49,7 +50,7 @@ class ExampleSource implements ChatResponseSource {
|
|||
}
|
||||
|
||||
@Override
|
||||
LinkedBlockingQueue<String> textQueue() {
|
||||
LinkedBlockingQueue<ChatEvent> eventQueue() {
|
||||
return queue
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package eu.m724.responsesource;
|
||||
|
||||
import eu.m724.chat.ChatEvent;
|
||||
import eu.m724.chat.ChatMessage;
|
||||
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
@ -20,7 +21,7 @@ public interface ChatResponse {
|
|||
*
|
||||
* @return the FIFO queue of {@link ChatEvent}s, each carrying a part of the response. An event with a non-null finish reason ends the sequence
|
||||
*/
|
||||
LinkedBlockingQueue<String> textQueue();
|
||||
LinkedBlockingQueue<ChatEvent> eventQueue();
|
||||
|
||||
/**
|
||||
* gets the resulting {@link ChatMessage} when it's ready
|
||||
|
|
|
@ -5,5 +5,14 @@ import eu.m724.chat.Chat;
|
|||
public interface ChatResponseSource {
|
||||
ChatResponseSourceInfo info();
|
||||
|
||||
ChatResponse ask(Chat chat);
|
||||
/**
 * Produces the response for the given chat. Invoked by {@link #ask(Chat)},
 * which additionally adds the resulting message to the chat.
 *
 * @param chat the chat to respond to
 * @return the response produced for {@code chat}
 */
ChatResponse onAsked(Chat chat);
|
||||
|
||||
default ChatResponse ask(Chat chat) {
|
||||
ChatResponse chatResponse = onAsked(chat);
|
||||
|
||||
// TODO make sure it works in parallel
|
||||
chatResponse.message().thenAccept(chat::addMessage);
|
||||
|
||||
return chatResponse;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue