Update
This commit is contained in:
parent 2976862362
commit a332805cd6
17 changed files with 491 additions and 225 deletions
@@ -1,9 +1,8 @@
# You can also set those in settings.json

# The bot token
BOT_TOKEN=

# The forum ID
FORUM_ID=

# API key for the OpenAI-compatible provider
OPENAI_API_KEY=
6 .gitignore vendored
@@ -1,3 +1,7 @@
__pycache__/

doc/
.env
threads/

.env
settings.json
@@ -56,6 +56,8 @@ COPY --from=builder --chown=appuser:appuser /app/src ./src
# This allows us to run executables directly (e.g., `gunicorn`, `uvicorn`)
ENV PATH="/app/.venv/bin:$PATH"

ENV SETTINGS_PATH="/settings.json"

# Switch to the non-root user
USER appuser
@@ -2,10 +2,5 @@ services:
  bot:
    build:
      dockerfile: Containerfile
    env_file:
      - .env
    environment:
      BOT_TOKEN:
      FORUM_ID:
      OPENAI_API_KEY:
      OPENAI_BASE_URL:
    volumes:
      - ./settings.json:/settings.json:ro
@@ -20,4 +20,4 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry.scripts]
start = "plugin_assist:main"
start = "src.plugin_assist:main"
36 settings.example.json Normal file
@@ -0,0 +1,36 @@
{
    "inference": {
        "api_key": "replace-me-with-your-actual-key",
        "base_url": "https://nano-gpt.com/api/v1",
        "models": {
            "fast": "kimi-k2-instruct-fast",
            "cheap": "openai/gpt-oss-120b",
            "image_transcript": "openai/gpt-4.1-nano"
        }
    },
    "assistants": {
        "RealWeather": {
            "system_prompt": [
                "You are an assistant who helps with the RealWeather Minecraft plugin.",
                "RealWeather is a plugin that synchronizes real weather, time and lightning strikes.",
                "It's alpha software, and the current limitations make it more ideal for ambience, docs will tell you more.",
                "You have the documentaton available, use it."
            ],
            "documentation_directory": "doc/RealWeather"
        },
        "Other": {
            "system_prompt": [
                "You are an assistant who answers inquiries about Minecraft plugins.",
                "Ask if help is needed first. Don't offer any help unless explicitly asked."
            ]
        }
    },
    "bot": {
        "token": "Replace.Me.With.Your.Actual.Token",
        "allowed_forum_ids": [
            1401918766044151808,
            1402900706054242445
        ],
        "thread_persistence_directory": "threads"
    }
}
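Note: a minimal sketch (not part of the commit) of how this example file is consumed, based on load_settings() and the Settings dataclasses added further down in this diff; the plugin_assist.settings import path is assumed from the src/ layout.

from json import loads
from os import environ

from plugin_assist.settings import Settings  # assumed import path, matching src/plugin_assist/settings.py below

settings_path = environ.get("SETTINGS_PATH", "settings.json")
settings = Settings.from_dict(loads(open(settings_path).read()))

print(settings.inference.models["fast"])                           # "kimi-k2-instruct-fast"
print(settings.assistants["RealWeather"].documentation_directory)  # "doc/RealWeather"
print(settings.bot.thread_persistence_directory)                   # "threads"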
@@ -1,25 +1,45 @@
from .assistant import Assistant
from .assistant import Assistant, Assistants, Documentation
from .bot import AssistantBot
from .settings import Settings, AssistantSettings, BotSettings

from openai import AsyncOpenAI
from os import environ
from os import environ, mkdir, path
from json import loads

def load_settings() -> Settings:
    settings_dict = loads(open(environ.get("SETTINGS_PATH", "settings.json")).read())
    return Settings.from_dict(settings_dict)

def make_assistants(settings: Settings, client: AsyncOpenAI) -> Assistants:
    return Assistants({assistant_name: Assistant(
        client = client,
        documentation = Documentation(assistant_settings.documentation_directory) if assistant_settings.documentation_directory else None,
        system_prompt = assistant_settings.system_prompt,
        models = settings.inference.models
    ) for assistant_name, assistant_settings in settings.assistants.items()})

def make_bot(settings: BotSettings, assistants: Assistants) -> AssistantBot:
    return AssistantBot(
        assistants = assistants,
        forum_ids = settings.allowed_forum_ids,
        thread_persistence_directory = settings.thread_persistence_directory
    )

def main():
    settings = load_settings()

    if not path.exists(settings.bot.thread_persistence_directory):
        mkdir(settings.bot.thread_persistence_directory)

    client = AsyncOpenAI(
        base_url = environ.get("OPENAI_BASE_URL")
        base_url = environ.get("OPENAI_BASE_URL", settings.inference.base_url),
        api_key = environ.get("OPENAI_API_KEY", settings.inference.api_key),
    )

    assistant = Assistant(
        client = client,
        documentation_directory = "doc"
    )

    bot = AssistantBot(
        assistant = assistant,
        forum_id = int(environ.get("FORUM_ID"))
    )
    assistants = make_assistants(settings, client)
    bot = make_bot(settings.bot, assistants)

    print("Starting bot")

    token = environ.get("BOT_TOKEN")
    token = environ.get("BOT_TOKEN", settings.bot.token)
    bot.run(token)
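Note: the dict comprehension in make_assistants() is fairly dense, so here is the same wiring written as an explicit loop for readability; the function name make_assistants_verbose is made up, everything else mirrors the code above.

def make_assistants_verbose(settings: Settings, client: AsyncOpenAI) -> Assistants:
    assistants: dict[str, Assistant] = {}

    for name, assistant_settings in settings.assistants.items():
        # Only build a Documentation index when a directory is configured.
        documentation = None
        if assistant_settings.documentation_directory:
            documentation = Documentation(assistant_settings.documentation_directory)

        assistants[name] = Assistant(
            client = client,
            documentation = documentation,
            system_prompt = assistant_settings.system_prompt,
            models = settings.inference.models
        )

    return Assistants(assistants)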
@@ -1,4 +1,5 @@
from .assistant import Assistant
from .multi import Assistants
from .documentation import Documentation

__all__ = ['Assistant', 'Documentation']
__all__ = ['Assistant', 'Assistants', 'Documentation']
@@ -2,18 +2,11 @@ from openai import AsyncOpenAI
from asyncio import AbstractEventLoop, Task, get_running_loop, CancelledError
from dataclasses import dataclass, field
from json import loads, dumps
from typing import Any, Literal, List, Dict
from typing import Any, Literal, List, Dict, Tuple, Coroutine

from .documentation import Documentation

_DEFAULT_SYSTEM_PROMPT = """
You are an assistant who helps with the RealWeather Minecraft plugin.
RealWeather is a plugin that synchronizes real weather, time and lightning strikes.
It's alpha software, and right now it's kinda more ideal for ambience, docs will tell you more.
You have the documentaton available, use it.
"""

_DEFAULT_TOOLS = [{
_DOCUMENTATION_SEARCH_TOOL = {
    "type": "function",
    "function": {
        "name": "search_documentation",
@@ -24,14 +17,12 @@ _DEFAULT_TOOLS = [{
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The detailed search query as tags."
                    "description": "The detailed search query as separated tags."
                }
            }
        }
    }
}]

_DEFAULT_MODEL = "kimi-k2-instruct-fast"
}

# A type hint for the status, making it clear what the possible values are.
ResponseStatus = Literal["completed", "cancelled", "error"]
@@ -41,26 +32,31 @@ class AssistantResponse:
    """A structured response from the Assistant."""
    status: ResponseStatus
    content: str = ""
    messages: List[Dict[str, Any]] = field(default_factory=list)
    new_messages: List[Dict[str, Any]] = field(default_factory=list)
    error: str | None = None

class Assistant:
    client: AsyncOpenAI
    documentation: Documentation
    _client: AsyncOpenAI
    documentation: Documentation | None
    system_prompt: str
    model: str
    tools: list[dict]
    models: Dict[str, str]

    def __init__(self, client: AsyncOpenAI, documentation_directory: str, system_prompt: str | None = None):
        self.client = client
        self.documentation = Documentation(documentation_directory)
        self.system_prompt = system_prompt if system_prompt is not None else _DEFAULT_SYSTEM_PROMPT
        self.model = _DEFAULT_MODEL
        self.tools = _DEFAULT_TOOLS
    _tools: List[Dict]
    _tool_map: Dict[str, Coroutine[Any, Any, Dict]]

    def __init__(self, client: AsyncOpenAI, documentation: Documentation | None, system_prompt: str, models: dict[str, str]):
        self._client = client
        self.documentation = documentation
        self.system_prompt = system_prompt
        self.models = models

        self._tools = [_DOCUMENTATION_SEARCH_TOOL] if documentation is not None else []
        self._tool_map = {"search_documentation": self.documentation.search_documentation} if documentation is not None else {}

    def ask(
        self,
        messages: list[dict],
        images_to_transcript: List[Tuple[int, int, str]] = [],
        *,
        loop: AbstractEventLoop | None = None,
    ) -> Task[AssistantResponse]:
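Note: ask() deliberately returns the asyncio.Task instead of awaiting it, so a caller can drop or cancel a stale generation (AssistedThread.respond() further down does exactly that). A minimal sketch of that calling pattern, assuming an already constructed Assistant and a running event loop:

async def demo(assistant: Assistant) -> bool:
    messages = [{"role": "user", "content": "How do I configure RealWeather?"}]

    # ask() schedules the work on the running loop and hands back the Task immediately.
    task = assistant.ask(messages)

    # A caller that still holds the previous Task could cancel it here instead:
    # previous_task.cancel()

    response = await task

    if response.status == "completed":
        print(response.content)
        return True

    print("No answer:", response.status, response.error)
    return False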
@@ -68,80 +64,107 @@ class Assistant:
        Start generating the assistant response and immediately return the asyncio.Task, so the caller can `await` it OR cancel it.
        """
        loop = loop or get_running_loop()
        task = loop.create_task(self._ask(messages))
        task = loop.create_task(self._ask(messages, images_to_transcript))
        return task

    # MODIFIED: The method now returns our structured response and handles cancellation.
    async def _ask(self, messages: list[dict]) -> AssistantResponse:
    async def _ask(self, messages: list[dict], images_to_transcript: List[Tuple[int, int, str]]) -> AssistantResponse:
        messages_copy = messages.copy() # The messages are not expected to be modified
        messages_copy.insert(0, {"role": "system", "content": self.system_prompt})

        try:
            # The core logic is now wrapped to catch cancellation.
            response_content = await self._ask_inner(messages_copy)
            new_messages, response_text = await self._ask_inner(messages_copy, images_to_transcript)
            print(messages_copy)

            return AssistantResponse(
                status="completed",
                content=response_content,
                messages=messages_copy
                content=response_text,
                new_messages=new_messages
            )

        except CancelledError:
            print("Assistant 'ask' task was cancelled.")
            # Return a specific response indicating cancellation.
            return AssistantResponse(
                status="cancelled",
                messages=messages_copy # Return the history up to the point of cancellation
                new_messages=[]
            )

        except Exception as e:
            print(f"An unexpected error occurred in 'ask': {e}")
            # Return a response for any other errors.
            print(f"An error occurred in 'ask': {e}")

            return AssistantResponse(
                status="error",
                messages=messages_copy,
                new_messages=[],
                error=str(e)
            )

    async def _ask_inner(self, messages: List[Dict[str, Any]]) -> str:
        response_segments = []
    async def _ask_inner(self, messages: List[Dict[str, Any]], images_to_transcript: List[Tuple[int, int, str]]) -> Tuple[List[Dict[str, Any]], str]:
        # +1 because system prompt
        for message_index, transcript_index, image_url in images_to_transcript:
            print("Transcribing", message_index+1, transcript_index, image_url)
            transcript = await self.transcribe_image(image_url)
            messages[message_index+1]['content'] = messages[message_index+1]['content'].replace(f"%TRANSCRIPT_{transcript_index}%", transcript)

        new_messages = []
        response_text_parts = []
        finish_reason = "tool_calls"

        model = self.models['cheap'] if len(messages) > 10 else self.models['fast'] # TODO not hardcode
        print("Using model", model)

        while finish_reason == "tool_calls":
            print("Making API call...")
            completion = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
            completion = await self._client.chat.completions.create(
                model=model,
                messages=messages + new_messages,
                temperature=0.0,
                tools=self.tools,
                tools=self._tools,
                tool_choice="auto"
            )

            choice = completion.choices[0]
            finish_reason = choice.finish_reason
            messages.append(choice.message)
            new_messages.append(choice.message)

            if choice.message.content:
                response_segments.append(choice.message.content)
                response_text_parts.append(choice.message.content)

            if finish_reason == "tool_calls" and choice.message.tool_calls is not None:
                tool_map = {"search_documentation": self.documentation.search_documentation}

                for tool_call in choice.message.tool_calls:
                    tool_call_name = tool_call.function.name
                    tool_call_arguments = loads(tool_call.function.arguments)
                    tool_function = tool_map[tool_call_name]

                    print(f"Executing tool: {tool_call_name}...")
                    tool_function = self._tool_map[tool_call_name]
                    tool_result = await tool_function(**tool_call_arguments)
                    print("Tool result:", tool_result)

                    messages.append({
                    new_messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call.id,
                        "name": tool_call_name,
                        "content": dumps(tool_result)
                    })

        return '\n\n'.join(response_segments)
        return new_messages, '\n\n'.join(response_text_parts)

    async def transcribe_image(self, image_url: str) -> str:
        response = await self._client.chat.completions.create(
            model=self.models['image_transcript'],
            messages=[{
                "role": "user",
                "content": [
                    {"type": "text", "text": "Transcribe this image in detail."},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": image_url,
                        },
                    },
                ],
            }],
        )

        print("Image transcript result", response)

        return response.choices[0].message.content
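Note: image attachments are handled by leaving a %TRANSCRIPT_<attachment id>% placeholder in the cached user message (see AssistedThread._add_new_message below) and splicing the transcript in here, in _ask_inner. A tiny sketch of just that substitution step, with made-up values:

# Hypothetical cached message as _add_new_message would build it for one image attachment:
cached = {"role": "user", "content": "Why is the sky green?\n\n---\nImage transcript:\n%TRANSCRIPT_12345%\n---"}

# Later, the transcript returned by transcribe_image() replaces the placeholder:
transcript = "A screenshot of a Minecraft world with a green-tinted sky."
cached["content"] = cached["content"].replace("%TRANSCRIPT_12345%", transcript)

print(cached["content"])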
@@ -4,55 +4,7 @@ from collections import Counter

from typing import List, Tuple, Any

from .persistent import caio_read_file

def _partition_by_most_frequent(data_list: List[Any]) -> Tuple[List[Any], List[Any]]:
    """
    Partitions a list into two lists based on element frequency.

    The first list contains all occurrences of the element(s) that appear
    most frequently. The second list contains all other elements.

    Args:
        data_list: The input list of elements.

    Returns:
        A tuple containing two lists:
        - The first list with all instances of the most frequent element(s).
        - The second list with the remaining elements.
    """
    if not data_list:
        return [], []

    counts = Counter(data_list)

    max_frequency = counts.most_common(1)[0][1]

    most_frequent_keys = {
        item for item, count in counts.items() if count == max_frequency
    }

    most_list = [item for item in data_list if item in most_frequent_keys]
    other_list = [item for item in data_list if item not in most_frequent_keys]

    return most_list, other_list

def _load_tags(base_directory: str) -> dict:
    tag_map = {}

    for root, _, files in os.walk(base_directory):
        for name in files:
            file_path = os.path.join(root, name)
            tag_line = open(file_path).readline()
            tags = [tag.strip() for tag in tag_line[6:].split(",")]

            for tag in tags:
                if tag in tag_map:
                    tag_map[tag].append(file_path)
                else :
                    tag_map[tag] = [file_path]

    return tag_map
from ..common.persistence import caio_read_file

class Documentation:
    base_directory: str
@@ -108,4 +60,42 @@ class Documentation:
                "probable": p[1]
            },
            "guide": "Call this function again with the filename as query to view that file."
        }
    }

def _partition_by_most_frequent(data_list: List[Any]) -> Tuple[List[Any], List[Any]]:
    if not data_list:
        return [], []

    counts = Counter(data_list)

    # Find the highest frequency value. Using max() is more direct than most_common().
    max_frequency = max(counts.values())

    most_frequent_items = []
    other_items = []

    # Iterate through the unique items and their counts
    for item, count in counts.items():
        if count == max_frequency:
            most_frequent_items.append(item)
        else:
            other_items.append(item)

    return most_frequent_items, other_items

def _load_tags(base_directory: str) -> dict:
    tag_map = {}

    for root, _, files in os.walk(base_directory):
        for name in files:
            file_path = os.path.join(root, name)
            tag_line = open(file_path).readline()
            tags = [tag.strip() for tag in tag_line[6:].split(",")]

            for tag in tags:
                if tag in tag_map:
                    tag_map[tag].append(file_path)
                else :
                    tag_map[tag] = [file_path]

    return tag_map
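Note: a concrete example of the rewritten _partition_by_most_frequent() above. Unlike the removed version, it partitions unique items rather than every occurrence:

tags = ["weather", "time", "weather", "lightning", "time", "weather"]

most, other = _partition_by_most_frequent(tags)

print(most)   # ['weather'] - the single most frequent unique item
print(other)  # ['time', 'lightning'] - every other unique item, in first-seen order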
16 src/plugin_assist/assistant/multi.py Normal file
@@ -0,0 +1,16 @@
from .assistant import Assistant

class Assistants:
    _assistants: dict[str, Assistant]

    def __init__(self, assistants: dict[str, Assistant]):
        self._assistants = assistants

    def get_assistant(self, id: str):
        return self._assistants[id]

    def get_assistant_ids(self) -> set[str]:
        return set(self._assistants.keys())

    def has_assistant(self, id: str):
        return id in self._assistants
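Note: Assistants is a thin registry keyed by assistant name, and the bot below uses the forum tag name as that key. A minimal usage sketch, assuming the registry built by make_assistants() with the example settings:

registry = make_assistants(settings, client)   # as wired up in __init__.py above

print(registry.get_assistant_ids())            # e.g. {"RealWeather", "Other"} with settings.example.json
print(registry.has_assistant("RealWeather"))   # True
print(registry.has_assistant("SomePlugin"))    # False; such a thread stays unassisted

assistant = registry.get_assistant("RealWeather")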
@@ -1,23 +1,25 @@
import disnake
from disnake.ext import commands

from ..assistant import Assistant
from .threads import AssistedThread
from ..assistant import Assistants
from .threads import AssistedThreads
from .persistence import ThreadPersistence

class AssistantBot(commands.InteractionBot):
    forum_id: int
    assistant: Assistant
    threads: dict[int, AssistedThread]
    assistants: Assistants
    forum_ids: list[int]

    threads: AssistedThreads

    def __init__(self, forum_id: int, assistant: Assistant, **kwargs):
    def __init__(self, assistants: Assistants, forum_ids: list[int], thread_persistence_directory: str, **kwargs):
        intents = disnake.Intents.default()
        intents.message_content = True

        super().__init__(intents=intents, **kwargs)

        self.forum_id = forum_id
        self.assistant = assistant
        self.threads = {}
        self.assistants = assistants
        self.forum_ids = forum_ids
        self.threads = AssistedThreads(assistants, ThreadPersistence(thread_persistence_directory))

    async def on_ready(self):
        print(f"Logged in as {self.user} (ID: {self.user.id})\n------")
@@ -27,50 +29,52 @@ class AssistantBot(commands.InteractionBot):
        )

    async def on_thread_create(self, thread: disnake.Thread):
        if thread.parent_id != self.forum_id:
        if thread.parent_id not in self.forum_ids:
            return

        if thread.id not in self.threads:
            self.threads[thread.id] = AssistedThread(thread, self.assistant)
        self.threads.create(thread)

    async def on_thread_update(self, before: disnake.Thread, thread: disnake.Thread):
        if thread.parent_id != self.forum_id:
        if thread.parent_id not in self.forum_ids:
            return

        assisted_thread = await self.threads.get_by_thread(thread)
        if assisted_thread is None:
            return
        if thread.id not in self.threads:
            return # TODO the two checks will be replaced

        print("Thread updated", thread.applied_tags)

        if len(thread.applied_tags) != 1:
            return

        plugin_name = thread.applied_tags[0].name
        assistant_name = thread.applied_tags[0].name

        print("and thread is good")

        if await self.threads[thread.id].initialize(plugin_name):
            await self.threads[thread.id].respond()

        if await assisted_thread.initialize(assistant_name):
            await assisted_thread.respond()

    async def on_message(self, message: disnake.Message):
        if message.author == self.user:
        if not await self.should_respond_to_message(message):
            return

        await self.threads.process_message(message)

    async def should_respond_to_message(self, message: disnake.Message) -> bool:
        if message.author == self.user:
            return False

        if isinstance(message.channel, disnake.DMChannel):
            await message.reply("I can't respond to DMs. I can help only on our server.")
            return
            return False

        if not isinstance(message.channel, disnake.Thread):
            return
            return False

        if message.channel.parent_id not in self.forum_ids:
            return False

        if '[-ai]' in message.content:
            return
            return False

        thread = message.channel

        if thread.parent_id != self.forum_id:
            return
        if thread.id not in self.threads:
            return # TODO the two checks will be replaced

        await self.threads[thread.id].on_message(message)
        return True
48 src/plugin_assist/bot/persistence.py Normal file
@@ -0,0 +1,48 @@
from json import JSONEncoder, JSONDecoder, loads, dumps
from openai.types.chat.chat_completion import ChatCompletionMessage
from disnake import Thread

from ..common.persistence import Persistence
from ..assistant import Assistants

from .threads import AssistedThread

class ThreadPersistence(Persistence):
    async def save(self, thread: AssistedThread):
        data = {
            "enabled": thread.enabled,
            "assistant_name": thread.assistant_name,
            "messages": thread.cached_messages,
            "transcript_queue": thread._transcript_queue
        }

        print("Saving", data)
        content = dumps(data, cls=ChatCompletionMessageJSONEncoder)
        await self.save_string(str(thread.thread.id), content)

    async def load(self, thread: Thread, assistants: Assistants) -> AssistedThread | None:
        data_raw = await self.load_string(str(thread.id))
        if data_raw is None:
            return None

        data = loads(data_raw)
        print("Loaded", data)

        return AssistedThread(
            thread=thread,
            assistants=assistants,
            enabled=data['enabled'],
            assistant_name=data['assistant_name'],
            cached_messages=data['messages'],
            _transcript_queue=data['transcript_queue']
        )

class ChatCompletionMessageJSONEncoder(JSONEncoder):
    def _prepare_message(self, message: ChatCompletionMessage):
        return {k: v for k, v in message.__dict__.items() if v is not None}

    def default(self, obj):
        if isinstance(obj, ChatCompletionMessage):
            return self._prepare_message(obj)

        return super().default(obj)
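Note: cached_messages can contain raw ChatCompletionMessage objects appended by the assistant's tool-call loop, which plain json.dumps cannot serialize; the encoder above drops their None fields instead. A small sketch of the persisted payload (values made up, and constructing a ChatCompletionMessage directly is an assumption about the openai types):

from json import dumps
from openai.types.chat.chat_completion import ChatCompletionMessage

payload = {
    "enabled": True,
    "assistant_name": "RealWeather",
    "messages": [
        {"role": "user", "content": "How do I set this up?"},
        ChatCompletionMessage(role="assistant", content="Here is what the documentation says..."),
    ],
    "transcript_queue": [],
}

# The assistant message is serialized as {"role": "assistant", "content": "..."} with its None fields dropped.
print(dumps(payload, cls=ChatCompletionMessageJSONEncoder))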
@@ -1,101 +1,171 @@
import disnake
from dataclasses import dataclass, field
from asyncio import Task
from typing import List
from typing import List, Dict, Tuple

from ..assistant.assistant import Assistant
from ..assistant.multi import Assistants

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from .persistence import ThreadPersistence

@dataclass
class AssistedThreads:
    assistants: Assistants
    persistence: 'ThreadPersistence'

    _threads: dict[int, 'AssistedThread'] = field(default_factory=dict)

    def create(self, thread: disnake.Thread) -> 'AssistedThread':
        if thread.id in self._threads:
            return self._threads[thread.id]

        self._threads[thread.id] = AssistedThread(thread, self.assistants)
        return self._threads[thread.id]

    def get_by_id(self, id: int) -> 'AssistedThread':
        return self._threads[id]

    async def get_by_thread(self, thread: disnake.Thread) -> 'AssistedThread':
        if not thread.id in self._threads:
            loaded_thread = await self.persistence.load(thread, self.assistants)

            if loaded_thread is not None:
                self._threads[thread.id] = loaded_thread
            else:
                self.create(thread)

        return self._threads[thread.id]

    async def process_message(self, message: disnake.Message):
        if not isinstance(message.channel, disnake.Thread):
            return

        assisted_thread = await self.get_by_thread(message.channel)

        await assisted_thread.on_message(message)
        await self.persistence.save(assisted_thread)

@dataclass
class AssistedThread:
    thread: disnake.Thread
    assistant: Assistant
    assistants: Assistants
    enabled: bool = False
    plugin: str = ""
    cached_messages: List[dict] = field(default_factory=list)
    last_warning: disnake.Message | None = None
    current_response: Task | None = None
    assistant_name: str = ""
    cached_messages: List[Dict] = field(default_factory=list)

    async def on_message(self, message: disnake.Message):
        """TODO handle editing and alike: messages = [{
            "role": "user" if m.author != self.user else "assistant",
            "content": m.content
        } async for m in query_message.channel.history(oldest_first=True)]"""

    _transcript_queue: List[Tuple[int, int, str]] = field(default_factory=list)
    _last_warning: disnake.Message | None = None
    _current_response: Task | None = None

    async def _add_new_message(self, message: disnake.Message):
        content = message.content
        message_index = len(self.cached_messages)

        for attachment in message.attachments:
            content += f"\n\nAttached file {attachment.filename}:\n"
            if 'text' in attachment.content_type:
                data = await attachment.read()
                content += f"\n\nAttached file {attachment.filename}:\n---\n{data.decode('utf-8')}\n---"
                data = await attachment.read() # TODO this is the only await maybe we could remove it?
                content += "---\n" + data.decode('utf-8') + "\n---"
            elif 'image' in attachment.content_type:
                self._transcript_queue.append((message_index, attachment.id, attachment.proxy_url))
                content += f"---\nImage transcript:\n%TRANSCRIPT_{attachment.id}%\n---"
            else:
                content += "\n\nAttached file " + attachment.filename + " is of unsupported type, cannot read."
                content += "unsupported type, cannot read."

        self.cached_messages.append(
            {'role': 'user', 'content': content}
        )

        if not self.enabled:
            if self.last_warning is None:
                if len(self.thread.applied_tags) != 1:
                    await self.warning(":label: Please keep only a single tag on this thread, I'll then be able to assist you better.")
                    return
                else:
                    plugin_name = self.thread.applied_tags[0].name
                    await self.initialize(plugin_name)
            else:
                return
    async def _check_tags(self) -> bool:
        if len(self.thread.applied_tags) == 0:
            await self.warning(":label: Please tag this thread, I'll be able to assist you better.")
            return False

        #await self.respond(message)
        await self.respond()
        if len(self.thread.applied_tags) > 1:
            await self.warning(":label: Please keep only a single tag on this thread, I'll be able to assist you better.")
            return False

        return True

    async def initialize(self, plugin: str) -> bool:
    async def on_message(self, message: disnake.Message):
        """TODO handle editing and alike?? messages = [{
            "role": "user" if m.author != self.user else "assistant",
            "content": m.content
        } async for m in query_message.channel.history(oldest_first=True)]"""

        await self._add_new_message(message)

        if not self.enabled:
            if self._last_warning is not None:
                return False

            if await self._check_tags():
                assistant_name = self.thread.applied_tags[0].name
                await self.initialize(assistant_name)

        return await self.respond()

    async def initialize(self, assistant_name: str) -> bool:
        if self.enabled:
            return False

        await self.remove_last_warning()

        if plugin != "RealWeather":
            #await self.warning(":slight_frown: The automated assistant cannot assist you with this plugin. :hourglass: Wait for human support.")
            self.enabled = False

        if not self.assistants.has_assistant(assistant_name):
            return False

        self.plugin = plugin
        self.assistant_name = assistant_name
        self.enabled = True

        return True

    async def warning(self, message: str, reply_to: disnake.Message | None = None):
        if reply_to is None:
            self.last_warning = await self.thread.send(message)
            self._last_warning = await self.thread.send(message)
        else:
            self.last_warning = await reply_to.reply(message)
            self._last_warning = await reply_to.reply(message)

    async def remove_last_warning(self):
        if self.last_warning is not None:
            await self.last_warning.delete()
            self.last_warning = None
        if self._last_warning is not None:
            await self._last_warning.delete()
            self._last_warning = None

    async def respond(self, message: disnake.Message | None = None):
        if self.current_response is not None:
    async def respond(self, message: disnake.Message | None = None) -> bool:
        if not self.enabled:
            print("Want to respond but disabled")
            return False

        if self._current_response is not None:
            print("Cancelled old message")
            self.current_response.cancel()
            self._current_response.cancel()

        async with self.thread.typing():
            self.current_response = self.assistant.ask(self.cached_messages)
            response = await self.current_response
            assistant = self.assistants.get_assistant(self.assistant_name)
            self._current_response = assistant.ask(self.cached_messages, self._transcript_queue)
            response = await self._current_response

            if response.status == 'completed':
                self.cached_messages = response.messages

                if message is not None:
                    await reply_long(message, response.content)
                else:
                    await send_long(self.thread, response.content)
                print("New messages:", response.new_messages)
                self.cached_messages += response.new_messages
                self._transcript_queue = []
                self._current_response = None

                await self.send_or_reply(response.content, message)

                return True
            elif response.status == 'cancelled':
                pass
            else:
                await message.reply(":bangbang: Sorry, an error occurred. Cannot proceed with your request.")
                await self.send_or_reply(":bangbang: Sorry, an error occurred. Cannot proceed with your request.", message)

        return False

    async def send_or_reply(self, content: str, message: disnake.Message | None = None):
        if message is not None:
            await reply_long(message, content)
        else:
            await send_long(self.thread, content)


# TODO move this stuff
19 src/plugin_assist/common/file_utils.py Normal file
@@ -0,0 +1,19 @@
from caio import AsyncioContext

async def caio_read_file(context: AsyncioContext, fd: int) -> str:
    message = bytearray()
    offset = 0

    while True:
        content = await context.read(512, fd, offset)

        if len(content) == 0:
            break

        message += content
        offset += len(content)

    return message.decode('utf-8')

async def caio_write_file(context: AsyncioContext, fd: int, content: str):
    await context.write(content.encode(), fd, 0)
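Note: these helpers work on an already open file descriptor plus a shared caio AsyncioContext, which is how Persistence.save_string()/load_string() below use them. A minimal standalone sketch with a made-up file name:

import asyncio
import os
from caio import AsyncioContext

async def demo() -> None:
    context = AsyncioContext()

    # Write, then read back, a small file through the async I/O context.
    fd = os.open("example.txt", os.O_CREAT | os.O_WRONLY)
    try:
        await caio_write_file(context, fd, "hello from caio")
    finally:
        os.close(fd)

    fd = os.open("example.txt", os.O_RDONLY)
    try:
        print(await caio_read_file(context, fd))  # "hello from caio"
    finally:
        os.close(fd)

asyncio.run(demo())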
@@ -1,23 +1,7 @@
import os
from caio import AsyncioContext

async def caio_read_file(context: AsyncioContext, fd: int) -> str:
    message = bytearray()
    offset = 0

    while True:
        content = await context.read(512, fd, offset)

        if len(content) == 0:
            break

        message += content
        offset += len(content)

    return message.decode('utf-8')

async def caio_write_file(context: AsyncioContext, fd: int, content: str):
    await context.write(content.encode(), fd, 0)
from .file_utils import caio_read_file, caio_write_file

class Persistence:
    context: AsyncioContext
@@ -27,7 +11,7 @@ class Persistence:
        self.context = AsyncioContext()
        self.base_directory = base_directory

    async def save(self, id: str, content: str):
    async def save_string(self, id: str, content: str):
        filename = os.path.join(self.base_directory, id)
        fd = os.open(filename, os.O_CREAT | os.O_WRONLY)
@@ -36,7 +20,7 @@ class Persistence:
        finally:
            os.close(fd)

    async def load(self, id: str) -> str:
    async def load_string(self, id: str) -> str | None:
        filename = os.path.join(self.base_directory, id)
        fd = os.open(filename, os.O_RDONLY)
55 src/plugin_assist/settings.py Normal file
@@ -0,0 +1,55 @@
from dataclasses import dataclass

@dataclass
class InferenceSettings:
    base_url: str | None
    api_key: str | None
    models: dict[str, str]

    @staticmethod
    def from_dict(data: dict) -> 'InferenceSettings':
        return InferenceSettings(
            base_url = data.get('base_url', None),
            api_key = data.get('api_key', None),
            models = data['models']
        )

@dataclass
class AssistantSettings:
    system_prompt: str
    documentation_directory: str | None

    @staticmethod
    def from_dict(data: dict) -> 'AssistantSettings':
        return AssistantSettings(
            system_prompt = '\n'.join(data['system_prompt']),
            documentation_directory = data.get('documentation_directory', None)
        )

@dataclass
class BotSettings:
    token: str | None
    allowed_forum_ids: list[int]
    thread_persistence_directory: str

    @staticmethod
    def from_dict(data: dict) -> 'BotSettings':
        return BotSettings(
            token = data.get('token', None),
            allowed_forum_ids = data['allowed_forum_ids'],
            thread_persistence_directory = data.get('thread_persistence_directory', 'threads')
        )

@dataclass
class Settings:
    inference: InferenceSettings
    assistants: dict[str, AssistantSettings]
    bot: BotSettings

    @staticmethod
    def from_dict(data: dict) -> 'Settings':
        return Settings(
            inference = InferenceSettings.from_dict(data['inference']),
            assistants = {k: AssistantSettings.from_dict(v) for k, v in data['assistants'].items()},
            bot = BotSettings.from_dict(data['bot'])
        )