diff --git a/.env.example b/.env.example index 498f4d8..2b1779c 100644 --- a/.env.example +++ b/.env.example @@ -1,9 +1,8 @@ +# You can also set those in settings.json + # The bot token BOT_TOKEN= -# The forum ID -FORUM_ID= - # API key for the OpenAI-compatible provider OPENAI_API_KEY= diff --git a/.gitignore b/.gitignore index a8dac6b..ced5fd5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ __pycache__/ + doc/ -.env \ No newline at end of file +threads/ + +.env +settings.json \ No newline at end of file diff --git a/Containerfile b/Containerfile index 696d6af..e7c0dda 100644 --- a/Containerfile +++ b/Containerfile @@ -56,6 +56,8 @@ COPY --from=builder --chown=appuser:appuser /app/src ./src # This allows us to run executables directly (e.g., `gunicorn`, `uvicorn`) ENV PATH="/app/.venv/bin:$PATH" +ENV SETTINGS_PATH="/settings.json" + # Switch to the non-root user USER appuser diff --git a/docker-compose.yml b/docker-compose.yml index d741563..d063022 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,10 +2,5 @@ services: bot: build: dockerfile: Containerfile - env_file: - - .env - environment: - BOT_TOKEN: - FORUM_ID: - OPENAI_API_KEY: - OPENAI_BASE_URL: \ No newline at end of file + volumes: + - ./settings.json:/settings.json:ro \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 74f9eef..1614671 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,4 +20,4 @@ build-backend = "poetry.core.masonry.api" [tool.poetry.scripts] -start = "plugin_assist:main" +start = "src.plugin_assist:main" diff --git a/settings.example.json b/settings.example.json new file mode 100644 index 0000000..9bca5a8 --- /dev/null +++ b/settings.example.json @@ -0,0 +1,36 @@ +{ + "inference": { + "api_key": "replace-me-with-your-actual-key", + "base_url": "https://nano-gpt.com/api/v1", + "models": { + "fast": "kimi-k2-instruct-fast", + "cheap": "openai/gpt-oss-120b", + "image_transcript": "openai/gpt-4.1-nano" + } + }, + "assistants": { + "RealWeather": { + "system_prompt": [ + "You are an assistant who helps with the RealWeather Minecraft plugin.", + "RealWeather is a plugin that synchronizes real weather, time and lightning strikes.", + "It's alpha software, and the current limitations make it more ideal for ambience, docs will tell you more.", + "You have the documentaton available, use it." + ], + "documentation_directory": "doc/RealWeather" + }, + "Other": { + "system_prompt": [ + "You are an assistant who answers inquiries about Minecraft plugins.", + "Ask if help is needed first. Don't offer any help unless explicitly asked." + ] + } + }, + "bot": { + "token": "Replace.Me.With.Your.Actual.Token", + "allowed_forum_ids": [ + 1401918766044151808, + 1402900706054242445 + ], + "thread_persistence_directory": "threads" + } +} \ No newline at end of file diff --git a/src/plugin_assist/__init__.py b/src/plugin_assist/__init__.py index b88d4a4..ff9f8eb 100644 --- a/src/plugin_assist/__init__.py +++ b/src/plugin_assist/__init__.py @@ -1,25 +1,45 @@ -from .assistant import Assistant +from .assistant import Assistant, Assistants, Documentation from .bot import AssistantBot +from .settings import Settings, AssistantSettings, BotSettings from openai import AsyncOpenAI -from os import environ +from os import environ, mkdir, path +from json import loads + +def load_settings() -> Settings: + settings_dict = loads(open(environ.get("SETTINGS_PATH", "settings.json")).read()) + return Settings.from_dict(settings_dict) + +def make_assistants(settings: Settings, client: AsyncOpenAI) -> Assistants: + return Assistants({assistant_name: Assistant( + client = client, + documentation = Documentation(assistant_settings.documentation_directory) if assistant_settings.documentation_directory else None, + system_prompt = assistant_settings.system_prompt, + models = settings.inference.models + ) for assistant_name, assistant_settings in settings.assistants.items()}) + +def make_bot(settings: BotSettings, assistants: Assistants) -> AssistantBot: + return AssistantBot( + assistants = assistants, + forum_ids = settings.allowed_forum_ids, + thread_persistence_directory = settings.thread_persistence_directory + ) def main(): + settings = load_settings() + + if not path.exists(settings.bot.thread_persistence_directory): + mkdir(settings.bot.thread_persistence_directory) + client = AsyncOpenAI( - base_url = environ.get("OPENAI_BASE_URL") + base_url = environ.get("OPENAI_BASE_URL", settings.inference.base_url), + api_key = environ.get("OPENAI_API_KEY", settings.inference.api_key), ) - assistant = Assistant( - client = client, - documentation_directory = "doc" - ) - - bot = AssistantBot( - assistant = assistant, - forum_id = int(environ.get("FORUM_ID")) - ) + assistants = make_assistants(settings, client) + bot = make_bot(settings.bot, assistants) print("Starting bot") - token = environ.get("BOT_TOKEN") + token = environ.get("BOT_TOKEN", settings.bot.token) bot.run(token) \ No newline at end of file diff --git a/src/plugin_assist/assistant/__init__.py b/src/plugin_assist/assistant/__init__.py index b4c0e3f..37b3694 100644 --- a/src/plugin_assist/assistant/__init__.py +++ b/src/plugin_assist/assistant/__init__.py @@ -1,4 +1,5 @@ from .assistant import Assistant +from .multi import Assistants from .documentation import Documentation -__all__ = ['Assistant', 'Documentation'] \ No newline at end of file +__all__ = ['Assistant', 'Assistants', 'Documentation'] \ No newline at end of file diff --git a/src/plugin_assist/assistant/assistant.py b/src/plugin_assist/assistant/assistant.py index 8da17d0..833921a 100644 --- a/src/plugin_assist/assistant/assistant.py +++ b/src/plugin_assist/assistant/assistant.py @@ -2,18 +2,11 @@ from openai import AsyncOpenAI from asyncio import AbstractEventLoop, Task, get_running_loop, CancelledError from dataclasses import dataclass, field from json import loads, dumps -from typing import Any, Literal, List, Dict +from typing import Any, Literal, List, Dict, Tuple, Coroutine from .documentation import Documentation -_DEFAULT_SYSTEM_PROMPT = """ -You are an assistant who helps with the RealWeather Minecraft plugin. -RealWeather is a plugin that synchronizes real weather, time and lightning strikes. -It's alpha software, and right now it's kinda more ideal for ambience, docs will tell you more. -You have the documentaton available, use it. -""" - -_DEFAULT_TOOLS = [{ +_DOCUMENTATION_SEARCH_TOOL = { "type": "function", "function": { "name": "search_documentation", @@ -24,14 +17,12 @@ _DEFAULT_TOOLS = [{ "properties": { "query": { "type": "string", - "description": "The detailed search query as tags." + "description": "The detailed search query as separated tags." } } } } -}] - -_DEFAULT_MODEL = "kimi-k2-instruct-fast" +} # A type hint for the status, making it clear what the possible values are. ResponseStatus = Literal["completed", "cancelled", "error"] @@ -41,26 +32,31 @@ class AssistantResponse: """A structured response from the Assistant.""" status: ResponseStatus content: str = "" - messages: List[Dict[str, Any]] = field(default_factory=list) + new_messages: List[Dict[str, Any]] = field(default_factory=list) error: str | None = None class Assistant: - client: AsyncOpenAI - documentation: Documentation + _client: AsyncOpenAI + documentation: Documentation | None system_prompt: str - model: str - tools: list[dict] + models: Dict[str, str] - def __init__(self, client: AsyncOpenAI, documentation_directory: str, system_prompt: str | None = None): - self.client = client - self.documentation = Documentation(documentation_directory) - self.system_prompt = system_prompt if system_prompt is not None else _DEFAULT_SYSTEM_PROMPT - self.model = _DEFAULT_MODEL - self.tools = _DEFAULT_TOOLS + _tools: List[Dict] + _tool_map: Dict[str, Coroutine[Any, Any, Dict]] + + def __init__(self, client: AsyncOpenAI, documentation: Documentation | None, system_prompt: str, models: dict[str, str]): + self._client = client + self.documentation = documentation + self.system_prompt = system_prompt + self.models = models + + self._tools = [_DOCUMENTATION_SEARCH_TOOL] if documentation is not None else [] + self._tool_map = {"search_documentation": self.documentation.search_documentation} if documentation is not None else {} def ask( self, messages: list[dict], + images_to_transcript: List[Tuple[int, int, str]] = [], *, loop: AbstractEventLoop | None = None, ) -> Task[AssistantResponse]: @@ -68,80 +64,107 @@ class Assistant: Start generating the assistant response and immediately return the asyncio.Task, so the caller can `await` it OR cancel it. """ loop = loop or get_running_loop() - task = loop.create_task(self._ask(messages)) + task = loop.create_task(self._ask(messages, images_to_transcript)) return task # MODIFIED: The method now returns our structured response and handles cancellation. - async def _ask(self, messages: list[dict]) -> AssistantResponse: + async def _ask(self, messages: list[dict], images_to_transcript: List[Tuple[int, int, str]]) -> AssistantResponse: messages_copy = messages.copy() # The messages are not expected to be modified messages_copy.insert(0, {"role": "system", "content": self.system_prompt}) try: # The core logic is now wrapped to catch cancellation. - response_content = await self._ask_inner(messages_copy) + new_messages, response_text = await self._ask_inner(messages_copy, images_to_transcript) print(messages_copy) return AssistantResponse( status="completed", - content=response_content, - messages=messages_copy + content=response_text, + new_messages=new_messages ) except CancelledError: - print("Assistant 'ask' task was cancelled.") - # Return a specific response indicating cancellation. return AssistantResponse( status="cancelled", - messages=messages_copy # Return the history up to the point of cancellation + new_messages=[] ) except Exception as e: - print(f"An unexpected error occurred in 'ask': {e}") - # Return a response for any other errors. + print(f"An error occurred in 'ask': {e}") + return AssistantResponse( status="error", - messages=messages_copy, + new_messages=[], error=str(e) ) - async def _ask_inner(self, messages: List[Dict[str, Any]]) -> str: - response_segments = [] + async def _ask_inner(self, messages: List[Dict[str, Any]], images_to_transcript: List[Tuple[int, int, str]]) -> Tuple[List[Dict[str, Any]], str]: + # +1 because system prompt + for message_index, transcript_index, image_url in images_to_transcript: + print("Transcribing", message_index+1, transcript_index, image_url) + transcript = await self.transcribe_image(image_url) + messages[message_index+1]['content'] = messages[message_index+1]['content'].replace(f"%TRANSCRIPT_{transcript_index}%", transcript) + + new_messages = [] + response_text_parts = [] finish_reason = "tool_calls" + model = self.models['cheap'] if len(messages) > 10 else self.models['fast'] # TODO not hardcode + print("Using model", model) + while finish_reason == "tool_calls": print("Making API call...") - completion = await self.client.chat.completions.create( - model=self.model, - messages=messages, + completion = await self._client.chat.completions.create( + model=model, + messages=messages + new_messages, temperature=0.0, - tools=self.tools, + tools=self._tools, tool_choice="auto" ) choice = completion.choices[0] finish_reason = choice.finish_reason - messages.append(choice.message) + new_messages.append(choice.message) if choice.message.content: - response_segments.append(choice.message.content) + response_text_parts.append(choice.message.content) if finish_reason == "tool_calls" and choice.message.tool_calls is not None: - tool_map = {"search_documentation": self.documentation.search_documentation} - for tool_call in choice.message.tool_calls: tool_call_name = tool_call.function.name tool_call_arguments = loads(tool_call.function.arguments) - tool_function = tool_map[tool_call_name] print(f"Executing tool: {tool_call_name}...") + tool_function = self._tool_map[tool_call_name] tool_result = await tool_function(**tool_call_arguments) print("Tool result:", tool_result) - messages.append({ + new_messages.append({ "role": "tool", "tool_call_id": tool_call.id, "name": tool_call_name, "content": dumps(tool_result) }) - return '\n\n'.join(response_segments) \ No newline at end of file + return new_messages, '\n\n'.join(response_text_parts) + + async def transcribe_image(self, image_url: str) -> str: + response = await self._client.chat.completions.create( + model=self.models['image_transcript'], + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "Transcribe this image in detail."}, + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + }, + ], + }], + ) + + print("Image transcript result", response) + + return response.choices[0].message.content \ No newline at end of file diff --git a/src/plugin_assist/assistant/documentation.py b/src/plugin_assist/assistant/documentation.py index 9f24978..e2578d4 100644 --- a/src/plugin_assist/assistant/documentation.py +++ b/src/plugin_assist/assistant/documentation.py @@ -4,55 +4,7 @@ from collections import Counter from typing import List, Tuple, Any -from .persistent import caio_read_file - -def _partition_by_most_frequent(data_list: List[Any]) -> Tuple[List[Any], List[Any]]: - """ - Partitions a list into two lists based on element frequency. - - The first list contains all occurrences of the element(s) that appear - most frequently. The second list contains all other elements. - - Args: - data_list: The input list of elements. - - Returns: - A tuple containing two lists: - - The first list with all instances of the most frequent element(s). - - The second list with the remaining elements. - """ - if not data_list: - return [], [] - - counts = Counter(data_list) - - max_frequency = counts.most_common(1)[0][1] - - most_frequent_keys = { - item for item, count in counts.items() if count == max_frequency - } - - most_list = [item for item in data_list if item in most_frequent_keys] - other_list = [item for item in data_list if item not in most_frequent_keys] - - return most_list, other_list - -def _load_tags(base_directory: str) -> dict: - tag_map = {} - - for root, _, files in os.walk(base_directory): - for name in files: - file_path = os.path.join(root, name) - tag_line = open(file_path).readline() - tags = [tag.strip() for tag in tag_line[6:].split(",")] - - for tag in tags: - if tag in tag_map: - tag_map[tag].append(file_path) - else : - tag_map[tag] = [file_path] - - return tag_map +from ..common.persistence import caio_read_file class Documentation: base_directory: str @@ -108,4 +60,42 @@ class Documentation: "probable": p[1] }, "guide": "Call this function again with the filename as query to view that file." - } \ No newline at end of file + } + +def _partition_by_most_frequent(data_list: List[Any]) -> Tuple[List[Any], List[Any]]: + if not data_list: + return [], [] + + counts = Counter(data_list) + + # Find the highest frequency value. Using max() is more direct than most_common(). + max_frequency = max(counts.values()) + + most_frequent_items = [] + other_items = [] + + # Iterate through the unique items and their counts + for item, count in counts.items(): + if count == max_frequency: + most_frequent_items.append(item) + else: + other_items.append(item) + + return most_frequent_items, other_items + +def _load_tags(base_directory: str) -> dict: + tag_map = {} + + for root, _, files in os.walk(base_directory): + for name in files: + file_path = os.path.join(root, name) + tag_line = open(file_path).readline() + tags = [tag.strip() for tag in tag_line[6:].split(",")] + + for tag in tags: + if tag in tag_map: + tag_map[tag].append(file_path) + else : + tag_map[tag] = [file_path] + + return tag_map \ No newline at end of file diff --git a/src/plugin_assist/assistant/multi.py b/src/plugin_assist/assistant/multi.py new file mode 100644 index 0000000..ae5f3fd --- /dev/null +++ b/src/plugin_assist/assistant/multi.py @@ -0,0 +1,16 @@ +from .assistant import Assistant + +class Assistants: + _assistants: dict[str, Assistant] + + def __init__(self, assistants: dict[str, Assistant]): + self._assistants = assistants + + def get_assistant(self, id: str): + return self._assistants[id] + + def get_assistant_ids(self) -> set[str]: + return set(self._assistants.keys()) + + def has_assistant(self, id: str): + return id in self._assistants \ No newline at end of file diff --git a/src/plugin_assist/bot/__init__.py b/src/plugin_assist/bot/__init__.py index 90ddaa5..48df1bd 100644 --- a/src/plugin_assist/bot/__init__.py +++ b/src/plugin_assist/bot/__init__.py @@ -1,23 +1,25 @@ import disnake from disnake.ext import commands -from ..assistant import Assistant -from .threads import AssistedThread +from ..assistant import Assistants +from .threads import AssistedThreads +from .persistence import ThreadPersistence class AssistantBot(commands.InteractionBot): - forum_id: int - assistant: Assistant - threads: dict[int, AssistedThread] + assistants: Assistants + forum_ids: list[int] + + threads: AssistedThreads - def __init__(self, forum_id: int, assistant: Assistant, **kwargs): + def __init__(self, assistants: Assistants, forum_ids: list[int], thread_persistence_directory: str, **kwargs): intents = disnake.Intents.default() intents.message_content = True super().__init__(intents=intents, **kwargs) - self.forum_id = forum_id - self.assistant = assistant - self.threads = {} + self.assistants = assistants + self.forum_ids = forum_ids + self.threads = AssistedThreads(assistants, ThreadPersistence(thread_persistence_directory)) async def on_ready(self): print(f"Logged in as {self.user} (ID: {self.user.id})\n------") @@ -27,50 +29,52 @@ class AssistantBot(commands.InteractionBot): ) async def on_thread_create(self, thread: disnake.Thread): - if thread.parent_id != self.forum_id: + if thread.parent_id not in self.forum_ids: return - if thread.id not in self.threads: - self.threads[thread.id] = AssistedThread(thread, self.assistant) + self.threads.create(thread) async def on_thread_update(self, before: disnake.Thread, thread: disnake.Thread): - if thread.parent_id != self.forum_id: + if thread.parent_id not in self.forum_ids: + return + + assisted_thread = await self.threads.get_by_thread(thread) + if assisted_thread is None: return - if thread.id not in self.threads: - return # TODO the two checks will be replaced print("Thread updated", thread.applied_tags) if len(thread.applied_tags) != 1: return - plugin_name = thread.applied_tags[0].name + assistant_name = thread.applied_tags[0].name print("and thread is good") - if await self.threads[thread.id].initialize(plugin_name): - await self.threads[thread.id].respond() - + if await assisted_thread.initialize(assistant_name): + await assisted_thread.respond() async def on_message(self, message: disnake.Message): - if message.author == self.user: + if not await self.should_respond_to_message(message): return + await self.threads.process_message(message) + + async def should_respond_to_message(self, message: disnake.Message) -> bool: + if message.author == self.user: + return False + if isinstance(message.channel, disnake.DMChannel): await message.reply("I can't respond to DMs. I can help only on our server.") - return + return False if not isinstance(message.channel, disnake.Thread): - return + return False + + if message.channel.parent_id not in self.forum_ids: + return False if '[-ai]' in message.content: - return + return False - thread = message.channel - - if thread.parent_id != self.forum_id: - return - if thread.id not in self.threads: - return # TODO the two checks will be replaced - - await self.threads[thread.id].on_message(message) \ No newline at end of file + return True \ No newline at end of file diff --git a/src/plugin_assist/bot/persistence.py b/src/plugin_assist/bot/persistence.py new file mode 100644 index 0000000..96991fc --- /dev/null +++ b/src/plugin_assist/bot/persistence.py @@ -0,0 +1,48 @@ +from json import JSONEncoder, JSONDecoder, loads, dumps +from openai.types.chat.chat_completion import ChatCompletionMessage +from disnake import Thread + +from ..common.persistence import Persistence +from ..assistant import Assistants + +from .threads import AssistedThread + +class ThreadPersistence(Persistence): + async def save(self, thread: AssistedThread): + data = { + "enabled": thread.enabled, + "assistant_name": thread.assistant_name, + "messages": thread.cached_messages, + "transcript_queue": thread._transcript_queue + } + + print("Saving", data) + content = dumps(data, cls=ChatCompletionMessageJSONEncoder) + await self.save_string(str(thread.thread.id), content) + + async def load(self, thread: Thread, assistants: Assistants) -> AssistedThread | None: + data_raw = await self.load_string(str(thread.id)) + if data_raw is None: + return None + + data = loads(data_raw) + print("Loaded", data) + + return AssistedThread( + thread=thread, + assistants=assistants, + enabled=data['enabled'], + assistant_name=data['assistant_name'], + cached_messages=data['messages'], + _transcript_queue=data['transcript_queue'] + ) + +class ChatCompletionMessageJSONEncoder(JSONEncoder): + def _prepare_message(self, message: ChatCompletionMessage): + return {k: v for k, v in message.__dict__.items() if v is not None} + + def default(self, obj): + if isinstance(obj, ChatCompletionMessage): + return self._prepare_message(obj) + + return super().default(obj) \ No newline at end of file diff --git a/src/plugin_assist/bot/threads.py b/src/plugin_assist/bot/threads.py index f0071cb..b0dbd69 100644 --- a/src/plugin_assist/bot/threads.py +++ b/src/plugin_assist/bot/threads.py @@ -1,101 +1,171 @@ import disnake from dataclasses import dataclass, field from asyncio import Task -from typing import List +from typing import List, Dict, Tuple -from ..assistant.assistant import Assistant +from ..assistant.multi import Assistants + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .persistence import ThreadPersistence + +@dataclass +class AssistedThreads: + assistants: Assistants + persistence: 'ThreadPersistence' + + _threads: dict[int, 'AssistedThread'] = field(default_factory=dict) + + def create(self, thread: disnake.Thread) -> 'AssistedThread': + if thread.id in self._threads: + return self._threads[thread.id] + + self._threads[thread.id] = AssistedThread(thread, self.assistants) + return self._threads[thread.id] + + def get_by_id(self, id: int) -> 'AssistedThread': + return self._threads[id] + + async def get_by_thread(self, thread: disnake.Thread) -> 'AssistedThread': + if not thread.id in self._threads: + loaded_thread = await self.persistence.load(thread, self.assistants) + + if loaded_thread is not None: + self._threads[thread.id] = loaded_thread + else: + self.create(thread) + + return self._threads[thread.id] + + async def process_message(self, message: disnake.Message): + if not isinstance(message.channel, disnake.Thread): + return + + assisted_thread = await self.get_by_thread(message.channel) + + await assisted_thread.on_message(message) + await self.persistence.save(assisted_thread) @dataclass class AssistedThread: thread: disnake.Thread - assistant: Assistant + assistants: Assistants enabled: bool = False - plugin: str = "" - cached_messages: List[dict] = field(default_factory=list) - last_warning: disnake.Message | None = None - current_response: Task | None = None + assistant_name: str = "" + cached_messages: List[Dict] = field(default_factory=list) - async def on_message(self, message: disnake.Message): - """TODO handle editing and alike: messages = [{ - "role": "user" if m.author != self.user else "assistant", - "content": m.content - } async for m in query_message.channel.history(oldest_first=True)]""" - + _transcript_queue: List[Tuple[int, int, str]] = field(default_factory=list) + _last_warning: disnake.Message | None = None + _current_response: Task | None = None + + async def _add_new_message(self, message: disnake.Message): content = message.content + message_index = len(self.cached_messages) for attachment in message.attachments: + content += f"\n\nAttached file {attachment.filename}:\n" if 'text' in attachment.content_type: - data = await attachment.read() - content += f"\n\nAttached file {attachment.filename}:\n---\n{data.decode('utf-8')}\n---" + data = await attachment.read() # TODO this is the only await maybe we could remove it? + content += "---\n" + data.decode('utf-8') + "\n---" + elif 'image' in attachment.content_type: + self._transcript_queue.append((message_index, attachment.id, attachment.proxy_url)) + content += f"---\nImage transcript:\n%TRANSCRIPT_{attachment.id}%\n---" else: - content += "\n\nAttached file " + attachment.filename + " is of unsupported type, cannot read." + content += "unsupported type, cannot read." self.cached_messages.append( {'role': 'user', 'content': content} ) - if not self.enabled: - if self.last_warning is None: - if len(self.thread.applied_tags) != 1: - await self.warning(":label: Please keep only a single tag on this thread, I'll then be able to assist you better.") - return - else: - plugin_name = self.thread.applied_tags[0].name - await self.initialize(plugin_name) - else: - return + async def _check_tags(self) -> bool: + if len(self.thread.applied_tags) == 0: + await self.warning(":label: Please tag this thread, I'll be able to assist you better.") + return False - #await self.respond(message) - await self.respond() + if len(self.thread.applied_tags) > 1: + await self.warning(":label: Please keep only a single tag on this thread, I'll be able to assist you better.") + return False + + return True - async def initialize(self, plugin: str) -> bool: + async def on_message(self, message: disnake.Message): + """TODO handle editing and alike?? messages = [{ + "role": "user" if m.author != self.user else "assistant", + "content": m.content + } async for m in query_message.channel.history(oldest_first=True)]""" + + await self._add_new_message(message) + + if not self.enabled: + if self._last_warning is not None: + return False + + if await self._check_tags(): + assistant_name = self.thread.applied_tags[0].name + await self.initialize(assistant_name) + + return await self.respond() + + async def initialize(self, assistant_name: str) -> bool: if self.enabled: return False await self.remove_last_warning() - - if plugin != "RealWeather": - #await self.warning(":slight_frown: The automated assistant cannot assist you with this plugin. :hourglass: Wait for human support.") - self.enabled = False + + if not self.assistants.has_assistant(assistant_name): return False - self.plugin = plugin + self.assistant_name = assistant_name self.enabled = True return True async def warning(self, message: str, reply_to: disnake.Message | None = None): if reply_to is None: - self.last_warning = await self.thread.send(message) + self._last_warning = await self.thread.send(message) else: - self.last_warning = await reply_to.reply(message) + self._last_warning = await reply_to.reply(message) async def remove_last_warning(self): - if self.last_warning is not None: - await self.last_warning.delete() - self.last_warning = None + if self._last_warning is not None: + await self._last_warning.delete() + self._last_warning = None - async def respond(self, message: disnake.Message | None = None): - if self.current_response is not None: + async def respond(self, message: disnake.Message | None = None) -> bool: + if not self.enabled: + print("Want to respond but disabled") + return False + + if self._current_response is not None: print("Cancelled old message") - self.current_response.cancel() + self._current_response.cancel() async with self.thread.typing(): - self.current_response = self.assistant.ask(self.cached_messages) - response = await self.current_response + assistant = self.assistants.get_assistant(self.assistant_name) + self._current_response = assistant.ask(self.cached_messages, self._transcript_queue) + response = await self._current_response if response.status == 'completed': - self.cached_messages = response.messages - - if message is not None: - await reply_long(message, response.content) - else: - await send_long(self.thread, response.content) + print("New messages:", response.new_messages) + self.cached_messages += response.new_messages + self._transcript_queue = [] + self._current_response = None + await self.send_or_reply(response.content, message) + + return True elif response.status == 'cancelled': pass else: - await message.reply(":bangbang: Sorry, an error occurred. Cannot proceed with your request.") + await self.send_or_reply(":bangbang: Sorry, an error occurred. Cannot proceed with your request.", message) + + return False + + async def send_or_reply(self, content: str, message: disnake.Message | None = None): + if message is not None: + await reply_long(message, content) + else: + await send_long(self.thread, content) # TODO move this stuff diff --git a/src/plugin_assist/common/file_utils.py b/src/plugin_assist/common/file_utils.py new file mode 100644 index 0000000..a8cdc1a --- /dev/null +++ b/src/plugin_assist/common/file_utils.py @@ -0,0 +1,19 @@ +from caio import AsyncioContext + +async def caio_read_file(context: AsyncioContext, fd: int) -> str: + message = bytearray() + offset = 0 + + while True: + content = await context.read(512, fd, offset) + + if len(content) == 0: + break + + message += content + offset += len(content) + + return message.decode('utf-8') + +async def caio_write_file(context: AsyncioContext, fd: int, content: str): + await context.write(content.encode(), fd, 0) \ No newline at end of file diff --git a/src/plugin_assist/assistant/persistent.py b/src/plugin_assist/common/persistence.py similarity index 56% rename from src/plugin_assist/assistant/persistent.py rename to src/plugin_assist/common/persistence.py index e981eaa..d14b29f 100644 --- a/src/plugin_assist/assistant/persistent.py +++ b/src/plugin_assist/common/persistence.py @@ -1,23 +1,7 @@ import os from caio import AsyncioContext -async def caio_read_file(context: AsyncioContext, fd: int) -> str: - message = bytearray() - offset = 0 - - while True: - content = await context.read(512, fd, offset) - - if len(content) == 0: - break - - message += content - offset += len(content) - - return message.decode('utf-8') - -async def caio_write_file(context: AsyncioContext, fd: int, content: str): - await context.write(content.encode(), fd, 0) +from .file_utils import caio_read_file, caio_write_file class Persistence: context: AsyncioContext @@ -27,7 +11,7 @@ class Persistence: self.context = AsyncioContext() self.base_directory = base_directory - async def save(self, id: str, content: str): + async def save_string(self, id: str, content: str): filename = os.path.join(self.base_directory, id) fd = os.open(filename, os.O_CREAT | os.O_WRONLY) @@ -36,7 +20,7 @@ class Persistence: finally: os.close(fd) - async def load(self, id: str) -> str: + async def load_string(self, id: str) -> str | None: filename = os.path.join(self.base_directory, id) fd = os.open(filename, os.O_RDONLY) diff --git a/src/plugin_assist/settings.py b/src/plugin_assist/settings.py new file mode 100644 index 0000000..ccceebf --- /dev/null +++ b/src/plugin_assist/settings.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass + +@dataclass +class InferenceSettings: + base_url: str | None + api_key: str | None + models: dict[str, str] + + @staticmethod + def from_dict(data: dict) -> 'InferenceSettings': + return InferenceSettings( + base_url = data.get('base_url', None), + api_key = data.get('api_key', None), + models = data['models'] + ) + +@dataclass +class AssistantSettings: + system_prompt: str + documentation_directory: str | None + + @staticmethod + def from_dict(data: dict) -> 'AssistantSettings': + return AssistantSettings( + system_prompt = '\n'.join(data['system_prompt']), + documentation_directory = data.get('documentation_directory', None) + ) + +@dataclass +class BotSettings: + token: str | None + allowed_forum_ids: list[int] + thread_persistence_directory: str + + @staticmethod + def from_dict(data: dict) -> 'BotSettings': + return BotSettings( + token = data.get('token', None), + allowed_forum_ids = data['allowed_forum_ids'], + thread_persistence_directory = data.get('thread_persistence_directory', 'threads') + ) + +@dataclass +class Settings: + inference: InferenceSettings + assistants: dict[str, AssistantSettings] + bot: BotSettings + + @staticmethod + def from_dict(data: dict) -> 'Settings': + return Settings( + inference = InferenceSettings.from_dict(data['inference']), + assistants = {k: AssistantSettings.from_dict(v) for k, v in data['assistants'].items()}, + bot = BotSettings.from_dict(data['bot']) + ) \ No newline at end of file