#!/usr/bin/env python3
"""
Database tools and functions.
"""
import MySQLdb
from flask import session
from passlib.hash import argon2


class Database():
    """
    An interface to interact with the database.
    """
    def __init__(self):
        """
        Initializes the database.

        Reads the MySQL user, database name and password (whitespace
        separated, in that order) from the `db_key` file. If querying the
        `users` table raises `MySQLdb.ProgrammingError`, every
        semicolon-separated statement in `anonkun.sql` is executed.
        """
        with open("db_key", "r") as file:    # TODO: encrypt this
            self.user, self.db, self.key = file.read().strip().split()

        try:
            self.execute("SELECT * FROM `users`").fetchone()
        except MySQLdb.ProgrammingError:
            # database not initialized
            with open("anonkun.sql", "r") as file:
                commands = file.read().split(";")
            for cmd in commands:
                cmd = cmd.strip()
                if not cmd:
                    continue
                self.execute(cmd)

    def execute(self, *args, **kwargs):
        """
        Opens a connection to the app's database, executes one SQL
        statement on it and commits.

        All arguments are forwarded to `cursor.execute()` (typically a
        query string and a parameter tuple). Returns the cursor; MySQLdb's
        default cursor buffers the whole result client-side, so
        `fetchone()`, `fetchall()` and `lastrowid` keep working after the
        connection is closed. MySQLdb exceptions propagate unchanged.
        """
        # Deliberately not `with MySQLdb.connect(...) as cur`: that idiom
        # only produced a cursor on mysqlclient < 2.0. On 2.x the context
        # manager yields the connection itself (no `execute` method).
        conn = MySQLdb.connect(
            user=self.user, passwd=self.key, db=self.db, charset='utf8mb4')
        try:
            cur = conn.cursor()
            cur.execute(*args, **kwargs)
            conn.commit()
        finally:
            # On error nothing was committed; closing the session lets the
            # server discard the open transaction.
            conn.close()
        return cur


_DB = Database()


def add_user(username, password, timestamp):
    """
    Adds a user to the database.

    Returns "username_too_long" if the username exceeds 20 characters,
    "username_taken" if it already exists, otherwise stores the user with
    an argon2 password hash and returns "success".
    """
    # Cheap length check first so invalid input never costs a DB query.
    if len(username) > 20:
        return "username_too_long"
    elif verify_username(username):    # username taken
        return "username_taken"

    pw_hash = argon2.hash(password)
    _DB.execute(
        "INSERT INTO `users` (`username`, `password_hash`, `signup_date`) "
        "VALUES (%s, %s, %s)",
        (username, pw_hash, timestamp))
    return "success"


def verify_password(username, password):
    """
    Verifies a user's password.

    On success stores the user's id in the Flask session under "user_id"
    and returns True; returns False for an unknown user or a wrong
    password.
    """
    user = verify_username(username)
    if not user:
        return False
    # NOTE(review): relies on `users` having exactly 4 columns with the id
    # first and the hash third (SELECT * order) — confirm against schema.
    user_id, _, pw_hash, _ = user

    if argon2.verify(password, pw_hash):
        session["user_id"] = user_id
        return True
    else:
        return False


def verify_username(username):
    """
    Checks to see if the given username is in the database.

    Returns the full `users` row if found, otherwise False.
    """
    user = _DB.execute(
        "SELECT * FROM `users` WHERE `username` = %s",
        (username,)).fetchone()
    if user:
        return user
    else:
        return False


def log_chat_message(data):
    """
    Logs chat messages into the database.

    'data' should be a dict containing the keys: message, date, room, and
    optionally user_id (stored in the `name_id` column; None if absent).
    Returns the id of the newly inserted message.
    """
    message = data["message"]
    date = data["date"]
    room_id = data["room"]
    user_id = data.get("user_id")

    cur = _DB.execute(
        "INSERT INTO `chat_messages` ("
        "`message`, `room_id`, `date`, `name_id`) VALUES ("
        "%s, %s, %s, %s)",
        (message, room_id, date, user_id))
    # lastrowid is read from the INSERT's own connection, so a concurrent
    # message in the same room cannot be returned instead.
    # Assumes `message_id` is the AUTO_INCREMENT key — TODO confirm.
    return cur.lastrowid


def get_chat_messages(room_id):
    """Retrieves all chat messages for the provided room_id, oldest first."""
    res = _DB.execute(
        "SELECT * FROM `chat_messages` WHERE `room_id` = %s "
        "ORDER BY `date` ASC",
        (room_id,)).fetchall()
    return res


def insert_quest(canon_title, ident_title, owner_id):
    """Creates a new quest entry and returns its quest_id."""
    cur = _DB.execute(
        "INSERT INTO `quest_meta` (`canon_title`, `ident_title`, `owner_id`) "
        "VALUES (%s, %s, %s)",
        (canon_title, ident_title, owner_id))
    # Same-connection id avoids racing other inserts with the same title.
    # Assumes `quest_id` is the AUTO_INCREMENT key — TODO confirm.
    return cur.lastrowid


def insert_quest_post(quest_id, post_type, post, timestamp):
    """Inserts a new quest post and returns its post_id."""
    cur = _DB.execute(
        "INSERT INTO `quest_data`"
        "(`quest_id`, `post_type`, `post`, `timestamp`) "
        "VALUES (%s, %s, %s, %s)",
        (quest_id, post_type, post, timestamp))
    # Same-connection id avoids racing other posts to the same quest.
    # Assumes `post_id` is the AUTO_INCREMENT key — TODO confirm.
    return cur.lastrowid


def get_quest_meta(quest_id=None, ident_title=None):
    """
    Retrieves all meta info about a quest.

    Allows searching by either quest_id or ident_title; quest_id wins if
    both are given. Returns the row, or None if not found or if neither
    argument is truthy.
    """
    statement = "SELECT * FROM `quest_meta` WHERE "
    if quest_id:
        statement += "`quest_id` = %s"
        data = _DB.execute(statement, (quest_id,)).fetchone()
    elif ident_title:
        statement += "`ident_title` = %s"
        data = _DB.execute(statement, (ident_title,)).fetchone()
    else:
        return
    return data


def get_quest_data(quest_id):
    """Retrieves all posts of a quest, ordered by post_id."""
    data = _DB.execute(
        "SELECT * FROM `quest_data` WHERE `quest_id` = %s "
        "ORDER BY `post_id` ASC",
        (quest_id,)).fetchall()
    return data


def get_quest_post(post_id=None, quest_id=None):
    """
    Retrieves the post data for the given post_id.

    If no post_id is given, returns the most recent post of quest_id.
    Returns None if neither argument is truthy.
    """
    if post_id:
        data = _DB.execute(
            "SELECT * FROM `quest_data` WHERE `post_id` = %s",
            (post_id,)).fetchone()
    elif quest_id:
        data = _DB.execute(
            "SELECT * FROM `quest_data` WHERE `quest_id` = %s "
            "ORDER BY `post_id` DESC",
            (quest_id,)).fetchone()
    else:
        return
    return data


def get_user_info(username):
    """Retrieves (user_id, signup_date) for the given username, or None."""
    data = _DB.execute(
        "SELECT `user_id`, `signup_date` FROM `users` WHERE `username` = %s",
        (username,)).fetchone()
    return data


def get_user_quests(user_id):
    """Retrieves all quests run by a particular user_id."""
    data = _DB.execute(
        "SELECT * FROM `quest_meta` WHERE `owner_id` = %s",
        (user_id,)).fetchall()
    return data


def update_quest_post(post_id, new_post):
    """Replaces the text of a quest post."""
    _DB.execute(
        "UPDATE `quest_data` SET `post` = %s WHERE `post_id` = %s",
        (new_post, post_id))


def set_post_open(post_id, quest_id):
    """Sets an active post open for the given quest."""
    _DB.execute(
        "UPDATE `quest_meta` SET `open_post_id` = %s WHERE `quest_id` = %s",
        (post_id, quest_id))


def set_post_closed(quest_id):
    """Closes a quest's open post (clears `open_post_id`)."""
    _DB.execute(
        "UPDATE `quest_meta` SET `open_post_id` = NULL WHERE `quest_id` = %s",
        (quest_id,))


def insert_dice_call(post_id, quest_id, new_call):
    """
    Inserts a new dice call.

    new_call is a 4-item sequence:
    (dice_roll, strict, dice_challenge, rolls_taken).
    quest_id is accepted for API symmetry but is not stored.
    """
    dice_roll, strict, dice_challenge, rolls_taken = new_call
    _DB.execute(
        "INSERT INTO `dice_calls`"
        "(`post_id`, `dice_roll`, `strict`,`dice_challenge`,`rolls_taken`)"
        "VALUES (%s, %s, %s, %s, %s)",
        (post_id, dice_roll, strict, dice_challenge, rolls_taken))


def get_dice_call(post_id):
    """Retrieves the dice call attached to post_id, or None."""
    data = _DB.execute(
        "SELECT * FROM `dice_calls` WHERE `post_id` = %s",
        (post_id,)).fetchone()
    return data


def insert_quest_roll(message_id, quest_id, post_id, roll_data):
    """
    Inserts a user roll into the `dice_rolls` table.

    roll_data is a 3-item sequence: (roll_dice, roll_results, roll_total).
    """
    # tuple() so callers may pass a list as well as a tuple.
    ins = (message_id, quest_id, post_id) + tuple(roll_data)
    _DB.execute(
        "INSERT INTO `dice_rolls`"
        "(`message_id`, `quest_id`, `post_id`, "
        "`roll_dice`, `roll_results`, `roll_total`)"
        "VALUES (%s, %s, %s, %s, %s, %s)",
        ins)


def get_dice_rolls(quest_id=None, post_id=None):
    """
    Gets all rolls for the given quest (or, failing that, post), ordered
    by message_id. Returns None if neither argument is truthy.
    """
    if quest_id:
        sql = "SELECT * FROM `dice_rolls` WHERE `quest_id` = %s "
        ins = quest_id
    elif post_id:
        sql = "SELECT * FROM `dice_rolls` WHERE `post_id` = %s "
        ins = post_id
    else:
        return
    sql += "ORDER BY `message_id` ASC"
    data = _DB.execute(sql, (ins,)).fetchall()
    return data


def insert_poll(post_id, quest_id, multi_choice, allow_writein):
    """Inserts a new poll post."""
    _DB.execute(
        "INSERT INTO `polls` "
        "(`post_id`, `quest_id`, `multi_choice`, `allow_writein`) "
        "VALUES (%s, %s, %s, %s)",
        (post_id, quest_id, multi_choice, allow_writein))


def get_poll(post_id):
    """Gets poll information for post_id, or None."""
    data = _DB.execute(
        "SELECT * FROM `polls` WHERE `post_id` = %s",
        (post_id,)).fetchone()
    return data


def insert_poll_option(post_id, option_text):
    """Insert a new poll option. ips_voted will be NULL to start."""
    _DB.execute(
        "INSERT INTO `poll_options` "
        "(`post_id`, `option_text`) VALUES (%s, %s)",
        (post_id, option_text))


def get_poll_options(quest_id):
    """Gets all poll options belonging to posts of the given quest."""
    data = _DB.execute(
        "SELECT DISTINCT "
        "`poll_options`.* FROM `poll_options`, `quest_data` "
        "WHERE `quest_data`.`post_id` = `poll_options`.`post_id` "
        "AND `quest_data`.`quest_id` = %s",
        (quest_id,)).fetchall()
    return data


def insert_poll_vote(option_id, ip_address):
    """
    Inserts a new vote for a poll option.

    A duplicate vote (MySQLdb.IntegrityError) is silently ignored.
    """
    try:
        _DB.execute(
            "INSERT INTO `poll_votes` "
            "(`option_id`, `ip_address`) VALUES (%s, %s)",
            (option_id, ip_address))
    except MySQLdb.IntegrityError:
        # this ip has already voted for this option
        return


def remove_poll_vote(option_id, ip_address):
    """Removes an ip_address's vote for a poll option."""
    _DB.execute(
        "DELETE FROM `poll_votes` "
        "WHERE `option_id` = %s AND `ip_address` = %s",
        (option_id, ip_address))


def get_poll_votes(quest_id):
    """Gets all poll votes for the given quest_id."""
    # The WHERE on `quest_data` makes these LEFT JOINs behave as inner joins.
    data = _DB.execute(
        "SELECT `poll_votes`.* FROM `poll_votes` LEFT JOIN `poll_options` "
        "ON `poll_votes`.`option_id` = `poll_options`.`option_id` "
        "LEFT JOIN `quest_data` "
        "ON `poll_options`.`post_id`=`quest_data`.`post_id` "
        "WHERE `quest_data`.`quest_id` = %s",
        (quest_id,)).fetchall()
    return data


def get_poll_votes_voted(post_id, ip_address):
    """Gets all votes made by the given ip_address for the given poll."""
    data = _DB.execute(
        "SELECT `poll_votes`.* FROM `poll_votes` LEFT JOIN `poll_options` "
        "ON `poll_votes`.`option_id` = `poll_options`.`option_id` "
        "WHERE `poll_options`.`post_id` = %s "
        "AND `poll_votes`.`ip_address` = %s",
        (post_id, ip_address)).fetchall()
    return data