Source code for smah.database.migration

import argparse
import hashlib
import os
import time
import importlib
import textwrap
import sqlite3
from smah.database.database import Database

class Migration:
    """File-based schema migration runner backed by a ``schema_migrations`` table.

    Every ``*.py`` file in :attr:`MIGRATIONS_DIR` (except ``__init__.py``) is
    treated as a migration module exposing ``up(cursor)`` and ``down(cursor)``.
    The applied state and the MD5 checksum of each file are recorded in the
    ``schema_migrations`` table of the target database.

    Methods taking ``database`` only use its ``connection`` attribute.
    """

    # Migrations are stored in the `migrations` directory
    MIGRATIONS_DIR = os.path.join(os.path.dirname(__file__), "migrations")

    # Upsert recording the checksum and applied flag of a single migration.
    _RECORD_STATE_SQL = """
        INSERT INTO schema_migrations (migration, checksum, applied)
        VALUES (?, ?, ?)
        ON CONFLICT(migration) DO UPDATE SET
            checksum = excluded.checksum,
            applied = excluded.applied
    """

    def __init__(self):
        pass

    @staticmethod
    def _abort_transaction(cursor) -> None:
        """Issue ``ROLLBACK`` on *cursor*, tolerating a missing transaction."""
        try:
            cursor.execute("ROLLBACK")
        except sqlite3.Error:
            # The migration code may already have ended the transaction
            # itself; the error that brought us here is the one worth
            # reporting, so a failing ROLLBACK must not replace it.
            pass

    @staticmethod
    def _complete(count: int, silent: bool, exit_on_finish: bool) -> tuple:
        """Report an early, successful end of a ``migrate`` run.

        Prints the completion message unless *silent*, then either raises
        ``SystemExit(0)`` (when *exit_on_finish*) or returns
        ``("success", message)``.
        """
        r = f"Migration Complete: changes applied {count}"
        if not silent:
            print(r)
        if exit_on_finish:
            raise SystemExit(0)
        return "success", r

    @staticmethod
    def get_schema_migrations(database: "Database"):
        """Return the rows of ``schema_migrations``, creating the table if needed.

        Args:
            database: Object whose ``connection`` attribute is the database
                connection to query.

        Returns:
            A list of dicts with the keys ``migration``, ``checksum``,
            ``applied`` (bool), ``created_at`` and ``modified_at``, ordered by
            migration name ascending.
        """
        create_table = textwrap.dedent(
            """
            CREATE TABLE IF NOT EXISTS schema_migrations(
                migration VARCHAR(255) PRIMARY KEY,
                checksum CHAR(32),
                applied BOOLEAN DEFAULT FALSE,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                modified_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
            """
        ).strip()
        cursor = database.connection.cursor()
        cursor.execute(create_table)
        cursor.execute(
            "SELECT migration, checksum, applied, created_at, modified_at "
            "FROM schema_migrations ORDER BY migration ASC"
        )
        result = cursor.fetchall()
        cursor.close()
        return [
            {
                'migration': migration,
                'checksum': checksum,
                # The flag comes back as an integer; normalise it to bool.
                'applied': applied == 1,
                'created_at': created_at,
                'modified_at': modified_at,
            }
            for migration, checksum, applied, created_at, modified_at in result
        ]

    @staticmethod
    def get_migrations():
        """List the migration files on disk together with their MD5 checksums.

        The migrations directory is created when missing. ``__init__.py`` is
        skipped because it is a package marker, not a migration.

        Returns:
            A list of ``{'file': <filename>, 'checksum': <md5 hexdigest>}``
            dicts sorted by filename, so that migrations are processed in a
            deterministic order (``os.listdir`` order is arbitrary).
        """
        os.makedirs(Migration.MIGRATIONS_DIR, exist_ok=True)
        migrations = []
        for migration in sorted(os.listdir(Migration.MIGRATIONS_DIR)):
            if not migration.endswith(".py") or migration == "__init__.py":
                continue
            with open(os.path.join(Migration.MIGRATIONS_DIR, migration), "rb") as handle:
                digest = hashlib.md5(handle.read()).hexdigest()
            migrations.append({'file': migration, 'checksum': digest})
        return migrations

    @staticmethod
    def apply_migration(database: "Database", migration: dict, silent: bool = False, exit_on_finish: bool = True) -> tuple:
        """Run ``up(cursor)`` of one migration and mark it as applied.

        The migration and the bookkeeping row are written inside a single
        explicit transaction which is rolled back on any exception.

        Args:
            database: Object whose ``connection`` attribute is the database
                connection.
            migration: Entry as produced by :meth:`get_migrations`
                (keys ``file`` and ``checksum``).
            silent: Suppress console output when true.
            exit_on_finish: On failure, raise ``SystemExit(2)`` instead of
                returning the error tuple.

        Returns:
            ``("success", message)`` or
            ``("error", ("Migration Failed", message))``.

        Raises:
            SystemExit: With code 2 when the migration fails and
                *exit_on_finish* is true.
        """
        # Import before opening a cursor so a broken module leaks nothing.
        module = importlib.import_module(f"smah.database.migrations.{migration['file'][:-3]}")
        cursor = database.connection.cursor()
        cursor.execute("BEGIN TRANSACTION")
        try:
            module.up(cursor)
            cursor.execute(
                Migration._RECORD_STATE_SQL,
                (migration['file'], migration['checksum'], True)
            )
            cursor.execute("COMMIT")
        except Exception as e:
            Migration._abort_transaction(cursor)
            cursor.close()
            r = f"Error Applying Migration {migration['file']}: {e}"
            if not silent:
                print(r)
            if exit_on_finish:
                raise SystemExit(2)
            return ("error", ("Migration Failed", r))
        cursor.close()
        r = f"Applied Migration {migration['file']}"
        if not silent:
            print(r)
        return "success", r

    @staticmethod
    def rollback_migration(database: "Database", migration: dict, silent: bool = False, exit_on_finish: bool = True) -> tuple:
        """Run ``down(cursor)`` of one migration and mark it as not applied.

        The revert and the bookkeeping row are written inside a single
        explicit transaction which is rolled back on any exception.

        Args:
            database: Object whose ``connection`` attribute is the database
                connection.
            migration: Entry as produced by :meth:`get_migrations`
                (keys ``file`` and ``checksum``).
            silent: Suppress console output when true.
            exit_on_finish: On failure, raise ``SystemExit(2)`` instead of
                returning the error tuple.

        Returns:
            ``("success", message)`` or
            ``("error", ("Rollback Failed", message))``.

        Raises:
            SystemExit: With code 2 when the revert fails and
                *exit_on_finish* is true.
        """
        # Import before opening a cursor so a broken module leaks nothing.
        module = importlib.import_module(f"smah.database.migrations.{migration['file'][:-3]}")
        cursor = database.connection.cursor()
        cursor.execute("BEGIN TRANSACTION")
        try:
            module.down(cursor)
            cursor.execute(
                Migration._RECORD_STATE_SQL,
                (migration['file'], migration['checksum'], False)
            )
            cursor.execute("COMMIT")
        except Exception as e:
            Migration._abort_transaction(cursor)
            cursor.close()
            r = f"Error Reverting Migration '{migration['file']}': {e}"
            if not silent:
                print(r)
            if exit_on_finish:
                raise SystemExit(2)
            return "error", ("Rollback Failed", r)
        cursor.close()
        r = f"Reverted Migration '{migration['file']}'"
        if not silent:
            print(r)
        return "success", r

    @staticmethod
    def migrate(
            database: "Database",
            args: argparse.Namespace,
            silent: bool = False,
            exit_on_finish: bool = True
    ) -> tuple:
        """Apply pending migrations in filename order.

        Args:
            database: Object whose ``connection`` attribute is the database
                connection.
            args: Parsed options; reads ``args.to`` (stop after this
                migration file), ``args.count`` (maximum number of migrations
                to apply) and ``args.reset_checksums`` (overwrite a stored
                checksum that no longer matches the file instead of failing).
            silent: Suppress console output when true.
            exit_on_finish: When true, early terminations (checksum mismatch,
                failed migration, ``--to``/``--count`` reached) raise
                ``SystemExit`` instead of returning.

        Returns:
            ``("nop", message)`` when nothing was applied,
            ``("success", message)`` when at least one migration was applied
            or a stop condition was reached, or ``("error", details)``.

        Raises:
            SystemExit: Code 0 on ``--to``/``--count`` completion, 1 on a
                checksum mismatch, 2 on a failed migration — only when
                *exit_on_finish* is true.
        """
        count = 0
        available_migrations = Migration.get_migrations()
        tracked_migrations = {m['migration']: m for m in Migration.get_schema_migrations(database)}

        for migration in available_migrations:
            tracked = tracked_migrations.get(migration['file'])

            if tracked and tracked['applied']:
                # Already applied: only verify the file has not changed since.
                if tracked['checksum'] != migration['checksum']:
                    r = f"Checksum Mismatch in {migration['file']}: expected {tracked['checksum']} but got {migration['checksum']}"
                    if not silent:
                        print(r)
                    if not args.reset_checksums:
                        if exit_on_finish:
                            raise SystemExit(1)
                        return "error", ("Checksum Mismatch", r)
                    cursor = database.connection.cursor()
                    cursor.execute(
                        "UPDATE schema_migrations SET checksum = ? WHERE migration = ?",
                        (migration['checksum'], migration['file'])
                    )
                    cursor.close()
                    # Persist the new checksum and close any implicitly opened
                    # transaction, otherwise the explicit BEGIN issued for the
                    # next migration fails.
                    database.connection.commit()
                if args.to and args.to == migration['file']:
                    return Migration._complete(count, silent, exit_on_finish)
                continue

            outcome, details = Migration.apply_migration(
                database, migration, silent=silent, exit_on_finish=exit_on_finish
            )
            if outcome == "error":
                if exit_on_finish:
                    raise SystemExit(2)
                return outcome, details
            count += 1
            if args.count and count >= args.count:
                return Migration._complete(count, silent, exit_on_finish)
            if args.to and args.to == migration['file']:
                return Migration._complete(count, silent, exit_on_finish)

        if count == 0:
            r = "No Migrations Pending"
            if not silent:
                print(r)
            return "nop", r
        r = f"Migration Complete: changes applied {count}"
        if not silent:
            print(r)
        return "success", r

    @staticmethod
    def rollback(database: "Database", args: argparse.Namespace):
        """Revert applied migrations, newest first.

        Only migrations that are both marked applied in the database and still
        present on disk are considered.

        Args:
            database: Object whose ``connection`` attribute is the database
                connection.
            args: Parsed options; reads ``args.to`` (revert down to and
                including this migration file) and ``args.count`` (maximum
                number of migrations to revert).

        Raises:
            SystemExit: Code 1 when ``args.to`` is not an applied, available
                migration; code 0 once the ``--to``/``--count`` limit is
                reached; code 2 (from :meth:`rollback_migration`) when a
                revert fails.
        """
        count = 0
        available_migrations = {m['file']: m for m in Migration.get_migrations()}
        tracked_migrations = [
            m for m in Migration.get_schema_migrations(database)
            if m['applied'] and m['migration'] in available_migrations
        ]
        tracked_migrations.reverse()

        # Verify to target is applied
        if args.to and not any(m['migration'] == args.to for m in tracked_migrations):
            print(f"Migration --to '{args.to}' not applied or available")
            raise SystemExit(1)

        for migration in tracked_migrations:
            Migration.rollback_migration(database, available_migrations[migration['migration']])
            count += 1
            if args.count and count >= args.count:
                print(f"Rollback Complete: changes applied {count}")
                raise SystemExit(0)
            if args.to and args.to == migration['migration']:
                print(f"Rollback Complete: changes applied {count}")
                raise SystemExit(0)

        if count == 0:
            print("No Migrations Available for Rollback")
        else:
            print(f"Rollback Complete #{count} migrations reverted.")

    @staticmethod
    def status(database: "Database"):
        """Print every migration file with its applied / checksum state.

        Args:
            database: Object whose ``connection`` attribute is the database
                connection.

        Raises:
            SystemExit: Code 0 when there are no migration files at all.
        """
        available_migrations = Migration.get_migrations()
        tracked_migrations = {m['migration']: m for m in Migration.get_schema_migrations(database)}
        if not available_migrations:
            print("No Migrations Available")
            raise SystemExit(0)

        lines = []
        for migration in available_migrations:
            tracked = tracked_migrations.get(migration['file'])
            migration['applied'] = tracked['applied'] if tracked else False
            migration['checksum_mismatch'] = bool(tracked) and tracked['checksum'] != migration['checksum']
            lines.append(
                f"{migration['file']} {'(applied)' if migration['applied'] else ''} {'(checksum mismatch)' if migration['checksum_mismatch'] else ''}\n"
            )
        out = "".join(lines)
        print(f"Migrations:\n{out}")

    @staticmethod
    def create(name: str) -> None:
        """Write a new, empty migration file named ``<epoch>_<name>.py``.

        The file is created in :attr:`MIGRATIONS_DIR` (created when missing)
        and contains stub ``up``/``down`` functions.

        Args:
            name: Suffix used in the migration filename after the current
                Unix timestamp.
        """
        epoch = int(time.time())
        template = textwrap.dedent(
            """
            def up(cursor):
                \"\"\"
                Apply schema.
                \"\"\"
                pass


            def down(cursor):
                \"\"\"
                Rollback schema.
                \"\"\"
                pass
            """).strip()
        mf = os.path.join(Migration.MIGRATIONS_DIR, f"{epoch}_{name}.py")
        os.makedirs(Migration.MIGRATIONS_DIR, exist_ok=True)
        with open(mf, "w") as file:
            file.write(template)
        print(f"Created migration: {mf}")