import argparse
import json
import sqlite3
import os
from typing import Optional
[docs]
class Database:
DEFAULT_DATABASE = os.path.expanduser("~/.smah/smah.db")
[docs]
@staticmethod
def default_database() -> str:
return Database.DEFAULT_DATABASE
[docs]
@staticmethod
def args_to_dict(args: argparse.Namespace) -> dict:
return vars(args)
def __init__(self, args):
file = args.database or self.default_database()
migrate = False
if not os.path.exists(file):
os.makedirs(os.path.dirname(file), exist_ok=True)
migrate = True
self.connection: sqlite3.Connection = sqlite3.connect(file)
[docs]
def last_session(self):
cursor = self.connection.cursor()
cursor.execute(
"""
SELECT setting_value
FROM settings
WHERE setting = ?
""",
("last_session",)
)
result = cursor.fetchone()
cursor.close()
if result:
(session_id,) = result
session_id = int(session_id)
return self.session(session_id)
return None
[docs]
def session(self, session_id: int):
cursor = self.connection.cursor()
cursor.execute(
"""
SELECT chat_history.id, chat_history.title, chat_history.created_on, chat_history.modified_on, chat_history_details.args, chat_history_details.plan, chat_history_details.pipe_input
FROM chat_history
JOIN chat_history_details
ON chat_history.id = chat_history_details.chat_history_id
WHERE chat_history.id = ?
""",
(session_id,)
)
result = cursor.fetchone()
# get chat_history_messages
mr = cursor.execute(
"""
SELECT message
FROM chat_history_message
WHERE chat_history_id = ?
ORDER BY id ASC
""",
(session_id,)
).fetchall()
messages = []
for row in mr:
(row,) = row
messages.append(json.loads(row))
cursor.close()
if result:
id, title, created_on, modified_on, args, plan, pipe = result
return {
"id": id,
"title": title,
"created_on": created_on,
"modified_on": modified_on,
"args": json.loads(args),
"plan": json.loads(plan),
"pipe": pipe,
"messages": messages
}
return None
[docs]
def history(self, limit: int = 10):
cursor = self.connection.cursor()
cursor.execute(
"""
SELECT chat_history.id, chat_history.title, chat_history.created_on, chat_history.modified_on
FROM chat_history
ORDER BY created_on DESC
LIMIT ?
""",
(limit,)
)
result = cursor.fetchall()
cursor.close()
response = []
for row in result:
id, title, created_on, modified_on = row
response.append({
"id": id,
"title": title,
"created_on": created_on,
"modified_on": modified_on
})
response.reverse()
return response
[docs]
def append_to_chat(self, session_id: int, messages: list) -> None:
cursor = self.connection.cursor()
cursor.execute("BEGIN TRANSACTION")
for message in messages:
cursor.execute(
"""
INSERT INTO chat_history_message (chat_history_id, message)
VALUES (?, ?)
""",
(session_id, json.dumps(message))
)
cursor.execute("COMMIT")
cursor.close()
[docs]
def save_chat(self, title: str, args: argparse.Namespace, plan: dict, messages: list, pipe: Optional[str] = None) -> None:
cursor = self.connection.cursor()
cursor.execute("BEGIN TRANSACTION")
# Insert into chat_history
cursor.execute(
"""
INSERT INTO chat_history (title)
VALUES (?)
""", (title,)
)
chat_history_id = cursor.lastrowid
# Insert into chat_history_details
cursor.execute(
"""
INSERT INTO chat_history_details (chat_history_id, args, plan, pipe_input)
VALUES (?, ?, ?, ?)
""", (chat_history_id, json.dumps(self.args_to_dict(args)), json.dumps(plan), pipe)
)
# Insert into chat_history_message
for message in messages:
cursor.execute(
"""
INSERT INTO chat_history_message (chat_history_id, message)
VALUES (?, ?)
""", (chat_history_id, json.dumps(message))
)
cursor.execute(
"""
INSERT INTO settings (setting, setting_value)
VALUES (?, ?)
ON CONFLICT(setting) DO UPDATE SET
setting_value = excluded.setting_value
""",
("last_session", f"{chat_history_id}")
)
# Commit the transaction
cursor.execute("COMMIT")
cursor.close()