make it stream

and other changes
This commit is contained in:
Minecon724 2024-08-31 19:12:55 +02:00
parent 0f59b4d09c
commit 1dcf09614e
Signed by: Minecon724
GPG key ID: 3CCC4D267742C8E8
7 changed files with 125 additions and 30 deletions

View file

@ -3,20 +3,14 @@ package eu.m724;
import eu.m724.chat.Chat;
import eu.m724.chat.ChatEvent;
import eu.m724.chat.ChatMessage;
import eu.m724.example.ExampleSource;
import eu.m724.example.OaiSource;
import eu.m724.source.ChatResponse;
import eu.m724.source.ChatSource;
import eu.m724.source.option.Option;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.Arrays;
import java.util.Map;
import java.util.Scanner;
public class Main {
public static void main(String[] args) throws InterruptedException {
@ -24,12 +18,12 @@ public class Main {
//source.options().setValue("name", "nekalakininahappenenawiwanatin");
ChatSource source = new OaiSource(System.getenv("API_KEY"));
//source.options().setValue("model", "chatgpt-4o-latest");
source.options().setValue("model", "chatgpt-4o-latest");
Chat chat = new Chat("Speak in super wuper uwu wanguage.");
System.out.println("Welcome to CHUT chat. Say something after the \033[1m>\033[0m, or type \033[1m:help\033[0m to see available commands");
System.out.println("Source: " + source.getClass().getName());
System.out.printf("Source: \033[1m%s\033[0m %s (%d) by %s\n", source.info().name(), source.info().versionName(), source.info().version(), source.info().author());
Scanner scanner = new Scanner(System.in);
@ -47,8 +41,10 @@ public class Main {
token = chatResponse.eventQueue().take();
if (token.finishReason() != "error") {
System.out.print(i++ % 2 == 1 ? "\033[1m" : "\033[0m");
System.out.print(token.text());
if (token.text() != null) {
System.out.print(i++ % 2 == 1 ? "\033[1m" : "\033[0m");
System.out.print(token.text());
}
} else {
System.out.print("Error: " + token.error().toString());
}
@ -75,12 +71,13 @@ public class Main {
break;
case "dump":
case "d":
int size = chat.messages.size();
System.out.printf("This chat has %d messages, or %d pairs", size, size / 2);
if (chat.systemPrompt == null) {
System.out.printf("This chat has %d messages.\n", chat.messages.size());
System.out.println("There's no system prompt.");
System.out.println(".\nThere's no system prompt.");
} else {
System.out.printf("This chat has %d messages, excluding system prompt.\n", chat.messages.size());
System.out.printf("System prompt:\n\"\"\"%s\"\"\"\n", chat.systemPrompt);
System.out.printf(", excluding system prompt.\nSystem prompt:\n\"\"\"%s\"\"\"\n", chat.systemPrompt);
}
for (ChatMessage message : chat.messages) {
System.out.printf("%s: %s\n", message.assistant() ? "ASSISTANT" : "USER", message.text());

View file

@ -1,7 +1,7 @@
package eu.m724.chat;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Flow;
/**
 * A single message of a chat: its text plus a flag telling who wrote it.
 *
 * @param assistant {@code true} when the message comes from the assistant; any other message
 *                  is treated as a user message
 * @param text      the content of the message
 */
public record ChatMessage(boolean assistant, String text) {

    /**
     * Creates a message attributed to the assistant.
     *
     * @param text the content of the message
     * @return a new message whose {@code assistant} flag is set
     */
    public static ChatMessage assistant(String text) {
        final boolean writtenByAssistant = true;
        return new ChatMessage(writtenByAssistant, text);
    }
}

View file

@ -4,7 +4,7 @@ import eu.m724.chat.Chat
import eu.m724.source.ChatResponse
import eu.m724.source.ChatSource
import eu.m724.source.ChatSourceInfo
import eu.m724.source.NonStreamingChatResponse
import eu.m724.source.SimpleChatResponse
import eu.m724.source.option.DoubleOption
import eu.m724.source.option.Options
import eu.m724.source.option.StringOption
@ -14,7 +14,6 @@ import org.json.JSONObject
import java.net.http.HttpClient
import java.net.http.HttpRequest
import java.net.http.HttpResponse
// for now let's not focus on readability
// this is more about finding out what is common and should be included in the common API
class OaiSource implements ChatSource {
@ -50,6 +49,7 @@ class OaiSource implements ChatSource {
.put("model", options.getStringValue("model"))
.put("temperature", options.getOptionValue("temperature", Double::class))
.put("presence_penalty", 0.1)
.put("stream", true)
.put("messages", formatChat(chat)).toString()
def request = HttpRequest.newBuilder()
@ -61,9 +61,10 @@ class OaiSource implements ChatSource {
def client = HttpClient.newHttpClient()
ChatResponse chatResponse = new NonStreamingChatResponse()
def chatResponse = new SimpleChatResponse()
def response = client.sendAsync(request, HttpResponse.BodyHandlers.ofLines())
def response = client.sendAsync(request, HttpResponse.BodyHandlers.ofString())
response.thenAccept {
Exception exception = null
@ -72,6 +73,30 @@ class OaiSource implements ChatSource {
exception = new Exception("Non 200 status code: %d".formatted(it.statusCode()))
}
it.body().forEach {
//System.out.println(it);
for (String line : it.split("\n")) {
if (line.startsWith("data: ")) {
def data = line.substring(6)
if (data != "[DONE]") {
def json = new JSONObject(data)
def choice = json.getJSONArray("choices").getJSONObject(0)
def finishReason = choice.get("finish_reason")
if (finishReason != JSONObject.NULL) {
//System.out.println("ending");
chatResponse.end(finishReason.toString())
} else {
def token = choice.getJSONObject("delta").getString("content")
chatResponse.put(token)
}
}
}
}
}
/*
if (exception != null) {
chatResponse.completeExceptionally(exception)
} else {
@ -79,7 +104,7 @@ class OaiSource implements ChatSource {
//System.out.println(json)
def completion = json.getJSONArray("choices").getJSONObject(0).getJSONObject("message").getString("content")
chatResponse.complete(completion)
}
}*/
}

View file

@ -1,7 +1,6 @@
package eu.m724.source;
import eu.m724.chat.Chat;
import eu.m724.source.option.Option;
import eu.m724.source.option.Options;
public interface ChatSource {
@ -33,6 +32,25 @@ public interface ChatSource {
default ChatResponse ask(Chat chat) {
ChatResponse chatResponse = onAsked(chat);
/* if (chatResponse.streaming()) {
StringBuilder total = new StringBuilder();
CompletableFuture.runAsync(() -> {
ChatEvent event;
while (true) {
try {
event = chatResponse.eventQueue().take();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
total.append(event.text());
if (event.finishReason() != null) {
chatResponse.message().complete(new ChatMessage(true, total.toString()));
break;
}
}
});
} this was a draft I'm keeping because I might change my mind */
// TODO make sure it works in parallel
chatResponse.message().thenAccept(message -> {
if (message != null) chat.addMessage(message);

View file

@ -7,8 +7,8 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.LinkedBlockingQueue;
public class NonStreamingChatResponse implements ChatResponse {
private LinkedBlockingQueue<ChatEvent> eventQueue = new LinkedBlockingQueue<>();
private CompletableFuture<ChatMessage> message = new CompletableFuture<>();
private final LinkedBlockingQueue<ChatEvent> eventQueue = new LinkedBlockingQueue<>();
private final CompletableFuture<ChatMessage> message = new CompletableFuture<>();
@Override
public boolean streaming() {

View file

@ -6,6 +6,57 @@ import eu.m724.chat.ChatMessage;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.LinkedBlockingQueue;
public record SimpleChatResponse(boolean streaming, LinkedBlockingQueue<ChatEvent> eventQueue,
CompletableFuture<ChatMessage> message) implements ChatResponse {
/**
 * A {@link ChatResponse} that is filled in piece by piece by whoever produces the answer.
 *
 * <p>Every piece is published as a {@link ChatEvent} on {@link #eventQueue()} and its text is
 * appended to an internal buffer. As soon as a piece carries a non-null finish reason,
 * {@link #message()} is completed with the buffered text as an assistant message.
 *
 * <p>NOTE(review): the text buffer is not synchronized; this assumes {@code put}/{@code end} are
 * called from a single producer thread - confirm against the sources that use this class.
 */
public class SimpleChatResponse implements ChatResponse {
    private final boolean streaming;
    private final LinkedBlockingQueue<ChatEvent> eventQueue;
    private final CompletableFuture<ChatMessage> message;

    /** Text of every token received so far; a builder avoids re-copying the text per token. */
    private final StringBuilder total = new StringBuilder();

    /** Creates an empty response that reports itself as streaming. */
    public SimpleChatResponse() {
        this.streaming = true;
        this.eventQueue = new LinkedBlockingQueue<>();
        this.message = new CompletableFuture<>();
    }

    /**
     * Publishes one event and updates the accumulated message.
     *
     * @param token        text of this piece, or {@code null} if the event carries no text
     * @param finishReason why the response ended, or {@code null} if more pieces will follow
     * @throws IllegalStateException if the calling thread is interrupted while queueing the
     *                               event; the thread's interrupt flag is restored first
     */
    public void put(String token, String finishReason) {
        try {
            eventQueue.put(ChatEvent.of(token, finishReason));
        } catch (InterruptedException e) {
            // Someone asked this thread to stop. Restore the flag that the exception cleared so
            // code further up the stack can still see the interruption, then abort this call.
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while queueing a chat event", e);
        }

        if (token != null) {
            total.append(token);
        }

        if (finishReason != null) {
            // complete() does nothing if the future is already completed, so a repeated finish
            // event cannot replace the first result.
            message.complete(ChatMessage.assistant(total.toString()));
        }
    }

    /**
     * Publishes a piece of text that does not end the response.
     *
     * @param token text of this piece
     */
    public void put(String token) {
        put(token, null);
    }

    /**
     * Publishes a text-less event that ends the response.
     *
     * @param finishReason why the response ended
     */
    public void end(String finishReason) {
        put(null, finishReason);
    }

    @Override
    public boolean streaming() {
        return streaming;
    }

    @Override
    public LinkedBlockingQueue<ChatEvent> eventQueue() {
        return eventQueue;
    }

    @Override
    public CompletableFuture<ChatMessage> message() {
        return message;
    }
}

View file

@ -3,6 +3,10 @@ so, the api is mostly good now, but new things to focus on:
- how to do requests
- how to make it friendly
delay
right now streaming seems delayed and batched, I don't know why
maybe it's just me
network
every platform has different libraries for that, so abstraction
how? I still don't know