big revamp

This commit is contained in:
Minecon724 2024-09-06 18:19:13 +02:00
parent ca7ebbdc2e
commit 6bfd51baa3
Signed by: Minecon724
GPG key ID: 3CCC4D267742C8E8
9 changed files with 194 additions and 222 deletions

View file

@ -10,8 +10,6 @@ import eu.m724.chatapi.source.option.Options
import eu.m724.chatapi.source.option.StringOption import eu.m724.chatapi.source.option.StringOption
import java.util.concurrent.CompletableFuture import java.util.concurrent.CompletableFuture
import java.util.concurrent.LinkedBlockingQueue
/** /**
* an example chatresponsesource chatresponsesource ChatResponseSource CHATRESPONSESOURCE CAHTSERREPOSNECSOURCE * an example chatresponsesource chatresponsesource ChatResponseSource CHATRESPONSESOURCE CAHTSERREPOSNECSOURCE
* note to self: rename that already... * note to self: rename that already...
@ -40,48 +38,30 @@ class ExampleSource implements ChatSource {
ChatResponse onAsked(Chat chat) { ChatResponse onAsked(Chat chat) {
//String prompt = chat.messages.getLast().text(); //String prompt = chat.messages.getLast().text();
String response = rollResponse(chat) String response = rollResponse(chat)
String[] parts = response.split(" ") String[] parts = response.split(" ")
LinkedBlockingQueue<ChatEvent> queue = new LinkedBlockingQueue<>() ChatMessage message = new ChatMessage(true);
CompletableFuture<ChatMessage> future = CompletableFuture.supplyAsync { CompletableFuture.supplyAsync {
for (int i=0; i<parts.length; i++) { for (int i=0; i<parts.length; i++) {
String token = (i > 0 ? " " : "") + parts[i] String token = (i > 0 ? " " : "") + parts[i]
queue.put(ChatEvent.of(token)); message.submitEvent(ChatEvent.of(token))
Thread.sleep(random.nextInt(50, 200)) Thread.sleep(random.nextInt(50, 200))
} }
queue.put(ChatEvent.finished("stop")) message.submitEvent(ChatEvent.finished("stop"))
return new ChatMessage(true, parts.join(" "))
} }
return new ChatResponse() { return new ChatResponse(message)
@Override
boolean streaming() {
return false
}
@Override
LinkedBlockingQueue<ChatEvent> eventQueue() {
return queue
}
@Override
CompletableFuture<ChatMessage> message() {
return future
}
}
} }
String rollResponse(Chat chat) { String rollResponse(Chat chat) {
def special = -1 def special = -1
def prompt = chat.messages.getLast().text() def prompt = chat.messages.getLast().content()
int messagesCount = (int) Math.ceil(chat.messages.size() / 2) int messagesCount = (int) Math.ceil(chat.messages.size() / 2)
def counter = lastCounter def counter = lastCounter
def response = "" def response = ""
if (prompt.toLowerCase().startsWith("my name is")) { if (prompt.toLowerCase().startsWith("my name is")) {
options.setValue("name", prompt.substring(11)) options.setValue("name", prompt.substring(11))
counter = 11 counter = 11

View file

@ -1,11 +1,12 @@
package eu.m724.chatapi.example package eu.m724.chatapi.example
import eu.m724.chatapi.chat.Chat import eu.m724.chatapi.chat.Chat
import eu.m724.chatapi.chat.ChatEvent
import eu.m724.chatapi.chat.ChatMessage
import eu.m724.chatapi.source.ChatResponse import eu.m724.chatapi.source.ChatResponse
import eu.m724.chatapi.source.ChatSource import eu.m724.chatapi.source.ChatSource
import eu.m724.chatapi.source.ChatSourceInfo import eu.m724.chatapi.source.ChatSourceInfo
import eu.m724.chatapi.source.exception.HttpException import eu.m724.chatapi.source.exception.HttpException
import eu.m724.chatapi.source.impl.StreamingChatResponse
import eu.m724.chatapi.source.option.DoubleOption import eu.m724.chatapi.source.option.DoubleOption
import eu.m724.chatapi.source.option.Options import eu.m724.chatapi.source.option.Options
import eu.m724.chatapi.source.option.StringOption import eu.m724.chatapi.source.option.StringOption
@ -65,7 +66,7 @@ class OaiSource implements ChatSource {
def client = HttpClient.newHttpClient() def client = HttpClient.newHttpClient()
def chatResponse = new StreamingChatResponse() def chatMessage = new ChatMessage(true)
def response = client.sendAsync(request, HttpResponse.BodyHandlers.ofLines()) def response = client.sendAsync(request, HttpResponse.BodyHandlers.ofLines())
@ -96,21 +97,22 @@ class OaiSource implements ChatSource {
def json = new JSONObject(data) def json = new JSONObject(data)
def choice = json.getJSONArray("choices").getJSONObject(0) def choice = json.getJSONArray("choices").getJSONObject(0)
def finishReason = choice.get("finish_reason") def finishReason = choice.get("finish_reason")
ChatEvent event
if (finishReason != JSONObject.NULL) { if (finishReason != JSONObject.NULL) {
//System.out.println("ending"); event = ChatEvent.finished(finishReason.toString())
chatResponse.end(finishReason.toString())
} else { } else {
def token = choice.getJSONObject("delta").getString("content") def token = choice.getJSONObject("delta").getString("content")
chatResponse.put(token) event = ChatEvent.of(token)
} }
chatMessage.submitEvent(event)
} }
} }
} }
} }
} }
return chatResponse return new ChatResponse(chatMessage)
} }
static JSONArray formatChat(Chat chat) { static JSONArray formatChat(Chat chat) {
@ -121,7 +123,7 @@ class OaiSource implements ChatSource {
} }
chat.messages.each { chat.messages.each {
array.put(new JSONObject().put("role", it.assistant() ? "assistant" : "user").put("content", it.text())) array.put(new JSONObject().put("role", it.response() ? "assistant" : "user").put("content", it.content()))
} }
return array return array

View file

@ -6,7 +6,6 @@ import eu.m724.chatapi.chat.ChatMessage;
import eu.m724.chatapi.example.OaiSource; import eu.m724.chatapi.example.OaiSource;
import eu.m724.chatapi.source.ChatResponse; import eu.m724.chatapi.source.ChatResponse;
import eu.m724.chatapi.source.ChatSource; import eu.m724.chatapi.source.ChatSource;
import eu.m724.chatapi.source.exception.HttpException;
import eu.m724.chatapi.source.option.Option; import eu.m724.chatapi.source.option.Option;
import eu.m724.chatapi.source.option.Options; import eu.m724.chatapi.source.option.Options;
@ -19,7 +18,6 @@ class Main {
Scanner scanner = new Scanner(System.in); Scanner scanner = new Scanner(System.in);
String apiKey = System.getenv("API_KEY"); String apiKey = System.getenv("API_KEY");
boolean complainedApiKey = false;
if (apiKey == null) { if (apiKey == null) {
System.out.print("\nAPI Key: "); System.out.print("\nAPI Key: ");
@ -32,10 +30,9 @@ class Main {
if (!Pattern.matches("sk-proj-.*?(?:\\s|$)", apiKey)) { if (!Pattern.matches("sk-proj-.*?(?:\\s|$)", apiKey)) {
System.out.println("This key looks invalid"); System.out.println("This key looks invalid");
complainedApiKey = true;
} }
//ChatSource source = new ExampleSource(); // ChatSource source = new ExampleSource();
ChatSource source = new OaiSource(apiKey); ChatSource source = new OaiSource(apiKey);
source.options().setValue("model", "chatgpt-4o-latest"); source.options().setValue("model", "chatgpt-4o-latest");
@ -59,28 +56,33 @@ class Main {
if (!prompt.startsWith(":")) { if (!prompt.startsWith(":")) {
chat.messages.add(new ChatMessage(false, prompt)); chat.messages.add(new ChatMessage(false, prompt));
ChatResponse chatResponse = source.ask(chat); ChatResponse chatResponse = source.ask(chat);
ChatMessage message = chatResponse.message();
ChatEvent token; ChatEvent token;
int i = 0; int[] i = {0}; // this makes no sense
do { message.addEventConsumer(event -> {
token = chatResponse.eventQueue().take(); if (!"error".equals(event.finishReason())) {
if (event.text() != null) {
if (!"error".equals(token.finishReason())) { // this looks bad but at least idea doesn't nag me System.out.print(i[0]++ % 2 == 1 ? "\033[1m" : "\033[0m");
if (token.text() != null) { System.out.print(event.text());
System.out.print(i++ % 2 == 1 ? "\033[1m" : "\033[0m");
System.out.print(token.text());
} }
} else { } else {
System.out.print("Error: " + token.error().toString()); System.out.print("Error: " + event.error().toString());
if (complainedApiKey && token.error() instanceof HttpException && ((HttpException)token.error()).statusCode == 401) { // here was a funny thing, but I had to remove it
System.out.print("\nTold you");
complainedApiKey = false;
}
} }
} while (token.finishReason() == null); }, true);
try {
message.onComplete().join();
} catch (Throwable e) {
e.printStackTrace();
}
System.out.println(); System.out.println();
} else { } else {
String[] parts = prompt.substring(1).split(" "); String[] parts = prompt.substring(1).split(" ");
if (parts[0].startsWith(":")) { if (parts[0].startsWith(":")) {
@ -110,7 +112,7 @@ class Main {
System.out.printf(", excluding system prompt.\nSystem prompt:\n\"\"\"%s\"\"\"\n", chat.systemPrompt); System.out.printf(", excluding system prompt.\nSystem prompt:\n\"\"\"%s\"\"\"\n", chat.systemPrompt);
} }
for (ChatMessage message : chat.messages) { for (ChatMessage message : chat.messages) {
System.out.printf("%s: %s\n", message.assistant() ? "ASSISTANT" : "USER", message.text()); System.out.printf("%s: %s\n", message.response() ? "ASSISTANT" : "USER", message.content());
} }
break; break;
case "options": case "options":

View file

@ -1,7 +1,133 @@
package eu.m724.chatapi.chat; package eu.m724.chatapi.chat;
public record ChatMessage(boolean assistant, String text) { import java.util.ArrayList;
public static ChatMessage assistant(String text) { import java.util.HashSet;
return new ChatMessage(true, text); import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
/**
* represents a message, user or assistant, streaming or not<br>
* this is quite a robust class, so don't overuse it
*/
public class ChatMessage {
private final boolean response;
private final CompletableFuture<Void> completedFuture = new CompletableFuture<>();
private final Set<Consumer<ChatEvent>> eventConsumers = new HashSet<>();
private final List<ChatEvent> events = new ArrayList<>();
private String content = ""; // TODO I was thinking about making this abstract so you could put different types like images documents etc
//
/**
* Registers a chat message that can be streamed to
*
* @param response is the message a response
*/
public ChatMessage(boolean response) {
this.response = response;
}
/**
* Creates a completed chat message
*
* @param response is the message a response
* @param content the text
*/
public ChatMessage(boolean response, String content) {
this.response = response;
this.content = content;
completedFuture.complete(null);
//submitEvent(ChatEvent.of(text, "end"));
}
//
/**
* registers a consumer for the {@link ChatEvent}s
*
* @param consumer the consumer
* @param history receive events before this signed up
*/
public void addEventConsumer(Consumer<ChatEvent> consumer, boolean history) {
eventConsumers.add(consumer);
if (history) {
for (ChatEvent event : events)
consumer.accept(event);
}
}
//
/**
* submits a {@link ChatEvent} to broadcast to consumers<br>
* it also makes me (the {@link ChatMessage}) use it to update the final content or if streaming ended
*
* @param event the event
*/
public void submitEvent(ChatEvent event) {
if (event.text() != null) {
content += event.text();
}
events.add(event);
eventConsumers.forEach(c -> {
try {
c.accept(event);
} catch (Throwable e) {
System.err.println("Error distributing event:");
e.printStackTrace();
}
});
if (event.error() != null) {
completedFuture.completeExceptionally(event.error());
} else if (event.finishReason() != null) {
completedFuture.complete(null);
}
}
//
/**
* returns this message's content<br>
* if called during a response, it will return what's already been generated
*
* @return this message's content
*/
public String content() {
return content;
}
/**
* whether this message is a response<br>
* a response is usually sent by the assistant, not by the user
*
* @return whether this message is a response
*/
public boolean response() {
return response;
}
/**
* a future that completes when the response is complete<br>
* it can also throw an error
*
* @return a future that completes when the response is complete
*/
public CompletableFuture<Void> onComplete() {
return completedFuture;
}
/**
* @return is the response complete
*/
public boolean completed() {
return completedFuture.isDone();
} }
} }

View file

@ -1,33 +1,15 @@
package eu.m724.chatapi.source; package eu.m724.chatapi.source;
import eu.m724.chatapi.chat.ChatEvent;
import eu.m724.chatapi.chat.ChatMessage; import eu.m724.chatapi.chat.ChatMessage;
import java.util.concurrent.CompletableFuture; public class ChatResponse {
import java.util.concurrent.LinkedBlockingQueue; private final ChatMessage message;
public interface ChatResponse { public ChatResponse(ChatMessage message) {
/** this.message = message;
* is this response streaming }
* if it's not, the queue will get one element that is the whole response
* I think about replacing with an abstract class and put this in the constructor
*
* @return is this response streaming
*/
boolean streaming();
/** public ChatMessage message() {
* if streamed, text token by token as it goes (or other splitting depending on the source) return message;
* if not, the {@link CompletableFuture} returns just the whole response after it's ready }
*
* @return the fifo queue with each element being a part. null ends the sequence
*/
LinkedBlockingQueue<ChatEvent> eventQueue();
/**
* gets the resulting {@link ChatMessage} when it's ready
*
* @return the resulting {@link ChatMessage} as soon as the response is complete
*/
CompletableFuture<ChatMessage> message();
} }

View file

@ -31,10 +31,7 @@ public interface ChatSource {
*/ */
default ChatResponse ask(Chat chat) { default ChatResponse ask(Chat chat) {
ChatResponse chatResponse = onAsked(chat); ChatResponse chatResponse = onAsked(chat);
chat.addMessage(chatResponse.message());
chatResponse.message().thenAccept(message -> {
if (message != null) chat.addMessage(message);
});
return chatResponse; return chatResponse;
} }

View file

@ -0,0 +1,20 @@
package eu.m724.chatapi.source.impl;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.Consumer;
public class BlockingQueueConsumer<T> {
public final LinkedBlockingQueue<T> queue = new LinkedBlockingQueue<>();
public final Consumer<T> consumer = new Consumer<>() {
@Override
public void accept(T t) {
try {
queue.put(t);
} catch (InterruptedException e) {
// again I don't know how that is relevant
throw new RuntimeException(e);
}
}
};
}

View file

@ -1,60 +0,0 @@
package eu.m724.chatapi.source.impl;
import eu.m724.chatapi.chat.ChatEvent;
import eu.m724.chatapi.chat.ChatMessage;
import eu.m724.chatapi.source.ChatResponse;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.LinkedBlockingQueue;
public class NonStreamingChatResponse implements ChatResponse {
private final LinkedBlockingQueue<ChatEvent> eventQueue = new LinkedBlockingQueue<>();
private final CompletableFuture<ChatMessage> message = new CompletableFuture<>();
@Override
public boolean streaming() {
return false;
}
@Override
public LinkedBlockingQueue<ChatEvent> eventQueue() {
return eventQueue;
}
@Override
public CompletableFuture<ChatMessage> message() {
return message;
}
public boolean complete(String content) {
if (message.isDone()) return false;
try {
eventQueue.put(ChatEvent.of(content, "stop"));
} catch (InterruptedException e) {
// I don't know what this exception means
// and I don't think how will it cause me problems
// so ignoring it for now
throw new RuntimeException(e);
}
message.complete(new ChatMessage(true, content));
return true;
}
public boolean completeExceptionally(Throwable throwable) {
if (message.isDone()) return false;
try {
eventQueue.put(ChatEvent.of(throwable));
} catch (InterruptedException e) {
// again
throw new RuntimeException(e);
}
message.complete(null);
return true;
}
}

View file

@ -1,77 +0,0 @@
package eu.m724.chatapi.source.impl;
import eu.m724.chatapi.chat.ChatEvent;
import eu.m724.chatapi.chat.ChatMessage;
import eu.m724.chatapi.source.ChatResponse;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.LinkedBlockingQueue;
public class StreamingChatResponse implements ChatResponse {
private final boolean streaming;
private final LinkedBlockingQueue<ChatEvent> eventQueue;
private final CompletableFuture<ChatMessage> message;
private String total = "";
public StreamingChatResponse() {
this.streaming = true;
this.eventQueue = new LinkedBlockingQueue<>();
this.message = new CompletableFuture<>();
}
public void put(String token, String finishReason) {
try {
eventQueue.put(ChatEvent.of(token, finishReason));
} catch (InterruptedException e) {
// I don't know what this exception means
// and I don't think how will it cause me problems
// so ignoring it for now
throw new RuntimeException(e);
}
if (token != null) {
total += token;
}
if (finishReason != null) {
message.complete(ChatMessage.assistant(total));
}
}
public void put(String token) {
put(token, null);
}
public void end(String finishReason) {
put(null, finishReason);
}
public void error(Throwable throwable) {
try {
eventQueue.put(ChatEvent.of(throwable));
} catch (InterruptedException e) {
// again
throw new RuntimeException(e);
}
message.complete(ChatMessage.assistant(total));
}
//
@Override
public boolean streaming() {
return streaming;
}
@Override
public LinkedBlockingQueue<ChatEvent> eventQueue() {
return eventQueue;
}
@Override
public CompletableFuture<ChatMessage> message() {
return message;
}
}