diff --git a/fileHost.py b/fileHost.py index 187dd97..7a209ab 100755 --- a/fileHost.py +++ b/fileHost.py @@ -3,9 +3,10 @@ Simple file host using Flask. """ import os -import sqlite3 -import secrets import string +import secrets +import sqlite3 +import threading from passlib.hash import argon2 from flask import Flask, session, request, abort, redirect, url_for, g, \ @@ -52,6 +53,21 @@ app.wsgi_app = ReverseProxied(app.wsgi_app) app.config['MAX_CONTENT_LENGTH'] = 128 * 1024 * 1024 app.config["UPLOAD_DIR"] = "/usr/local/www/html/up" app.config["DB_NAME"] = "fileHost.db" +app.config["DB_LOCK"] = threading.Lock() + + +def db_execute(*args, **kwargs): + """ + Opens a connection to the app's database and executes the SQL statements + passed to this function. + """ + with sqlite3.connect(app.config.get("DB_NAME")) as con: + app.config.get("DB_LOCK").acquire() + cur = con.cursor() + res = cur.execute(*args, **kwargs) + app.config.get("DB_LOCK").release() + return res + def init(): @@ -69,38 +85,33 @@ def init(): file.write(secret_key) app.secret_key = secret_key - con = sqlite3.connect(app.config.get("DB_NAME")) - db = con.cursor() try: - db.execute("SELECT * FROM users").fetchone() - db.execute("SELECT * FROM uploads").fetchone() + db_execute("SELECT * FROM users").fetchone() + db_execute("SELECT * FROM uploads").fetchone() except sqlite3.OperationalError: - db.execute("CREATE TABLE users(" + db_execute("CREATE TABLE users(" "id INTEGER PRIMARY KEY," "username TEXT," "pw_hash TEXT," "admin BOOL DEFAULT FALSE)") - db.execute("CREATE TABLE uploads(" + db_execute("CREATE TABLE uploads(" "filename TEXT, uploaded_by TEXT," "uploaded_date INT DEFAULT (STRFTIME('%s', 'now')))") - return con, db - def add_user(username, password, admin="FALSE"): """ Adds a user to the database. """ - u = db.execute("SELECT username FROM users WHERE username = ?", + u = db_execute("SELECT username FROM users WHERE username = ?", (username,)).fetchone() if u: return False pw_hash = argon2.hash(password) - db.execute("INSERT INTO users (username, pw_hash, admin) VALUES (?,?,?)", + db_execute("INSERT INTO users (username, pw_hash, admin) VALUES (?,?,?)", (username, pw_hash, admin)) - con.commit() return True @@ -126,7 +137,7 @@ def verify_username(username): """ Checks to see if the given username is in the database. """ - user = db.execute("SELECT * FROM users WHERE username = ?", + user = db_execute("SELECT * FROM users WHERE username = ?", (username,)).fetchone() if user: return user @@ -212,7 +223,7 @@ def change_password(): return "The new passwords do not match!" pw_hash = argon2.hash(new_password) - db.execute("UPDATE users SET pw_hash = ? WHERE username = ?", + db_execute("UPDATE users SET pw_hash = ? WHERE username = ?", (pw_hash, username)) session.pop("username") return redirect(url_for("login")) @@ -267,9 +278,8 @@ def index(): fname = pre + "_" + fname file.save(os.path.join(fdir, fname)) - db.execute("INSERT INTO uploads (filename, uploaded_by) VALUES (?,?)", + db_execute("INSERT INTO uploads (filename, uploaded_by) VALUES (?,?)", (fname, username)) - con.commit() #TODO: make this not hardcoded # url = request.url_root + "up/" + fname @@ -289,8 +299,8 @@ def get_rand_chars(n): return "".join(chars) -con, db = init() -# TODO: make these not global variables +init() +# TODO: make these not global variables? if __name__ == "__main__": import sys add_user(sys.argv[1], sys.argv[2], "TRUE")