#!/usr/bin/env python3
"""
Database tools and functions.
"""
import MySQLdb
from flask import session
from passlib.hash import argon2


class Database:
    """
    An interface to interact with the database.
    """

    def __init__(self):
        """
        Initializes the database.

        Reads the MySQL credentials from the ``db_key`` file, which must hold
        exactly three whitespace-separated tokens: user, database name, and
        password, in that order. If the `users` table cannot be queried
        (MySQLdb.ProgrammingError), the schema is bootstrapped by executing
        every ';'-separated statement found in ``anonkun.sql``.
        """
        with open("db_key", "r", encoding="utf-8") as key_file:  # TODO: encrypt this
            self.user, self.db, self.key = key_file.read().strip().split()

        try:
            self.execute("SELECT * FROM `users`").fetchone()
        except MySQLdb.ProgrammingError:
            # database not initialized
            with open("anonkun.sql", "r", encoding="utf-8") as sql_file:
                # NOTE: naive split -- a ';' inside a string literal or a
                # comment in anonkun.sql would cut that statement in two.
                commands = sql_file.read().split(";")
            for cmd in commands:
                cmd = cmd.strip()
                if not cmd:
                    continue
                self.execute(cmd)

    def execute(self, *args, **kwargs):
        """
        Opens a connection to the app's database, executes one SQL
        statement, commits it, and closes the connection.

        All arguments are forwarded unchanged to ``cursor.execute`` (query
        string plus optional parameter tuple). Returns the cursor so callers
        can use ``fetchone()`` / ``fetchall()``: MySQLdb's default cursor
        buffers the whole result set client-side during ``execute``, so the
        rows remain readable after the connection has been closed.

        Any database error raised by ``execute`` propagates to the caller.
        """
        # The connection is managed explicitly instead of via
        # ``with MySQLdb.connect(...) as cur``. That form depended on the
        # legacy MySQL-python behaviour (``__enter__`` returned a cursor,
        # ``__exit__`` committed but never closed the connection, leaking
        # one connection per query); in mysqlclient >= 2.0 the same ``with``
        # yields the Connection itself, which has no ``execute`` method.
        conn = MySQLdb.connect(user=self.user, passwd=self.key, db=self.db)
        try:
            cur = conn.cursor()
            cur.execute(*args, **kwargs)
            # Only reached on success; if execute raised, the uncommitted
            # transaction is discarded when the connection is closed.
            conn.commit()
        finally:
            conn.close()
        return cur


_DB = Database()


def add_user(username, password):
    """
    Adds a user to the database.

    The password is stored as an argon2 hash. Returns a status string:
      "username_taken"    -- a row with this username already exists
      "username_too_long" -- the username is longer than 20 characters
      "success"           -- the user row was inserted
    """
    # NOTE(review): check-then-insert is not atomic; two concurrent sign-ups
    # with the same name can both pass this check. Whether the second INSERT
    # then fails depends on a UNIQUE constraint in anonkun.sql -- confirm.
    if verify_username(username):
        return "username_taken"
    if len(username) > 20:
        return "username_too_long"

    pw_hash = argon2.hash(password)
    _DB.execute(
        "INSERT INTO `users` (`username`, `password_hash`) VALUES (%s, %s)",
        (username, pw_hash))
    return "success"


def verify_password(username, password):
    """
    Verifies a user's password.

    On success, stores the user's id in the Flask session under "user_id"
    and returns True. Returns False when the username does not exist or the
    password does not match the stored argon2 hash.
    """
    user = verify_username(username)
    if not user:
        return False

    # Assumes `SELECT *` on `users` yields exactly three columns in the
    # order (id, username, password_hash) -- TODO confirm against the schema.
    user_id, _, pw_hash = user
    if not argon2.verify(password, pw_hash):
        return False

    session["user_id"] = user_id
    return True


def verify_username(username):
    """
    Checks to see if the given username is in the database.

    Returns the full `users` row (as returned by ``fetchone()``) when the
    username exists, otherwise False.
    """
    user = _DB.execute(
        "SELECT * FROM `users` WHERE `username` = %s",
        (username,)).fetchone()
    return user if user else False


def log_chat_message(data):
    """
    Logs chat messages into the database.

    'data' should be a dict containing: message, date, room, name, and
    user_id (optional). 'room' must be convertible with int(); it is stored
    in the `room_id` column. 'user_id' is stored in the `name_id` column and
    is None when the key is absent.
    """
    message = data["message"]
    date = data["date"]
    room_id = int(data["room"])
    name = data["name"]
    user_id = data.get("user_id")

    _DB.execute(
        "INSERT INTO `chat_messages` ("
        "`message`, `room_id`, `date`, `name`, `name_id`) VALUES ("
        "%s, %s, %s, %s, %s)",
        (message, room_id, date, name, user_id))


def get_chat_messages(room_id):
    """
    Retrieves all chat messages for the provided room_id.

    Rows are ordered by `date`, oldest first, and returned as produced by
    the cursor's ``fetchall()``.
    """
    return _DB.execute(
        "SELECT * FROM `chat_messages` WHERE `room_id` = %s "
        "ORDER BY `date` ASC",
        (room_id,)).fetchall()