chatapi/src/main/java/eu/m724/Main.java
2024-08-29 19:50:58 +02:00

96 lines
No EOL
3.3 KiB
Java

package eu.m724;
import eu.m724.chat.Chat;
import eu.m724.chat.ChatEvent;
import eu.m724.chat.ChatMessage;
import eu.m724.example.ExampleSource;
import eu.m724.source.ChatResponse;
import eu.m724.source.ChatSource;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
/**
 * Demo entry point for the chat API.
 *
 * <p>Asks an {@link ExampleSource} to reply to a single {@code "hello"} message, echoes the
 * streamed reply to stdout as it arrives, and then prints a per-token timing breakdown.
 */
public class Main {
    // ANSI SGR escape sequences used when printing the timing breakdown.
    private static final String RESET = "\033[0m";
    private static final String BOLD = "\033[1m";
    private static final String BLINK = "\033[5m";
    private static final String CONCEAL = "\033[8m";

    /**
     * One streamed token together with the {@link System#nanoTime()} reading taken right after
     * it was echoed.
     */
    private record TimedToken(String text, long arrivalNanos) {}

    /**
     * Runs the demo.
     *
     * @param args ignored
     * @throws InterruptedException if interrupted while waiting in the response's
     *     {@code eventQueue().take()}
     */
    public static void main(String[] args) throws InterruptedException {
        ChatSource source = new ExampleSource();
        // NOTE(review): readResourceFile returns null when name.txt is missing or unreadable;
        // confirm that setValue tolerates a null value.
        source.options().setValue("name", readResourceFile("name.txt"));

        Chat chat = new Chat();
        chat.messages.add(new ChatMessage(false, "hello"));
        ChatResponse chatResponse = source.ask(chat);

        printOptions(source);

        System.out.println("Streaming response now\n");
        // I was thinking about integrating this timing data into ChatMessage
        List<TimedToken> tokens = streamTokens(chatResponse);
        System.out.println("\n");

        printTimeline(tokens);
    }

    /** Prints the source's class name, its option count, and each option's label and type. */
    private static void printOptions(ChatSource source) {
        String optionList = source.options().getOptions().values().stream()
                .map(o -> "%s (%s)".formatted(o.label, o.getType().getName()))
                .collect(Collectors.joining(", "));
        System.out.printf("%s has %d options: %s%n%n",
                source.getClass().getName(),
                source.options().count(),
                optionList);
    }

    /**
     * Echoes each streamed token to stdout and records it with its arrival time.
     *
     * <p>Stops at the first event whose {@code finishReason()} is non-null; that event's text is
     * neither printed nor recorded. Usually the finish reason will be alongside a token, but this
     * is simpler.
     *
     * @return the received tokens in arrival order (possibly empty)
     */
    private static List<TimedToken> streamTokens(ChatResponse response) throws InterruptedException {
        List<TimedToken> tokens = new ArrayList<>();
        ChatEvent event;
        while ((event = response.eventQueue().take()).finishReason() == null) {
            System.out.print(event.text());
            // nanoTime is monotonic, so measured intervals are unaffected by wall-clock adjustments.
            tokens.add(new TimedToken(event.text(), System.nanoTime()));
        }
        return tokens;
    }

    /**
     * Prints the token count, then every token (newlines escaped) separated by the milliseconds
     * elapsed since the previous one, then the total time from the first to the last token.
     */
    private static void printTimeline(List<TimedToken> tokens) {
        System.out.printf("Tokens: %d%n", tokens.size());
        if (tokens.isEmpty()) {
            // No tokens means no timings; getFirst()/getLast() below would throw.
            return;
        }

        TimedToken first = tokens.getFirst();
        long previousNanos = first.arrivalNanos();
        System.out.printf("\"%s\"", escapeNewlines(first.text()));
        for (int i = 1; i < tokens.size(); i++) {
            TimedToken token = tokens.get(i);
            // Alternate bold/normal so that adjacent tokens are visually distinguishable.
            System.out.print(i % 2 == 1 ? BOLD : RESET);
            System.out.printf(" + %dms + \"%s\"",
                    TimeUnit.NANOSECONDS.toMillis(token.arrivalNanos() - previousNanos),
                    escapeNewlines(token.text()));
            previousNanos = token.arrivalNanos();
            // Wrap after every fifth token, except when it is the last one.
            if (i % 5 == 0 && i != tokens.size() - 1) {
                System.out.println(" +" + RESET);
            }
        }
        System.out.println(RESET + "\n");

        long totalMillis = TimeUnit.NANOSECONDS.toMillis(tokens.getLast().arrivalNanos() - first.arrivalNanos());
        // NOTE(review): SGR 8 (conceal) makes the number invisible on most terminals; confirm
        // that hiding the total is intended.
        System.out.printf(BLINK + "Total: " + CONCEAL + "%dms" + RESET + "%n", totalMillis);
    }

    /** Makes embedded line feeds visible by replacing each with the two characters {@code \n}. */
    private static String escapeNewlines(String text) {
        return text == null ? null : text.replace("\n", "\\n");
    }

    /**
     * Reads a classpath resource as UTF-8 text.
     *
     * <p>The content is read line by line and re-joined with {@code "\n"}, so CR/CRLF line
     * endings are normalized to LF and any trailing line terminator is dropped.
     *
     * @param resourcePath resource name resolved by this class's class loader, e.g.
     *     {@code "name.txt"}
     * @return the resource contents, or {@code null} if the resource does not exist or cannot be
     *     read (a diagnostic is written to stderr in that case)
     * @throws NullPointerException if {@code resourcePath} is {@code null}
     */
    public static String readResourceFile(String resourcePath) {
        URL resourceUrl = Main.class.getClassLoader().getResource(resourcePath);
        if (resourceUrl == null) {
            System.err.println("Resource not found: " + resourcePath);
            return null;
        }
        try (InputStream inputStream = resourceUrl.openStream();
             BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            return reader.lines().collect(Collectors.joining("\n"));
        } catch (IOException | UncheckedIOException e) {
            // openStream/close throw IOException; BufferedReader.lines() wraps read failures
            // in UncheckedIOException.
            e.printStackTrace();
            return null;
        }
    }
}