186 lines
No EOL
8.5 KiB
Java
186 lines
No EOL
8.5 KiB
Java
package eu.m724;
|
|
|
|
import eu.m724.chat.Chat;
|
|
import eu.m724.chat.ChatEvent;
|
|
import eu.m724.chat.ChatMessage;
|
|
import eu.m724.example.OaiSource;
|
|
import eu.m724.source.ChatResponse;
|
|
import eu.m724.source.ChatSource;
|
|
import eu.m724.source.exception.HttpException;
|
|
import eu.m724.source.option.Option;
|
|
import eu.m724.source.option.Options;
|
|
|
|
import java.util.NoSuchElementException;
|
|
import java.util.Scanner;
|
|
import java.util.regex.Pattern;
|
|
|
|
class Main {
|
|
public static void main(String[] args) throws InterruptedException {
|
|
Scanner scanner = new Scanner(System.in);
|
|
|
|
String apiKey = System.getenv("API_KEY");
|
|
boolean complainedApiKey = false;
|
|
|
|
if (apiKey == null) {
|
|
System.out.print("\nAPI Key: ");
|
|
apiKey = scanner.nextLine();
|
|
if (apiKey.isBlank()) {
|
|
System.out.println("Wrong");
|
|
return;
|
|
}
|
|
}
|
|
|
|
if (!Pattern.matches("sk-proj-.*?(?:\\s|$)", apiKey)) {
|
|
System.out.println("This key looks invalid");
|
|
complainedApiKey = true;
|
|
}
|
|
|
|
ChatSource source = new OaiSource(apiKey);
|
|
source.options().setValue("model", "chatgpt-4o-latest");
|
|
|
|
Chat chat = new Chat("Speak in super wuper uwu wanguage.");
|
|
|
|
System.out.println("\nWelcome to CHUT chat. Say something after the \033[1m>\033[0m, or type \033[1m:help\033[0m to see available commands");
|
|
System.out.println("Working directory: " + System.getProperty("user.dir"));
|
|
System.out.printf("\033[1m%s\033[0m %s by %s (%s v%d)\n", source.info().name(), source.info().versionName(), source.info().author(), source.getClass().getName(), source.info().version());
|
|
|
|
String prompt;
|
|
|
|
while (true) {
|
|
System.out.print("\n> ");
|
|
|
|
try {
|
|
prompt = scanner.nextLine();
|
|
} catch (NoSuchElementException e) {
|
|
System.out.println("Exiting");
|
|
break;
|
|
}
|
|
|
|
if (!prompt.startsWith(":")) {
|
|
chat.messages.add(new ChatMessage(false, prompt));
|
|
ChatResponse chatResponse = source.ask(chat);
|
|
ChatEvent token;
|
|
|
|
int i = 0;
|
|
do {
|
|
token = chatResponse.eventQueue().take();
|
|
|
|
if (!"error".equals(token.finishReason())) { // this looks bad but at least idea doesn't nag me
|
|
if (token.text() != null) {
|
|
System.out.print(i++ % 2 == 1 ? "\033[1m" : "\033[0m");
|
|
System.out.print(token.text());
|
|
}
|
|
} else {
|
|
System.out.print("Error: " + token.error().toString());
|
|
if (complainedApiKey && token.error() instanceof HttpException && ((HttpException)token.error()).statusCode == 401) {
|
|
System.out.print("\nTold you");
|
|
complainedApiKey = false;
|
|
}
|
|
}
|
|
} while (token.finishReason() == null);
|
|
|
|
System.out.println();
|
|
} else {
|
|
String[] parts = prompt.substring(1).split(" ");
|
|
if (parts[0].startsWith(":")) {
|
|
System.out.println("If you want to start a message with a \033[1m:\033[0m, you can't");
|
|
}
|
|
|
|
switch (parts[0]) {
|
|
case "help":
|
|
case "h":
|
|
case "":
|
|
System.out.println("Source: " + source.getClass().getName());
|
|
System.out.println("Available commands:");
|
|
System.out.println(":help - this");
|
|
System.out.println(":dump - recap of the chat");
|
|
System.out.println(":opt - change source options");
|
|
System.out.println(":system - system prompt");
|
|
System.out.println("Most commands have a few abbreviations like \033[1m:s\033[0m or \033[1m:sys\033[0m for :system");
|
|
break;
|
|
case "dump":
|
|
case "d":
|
|
int size = chat.messages.size();
|
|
System.out.printf("This chat has %d messages, or %d pairs", size, size / 2);
|
|
|
|
if (chat.systemPrompt == null) {
|
|
System.out.println(".\nThere's no system prompt.");
|
|
} else {
|
|
System.out.printf(", excluding system prompt.\nSystem prompt:\n\"\"\"%s\"\"\"\n", chat.systemPrompt);
|
|
}
|
|
for (ChatMessage message : chat.messages) {
|
|
System.out.printf("%s: %s\n", message.assistant() ? "ASSISTANT" : "USER", message.text());
|
|
}
|
|
break;
|
|
case "options":
|
|
case "opt":
|
|
case "o":
|
|
Options options = source.options();
|
|
|
|
boolean shouldShowOptions = parts.length < 3;
|
|
boolean optionExists = parts.length > 1 && source.options().getKeys().contains(parts[1]);
|
|
boolean complainNoOption = parts.length > 1 && !optionExists;
|
|
String chosenOption = parts.length > 1 ? parts[1] : null;
|
|
|
|
if (!shouldShowOptions) {
|
|
try {
|
|
Option<?> option = options.getOption(parts[1]);
|
|
Object value = option.fromString(parts[2]);
|
|
option.setValue(value);
|
|
System.out.printf("Set %s to %s\n", option.label, option.getValue());
|
|
} catch (NoSuchElementException e) {
|
|
shouldShowOptions = true;
|
|
System.out.printf("Unknown option \"%s\". ", parts[1]);
|
|
}
|
|
}
|
|
|
|
if (shouldShowOptions) {
|
|
if (complainNoOption)
|
|
System.out.printf("Unknown option \"%s\". ", chosenOption);
|
|
System.out.println("Available options:");
|
|
|
|
for (Option<?> option : options.getOptions()) {
|
|
String value = option.toString() + " (" + option.getType().getName() + ")";
|
|
|
|
if (option.id.equals(chosenOption)) {
|
|
System.out.printf("\033[1m%s (%s) = %s\033[0m\n", option.label, option.id, value);
|
|
} else {
|
|
if (option.label.toLowerCase().contains("key")) {
|
|
value = "(looks confidential, specify to see)";
|
|
}
|
|
System.out.printf("%s (%s) = %s\n", option.label, option.id, value);
|
|
}
|
|
}
|
|
}
|
|
|
|
break;
|
|
case "system":
|
|
case "sys":
|
|
case "s":
|
|
if (parts.length == 1) {
|
|
if (chat.systemPrompt != null) {
|
|
System.out.printf("System prompt:\n\033[1m%s\033[0m\n\n", chat.systemPrompt);
|
|
System.out.println("Set to \033[1mnull\033[0m to remove");
|
|
} else {
|
|
System.out.println("No system prompt");
|
|
}
|
|
} else {
|
|
System.out.printf("Previous system prompt:\n%s\n\n", chat.systemPrompt);
|
|
if (parts[1].equals("null")) {
|
|
chat.systemPrompt = null;
|
|
System.out.println("System prompt removed");
|
|
} else {
|
|
chat.systemPrompt = prompt.substring(parts[0].length() + 2).replace("\\n", "\n");
|
|
System.out.printf("New system prompt:\n\033[1m%s\033[0m\n", chat.systemPrompt);
|
|
}
|
|
}
|
|
break;
|
|
default:
|
|
System.out.println("Invalid command: \033[1m" + parts[0] + "\033[0m");
|
|
break;
|
|
}
|
|
}
|
|
System.out.print("\033[0m");
|
|
}
|
|
}
|
|
} |