big revamp

This commit is contained in:
Minecon724 2024-09-06 18:19:13 +02:00
parent ca7ebbdc2e
commit 6bfd51baa3
Signed by: Minecon724
GPG key ID: 3CCC4D267742C8E8
9 changed files with 194 additions and 222 deletions

View file

@ -10,8 +10,6 @@ import eu.m724.chatapi.source.option.Options
import eu.m724.chatapi.source.option.StringOption
import java.util.concurrent.CompletableFuture
import java.util.concurrent.LinkedBlockingQueue
/**
* an example chatresponsesource chatresponsesource ChatResponseSource CHATRESPONSESOURCE CAHTSERREPOSNECSOURCE
* note to self: rename that already...
@ -40,48 +38,30 @@ class ExampleSource implements ChatSource {
ChatResponse onAsked(Chat chat) {
//String prompt = chat.messages.getLast().text();
String response = rollResponse(chat)
String[] parts = response.split(" ")
LinkedBlockingQueue<ChatEvent> queue = new LinkedBlockingQueue<>()
ChatMessage message = new ChatMessage(true);
CompletableFuture<ChatMessage> future = CompletableFuture.supplyAsync {
CompletableFuture.supplyAsync {
for (int i=0; i<parts.length; i++) {
String token = (i > 0 ? " " : "") + parts[i]
queue.put(ChatEvent.of(token));
message.submitEvent(ChatEvent.of(token))
Thread.sleep(random.nextInt(50, 200))
}
queue.put(ChatEvent.finished("stop"))
return new ChatMessage(true, parts.join(" "))
message.submitEvent(ChatEvent.finished("stop"))
}
return new ChatResponse() {
@Override
boolean streaming() {
return false
}
@Override
LinkedBlockingQueue<ChatEvent> eventQueue() {
return queue
}
@Override
CompletableFuture<ChatMessage> message() {
return future
}
}
return new ChatResponse(message)
}
String rollResponse(Chat chat) {
def special = -1
def prompt = chat.messages.getLast().text()
def prompt = chat.messages.getLast().content()
int messagesCount = (int) Math.ceil(chat.messages.size() / 2)
def counter = lastCounter
def response = ""
if (prompt.toLowerCase().startsWith("my name is")) {
options.setValue("name", prompt.substring(11))
counter = 11

View file

@ -1,11 +1,12 @@
package eu.m724.chatapi.example
import eu.m724.chatapi.chat.Chat
import eu.m724.chatapi.chat.ChatEvent
import eu.m724.chatapi.chat.ChatMessage
import eu.m724.chatapi.source.ChatResponse
import eu.m724.chatapi.source.ChatSource
import eu.m724.chatapi.source.ChatSourceInfo
import eu.m724.chatapi.source.exception.HttpException
import eu.m724.chatapi.source.impl.StreamingChatResponse
import eu.m724.chatapi.source.option.DoubleOption
import eu.m724.chatapi.source.option.Options
import eu.m724.chatapi.source.option.StringOption
@ -65,7 +66,7 @@ class OaiSource implements ChatSource {
def client = HttpClient.newHttpClient()
def chatResponse = new StreamingChatResponse()
def chatMessage = new ChatMessage(true)
def response = client.sendAsync(request, HttpResponse.BodyHandlers.ofLines())
@ -96,21 +97,22 @@ class OaiSource implements ChatSource {
def json = new JSONObject(data)
def choice = json.getJSONArray("choices").getJSONObject(0)
def finishReason = choice.get("finish_reason")
ChatEvent event
if (finishReason != JSONObject.NULL) {
//System.out.println("ending");
chatResponse.end(finishReason.toString())
event = ChatEvent.finished(finishReason.toString())
} else {
def token = choice.getJSONObject("delta").getString("content")
chatResponse.put(token)
event = ChatEvent.of(token)
}
chatMessage.submitEvent(event)
}
}
}
}
}
return chatResponse
return new ChatResponse(chatMessage)
}
static JSONArray formatChat(Chat chat) {
@ -121,7 +123,7 @@ class OaiSource implements ChatSource {
}
chat.messages.each {
array.put(new JSONObject().put("role", it.assistant() ? "assistant" : "user").put("content", it.text()))
array.put(new JSONObject().put("role", it.response() ? "assistant" : "user").put("content", it.content()))
}
return array

View file

@ -6,7 +6,6 @@ import eu.m724.chatapi.chat.ChatMessage;
import eu.m724.chatapi.example.OaiSource;
import eu.m724.chatapi.source.ChatResponse;
import eu.m724.chatapi.source.ChatSource;
import eu.m724.chatapi.source.exception.HttpException;
import eu.m724.chatapi.source.option.Option;
import eu.m724.chatapi.source.option.Options;
@ -19,7 +18,6 @@ class Main {
Scanner scanner = new Scanner(System.in);
String apiKey = System.getenv("API_KEY");
boolean complainedApiKey = false;
if (apiKey == null) {
System.out.print("\nAPI Key: ");
@ -32,7 +30,6 @@ class Main {
if (!Pattern.matches("sk-proj-.*?(?:\\s|$)", apiKey)) {
System.out.println("This key looks invalid");
complainedApiKey = true;
}
// ChatSource source = new ExampleSource();
@ -59,28 +56,33 @@ class Main {
if (!prompt.startsWith(":")) {
chat.messages.add(new ChatMessage(false, prompt));
ChatResponse chatResponse = source.ask(chat);
ChatMessage message = chatResponse.message();
ChatEvent token;
int i = 0;
do {
token = chatResponse.eventQueue().take();
if (!"error".equals(token.finishReason())) { // this looks bad but at least idea doesn't nag me
if (token.text() != null) {
System.out.print(i++ % 2 == 1 ? "\033[1m" : "\033[0m");
System.out.print(token.text());
int[] i = {0}; // this makes no sense
message.addEventConsumer(event -> {
if (!"error".equals(event.finishReason())) {
if (event.text() != null) {
System.out.print(i[0]++ % 2 == 1 ? "\033[1m" : "\033[0m");
System.out.print(event.text());
}
} else {
System.out.print("Error: " + token.error().toString());
if (complainedApiKey && token.error() instanceof HttpException && ((HttpException)token.error()).statusCode == 401) {
System.out.print("\nTold you");
complainedApiKey = false;
System.out.print("Error: " + event.error().toString());
// here was a funny thing, but I had to remove it
}
}, true);
try {
message.onComplete().join();
} catch (Throwable e) {
e.printStackTrace();
}
} while (token.finishReason() == null);
System.out.println();
} else {
String[] parts = prompt.substring(1).split(" ");
if (parts[0].startsWith(":")) {
@ -110,7 +112,7 @@ class Main {
System.out.printf(", excluding system prompt.\nSystem prompt:\n\"\"\"%s\"\"\"\n", chat.systemPrompt);
}
for (ChatMessage message : chat.messages) {
System.out.printf("%s: %s\n", message.assistant() ? "ASSISTANT" : "USER", message.text());
System.out.printf("%s: %s\n", message.response() ? "ASSISTANT" : "USER", message.content());
}
break;
case "options":

View file

@ -1,7 +1,133 @@
package eu.m724.chatapi.chat;
public record ChatMessage(boolean assistant, String text) {
public static ChatMessage assistant(String text) {
return new ChatMessage(true, text);
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
/**
 * A single chat message, written either by the user or by the assistant.<br>
 * A message is either created already complete, or created empty and then filled in by
 * {@link ChatEvent}s while a response is being streamed.<br>
 * Every instance carries its own event log, consumer set and completion future,
 * so this is quite a heavy class; don't overuse it.
 * <p>
 * Events may be submitted from a different thread than the one that registers consumers
 * or reads the content, so all mutable state is guarded by a single private lock.
 * Consumers are invoked while that lock is held (this is what keeps history replay and
 * live delivery in order), so a consumer must not block waiting on another thread that
 * itself uses this message.
 */
public class ChatMessage {
    private final boolean response;
    private final CompletableFuture<Void> completedFuture = new CompletableFuture<>();

    /** guards {@link #eventConsumers}, {@link #events} and {@link #content} */
    private final Object lock = new Object();
    private final Set<Consumer<ChatEvent>> eventConsumers = new HashSet<>();
    private final List<ChatEvent> events = new ArrayList<>();
    private String content = ""; // TODO I was thinking about making this abstract so you could put different types like images documents etc

    /**
     * Creates an empty message that can be streamed to with {@link #submitEvent(ChatEvent)}
     *
     * @param response is the message a response
     */
    public ChatMessage(boolean response) {
        this.response = response;
    }

    /**
     * Creates a completed chat message<br>
     * no {@link ChatEvent} is recorded for it, so consumers asking for history receive nothing
     *
     * @param response is the message a response
     * @param content the text
     */
    public ChatMessage(boolean response, String content) {
        this.response = response;
        this.content = content;
        completedFuture.complete(null);
    }

    /**
     * Registers a consumer for the {@link ChatEvent}s
     *
     * @param consumer the consumer
     * @param history whether to first replay the events submitted before the consumer was registered
     */
    public void addEventConsumer(Consumer<ChatEvent> consumer, boolean history) {
        // registration and replay happen under one lock, so an event submitted
        // concurrently is delivered exactly once: either replayed or live
        synchronized (lock) {
            eventConsumers.add(consumer);
            if (history) {
                for (ChatEvent event : events) {
                    consumer.accept(event);
                }
            }
        }
    }

    /**
     * Submits a {@link ChatEvent}: appends its text (if any) to the content, records it,
     * and broadcasts it to every registered consumer<br>
     * an event carrying an error completes {@link #onComplete()} exceptionally,
     * otherwise an event carrying a finish reason completes it normally
     *
     * @param event the event
     */
    public void submitEvent(ChatEvent event) {
        synchronized (lock) {
            if (event.text() != null) {
                content += event.text();
            }
            events.add(event);

            for (Consumer<ChatEvent> consumer : eventConsumers) {
                try {
                    consumer.accept(event);
                } catch (Throwable e) {
                    // best effort: one failing consumer must not starve the others
                    System.err.println("Error distributing event:");
                    e.printStackTrace();
                }
            }
        }

        // completed outside the lock, because callbacks chained on the future run on this thread
        if (event.error() != null) {
            completedFuture.completeExceptionally(event.error());
        } else if (event.finishReason() != null) {
            completedFuture.complete(null);
        }
    }

    /**
     * Returns this message's content<br>
     * if called during a response, it will return what's already been generated
     *
     * @return this message's content
     */
    public String content() {
        synchronized (lock) {
            return content;
        }
    }

    /**
     * Whether this message is a response<br>
     * a response is usually sent by the assistant, not by the user
     *
     * @return whether this message is a response
     */
    public boolean response() {
        return response;
    }

    /**
     * A future that completes when the response is complete<br>
     * it completes exceptionally if an event carrying an error was submitted
     *
     * @return a future that completes when the response is complete
     */
    public CompletableFuture<Void> onComplete() {
        return completedFuture;
    }

    /**
     * @return whether the response is complete, successfully or not
     */
    public boolean completed() {
        return completedFuture.isDone();
    }
}

View file

@ -1,33 +1,15 @@
package eu.m724.chatapi.source;
import eu.m724.chatapi.chat.ChatEvent;
import eu.m724.chatapi.chat.ChatMessage;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.LinkedBlockingQueue;
public class ChatResponse {
private final ChatMessage message;
public interface ChatResponse {
/**
* is this response streaming
* if it's not, the queue will get one element that is the whole response
* I think about replacing with an abstract class and put this in the constructor
*
* @return is this response streaming
*/
boolean streaming();
/**
* if streamed, text token by token as it goes (or other splitting depending on the source)
* if not, the {@link CompletableFuture} returns just the whole response after it's ready
*
* @return the fifo queue with each element being a part. null ends the sequence
*/
LinkedBlockingQueue<ChatEvent> eventQueue();
/**
* gets the resulting {@link ChatMessage} when it's ready
*
* @return the resulting {@link ChatMessage} as soon as the response is complete
*/
CompletableFuture<ChatMessage> message();
public ChatResponse(ChatMessage message) {
this.message = message;
}
public ChatMessage message() {
return message;
}
}

View file

@ -31,10 +31,7 @@ public interface ChatSource {
*/
default ChatResponse ask(Chat chat) {
ChatResponse chatResponse = onAsked(chat);
chatResponse.message().thenAccept(message -> {
if (message != null) chat.addMessage(message);
});
chat.addMessage(chatResponse.message());
return chatResponse;
}

View file

@ -0,0 +1,20 @@
package eu.m724.chatapi.source.impl;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.Consumer;
/**
 * Bridges push-style delivery to pull-style consumption: every element handed to
 * {@link #consumer} is appended to {@link #queue}, from which it can be taken later.
 *
 * @param <T> the element type
 */
public class BlockingQueueConsumer<T> {
    /** receives every element accepted by {@link #consumer} */
    public final LinkedBlockingQueue<T> queue = new LinkedBlockingQueue<>();

    /**
     * puts each accepted element into {@link #queue}<br>
     * if the calling thread is interrupted, the interrupt flag is restored and a
     * {@link RuntimeException} wrapping the {@link InterruptedException} is thrown
     */
    public final Consumer<T> consumer = element -> {
        try {
            queue.put(element);
        } catch (InterruptedException e) {
            // catching InterruptedException clears the flag; set it again so
            // code further up the stack can still see the interruption
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    };
}

View file

@ -1,60 +0,0 @@
package eu.m724.chatapi.source.impl;
import eu.m724.chatapi.chat.ChatEvent;
import eu.m724.chatapi.chat.ChatMessage;
import eu.m724.chatapi.source.ChatResponse;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.LinkedBlockingQueue;
/**
 * A {@link ChatResponse} for sources that deliver the whole answer at once:
 * the event queue receives a single event and the message future is completed right after.
 */
public class NonStreamingChatResponse implements ChatResponse {
    private final LinkedBlockingQueue<ChatEvent> eventQueue = new LinkedBlockingQueue<>();
    private final CompletableFuture<ChatMessage> message = new CompletableFuture<>();

    @Override
    public boolean streaming() {
        return false;
    }

    @Override
    public LinkedBlockingQueue<ChatEvent> eventQueue() {
        return eventQueue;
    }

    @Override
    public CompletableFuture<ChatMessage> message() {
        return message;
    }

    /**
     * Completes this response with the given text: queues one event carrying the
     * content with finish reason "stop", then completes the message future.
     *
     * @param content the whole response text
     * @return false if the response was already completed (nothing is done then), true otherwise
     */
    public boolean complete(String content) {
        if (message.isDone()) return false;

        enqueue(ChatEvent.of(content, "stop"));
        message.complete(new ChatMessage(true, content));
        return true;
    }

    /**
     * Fails this response: queues one event carrying the throwable, then completes
     * the message future with null (normally, not exceptionally).
     *
     * @param throwable the cause of the failure
     * @return false if the response was already completed (nothing is done then), true otherwise
     */
    public boolean completeExceptionally(Throwable throwable) {
        if (message.isDone()) return false;

        enqueue(ChatEvent.of(throwable));
        message.complete(null);
        return true;
    }

    /** Puts an event on the queue, restoring the interrupt flag if the thread was interrupted. */
    private void enqueue(ChatEvent event) {
        try {
            eventQueue.put(event);
        } catch (InterruptedException e) {
            // catching InterruptedException clears the flag; set it again so
            // code further up the stack can still see the interruption
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }
}

View file

@ -1,77 +0,0 @@
package eu.m724.chatapi.source.impl;
import eu.m724.chatapi.chat.ChatEvent;
import eu.m724.chatapi.chat.ChatMessage;
import eu.m724.chatapi.source.ChatResponse;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.LinkedBlockingQueue;
/**
 * A {@link ChatResponse} that is filled in piece by piece: each token is queued as a
 * {@link ChatEvent} and accumulated, and the message future is completed with the
 * accumulated text once a finish reason or an error arrives.
 */
public class StreamingChatResponse implements ChatResponse {
    private final boolean streaming;
    private final LinkedBlockingQueue<ChatEvent> eventQueue;
    private final CompletableFuture<ChatMessage> message;

    /** everything received so far */
    private final StringBuilder total = new StringBuilder();

    public StreamingChatResponse() {
        this.streaming = true;
        this.eventQueue = new LinkedBlockingQueue<>();
        this.message = new CompletableFuture<>();
    }

    /**
     * Queues an event made of the given token and finish reason.
     * A non-null token is appended to the accumulated text; a non-null finish reason
     * completes the message future with the accumulated text.
     *
     * @param token the next piece of text, or null
     * @param finishReason why the response ended, or null if it has not ended
     */
    public void put(String token, String finishReason) {
        enqueue(ChatEvent.of(token, finishReason));

        if (token != null) {
            total.append(token);
        }

        if (finishReason != null) {
            message.complete(ChatMessage.assistant(total.toString()));
        }
    }

    /**
     * Queues a token without ending the response.
     *
     * @param token the next piece of text
     */
    public void put(String token) {
        put(token, null);
    }

    /**
     * Ends the response without adding any text.
     *
     * @param finishReason why the response ended
     */
    public void end(String finishReason) {
        put(null, finishReason);
    }

    /**
     * Queues an event carrying the throwable, then completes the message future
     * (normally, not exceptionally) with the text accumulated so far.
     *
     * @param throwable the cause of the failure
     */
    public void error(Throwable throwable) {
        enqueue(ChatEvent.of(throwable));
        message.complete(ChatMessage.assistant(total.toString()));
    }

    /** Puts an event on the queue, restoring the interrupt flag if the thread was interrupted. */
    private void enqueue(ChatEvent event) {
        try {
            eventQueue.put(event);
        } catch (InterruptedException e) {
            // catching InterruptedException clears the flag; set it again so
            // code further up the stack can still see the interruption
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    @Override
    public boolean streaming() {
        return streaming;
    }

    @Override
    public LinkedBlockingQueue<ChatEvent> eventQueue() {
        return eventQueue;
    }

    @Override
    public CompletableFuture<ChatMessage> message() {
        return message;
    }
}