switches and toggles
This commit is contained in:
parent
1eb880be4f
commit
f75d1ca715
9 changed files with 175 additions and 22 deletions
|
@ -4,27 +4,29 @@ import eu.m724.chat.Chat;
|
||||||
import eu.m724.chat.ChatEvent;
|
import eu.m724.chat.ChatEvent;
|
||||||
import eu.m724.chat.ChatMessage;
|
import eu.m724.chat.ChatMessage;
|
||||||
import eu.m724.example.ExampleSource;
|
import eu.m724.example.ExampleSource;
|
||||||
|
import eu.m724.example.OaiSource;
|
||||||
import eu.m724.source.ChatResponse;
|
import eu.m724.source.ChatResponse;
|
||||||
import eu.m724.source.ChatSource;
|
import eu.m724.source.ChatSource;
|
||||||
|
import eu.m724.source.option.Option;
|
||||||
|
|
||||||
import java.io.BufferedReader;
|
import java.io.BufferedReader;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
import java.io.InputStreamReader;
|
import java.io.InputStreamReader;
|
||||||
import java.net.URL;
|
import java.net.URL;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.ArrayList;
|
import java.util.*;
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Scanner;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.LongStream;
|
import java.util.stream.LongStream;
|
||||||
|
|
||||||
public class Main {
|
public class Main {
|
||||||
public static void main(String[] args) throws InterruptedException {
|
public static void main(String[] args) throws InterruptedException {
|
||||||
ChatSource source = new ExampleSource();
|
//ChatSource source = new ExampleSource();
|
||||||
source.options().setValue("name", "nekalakininahappenenawiwanatin");
|
//source.options().setValue("name", "nekalakininahappenenawiwanatin");
|
||||||
|
|
||||||
Chat chat = new Chat();
|
ChatSource source = new OaiSource(System.getenv("API_KEY"));
|
||||||
|
//source.options().setValue("model", "chatgpt-4o-latest");
|
||||||
|
|
||||||
|
Chat chat = new Chat("Speak in uwu wanguage.");
|
||||||
|
|
||||||
System.out.println("Welcome to CHUT chat. Say something after the \033[1m>\033[0m, or type \033[1m:help\033[0m to see available commands");
|
System.out.println("Welcome to CHUT chat. Say something after the \033[1m>\033[0m, or type \033[1m:help\033[0m to see available commands");
|
||||||
Scanner scanner = new Scanner(System.in);
|
Scanner scanner = new Scanner(System.in);
|
||||||
|
@ -39,14 +41,21 @@ public class Main {
|
||||||
ChatEvent token;
|
ChatEvent token;
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
while ((token = chatResponse.eventQueue().take()).finishReason() == null) {
|
do {
|
||||||
|
token = chatResponse.eventQueue().take();
|
||||||
|
|
||||||
|
if (token.finishReason() != "error") {
|
||||||
System.out.print(i++ % 2 == 1 ? "\033[1m" : "\033[0m");
|
System.out.print(i++ % 2 == 1 ? "\033[1m" : "\033[0m");
|
||||||
System.out.print(token.text());
|
System.out.print(token.text());
|
||||||
|
} else {
|
||||||
|
System.out.print("Error: " + token.error().toString());
|
||||||
}
|
}
|
||||||
|
} while (token.finishReason() == null);
|
||||||
|
|
||||||
System.out.println();
|
System.out.println();
|
||||||
} else {
|
} else {
|
||||||
String[] parts = prompt.substring(1).split(" ");
|
String[] parts = prompt.substring(1).split(" ");
|
||||||
|
|
||||||
switch (parts[0]) {
|
switch (parts[0]) {
|
||||||
case "help":
|
case "help":
|
||||||
case "h":
|
case "h":
|
||||||
|
@ -54,6 +63,7 @@ public class Main {
|
||||||
System.out.println("Available commands:");
|
System.out.println("Available commands:");
|
||||||
System.out.println(":help - this");
|
System.out.println(":help - this");
|
||||||
System.out.println(":dump - recap of the chat");
|
System.out.println(":dump - recap of the chat");
|
||||||
|
System.out.println(":opt - change source options");
|
||||||
break;
|
break;
|
||||||
case "dump":
|
case "dump":
|
||||||
case "d":
|
case "d":
|
||||||
|
@ -68,6 +78,21 @@ public class Main {
|
||||||
System.out.printf("%s: %s\n", message.assistant() ? "ASSISTANT" : "USER", message.text());
|
System.out.printf("%s: %s\n", message.assistant() ? "ASSISTANT" : "USER", message.text());
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
case "options":
|
||||||
|
case "opt":
|
||||||
|
case "o":
|
||||||
|
if (parts.length < 3 || Arrays.stream(source.options().keys()).noneMatch(parts[1]::equals)) {
|
||||||
|
System.out.println("Available options:");
|
||||||
|
for (Map.Entry<String, Option<?>> entry : source.options().getOptions().entrySet()) {
|
||||||
|
System.out.printf("%s (%s) = %s\n", entry.getValue().label, entry.getKey(), entry.getValue().getValue().toString());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Option<?> option = source.options().getOptions().get(parts[1]);
|
||||||
|
Object value = option.fromString(parts[2]);
|
||||||
|
option.setValue(value);
|
||||||
|
System.out.printf("Set %s to %s\n", option.label, option.getValue());
|
||||||
|
}
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
System.out.println("Invalid command: " + parts[0]);
|
System.out.println("Invalid command: " + parts[0]);
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -6,7 +6,10 @@ import eu.m724.chat.ChatMessage
|
||||||
import eu.m724.source.ChatResponse
|
import eu.m724.source.ChatResponse
|
||||||
import eu.m724.source.ChatSource
|
import eu.m724.source.ChatSource
|
||||||
import eu.m724.source.ChatSourceInfo
|
import eu.m724.source.ChatSourceInfo
|
||||||
|
import eu.m724.source.NonStreamingChatResponse
|
||||||
import eu.m724.source.SimpleChatResponse
|
import eu.m724.source.SimpleChatResponse
|
||||||
|
import eu.m724.source.option.DoubleOption
|
||||||
|
import eu.m724.source.option.NumberOption
|
||||||
import eu.m724.source.option.Options
|
import eu.m724.source.option.Options
|
||||||
import eu.m724.source.option.StringOption
|
import eu.m724.source.option.StringOption
|
||||||
import org.json.JSONArray
|
import org.json.JSONArray
|
||||||
|
@ -21,11 +24,19 @@ import java.util.concurrent.LinkedBlockingQueue
|
||||||
// for now let's not focus on readability
|
// for now let's not focus on readability
|
||||||
// this is more about find out what is common and should be included in the common api
|
// this is more about find out what is common and should be included in the common api
|
||||||
class OaiSource implements ChatSource {
|
class OaiSource implements ChatSource {
|
||||||
|
private final String apiKey
|
||||||
|
|
||||||
|
OaiSource(String apiKey) {
|
||||||
|
this.apiKey = apiKey
|
||||||
|
}
|
||||||
|
|
||||||
private def info =
|
private def info =
|
||||||
new ChatSourceInfo("oai source", "me", "1.0", 1)
|
new ChatSourceInfo("oai source", "me", "1.0", 1)
|
||||||
|
|
||||||
private def options = new Options(
|
private def options = new Options(
|
||||||
new StringOption("apiKey", "API key", null),
|
new StringOption("apiKey", "API key", apiKey),
|
||||||
|
new StringOption("model", "Model", "gpt-4o-mini"),
|
||||||
|
new DoubleOption("temperature", "Temperature", 1.2)
|
||||||
)
|
)
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -40,8 +51,13 @@ class OaiSource implements ChatSource {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
ChatResponse onAsked(Chat chat) {
|
ChatResponse onAsked(Chat chat) {
|
||||||
def apiKey = options.getOptionValue("apiKey", String::class) // TODO handle null
|
def apiKey = options.getStringValue("apiKey") // TODO handle null
|
||||||
def requestBody = new JSONObject().put("model", "gpt-4o-mini").put("messages", formatChat(chat)).toString()
|
def requestBody = new JSONObject()
|
||||||
|
.put("model", options.getStringValue("model"))
|
||||||
|
.put("temperature", options.getOptionValue("temperature", Double::class))
|
||||||
|
.put("presence_penalty", 0.1)
|
||||||
|
.put("frequency_penalty", 0.1)
|
||||||
|
.put("messages", formatChat(chat)).toString()
|
||||||
|
|
||||||
def request = HttpRequest.newBuilder()
|
def request = HttpRequest.newBuilder()
|
||||||
.uri(URI.create("https://api.openai.com/v1/chat/completions"))
|
.uri(URI.create("https://api.openai.com/v1/chat/completions"))
|
||||||
|
@ -52,8 +68,7 @@ class OaiSource implements ChatSource {
|
||||||
|
|
||||||
def client = HttpClient.newHttpClient()
|
def client = HttpClient.newHttpClient()
|
||||||
|
|
||||||
LinkedBlockingQueue<ChatEvent> eventQueue = new LinkedBlockingQueue<>()
|
ChatResponse chatResponse = new NonStreamingChatResponse()
|
||||||
CompletableFuture<ChatMessage> messageFuture = new CompletableFuture<>()
|
|
||||||
|
|
||||||
def response = client.sendAsync(request, HttpResponse.BodyHandlers.ofString())
|
def response = client.sendAsync(request, HttpResponse.BodyHandlers.ofString())
|
||||||
|
|
||||||
|
@ -65,20 +80,26 @@ class OaiSource implements ChatSource {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (exception != null) {
|
if (exception != null) {
|
||||||
eventQueue.put(ChatEvent.of(exception))
|
chatResponse.completeExceptionally(exception)
|
||||||
messageFuture.completeExceptionally(exception)
|
|
||||||
} else {
|
} else {
|
||||||
|
def json = new JSONObject(it.body())
|
||||||
|
//System.out.println(json)
|
||||||
|
def completion = json.getJSONArray("choices").getJSONObject(0).getJSONObject("message").getString("content")
|
||||||
|
chatResponse.complete(completion)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return new SimpleChatResponse(false, eventQueue, messageFuture)
|
return chatResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
static JSONArray formatChat(Chat chat) {
|
static JSONArray formatChat(Chat chat) {
|
||||||
def array = new JSONArray()
|
def array = new JSONArray()
|
||||||
|
|
||||||
|
if (chat.systemPrompt != null) {
|
||||||
|
array.put(new JSONObject().put("role", "system").put("content", chat.systemPrompt))
|
||||||
|
}
|
||||||
|
|
||||||
chat.messages.each {
|
chat.messages.each {
|
||||||
array.put(new JSONObject().put("role", it.assistant() ? "assistant" : "user").put("content", it.text()))
|
array.put(new JSONObject().put("role", it.assistant() ? "assistant" : "user").put("content", it.text()))
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,7 +34,9 @@ public interface ChatSource {
|
||||||
ChatResponse chatResponse = onAsked(chat);
|
ChatResponse chatResponse = onAsked(chat);
|
||||||
|
|
||||||
// TODO make sure it works in parallel
|
// TODO make sure it works in parallel
|
||||||
chatResponse.message().thenAccept(chat::addMessage);
|
chatResponse.message().thenAccept(message -> {
|
||||||
|
if (message != null) chat.addMessage(message);
|
||||||
|
});
|
||||||
|
|
||||||
return chatResponse;
|
return chatResponse;
|
||||||
}
|
}
|
||||||
|
|
53
src/main/java/eu/m724/source/NonStreamingChatResponse.java
Normal file
53
src/main/java/eu/m724/source/NonStreamingChatResponse.java
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
package eu.m724.source;
|
||||||
|
|
||||||
|
import eu.m724.chat.ChatEvent;
|
||||||
|
import eu.m724.chat.ChatMessage;
|
||||||
|
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.concurrent.LinkedBlockingQueue;
|
||||||
|
|
||||||
|
public class NonStreamingChatResponse implements ChatResponse {
|
||||||
|
private LinkedBlockingQueue<ChatEvent> eventQueue = new LinkedBlockingQueue<>();
|
||||||
|
private CompletableFuture<ChatMessage> message = new CompletableFuture<>();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean streaming() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public LinkedBlockingQueue<ChatEvent> eventQueue() {
|
||||||
|
return eventQueue;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CompletableFuture<ChatMessage> message() {
|
||||||
|
return message;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean complete(String content) {
|
||||||
|
if (message.isDone()) return false;
|
||||||
|
|
||||||
|
try {
|
||||||
|
eventQueue.put(ChatEvent.of(content, "stop"));
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
throw new RuntimeException(e); // TODO probably should handle this
|
||||||
|
}
|
||||||
|
message.complete(new ChatMessage(true, content));
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean completeExceptionally(Throwable throwable) {
|
||||||
|
if (message.isDone()) return false;
|
||||||
|
|
||||||
|
try {
|
||||||
|
eventQueue.put(ChatEvent.of(throwable));
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
throw new RuntimeException(e); // TODO probably should handle this
|
||||||
|
}
|
||||||
|
message.complete(null);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
26
src/main/java/eu/m724/source/option/DoubleOption.java
Normal file
26
src/main/java/eu/m724/source/option/DoubleOption.java
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
package eu.m724.source.option;
|
||||||
|
|
||||||
|
public class DoubleOption extends Option<Double> {
|
||||||
|
private double minValue = Double.MIN_VALUE;
|
||||||
|
private double maxValue = Double.MAX_VALUE;
|
||||||
|
|
||||||
|
public DoubleOption(String id, String label, Double value) {
|
||||||
|
super(id, label, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
public DoubleOption(String id, String label, Double value, double minValue, double maxValue) {
|
||||||
|
super(id, label, value);
|
||||||
|
this.minValue = minValue;
|
||||||
|
this.maxValue = maxValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
boolean isValid(Double value) {
|
||||||
|
return value >= minValue && value <= maxValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Double fromString(String text) {
|
||||||
|
return Double.valueOf(text);
|
||||||
|
}
|
||||||
|
}
|
|
@ -18,4 +18,9 @@ public class NumberOption extends Option<Integer> {
|
||||||
boolean isValid(Integer value) {
|
boolean isValid(Integer value) {
|
||||||
return value >= minValue && value <= maxValue;
|
return value >= minValue && value <= maxValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Integer fromString(String text) {
|
||||||
|
return Integer.valueOf(text);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package eu.m724.source.option;
|
package eu.m724.source.option;
|
||||||
|
|
||||||
import java.lang.reflect.ParameterizedType;
|
import java.lang.reflect.ParameterizedType;
|
||||||
import java.lang.reflect.Type;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* represents an option that is a text label and value of any type
|
* represents an option that is a text label and value of any type
|
||||||
|
@ -49,5 +48,18 @@ public abstract class Option<T> {
|
||||||
|
|
||||||
// TODO I'm not a fan of that, probably should address the warnings
|
// TODO I'm not a fan of that, probably should address the warnings
|
||||||
|
|
||||||
|
/**
|
||||||
|
* checks if the option is valid given current constraints
|
||||||
|
* @param value the checked value
|
||||||
|
* @return is it valid
|
||||||
|
*/
|
||||||
abstract boolean isValid(T value);
|
abstract boolean isValid(T value);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* convert a string to the type of this option
|
||||||
|
* TODO fix english
|
||||||
|
* @param text a text representation of a value
|
||||||
|
* @return a value in an acceptable type
|
||||||
|
*/
|
||||||
|
public abstract T fromString(String text);
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,6 +34,10 @@ public class Options {
|
||||||
return (T) options.get(id).getValue();
|
return (T) options.get(id).getValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public String getStringValue(String id) {
|
||||||
|
return (String) getOptionValue(id, String.class);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* set a value of an option
|
* set a value of an option
|
||||||
* @param id the option id
|
* @param id the option id
|
||||||
|
|
|
@ -23,4 +23,9 @@ public class StringOption extends Option<String> {
|
||||||
boolean isValid(String value) {
|
boolean isValid(String value) {
|
||||||
return pattern == null || pattern.matcher(value).matches();
|
return pattern == null || pattern.matcher(value).matches();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String fromString(String text) {
|
||||||
|
return text;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue