Initial commit
This commit is contained in:
commit
ed7ac078f9
16 changed files with 1821 additions and 0 deletions
3
.dockerignore
Normal file
3
.dockerignore
Normal file
|
@ -0,0 +1,3 @@
|
|||
__pycache__/
|
||||
doc/
|
||||
.env
|
11
.env.example
Normal file
11
.env.example
Normal file
|
@ -0,0 +1,11 @@
|
|||
# The bot token
|
||||
BOT_TOKEN=
|
||||
|
||||
# The forum ID
|
||||
FORUM_ID=
|
||||
|
||||
# API key for the OpenAI-compatible provider
|
||||
OPENAI_API_KEY=
|
||||
|
||||
# Base URL for the OpenAI-compatible provider
|
||||
OPENAI_BASE_URL=
|
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
__pycache__/
|
||||
doc/
|
||||
.env
|
66
Containerfile
Normal file
66
Containerfile
Normal file
|
@ -0,0 +1,66 @@
|
|||
# =====================================================================
|
||||
# Builder Stage: Install dependencies into a virtual environment
|
||||
# =====================================================================
|
||||
FROM python:3.13-alpine AS builder
|
||||
|
||||
# Set environment variables
|
||||
# - PYTHONDONTWRITEBYTECODE: Prevents Python from writing .pyc files
|
||||
# - PYTHONUNBUFFERED: Ensures that Python output is sent straight to the terminal
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1
|
||||
|
||||
# Install build-time system dependencies like gcc.
|
||||
# 'build-base' is a meta-package on Alpine that includes gcc, g++, make, and other essentials.
|
||||
# We use --no-cache to avoid storing the package index, keeping the layer smaller.
|
||||
RUN apk add --no-cache build-base libaio-dev linux-headers
|
||||
|
||||
# Install Poetry
|
||||
RUN pip install poetry
|
||||
|
||||
# Set the working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Configure Poetry to create the virtual environment inside the project's directory
|
||||
# This makes it easy to copy the venv to the next stage
|
||||
RUN poetry config virtualenvs.in-project true
|
||||
|
||||
# Copy only the dependency files to leverage Docker cache
|
||||
COPY pyproject.toml poetry.lock ./
|
||||
|
||||
# Install production dependencies
|
||||
# --no-interaction and --no-ansi prevent interactive prompts and color output
|
||||
RUN poetry install --no-interaction --no-ansi --no-root --only main
|
||||
|
||||
# Copy the rest of the application code
|
||||
COPY src/ ./src/
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Final Stage: Create the production image
|
||||
# =====================================================================
|
||||
FROM python:3.13-alpine AS final
|
||||
|
||||
# Create a non-root user and group
|
||||
RUN addgroup -S appuser && adduser -S -G appuser appuser
|
||||
|
||||
# Set the working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy the virtual environment from the builder stage
|
||||
COPY --from=builder --chown=appuser:appuser /app/.venv ./.venv
|
||||
|
||||
# Copy the application code from the builder stage
|
||||
COPY --from=builder --chown=appuser:appuser /app/src ./src
|
||||
|
||||
# Add the virtual environment's bin directory to the PATH
|
||||
# This allows us to run executables directly (e.g., `gunicorn`, `uvicorn`)
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
|
||||
# Switch to the non-root user
|
||||
USER appuser
|
||||
|
||||
# Define the command to run the application.
|
||||
# Assumes your `pyproject.toml` has a `[tool.poetry.scripts]` entry like:
|
||||
# start = "gunicorn --bind 0.0.0.0:8000 my_app.wsgi:application"
|
||||
# The `start` script is now on the PATH, so we can call it directly.
|
||||
CMD ["python3", "-m", "src.plugin_assist"]
|
0
README.md
Normal file
0
README.md
Normal file
11
docker-compose.yml
Normal file
11
docker-compose.yml
Normal file
|
@ -0,0 +1,11 @@
|
|||
services:
|
||||
bot:
|
||||
image:
|
||||
build: .
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
BOT_TOKEN:
|
||||
FORUM_ID:
|
||||
OPENAI_API_KEY:
|
||||
OPENAI_BASE_URL:
|
1127
poetry.lock
generated
Normal file
1127
poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
23
pyproject.toml
Normal file
23
pyproject.toml
Normal file
|
@ -0,0 +1,23 @@
|
|||
[project]
|
||||
name = "plugin_assist"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
authors = [
|
||||
{name = "Minecon724",email = "dm@m724.eu"}
|
||||
]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"openai (>=1.98.0,<2.0.0)",
|
||||
"disnake (>=2.10.1,<3.0.0)",
|
||||
"caio (>=0.9.24,<0.10.0)"
|
||||
]
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=2.0.0,<3.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
|
||||
[tool.poetry.scripts]
|
||||
start = "plugin_assist:main"
|
25
src/plugin_assist/__init__.py
Normal file
25
src/plugin_assist/__init__.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
from .assistant import Assistant
|
||||
from .bot import AssistantBot
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from os import environ
|
||||
|
||||
def main():
|
||||
client = AsyncOpenAI(
|
||||
base_url = environ.get("OPENAI_BASE_URL")
|
||||
)
|
||||
|
||||
assistant = Assistant(
|
||||
client = client,
|
||||
documentation_directory = "doc"
|
||||
)
|
||||
|
||||
bot = AssistantBot(
|
||||
assistant = assistant,
|
||||
forum_id = int(environ.get("FORUM_ID"))
|
||||
)
|
||||
|
||||
print("Starting bot")
|
||||
|
||||
token = environ.get("BOT_TOKEN")
|
||||
bot.run(token)
|
4
src/plugin_assist/__main__.py
Normal file
4
src/plugin_assist/__main__.py
Normal file
|
@ -0,0 +1,4 @@
|
|||
from . import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
4
src/plugin_assist/assistant/__init__.py
Normal file
4
src/plugin_assist/assistant/__init__.py
Normal file
|
@ -0,0 +1,4 @@
|
|||
from .assistant import Assistant
|
||||
from .documentation import Documentation
|
||||
|
||||
__all__ = ['Assistant', 'Documentation']
|
147
src/plugin_assist/assistant/assistant.py
Normal file
147
src/plugin_assist/assistant/assistant.py
Normal file
|
@ -0,0 +1,147 @@
|
|||
from openai import AsyncOpenAI
|
||||
from asyncio import AbstractEventLoop, Task, get_running_loop, CancelledError
|
||||
from dataclasses import dataclass, field
|
||||
from json import loads, dumps
|
||||
from typing import Any, Literal, List, Dict
|
||||
|
||||
from .documentation import Documentation
|
||||
|
||||
_DEFAULT_SYSTEM_PROMPT = """
|
||||
You are an assistant who helps with the RealWeather Minecraft plugin.
|
||||
RealWeather is a plugin that synchronizes real weather, time and lightning strikes.
|
||||
It's alpha software, and right now it's kinda more ideal for ambience, docs will tell you more.
|
||||
You have the documentaton available, use it.
|
||||
"""
|
||||
|
||||
_DEFAULT_TOOLS = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_documentation",
|
||||
"description": "Search the plugin documentation.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"required": ["query"],
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The detailed search query as tags."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
_DEFAULT_MODEL = "kimi-k2-instruct-fast"
|
||||
|
||||
# A type hint for the status, making it clear what the possible values are.
|
||||
ResponseStatus = Literal["completed", "cancelled", "error"]
|
||||
|
||||
@dataclass
|
||||
class AssistantResponse:
|
||||
"""A structured response from the Assistant."""
|
||||
status: ResponseStatus
|
||||
content: str = ""
|
||||
messages: List[Dict[str, Any]] = field(default_factory=list)
|
||||
error: str | None = None
|
||||
|
||||
class Assistant:
|
||||
client: AsyncOpenAI
|
||||
documentation: Documentation
|
||||
system_prompt: str
|
||||
model: str
|
||||
tools: list[dict]
|
||||
|
||||
def __init__(self, client: AsyncOpenAI, documentation_directory: str, system_prompt: str | None = None):
|
||||
self.client = client
|
||||
self.documentation = Documentation(documentation_directory)
|
||||
self.system_prompt = system_prompt if system_prompt is not None else _DEFAULT_SYSTEM_PROMPT
|
||||
self.model = _DEFAULT_MODEL
|
||||
self.tools = _DEFAULT_TOOLS
|
||||
|
||||
def ask(
|
||||
self,
|
||||
messages: list[dict],
|
||||
*,
|
||||
loop: AbstractEventLoop | None = None,
|
||||
) -> Task[AssistantResponse]:
|
||||
"""
|
||||
Start generating the assistant response and immediately return the asyncio.Task, so the caller can `await` it OR cancel it.
|
||||
"""
|
||||
loop = loop or get_running_loop()
|
||||
task = loop.create_task(self._ask(messages))
|
||||
return task
|
||||
|
||||
# MODIFIED: The method now returns our structured response and handles cancellation.
|
||||
async def _ask(self, messages: list[dict]) -> AssistantResponse:
|
||||
messages_copy = messages.copy() # The messages are not expected to be modified
|
||||
messages_copy.insert(0, {"role": "system", "content": self.system_prompt})
|
||||
|
||||
try:
|
||||
# The core logic is now wrapped to catch cancellation.
|
||||
response_content = await self._ask_inner(messages_copy)
|
||||
print(messages_copy)
|
||||
|
||||
return AssistantResponse(
|
||||
status="completed",
|
||||
content=response_content,
|
||||
messages=messages_copy
|
||||
)
|
||||
|
||||
except CancelledError:
|
||||
print("Assistant 'ask' task was cancelled.")
|
||||
# Return a specific response indicating cancellation.
|
||||
return AssistantResponse(
|
||||
status="cancelled",
|
||||
messages=messages_copy # Return the history up to the point of cancellation
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred in 'ask': {e}")
|
||||
# Return a response for any other errors.
|
||||
return AssistantResponse(
|
||||
status="error",
|
||||
messages=messages_copy,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
async def _ask_inner(self, messages: List[Dict[str, Any]]) -> str:
|
||||
response_segments = []
|
||||
finish_reason = "tool_calls"
|
||||
|
||||
while finish_reason == "tool_calls":
|
||||
print("Making API call...")
|
||||
completion = await self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.0,
|
||||
tools=self.tools,
|
||||
tool_choice="auto"
|
||||
)
|
||||
|
||||
choice = completion.choices[0]
|
||||
finish_reason = choice.finish_reason
|
||||
messages.append(choice.message)
|
||||
|
||||
if choice.message.content:
|
||||
response_segments.append(choice.message.content)
|
||||
|
||||
if finish_reason == "tool_calls" and choice.message.tool_calls is not None:
|
||||
tool_map = {"search_documentation": self.documentation.search_documentation}
|
||||
|
||||
for tool_call in choice.message.tool_calls:
|
||||
tool_call_name = tool_call.function.name
|
||||
tool_call_arguments = loads(tool_call.function.arguments)
|
||||
tool_function = tool_map[tool_call_name]
|
||||
|
||||
print(f"Executing tool: {tool_call_name}...")
|
||||
tool_result = await tool_function(**tool_call_arguments)
|
||||
print("Tool result:", tool_result)
|
||||
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"name": tool_call_name,
|
||||
"content": dumps(tool_result)
|
||||
})
|
||||
|
||||
return '\n\n'.join(response_segments)
|
111
src/plugin_assist/assistant/documentation.py
Normal file
111
src/plugin_assist/assistant/documentation.py
Normal file
|
@ -0,0 +1,111 @@
|
|||
import os
|
||||
from caio import AsyncioContext
|
||||
from collections import Counter
|
||||
|
||||
from typing import List, Tuple, Any
|
||||
|
||||
from .persistent import caio_read_file
|
||||
|
||||
def _partition_by_most_frequent(data_list: List[Any]) -> Tuple[List[Any], List[Any]]:
|
||||
"""
|
||||
Partitions a list into two lists based on element frequency.
|
||||
|
||||
The first list contains all occurrences of the element(s) that appear
|
||||
most frequently. The second list contains all other elements.
|
||||
|
||||
Args:
|
||||
data_list: The input list of elements.
|
||||
|
||||
Returns:
|
||||
A tuple containing two lists:
|
||||
- The first list with all instances of the most frequent element(s).
|
||||
- The second list with the remaining elements.
|
||||
"""
|
||||
if not data_list:
|
||||
return [], []
|
||||
|
||||
counts = Counter(data_list)
|
||||
|
||||
max_frequency = counts.most_common(1)[0][1]
|
||||
|
||||
most_frequent_keys = {
|
||||
item for item, count in counts.items() if count == max_frequency
|
||||
}
|
||||
|
||||
most_list = [item for item in data_list if item in most_frequent_keys]
|
||||
other_list = [item for item in data_list if item not in most_frequent_keys]
|
||||
|
||||
return most_list, other_list
|
||||
|
||||
def _load_tags(base_directory: str) -> dict:
|
||||
tag_map = {}
|
||||
|
||||
for root, _, files in os.walk(base_directory):
|
||||
for name in files:
|
||||
file_path = os.path.join(root, name)
|
||||
tag_line = open(file_path).readline()
|
||||
tags = [tag.strip() for tag in tag_line[6:].split(",")]
|
||||
|
||||
for tag in tags:
|
||||
if tag in tag_map:
|
||||
tag_map[tag].append(file_path)
|
||||
else :
|
||||
tag_map[tag] = [file_path]
|
||||
|
||||
return tag_map
|
||||
|
||||
class Documentation:
|
||||
base_directory: str
|
||||
tag_map: dict[str, list]
|
||||
context: AsyncioContext
|
||||
|
||||
def __init__(self, base_directory: str):
|
||||
self.base_directory = base_directory
|
||||
self.tag_map = _load_tags(base_directory)
|
||||
self.context = AsyncioContext()
|
||||
|
||||
print("Tags loaded:", self.tag_map)
|
||||
|
||||
async def search_documentation(self, query: str) -> dict:
|
||||
print("Saerching docs with query", query)
|
||||
|
||||
if [query] in self.tag_map.values():
|
||||
fd = os.open(query, os.O_RDONLY)
|
||||
|
||||
try:
|
||||
return await caio_read_file(self.context, fd)
|
||||
finally:
|
||||
os.close(fd)
|
||||
|
||||
candidates = []
|
||||
|
||||
tags_processed = ['']
|
||||
for c in query:
|
||||
c = c.lower()
|
||||
if c.isalpha():
|
||||
tags_processed[-1] += c
|
||||
else:
|
||||
tags_processed.append('')
|
||||
|
||||
tags_processed = [tag for tag in tags_processed if tag != '']
|
||||
|
||||
print("Searching for tags:", tags_processed)
|
||||
|
||||
|
||||
candidates = []
|
||||
|
||||
for tc in tags_processed:
|
||||
for tag in [t for t in self.tag_map.keys() if t in tc]:
|
||||
candidates += self.tag_map[tag]
|
||||
|
||||
print("Candidate files:", tags_processed)
|
||||
|
||||
p = _partition_by_most_frequent(candidates)
|
||||
|
||||
return {
|
||||
"matches": {
|
||||
"best": p[0],
|
||||
"probable": p[1]
|
||||
},
|
||||
"guide": "Call this function again with the filename as query to view that file."
|
||||
}
|
46
src/plugin_assist/assistant/persistent.py
Normal file
46
src/plugin_assist/assistant/persistent.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
import os
|
||||
from caio import AsyncioContext
|
||||
|
||||
async def caio_read_file(context: AsyncioContext, fd: int) -> str:
|
||||
message = bytearray()
|
||||
offset = 0
|
||||
|
||||
while True:
|
||||
content = await context.read(512, fd, offset)
|
||||
|
||||
if len(content) == 0:
|
||||
break
|
||||
|
||||
message += content
|
||||
offset += len(content)
|
||||
|
||||
return message.decode('utf-8')
|
||||
|
||||
async def caio_write_file(context: AsyncioContext, fd: int, content: str):
|
||||
await context.write(content.encode(), fd, 0)
|
||||
|
||||
class Persistence:
|
||||
context: AsyncioContext
|
||||
base_directory: str
|
||||
|
||||
def __init__(self, base_directory: str):
|
||||
self.context = AsyncioContext()
|
||||
self.base_directory = base_directory
|
||||
|
||||
async def save(self, id: str, content: str):
|
||||
filename = os.path.join(self.base_directory, id)
|
||||
fd = os.open(filename, os.O_CREAT | os.O_WRONLY)
|
||||
|
||||
try:
|
||||
await caio_write_file(self.context, fd, content)
|
||||
finally:
|
||||
os.close(fd)
|
||||
|
||||
async def load(self, id: str) -> str:
|
||||
filename = os.path.join(self.base_directory, id)
|
||||
fd = os.open(filename, os.O_RDONLY)
|
||||
|
||||
try:
|
||||
return await caio_read_file(self.context, fd)
|
||||
finally:
|
||||
os.close(fd)
|
76
src/plugin_assist/bot/__init__.py
Normal file
76
src/plugin_assist/bot/__init__.py
Normal file
|
@ -0,0 +1,76 @@
|
|||
import disnake
|
||||
from disnake.ext import commands
|
||||
|
||||
from ..assistant import Assistant
|
||||
from .threads import AssistedThread
|
||||
|
||||
class AssistantBot(commands.InteractionBot):
|
||||
forum_id: int
|
||||
assistant: Assistant
|
||||
threads: dict[int, AssistedThread]
|
||||
|
||||
def __init__(self, forum_id: int, assistant: Assistant, **kwargs):
|
||||
intents = disnake.Intents.default()
|
||||
intents.message_content = True
|
||||
|
||||
super().__init__(intents=intents, **kwargs)
|
||||
|
||||
self.forum_id = forum_id
|
||||
self.assistant = assistant
|
||||
self.threads = {}
|
||||
|
||||
async def on_ready(self):
|
||||
print(f"Logged in as {self.user} (ID: {self.user.id})\n------")
|
||||
|
||||
await self.change_presence(
|
||||
activity = disnake.Game("with you")
|
||||
)
|
||||
|
||||
async def on_thread_create(self, thread: disnake.Thread):
|
||||
if thread.parent_id != self.forum_id:
|
||||
return
|
||||
|
||||
if thread.id not in self.threads:
|
||||
self.threads[thread.id] = AssistedThread(thread, self.assistant)
|
||||
|
||||
async def on_thread_update(self, before: disnake.Thread, thread: disnake.Thread):
|
||||
if thread.parent_id != self.forum_id:
|
||||
return
|
||||
if thread.id not in self.threads:
|
||||
return # TODO the two checks will be replaced
|
||||
|
||||
print("Thread updated", thread.applied_tags)
|
||||
|
||||
if len(thread.applied_tags) != 1:
|
||||
return
|
||||
|
||||
plugin_name = thread.applied_tags[0].name
|
||||
|
||||
print("and thread is good")
|
||||
|
||||
if await self.threads[thread.id].initialize(plugin_name):
|
||||
await self.threads[thread.id].respond()
|
||||
|
||||
|
||||
async def on_message(self, message: disnake.Message):
|
||||
if message.author == self.user:
|
||||
return
|
||||
|
||||
if isinstance(message.channel, disnake.DMChannel):
|
||||
await message.reply("I can't respond to DMs. I can help only on our server.")
|
||||
return
|
||||
|
||||
if not isinstance(message.channel, disnake.Thread):
|
||||
return
|
||||
|
||||
if '[-ai]' in message.content:
|
||||
return
|
||||
|
||||
thread = message.channel
|
||||
|
||||
if thread.parent_id != self.forum_id:
|
||||
return
|
||||
if thread.id not in self.threads:
|
||||
return # TODO the two checks will be replaced
|
||||
|
||||
await self.threads[thread.id].on_message(message)
|
164
src/plugin_assist/bot/threads.py
Normal file
164
src/plugin_assist/bot/threads.py
Normal file
|
@ -0,0 +1,164 @@
|
|||
import disnake
|
||||
from dataclasses import dataclass, field
|
||||
from asyncio import Task
|
||||
from typing import List
|
||||
|
||||
from ..assistant.assistant import Assistant
|
||||
|
||||
@dataclass
|
||||
class AssistedThread:
|
||||
thread: disnake.Thread
|
||||
assistant: Assistant
|
||||
enabled: bool = False
|
||||
plugin: str = ""
|
||||
cached_messages: List[dict] = field(default_factory=list)
|
||||
last_warning: disnake.Message | None = None
|
||||
current_response: Task | None = None
|
||||
|
||||
async def on_message(self, message: disnake.Message):
|
||||
"""TODO handle editing and alike: messages = [{
|
||||
"role": "user" if m.author != self.user else "assistant",
|
||||
"content": m.content
|
||||
} async for m in query_message.channel.history(oldest_first=True)]"""
|
||||
|
||||
content = message.content
|
||||
|
||||
for attachment in message.attachments:
|
||||
if 'text' in attachment.content_type:
|
||||
data = await attachment.read()
|
||||
content += f"\n\nAttached file {attachment.filename}:\n---\n{data.decode('utf-8')}\n---"
|
||||
else:
|
||||
content += "\n\nAttached file " + attachment.filename + " is of unsupported type, cannot read."
|
||||
|
||||
self.cached_messages.append(
|
||||
{'role': 'user', 'content': content}
|
||||
)
|
||||
|
||||
if not self.enabled:
|
||||
if self.last_warning is None:
|
||||
if len(self.thread.applied_tags) != 1:
|
||||
await self.warning(":label: Please keep only a single tag on this thread, I'll then be able to assist you better.")
|
||||
return
|
||||
else:
|
||||
plugin_name = self.thread.applied_tags[0].name
|
||||
await self.initialize(plugin_name)
|
||||
else:
|
||||
return
|
||||
|
||||
#await self.respond(message)
|
||||
await self.respond()
|
||||
|
||||
async def initialize(self, plugin: str) -> bool:
|
||||
if self.enabled:
|
||||
return False
|
||||
|
||||
await self.remove_last_warning()
|
||||
|
||||
if plugin != "RealWeather":
|
||||
#await self.warning(":slight_frown: The automated assistant cannot assist you with this plugin. :hourglass: Wait for human support.")
|
||||
self.enabled = False
|
||||
return False
|
||||
|
||||
self.plugin = plugin
|
||||
self.enabled = True
|
||||
|
||||
return True
|
||||
|
||||
async def warning(self, message: str, reply_to: disnake.Message | None = None):
|
||||
if reply_to is None:
|
||||
self.last_warning = await self.thread.send(message)
|
||||
else:
|
||||
self.last_warning = await reply_to.reply(message)
|
||||
|
||||
async def remove_last_warning(self):
|
||||
if self.last_warning is not None:
|
||||
await self.last_warning.delete()
|
||||
self.last_warning = None
|
||||
|
||||
async def respond(self, message: disnake.Message | None = None):
|
||||
if self.current_response is not None:
|
||||
print("Cancelled old message")
|
||||
self.current_response.cancel()
|
||||
|
||||
async with self.thread.typing():
|
||||
self.current_response = self.assistant.ask(self.cached_messages)
|
||||
response = await self.current_response
|
||||
|
||||
if response.status == 'completed':
|
||||
self.cached_messages = response.messages
|
||||
|
||||
if message is not None:
|
||||
await reply_long(message, response.content)
|
||||
else:
|
||||
await send_long(self.thread, response.content)
|
||||
|
||||
elif response.status == 'cancelled':
|
||||
pass
|
||||
else:
|
||||
await message.reply(":bangbang: Sorry, an error occurred. Cannot proceed with your request.")
|
||||
|
||||
|
||||
# TODO move this stuff
|
||||
|
||||
async def reply_long(reply_to: disnake.Message, message: str):
|
||||
split_message = split_text(message, 2000)
|
||||
|
||||
if message is None:
|
||||
await reply_to.reply(split_message[0])
|
||||
|
||||
for part in split_message[1:]:
|
||||
await reply_to.channel.send(part)
|
||||
|
||||
async def send_long(channel: disnake.abc.Messageable, message: str):
|
||||
split_message = split_text(message, 2000)
|
||||
|
||||
for part in split_message:
|
||||
await channel.send(part)
|
||||
|
||||
def split_text(text: str, max_len: int) -> List[str]:
|
||||
"""
|
||||
Split `text` into chunks whose length is <= `max_len`.
|
||||
Priority of break-points:
|
||||
1. double newline (\n\n)
|
||||
2. single newline (\n)
|
||||
3. space (word boundary)
|
||||
4. hard cut (character boundary)
|
||||
"""
|
||||
if len(text) <= max_len:
|
||||
return [text]
|
||||
|
||||
def _pack(pieces: List[str], sep: str) -> List[str]:
|
||||
"""
|
||||
Greedily packs `pieces` (already split on `sep`) into chunks that
|
||||
respect `max_len`, re-adding `sep` between pieces.
|
||||
If an individual piece is still too large, we signal with None.
|
||||
"""
|
||||
chunks, buff = [], ""
|
||||
for piece in pieces:
|
||||
add_len = len(piece) + (len(sep) if buff else 0)
|
||||
if len(piece) > max_len:
|
||||
return None
|
||||
if len(buff) + add_len <= max_len:
|
||||
buff += (sep if buff else "") + piece
|
||||
else:
|
||||
chunks.append(buff)
|
||||
buff = piece
|
||||
if buff:
|
||||
chunks.append(buff)
|
||||
return chunks
|
||||
|
||||
for delim in ("\n\n", "\n", " "):
|
||||
parts = text.split(delim)
|
||||
packed = _pack(parts, delim)
|
||||
if packed is None:
|
||||
final_chunks = []
|
||||
for p in parts:
|
||||
if len(p) > max_len:
|
||||
final_chunks.extend(split_text(p, max_len))
|
||||
else:
|
||||
final_chunks.append(p)
|
||||
return split_text(delim.join(final_chunks), max_len)
|
||||
else:
|
||||
return packed
|
||||
|
||||
return [text[i : i + max_len] for i in range(0, len(text), max_len)]
|
Loading…
Add table
Add a link
Reference in a new issue