Compare commits

..

3 commits

Author SHA1 Message Date
1dcf09614e
make it stream
and other changes
2024-08-31 19:12:55 +02:00
0f59b4d09c
do this 2024-08-31 17:38:52 +02:00
63201d9888
stuff 2024-08-31 17:34:57 +02:00
8 changed files with 187 additions and 47 deletions

View file

@ -1,3 +1 @@
This is a Java library you can write scripts for to support an api with conversational language models A scriptable Java library for chatbots
I can't really say much because I'm still working out stuff so see `thinkings` directory

View file

@ -3,20 +3,14 @@ package eu.m724;
import eu.m724.chat.Chat; import eu.m724.chat.Chat;
import eu.m724.chat.ChatEvent; import eu.m724.chat.ChatEvent;
import eu.m724.chat.ChatMessage; import eu.m724.chat.ChatMessage;
import eu.m724.example.ExampleSource;
import eu.m724.example.OaiSource; import eu.m724.example.OaiSource;
import eu.m724.source.ChatResponse; import eu.m724.source.ChatResponse;
import eu.m724.source.ChatSource; import eu.m724.source.ChatSource;
import eu.m724.source.option.Option; import eu.m724.source.option.Option;
import java.io.BufferedReader; import java.util.Arrays;
import java.io.InputStream; import java.util.Map;
import java.io.InputStreamReader; import java.util.Scanner;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
public class Main { public class Main {
public static void main(String[] args) throws InterruptedException { public static void main(String[] args) throws InterruptedException {
@ -24,11 +18,13 @@ public class Main {
//source.options().setValue("name", "nekalakininahappenenawiwanatin"); //source.options().setValue("name", "nekalakininahappenenawiwanatin");
ChatSource source = new OaiSource(System.getenv("API_KEY")); ChatSource source = new OaiSource(System.getenv("API_KEY"));
//source.options().setValue("model", "chatgpt-4o-latest"); source.options().setValue("model", "chatgpt-4o-latest");
Chat chat = new Chat("Speak in uwu wanguage."); Chat chat = new Chat("Speak in super wuper uwu wanguage.");
System.out.println("Welcome to CHUT chat. Say something after the \033[1m>\033[0m, or type \033[1m:help\033[0m to see available commands"); System.out.println("Welcome to CHUT chat. Say something after the \033[1m>\033[0m, or type \033[1m:help\033[0m to see available commands");
System.out.printf("Source: \033[1m%s\033[0m %s (%d) by %s\n", source.info().name(), source.info().versionName(), source.info().version(), source.info().author());
Scanner scanner = new Scanner(System.in); Scanner scanner = new Scanner(System.in);
while (true) { while (true) {
@ -45,8 +41,10 @@ public class Main {
token = chatResponse.eventQueue().take(); token = chatResponse.eventQueue().take();
if (token.finishReason() != "error") { if (token.finishReason() != "error") {
System.out.print(i++ % 2 == 1 ? "\033[1m" : "\033[0m"); if (token.text() != null) {
System.out.print(token.text()); System.out.print(i++ % 2 == 1 ? "\033[1m" : "\033[0m");
System.out.print(token.text());
}
} else { } else {
System.out.print("Error: " + token.error().toString()); System.out.print("Error: " + token.error().toString());
} }
@ -55,24 +53,31 @@ public class Main {
System.out.println(); System.out.println();
} else { } else {
String[] parts = prompt.substring(1).split(" "); String[] parts = prompt.substring(1).split(" ");
if (parts[0].startsWith(":")) {
System.out.println("If you want to start a message with a \033[1m:\033[0m, you can't");
}
switch (parts[0]) { switch (parts[0]) {
case "help": case "help":
case "h": case "h":
case "": case "":
System.out.println("Source: " + source.getClass().getName());
System.out.println("Available commands:"); System.out.println("Available commands:");
System.out.println(":help - this"); System.out.println(":help - this");
System.out.println(":dump - recap of the chat"); System.out.println(":dump - recap of the chat");
System.out.println(":opt - change source options"); System.out.println(":opt - change source options");
System.out.println(":system - system prompt");
System.out.println("Most commands have a few abbreviations like \033[1m:s\033[0m or \033[1m:sys\033[0m for :system");
break; break;
case "dump": case "dump":
case "d": case "d":
int size = chat.messages.size();
System.out.printf("This chat has %d messages, or %d pairs", size, size / 2);
if (chat.systemPrompt == null) { if (chat.systemPrompt == null) {
System.out.printf("This chat has %d messages.\n", chat.messages.size()); System.out.println(".\nThere's no system prompt.");
System.out.println("There's no system prompt.");
} else { } else {
System.out.printf("This chat has %d messages, excluding system prompt.\n", chat.messages.size()); System.out.printf(", excluding system prompt.\nSystem prompt:\n\"\"\"%s\"\"\"\n", chat.systemPrompt);
System.out.printf("System prompt:\n\"\"\"%s\"\"\"\n", chat.systemPrompt);
} }
for (ChatMessage message : chat.messages) { for (ChatMessage message : chat.messages) {
System.out.printf("%s: %s\n", message.assistant() ? "ASSISTANT" : "USER", message.text()); System.out.printf("%s: %s\n", message.assistant() ? "ASSISTANT" : "USER", message.text());
@ -81,20 +86,66 @@ public class Main {
case "options": case "options":
case "opt": case "opt":
case "o": case "o":
if (parts.length < 3 || Arrays.stream(source.options().keys()).noneMatch(parts[1]::equals)) { boolean shouldShowOptions = parts.length < 3;
boolean optionExists = parts.length > 1 && Arrays.asList(source.options().keys()).contains(parts[1]);
boolean complainNoOption = parts.length > 1 && !optionExists;
String chosenOption = parts.length > 1 ? parts[1] : null;
if (!shouldShowOptions) {
Option<?> option = source.options().getOptions().get(parts[1]);
if (option != null) {
Object value = option.fromString(parts[2]);
option.setValue(value);
System.out.printf("Set %s to %s\n", option.label, option.getValue());
} else {
shouldShowOptions = true;
System.out.printf("Unknown option \"%s\". ", parts[1]);
}
}
if (shouldShowOptions) {
if (complainNoOption)
System.out.printf("Unknown option \"%s\". ", chosenOption);
System.out.println("Available options:"); System.out.println("Available options:");
for (Map.Entry<String, Option<?>> entry : source.options().getOptions().entrySet()) { for (Map.Entry<String, Option<?>> entry : source.options().getOptions().entrySet()) {
System.out.printf("%s (%s) = %s\n", entry.getValue().label, entry.getKey(), entry.getValue().getValue().toString()); String value = entry.getValue().getValue().toString() + " (" + entry.getValue().getType().getName() + ")";
if (entry.getKey().equals(chosenOption)) {
System.out.printf("\033[1m%s (%s) = %s\033[0m\n", entry.getValue().label, entry.getKey(), value);
} else {
if (entry.getValue().label.toLowerCase().contains("key")) {
value = "(looks confidential, specify to see)";
}
System.out.printf("%s (%s) = %s\n", entry.getValue().label, entry.getKey(), value);
}
}
}
break;
case "system":
case "sys":
case "s":
if (parts.length == 1) {
if (chat.systemPrompt != null) {
System.out.printf("System prompt:\n\033[1m%s\033[0m\n\n", chat.systemPrompt);
System.out.println("Set to \033[1mnull\033[0m to remove");
} else {
System.out.println("No system prompt");
} }
} else { } else {
Option<?> option = source.options().getOptions().get(parts[1]); System.out.printf("Previous system prompt:\n%s\n\n", chat.systemPrompt);
Object value = option.fromString(parts[2]); if (parts[1].equals("null")) {
option.setValue(value); chat.systemPrompt = null;
System.out.printf("Set %s to %s\n", option.label, option.getValue()); System.out.println("System prompt removed");
} else {
chat.systemPrompt = prompt.substring(parts[0].length() + 2).replace("\\n", "\n");
System.out.printf("New system prompt:\n\033[1m%s\033[0m\n", chat.systemPrompt);
}
} }
break; break;
default: default:
System.out.println("Invalid command: " + parts[0]); System.out.println("Invalid command: \033[1m" + parts[0] + "\033[0m");
break; break;
} }
} }

View file

@ -1,7 +1,7 @@
package eu.m724.chat; package eu.m724.chat;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Flow;
public record ChatMessage(boolean assistant, String text) { public record ChatMessage(boolean assistant, String text) {
public static ChatMessage assistant(String text) {
return new ChatMessage(true, text);
}
} }

View file

@ -1,15 +1,11 @@
package eu.m724.example package eu.m724.example
import eu.m724.chat.Chat import eu.m724.chat.Chat
import eu.m724.chat.ChatEvent
import eu.m724.chat.ChatMessage
import eu.m724.source.ChatResponse import eu.m724.source.ChatResponse
import eu.m724.source.ChatSource import eu.m724.source.ChatSource
import eu.m724.source.ChatSourceInfo import eu.m724.source.ChatSourceInfo
import eu.m724.source.NonStreamingChatResponse
import eu.m724.source.SimpleChatResponse import eu.m724.source.SimpleChatResponse
import eu.m724.source.option.DoubleOption import eu.m724.source.option.DoubleOption
import eu.m724.source.option.NumberOption
import eu.m724.source.option.Options import eu.m724.source.option.Options
import eu.m724.source.option.StringOption import eu.m724.source.option.StringOption
import org.json.JSONArray import org.json.JSONArray
@ -18,9 +14,6 @@ import org.json.JSONObject
import java.net.http.HttpClient import java.net.http.HttpClient
import java.net.http.HttpRequest import java.net.http.HttpRequest
import java.net.http.HttpResponse import java.net.http.HttpResponse
import java.util.concurrent.CompletableFuture
import java.util.concurrent.LinkedBlockingQueue
// for now let's not focus on readability // for now let's not focus on readability
// this is more about find out what is common and should be included in the common api // this is more about find out what is common and should be included in the common api
class OaiSource implements ChatSource { class OaiSource implements ChatSource {
@ -36,7 +29,7 @@ class OaiSource implements ChatSource {
private def options = new Options( private def options = new Options(
new StringOption("apiKey", "API key", apiKey), new StringOption("apiKey", "API key", apiKey),
new StringOption("model", "Model", "gpt-4o-mini"), new StringOption("model", "Model", "gpt-4o-mini"),
new DoubleOption("temperature", "Temperature", 1.2) new DoubleOption("temperature", "Temperature", 1.1)
) )
@Override @Override
@ -56,7 +49,7 @@ class OaiSource implements ChatSource {
.put("model", options.getStringValue("model")) .put("model", options.getStringValue("model"))
.put("temperature", options.getOptionValue("temperature", Double::class)) .put("temperature", options.getOptionValue("temperature", Double::class))
.put("presence_penalty", 0.1) .put("presence_penalty", 0.1)
.put("frequency_penalty", 0.1) .put("stream", true)
.put("messages", formatChat(chat)).toString() .put("messages", formatChat(chat)).toString()
def request = HttpRequest.newBuilder() def request = HttpRequest.newBuilder()
@ -68,9 +61,10 @@ class OaiSource implements ChatSource {
def client = HttpClient.newHttpClient() def client = HttpClient.newHttpClient()
ChatResponse chatResponse = new NonStreamingChatResponse() def chatResponse = new SimpleChatResponse()
def response = client.sendAsync(request, HttpResponse.BodyHandlers.ofLines())
def response = client.sendAsync(request, HttpResponse.BodyHandlers.ofString())
response.thenAccept { response.thenAccept {
Exception exception = null Exception exception = null
@ -79,6 +73,30 @@ class OaiSource implements ChatSource {
exception = new Exception("Non 200 status code: %d".formatted(it.statusCode())) exception = new Exception("Non 200 status code: %d".formatted(it.statusCode()))
} }
it.body().forEach {
//System.out.println(it);
for (String line : it.split("\n")) {
if (line.startsWith("data: ")) {
def data = line.substring(6)
if (data != "[DONE]") {
def json = new JSONObject(data)
def choice = json.getJSONArray("choices").getJSONObject(0)
def finishReason = choice.get("finish_reason")
if (finishReason != JSONObject.NULL) {
//System.out.println("ending");
chatResponse.end(finishReason.toString())
} else {
def token = choice.getJSONObject("delta").getString("content")
chatResponse.put(token)
}
}
}
}
}
/*
if (exception != null) { if (exception != null) {
chatResponse.completeExceptionally(exception) chatResponse.completeExceptionally(exception)
} else { } else {
@ -86,7 +104,7 @@ class OaiSource implements ChatSource {
//System.out.println(json) //System.out.println(json)
def completion = json.getJSONArray("choices").getJSONObject(0).getJSONObject("message").getString("content") def completion = json.getJSONArray("choices").getJSONObject(0).getJSONObject("message").getString("content")
chatResponse.complete(completion) chatResponse.complete(completion)
} }*/
} }

View file

@ -1,7 +1,6 @@
package eu.m724.source; package eu.m724.source;
import eu.m724.chat.Chat; import eu.m724.chat.Chat;
import eu.m724.source.option.Option;
import eu.m724.source.option.Options; import eu.m724.source.option.Options;
public interface ChatSource { public interface ChatSource {
@ -33,6 +32,25 @@ public interface ChatSource {
default ChatResponse ask(Chat chat) { default ChatResponse ask(Chat chat) {
ChatResponse chatResponse = onAsked(chat); ChatResponse chatResponse = onAsked(chat);
/* if (chatResponse.streaming()) {
StringBuilder total = new StringBuilder();
CompletableFuture.runAsync(() -> {
ChatEvent event;
while (true) {
try {
event = chatResponse.eventQueue().take();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
total.append(event.text());
if (event.finishReason() != null) {
chatResponse.message().complete(new ChatMessage(true, total.toString()));
break;
}
}
});
} this was a draft I'm keeping because I might change my mind */
// TODO make sure it works in parallel // TODO make sure it works in parallel
chatResponse.message().thenAccept(message -> { chatResponse.message().thenAccept(message -> {
if (message != null) chat.addMessage(message); if (message != null) chat.addMessage(message);

View file

@ -7,8 +7,8 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.LinkedBlockingQueue;
public class NonStreamingChatResponse implements ChatResponse { public class NonStreamingChatResponse implements ChatResponse {
private LinkedBlockingQueue<ChatEvent> eventQueue = new LinkedBlockingQueue<>(); private final LinkedBlockingQueue<ChatEvent> eventQueue = new LinkedBlockingQueue<>();
private CompletableFuture<ChatMessage> message = new CompletableFuture<>(); private final CompletableFuture<ChatMessage> message = new CompletableFuture<>();
@Override @Override
public boolean streaming() { public boolean streaming() {

View file

@ -6,6 +6,57 @@ import eu.m724.chat.ChatMessage;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.LinkedBlockingQueue;
public record SimpleChatResponse(boolean streaming, LinkedBlockingQueue<ChatEvent> eventQueue, public class SimpleChatResponse implements ChatResponse {
CompletableFuture<ChatMessage> message) implements ChatResponse { private final boolean streaming;
private final LinkedBlockingQueue<ChatEvent> eventQueue;
private final CompletableFuture<ChatMessage> message;
private String total = "";
/**
 * Creates a response that reports itself as streaming, with an empty
 * event queue and a not-yet-completed message future.
 */
public SimpleChatResponse() {
    this.message = new CompletableFuture<>();
    this.eventQueue = new LinkedBlockingQueue<>();
    this.streaming = true;
}
/**
 * Publishes one event to the queue and folds it into the accumulated reply.
 *
 * <p>The event is enqueued first. A non-null {@code token} is then appended
 * to the running total, and a non-null {@code finishReason} completes the
 * {@code message} future with an assistant message holding that total.
 *
 * @param token        text fragment to publish, or {@code null} if this event carries no text
 * @param finishReason reason the reply ended, or {@code null} if the reply is still in progress
 * @throws RuntimeException if the calling thread is interrupted while enqueueing;
 *                          the thread's interrupt flag is restored before throwing
 */
public void put(String token, String finishReason) {
    try {
        eventQueue.put(ChatEvent.of(token, finishReason));
    } catch (InterruptedException e) {
        // Catching InterruptedException clears the interrupt flag; set it again
        // so code further up the stack can still observe the interruption.
        Thread.currentThread().interrupt();
        throw new RuntimeException("Interrupted while queueing a chat event", e);
    }

    if (token != null) {
        total += token;
    }

    if (finishReason != null) {
        message.complete(ChatMessage.assistant(total));
    }
}
/**
 * Publishes a text fragment with no finish reason.
 *
 * @param token text fragment to publish
 */
public void put(String token) {
    final String noFinishReason = null;
    put(token, noFinishReason);
}
/**
 * Publishes a final event that carries no text, only the finish reason.
 *
 * @param finishReason reason the reply ended
 */
public void end(String finishReason) {
    final String noToken = null;
    put(noToken, finishReason);
}
/** @return the streaming flag set at construction */
@Override
public boolean streaming() {
    return this.streaming;
}
/** @return the queue that {@code put} publishes events to */
@Override
public LinkedBlockingQueue<ChatEvent> eventQueue() {
    return this.eventQueue;
}
/** @return the future completed with the full assistant message once a finish reason is published */
@Override
public CompletableFuture<ChatMessage> message() {
    return this.message;
}
} }

View file

@ -3,6 +3,10 @@ so, the api is mostly good now, but new things to focus on:
- how to do requests - how to do requests
- how to make it friendly - how to make it friendly
delay
right now streamed tokens seem to arrive delayed and in batches; I don't know why yet
it might just be my setup
network network
every platform has different libraries for that, so abstraction every platform has different libraries for that, so abstraction
how? I still don't know how? I still don't know