big revamp
This commit is contained in:
parent
ca7ebbdc2e
commit
6bfd51baa3
9 changed files with 194 additions and 222 deletions
|
@ -10,8 +10,6 @@ import eu.m724.chatapi.source.option.Options
|
|||
import eu.m724.chatapi.source.option.StringOption
|
||||
|
||||
import java.util.concurrent.CompletableFuture
|
||||
import java.util.concurrent.LinkedBlockingQueue
|
||||
|
||||
/**
|
||||
* an example chatresponsesource chatresponsesource ChatResponseSource CHATRESPONSESOURCE CAHTSERREPOSNECSOURCE
|
||||
* note to self: rename that already...
|
||||
|
@ -40,48 +38,30 @@ class ExampleSource implements ChatSource {
|
|||
ChatResponse onAsked(Chat chat) {
|
||||
//String prompt = chat.messages.getLast().text();
|
||||
String response = rollResponse(chat)
|
||||
|
||||
String[] parts = response.split(" ")
|
||||
|
||||
LinkedBlockingQueue<ChatEvent> queue = new LinkedBlockingQueue<>()
|
||||
ChatMessage message = new ChatMessage(true);
|
||||
|
||||
CompletableFuture<ChatMessage> future = CompletableFuture.supplyAsync {
|
||||
CompletableFuture.supplyAsync {
|
||||
for (int i=0; i<parts.length; i++) {
|
||||
String token = (i > 0 ? " " : "") + parts[i]
|
||||
queue.put(ChatEvent.of(token));
|
||||
message.submitEvent(ChatEvent.of(token))
|
||||
Thread.sleep(random.nextInt(50, 200))
|
||||
}
|
||||
|
||||
queue.put(ChatEvent.finished("stop"))
|
||||
return new ChatMessage(true, parts.join(" "))
|
||||
message.submitEvent(ChatEvent.finished("stop"))
|
||||
}
|
||||
|
||||
return new ChatResponse() {
|
||||
@Override
|
||||
boolean streaming() {
|
||||
return false
|
||||
}
|
||||
|
||||
@Override
|
||||
LinkedBlockingQueue<ChatEvent> eventQueue() {
|
||||
return queue
|
||||
}
|
||||
|
||||
@Override
|
||||
CompletableFuture<ChatMessage> message() {
|
||||
return future
|
||||
}
|
||||
}
|
||||
return new ChatResponse(message)
|
||||
}
|
||||
|
||||
String rollResponse(Chat chat) {
|
||||
def special = -1
|
||||
def prompt = chat.messages.getLast().text()
|
||||
def prompt = chat.messages.getLast().content()
|
||||
int messagesCount = (int) Math.ceil(chat.messages.size() / 2)
|
||||
def counter = lastCounter
|
||||
def response = ""
|
||||
|
||||
|
||||
if (prompt.toLowerCase().startsWith("my name is")) {
|
||||
options.setValue("name", prompt.substring(11))
|
||||
counter = 11
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
package eu.m724.chatapi.example
|
||||
|
||||
import eu.m724.chatapi.chat.Chat
|
||||
import eu.m724.chatapi.chat.ChatEvent
|
||||
import eu.m724.chatapi.chat.ChatMessage
|
||||
import eu.m724.chatapi.source.ChatResponse
|
||||
import eu.m724.chatapi.source.ChatSource
|
||||
import eu.m724.chatapi.source.ChatSourceInfo
|
||||
import eu.m724.chatapi.source.exception.HttpException
|
||||
import eu.m724.chatapi.source.impl.StreamingChatResponse
|
||||
import eu.m724.chatapi.source.option.DoubleOption
|
||||
import eu.m724.chatapi.source.option.Options
|
||||
import eu.m724.chatapi.source.option.StringOption
|
||||
|
@ -65,7 +66,7 @@ class OaiSource implements ChatSource {
|
|||
|
||||
def client = HttpClient.newHttpClient()
|
||||
|
||||
def chatResponse = new StreamingChatResponse()
|
||||
def chatMessage = new ChatMessage(true)
|
||||
|
||||
def response = client.sendAsync(request, HttpResponse.BodyHandlers.ofLines())
|
||||
|
||||
|
@ -96,21 +97,22 @@ class OaiSource implements ChatSource {
|
|||
def json = new JSONObject(data)
|
||||
def choice = json.getJSONArray("choices").getJSONObject(0)
|
||||
def finishReason = choice.get("finish_reason")
|
||||
ChatEvent event
|
||||
|
||||
if (finishReason != JSONObject.NULL) {
|
||||
//System.out.println("ending");
|
||||
chatResponse.end(finishReason.toString())
|
||||
event = ChatEvent.finished(finishReason.toString())
|
||||
} else {
|
||||
def token = choice.getJSONObject("delta").getString("content")
|
||||
chatResponse.put(token)
|
||||
event = ChatEvent.of(token)
|
||||
}
|
||||
chatMessage.submitEvent(event)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return chatResponse
|
||||
return new ChatResponse(chatMessage)
|
||||
}
|
||||
|
||||
static JSONArray formatChat(Chat chat) {
|
||||
|
@ -121,7 +123,7 @@ class OaiSource implements ChatSource {
|
|||
}
|
||||
|
||||
chat.messages.each {
|
||||
array.put(new JSONObject().put("role", it.assistant() ? "assistant" : "user").put("content", it.text()))
|
||||
array.put(new JSONObject().put("role", it.response() ? "assistant" : "user").put("content", it.content()))
|
||||
}
|
||||
|
||||
return array
|
||||
|
|
|
@ -6,7 +6,6 @@ import eu.m724.chatapi.chat.ChatMessage;
|
|||
import eu.m724.chatapi.example.OaiSource;
|
||||
import eu.m724.chatapi.source.ChatResponse;
|
||||
import eu.m724.chatapi.source.ChatSource;
|
||||
import eu.m724.chatapi.source.exception.HttpException;
|
||||
import eu.m724.chatapi.source.option.Option;
|
||||
import eu.m724.chatapi.source.option.Options;
|
||||
|
||||
|
@ -19,7 +18,6 @@ class Main {
|
|||
Scanner scanner = new Scanner(System.in);
|
||||
|
||||
String apiKey = System.getenv("API_KEY");
|
||||
boolean complainedApiKey = false;
|
||||
|
||||
if (apiKey == null) {
|
||||
System.out.print("\nAPI Key: ");
|
||||
|
@ -32,10 +30,9 @@ class Main {
|
|||
|
||||
if (!Pattern.matches("sk-proj-.*?(?:\\s|$)", apiKey)) {
|
||||
System.out.println("This key looks invalid");
|
||||
complainedApiKey = true;
|
||||
}
|
||||
|
||||
//ChatSource source = new ExampleSource();
|
||||
// ChatSource source = new ExampleSource();
|
||||
ChatSource source = new OaiSource(apiKey);
|
||||
source.options().setValue("model", "chatgpt-4o-latest");
|
||||
|
||||
|
@ -59,28 +56,33 @@ class Main {
|
|||
|
||||
if (!prompt.startsWith(":")) {
|
||||
chat.messages.add(new ChatMessage(false, prompt));
|
||||
|
||||
ChatResponse chatResponse = source.ask(chat);
|
||||
ChatMessage message = chatResponse.message();
|
||||
|
||||
ChatEvent token;
|
||||
|
||||
int i = 0;
|
||||
do {
|
||||
token = chatResponse.eventQueue().take();
|
||||
|
||||
if (!"error".equals(token.finishReason())) { // this looks bad but at least idea doesn't nag me
|
||||
if (token.text() != null) {
|
||||
System.out.print(i++ % 2 == 1 ? "\033[1m" : "\033[0m");
|
||||
System.out.print(token.text());
|
||||
int[] i = {0}; // this makes no sense
|
||||
message.addEventConsumer(event -> {
|
||||
if (!"error".equals(event.finishReason())) {
|
||||
if (event.text() != null) {
|
||||
System.out.print(i[0]++ % 2 == 1 ? "\033[1m" : "\033[0m");
|
||||
System.out.print(event.text());
|
||||
}
|
||||
} else {
|
||||
System.out.print("Error: " + token.error().toString());
|
||||
if (complainedApiKey && token.error() instanceof HttpException && ((HttpException)token.error()).statusCode == 401) {
|
||||
System.out.print("\nTold you");
|
||||
complainedApiKey = false;
|
||||
}
|
||||
System.out.print("Error: " + event.error().toString());
|
||||
// here was a funny thing, but I had to remove it
|
||||
}
|
||||
} while (token.finishReason() == null);
|
||||
}, true);
|
||||
|
||||
try {
|
||||
message.onComplete().join();
|
||||
} catch (Throwable e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
System.out.println();
|
||||
|
||||
} else {
|
||||
String[] parts = prompt.substring(1).split(" ");
|
||||
if (parts[0].startsWith(":")) {
|
||||
|
@ -110,7 +112,7 @@ class Main {
|
|||
System.out.printf(", excluding system prompt.\nSystem prompt:\n\"\"\"%s\"\"\"\n", chat.systemPrompt);
|
||||
}
|
||||
for (ChatMessage message : chat.messages) {
|
||||
System.out.printf("%s: %s\n", message.assistant() ? "ASSISTANT" : "USER", message.text());
|
||||
System.out.printf("%s: %s\n", message.response() ? "ASSISTANT" : "USER", message.content());
|
||||
}
|
||||
break;
|
||||
case "options":
|
||||
|
|
|
@ -1,7 +1,133 @@
|
|||
package eu.m724.chatapi.chat;
|
||||
|
||||
public record ChatMessage(boolean assistant, String text) {
|
||||
public static ChatMessage assistant(String text) {
|
||||
return new ChatMessage(true, text);
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
* represents a message, user or assistant, streaming or not<br>
|
||||
* this is quite a robust class, so don't overuse it
|
||||
*/
|
||||
public class ChatMessage {
|
||||
private final boolean response;
|
||||
private final CompletableFuture<Void> completedFuture = new CompletableFuture<>();
|
||||
private final Set<Consumer<ChatEvent>> eventConsumers = new HashSet<>();
|
||||
|
||||
private final List<ChatEvent> events = new ArrayList<>();
|
||||
|
||||
private String content = ""; // TODO I was thinking about making this abstract so you could put different types like images documents etc
|
||||
|
||||
//
|
||||
|
||||
/**
|
||||
* Registers a chat message that can be streamed to
|
||||
*
|
||||
* @param response is the message a response
|
||||
*/
|
||||
public ChatMessage(boolean response) {
|
||||
this.response = response;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a completed chat message
|
||||
*
|
||||
* @param response is the message a response
|
||||
* @param content the text
|
||||
*/
|
||||
public ChatMessage(boolean response, String content) {
|
||||
this.response = response;
|
||||
this.content = content;
|
||||
|
||||
completedFuture.complete(null);
|
||||
//submitEvent(ChatEvent.of(text, "end"));
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
/**
|
||||
* registers a consumer for the {@link ChatEvent}s
|
||||
*
|
||||
* @param consumer the consumer
|
||||
* @param history receive events before this signed up
|
||||
*/
|
||||
public void addEventConsumer(Consumer<ChatEvent> consumer, boolean history) {
|
||||
eventConsumers.add(consumer);
|
||||
if (history) {
|
||||
for (ChatEvent event : events)
|
||||
consumer.accept(event);
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
/**
|
||||
* submits a {@link ChatEvent} to broadcast to consumers<br>
|
||||
* it also makes me (the {@link ChatMessage}) use it to update the final content or if streaming ended
|
||||
*
|
||||
* @param event the event
|
||||
*/
|
||||
public void submitEvent(ChatEvent event) {
|
||||
if (event.text() != null) {
|
||||
content += event.text();
|
||||
}
|
||||
|
||||
events.add(event);
|
||||
eventConsumers.forEach(c -> {
|
||||
try {
|
||||
c.accept(event);
|
||||
} catch (Throwable e) {
|
||||
System.err.println("Error distributing event:");
|
||||
e.printStackTrace();
|
||||
}
|
||||
});
|
||||
|
||||
if (event.error() != null) {
|
||||
completedFuture.completeExceptionally(event.error());
|
||||
} else if (event.finishReason() != null) {
|
||||
completedFuture.complete(null);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
/**
|
||||
* returns this message's content<br>
|
||||
* if called during a response, it will return what's already been generated
|
||||
*
|
||||
* @return this message's content
|
||||
*/
|
||||
public String content() {
|
||||
return content;
|
||||
}
|
||||
|
||||
/**
|
||||
* whether this message is a response<br>
|
||||
* a response is usually sent by the assistant, not by the user
|
||||
*
|
||||
* @return whether this message is a response
|
||||
*/
|
||||
public boolean response() {
|
||||
return response;
|
||||
}
|
||||
|
||||
/**
|
||||
* a future that completes when the response is complete<br>
|
||||
* it can also throw an error
|
||||
*
|
||||
* @return a future that completes when the response is complete
|
||||
*/
|
||||
public CompletableFuture<Void> onComplete() {
|
||||
return completedFuture;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return is the response complete
|
||||
*/
|
||||
public boolean completed() {
|
||||
return completedFuture.isDone();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,33 +1,15 @@
|
|||
package eu.m724.chatapi.source;
|
||||
|
||||
import eu.m724.chatapi.chat.ChatEvent;
|
||||
import eu.m724.chatapi.chat.ChatMessage;
|
||||
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.LinkedBlockingQueue;
|
||||
public class ChatResponse {
|
||||
private final ChatMessage message;
|
||||
|
||||
public interface ChatResponse {
|
||||
/**
|
||||
* is this response streaming
|
||||
* if it's not, the queue will get one element that is the whole response
|
||||
* I think about replacing with an abstract class and put this in the constructor
|
||||
*
|
||||
* @return is this response streaming
|
||||
*/
|
||||
boolean streaming();
|
||||
public ChatResponse(ChatMessage message) {
|
||||
this.message = message;
|
||||
}
|
||||
|
||||
/**
|
||||
* if streamed, text token by token as it goes (or other splitting depending on the source)
|
||||
* if not, the {@link CompletableFuture} returns just the whole response after it's ready
|
||||
*
|
||||
* @return the fifo queue with each element being a part. null ends the sequence
|
||||
*/
|
||||
LinkedBlockingQueue<ChatEvent> eventQueue();
|
||||
|
||||
/**
|
||||
* gets the resulting {@link ChatMessage} when it's ready
|
||||
*
|
||||
* @return the resulting {@link ChatMessage} as soon as the response is complete
|
||||
*/
|
||||
CompletableFuture<ChatMessage> message();
|
||||
public ChatMessage message() {
|
||||
return message;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,10 +31,7 @@ public interface ChatSource {
|
|||
*/
|
||||
default ChatResponse ask(Chat chat) {
|
||||
ChatResponse chatResponse = onAsked(chat);
|
||||
|
||||
chatResponse.message().thenAccept(message -> {
|
||||
if (message != null) chat.addMessage(message);
|
||||
});
|
||||
chat.addMessage(chatResponse.message());
|
||||
|
||||
return chatResponse;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
package eu.m724.chatapi.source.impl;
|
||||
|
||||
import java.util.concurrent.LinkedBlockingQueue;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public class BlockingQueueConsumer<T> {
|
||||
public final LinkedBlockingQueue<T> queue = new LinkedBlockingQueue<>();
|
||||
|
||||
public final Consumer<T> consumer = new Consumer<>() {
|
||||
@Override
|
||||
public void accept(T t) {
|
||||
try {
|
||||
queue.put(t);
|
||||
} catch (InterruptedException e) {
|
||||
// again I don't know how that is relevant
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
|
@ -1,60 +0,0 @@
|
|||
package eu.m724.chatapi.source.impl;
|
||||
|
||||
import eu.m724.chatapi.chat.ChatEvent;
|
||||
import eu.m724.chatapi.chat.ChatMessage;
|
||||
import eu.m724.chatapi.source.ChatResponse;
|
||||
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.LinkedBlockingQueue;
|
||||
|
||||
public class NonStreamingChatResponse implements ChatResponse {
|
||||
private final LinkedBlockingQueue<ChatEvent> eventQueue = new LinkedBlockingQueue<>();
|
||||
private final CompletableFuture<ChatMessage> message = new CompletableFuture<>();
|
||||
|
||||
@Override
|
||||
public boolean streaming() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public LinkedBlockingQueue<ChatEvent> eventQueue() {
|
||||
return eventQueue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletableFuture<ChatMessage> message() {
|
||||
return message;
|
||||
}
|
||||
|
||||
public boolean complete(String content) {
|
||||
if (message.isDone()) return false;
|
||||
|
||||
try {
|
||||
eventQueue.put(ChatEvent.of(content, "stop"));
|
||||
} catch (InterruptedException e) {
|
||||
// I don't know what this exception means
|
||||
// and I don't think how will it cause me problems
|
||||
// so ignoring it for now
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
message.complete(new ChatMessage(true, content));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
public boolean completeExceptionally(Throwable throwable) {
|
||||
if (message.isDone()) return false;
|
||||
|
||||
try {
|
||||
eventQueue.put(ChatEvent.of(throwable));
|
||||
} catch (InterruptedException e) {
|
||||
// again
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
message.complete(null);
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
|
@ -1,77 +0,0 @@
|
|||
package eu.m724.chatapi.source.impl;
|
||||
|
||||
import eu.m724.chatapi.chat.ChatEvent;
|
||||
import eu.m724.chatapi.chat.ChatMessage;
|
||||
import eu.m724.chatapi.source.ChatResponse;
|
||||
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.LinkedBlockingQueue;
|
||||
|
||||
public class StreamingChatResponse implements ChatResponse {
|
||||
private final boolean streaming;
|
||||
private final LinkedBlockingQueue<ChatEvent> eventQueue;
|
||||
private final CompletableFuture<ChatMessage> message;
|
||||
|
||||
private String total = "";
|
||||
|
||||
public StreamingChatResponse() {
|
||||
this.streaming = true;
|
||||
this.eventQueue = new LinkedBlockingQueue<>();
|
||||
this.message = new CompletableFuture<>();
|
||||
}
|
||||
|
||||
public void put(String token, String finishReason) {
|
||||
try {
|
||||
eventQueue.put(ChatEvent.of(token, finishReason));
|
||||
} catch (InterruptedException e) {
|
||||
// I don't know what this exception means
|
||||
// and I don't think how will it cause me problems
|
||||
// so ignoring it for now
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
if (token != null) {
|
||||
total += token;
|
||||
}
|
||||
|
||||
if (finishReason != null) {
|
||||
message.complete(ChatMessage.assistant(total));
|
||||
}
|
||||
}
|
||||
|
||||
public void put(String token) {
|
||||
put(token, null);
|
||||
}
|
||||
|
||||
public void end(String finishReason) {
|
||||
put(null, finishReason);
|
||||
}
|
||||
|
||||
public void error(Throwable throwable) {
|
||||
try {
|
||||
eventQueue.put(ChatEvent.of(throwable));
|
||||
} catch (InterruptedException e) {
|
||||
// again
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
message.complete(ChatMessage.assistant(total));
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
@Override
|
||||
public boolean streaming() {
|
||||
return streaming;
|
||||
}
|
||||
|
||||
@Override
|
||||
public LinkedBlockingQueue<ChatEvent> eventQueue() {
|
||||
return eventQueue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletableFuture<ChatMessage> message() {
|
||||
return message;
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue