Compare commits

..

2 commits

Author SHA1 Message Date
6ce5453de0
make it better
so I somewhat figured it
2024-08-29 13:22:48 +02:00
ac3c2b0386
make it streamable 2024-08-28 15:48:52 +02:00
7 changed files with 106 additions and 26 deletions

View file

@ -1,14 +1,18 @@
package eu.m724; package eu.m724;
import eu.m724.chat.Chat; import eu.m724.chat.Chat;
import eu.m724.chat.ChatEvent;
import eu.m724.chat.ChatMessage; import eu.m724.chat.ChatMessage;
import eu.m724.example.ExampleSource; import eu.m724.example.ExampleSource;
import eu.m724.responsesource.ChatResponse; import eu.m724.responsesource.ChatResponse;
import eu.m724.responsesource.ChatResponseSource; import eu.m724.responsesource.ChatResponseSource;
import groovy.lang.GroovyShell; import groovy.lang.GroovyShell;
import java.util.ArrayList;
import java.util.List;
public class Main { public class Main {
public static void main(String[] args) { public static void main(String[] args) throws InterruptedException {
ChatResponseSource source = new ExampleSource(); ChatResponseSource source = new ExampleSource();
Chat chat = new Chat(); Chat chat = new Chat();
@ -16,7 +20,32 @@ public class Main {
ChatResponse chatResponse = source.ask(chat); ChatResponse chatResponse = source.ask(chat);
System.out.println(chatResponse.text().join()); // I was thinking about integrating this into ChatMessage
System.out.println(chatResponse.message().text); List<String> tokens = new ArrayList<>();
List<Long> delays = new ArrayList<>();
System.out.println("Streaming response now\n");
ChatEvent token;
// usually finish reason will be alongside a token but this is simpler
while ((token = chatResponse.eventQueue().take()).finishReason() == null) {
System.out.print(token.text());
tokens.add(token.text());
long now = System.currentTimeMillis();
delays.add(now);
}
System.out.println("\n");
System.out.printf("Tokens: %d\n", tokens.size());
long time = delays.getFirst();
for (int i=0; i<tokens.size()-1; i++) {
System.out.printf("\"%s\" + %dms, ", tokens.get(i), delays.get(i+1) - time);
time = delays.get(i+1);
}
System.out.printf("\"%s\"\n\n", tokens.getLast());
System.out.printf("Text: %s\n", chatResponse.message().join().text());
} }
} }

View file

@ -12,4 +12,8 @@ public class Chat {
} }
public Chat() {} public Chat() {}
public void addMessage(ChatMessage message) {
this.messages.add(message);
}
} }

View file

@ -0,0 +1,23 @@
package eu.m724.chat;
public record ChatEvent(
String text,
String finishReason,
Throwable error
) {
public static ChatEvent of(String text) {
return new ChatEvent(text, null, null);
}
public static ChatEvent of(String text, String finishReason) {
return new ChatEvent(text, finishReason, null);
}
public static ChatEvent finished(String finishReason) {
return ChatEvent.of(null, finishReason);
}
public static ChatEvent of(Throwable error) {
return new ChatEvent(null, "error", error);
}
}

View file

@ -3,12 +3,5 @@ package eu.m724.chat;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Flow; import java.util.concurrent.Flow;
public class ChatMessage { public record ChatMessage(boolean assistant, String text) {
public boolean assistant;
public String text; // TODO make it private and modifiable other way
public ChatMessage(boolean assistant, String text) {
this.assistant = assistant;
this.text = text;
}
} }

View file

@ -1,12 +1,16 @@
package eu.m724.example package eu.m724.example
import eu.m724.chat.Chat import eu.m724.chat.Chat
import eu.m724.chat.ChatEvent
import eu.m724.chat.ChatMessage import eu.m724.chat.ChatMessage
import eu.m724.responsesource.ChatResponse import eu.m724.responsesource.ChatResponse
import eu.m724.responsesource.ChatResponseSource import eu.m724.responsesource.ChatResponseSource
import eu.m724.responsesource.ChatResponseSourceInfo import eu.m724.responsesource.ChatResponseSourceInfo
import java.util.concurrent.CompletableFuture import java.util.concurrent.CompletableFuture
import java.util.concurrent.Executors
import java.util.concurrent.LinkedBlockingDeque
import java.util.concurrent.LinkedBlockingQueue
/** /**
* an example chatresponsesource chatresponsesource ChatResponseSource CHATRESPONSESOURCE CAHTSERREPOSNECSOURCE * an example chatresponsesource chatresponsesource ChatResponseSource CHATRESPONSESOURCE CAHTSERREPOSNECSOURCE
@ -15,6 +19,7 @@ import java.util.concurrent.CompletableFuture
class ExampleSource implements ChatResponseSource { class ExampleSource implements ChatResponseSource {
private ChatResponseSourceInfo info = private ChatResponseSourceInfo info =
new ChatResponseSourceInfo("yo", "ye", "1.0", 1) new ChatResponseSourceInfo("yo", "ye", "1.0", 1)
private Random random = new Random()
@Override @Override
ChatResponseSourceInfo info() { ChatResponseSourceInfo info() {
@ -22,24 +27,36 @@ class ExampleSource implements ChatResponseSource {
} }
@Override @Override
ChatResponse ask(Chat chat) { ChatResponse onAsked(Chat chat) {
return new ChatResponse() { String[] parts = "hello how can I assist you today".split(" ")
String[] parts
CompletableFuture<String> completableFuture = new CompletableFuture<>();
LinkedBlockingQueue<ChatEvent> queue = new LinkedBlockingQueue<>()
CompletableFuture<ChatMessage> future = CompletableFuture.supplyAsync {
for (int i=0; i<parts.length; i++) {
String token = (i > 0 ? " " : "") + parts[i]
queue.put(ChatEvent.of(token));
Thread.sleep(random.nextInt(200, 500))
}
queue.put(ChatEvent.finished("stop"))
return new ChatMessage(true, parts.join(" "))
}
return new ChatResponse() {
@Override @Override
boolean isStreaming() { boolean isStreaming() {
return false return false
} }
@Override @Override
CompletableFuture<String> text() { LinkedBlockingQueue<ChatEvent> eventQueue() {
return CompletableFuture.completedFuture("hello how can i assist you today") return queue
} }
@Override @Override
ChatMessage message() { CompletableFuture<ChatMessage> message() {
return new ChatMessage(true, "i assisted you already bye") return future
} }
} }
} }

View file

@ -1,12 +1,16 @@
package eu.m724.responsesource; package eu.m724.responsesource;
import eu.m724.chat.ChatEvent;
import eu.m724.chat.ChatMessage; import eu.m724.chat.ChatMessage;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.LinkedBlockingQueue;
public interface ChatResponse { public interface ChatResponse {
/** /**
* is this response streaming * is this response streaming
* if it's not, the queue will get one element that is the whole response
*
* @return is this response streaming * @return is this response streaming
*/ */
boolean isStreaming(); boolean isStreaming();
@ -14,14 +18,15 @@ public interface ChatResponse {
/** /**
* if streamed, text token by token as it goes (or other splitting depending on the source) * if streamed, text token by token as it goes (or other splitting depending on the source)
* if not, the {@link CompletableFuture} returns just the whole response after it's ready * if not, the {@link CompletableFuture} returns just the whole response after it's ready
* @return yeah *
* @return the fifo queue with each element being a part. null ends the sequence
*/ */
CompletableFuture<String> text(); // TODO completablefuture is not correct here also fix the doc LinkedBlockingQueue<ChatEvent> eventQueue();
/** /**
* gets the resulting {@link ChatMessage} * gets the resulting {@link ChatMessage} when it's ready
* TODO I think it should be available after streaming is done so maybe wrap this in {@link CompletableFuture} *
* @return the resulting {@link ChatMessage} * @return the resulting {@link ChatMessage} as soon as the response is complete
*/ */
ChatMessage message(); CompletableFuture<ChatMessage> message();
} }

View file

@ -5,5 +5,14 @@ import eu.m724.chat.Chat;
public interface ChatResponseSource { public interface ChatResponseSource {
ChatResponseSourceInfo info(); ChatResponseSourceInfo info();
ChatResponse ask(Chat chat); ChatResponse onAsked(Chat chat);
default ChatResponse ask(Chat chat) {
ChatResponse chatResponse = onAsked(chat);
// TODO make sure it works in parallel
chatResponse.message().thenAccept(chat::addMessage);
return chatResponse;
}
} }