made db connection not constant

This commit is contained in:
iou1name 2018-04-14 20:34:49 -04:00
parent 0499adaddf
commit 62d8195705

View File

@ -3,9 +3,10 @@
Simple file host using Flask. Simple file host using Flask.
""" """
import os import os
import sqlite3
import secrets
import string import string
import secrets
import sqlite3
import threading
from passlib.hash import argon2 from passlib.hash import argon2
from flask import Flask, session, request, abort, redirect, url_for, g, \ from flask import Flask, session, request, abort, redirect, url_for, g, \
@ -52,6 +53,21 @@ app.wsgi_app = ReverseProxied(app.wsgi_app)
app.config['MAX_CONTENT_LENGTH'] = 128 * 1024 * 1024 app.config['MAX_CONTENT_LENGTH'] = 128 * 1024 * 1024
app.config["UPLOAD_DIR"] = "/usr/local/www/html/up" app.config["UPLOAD_DIR"] = "/usr/local/www/html/up"
app.config["DB_NAME"] = "fileHost.db" app.config["DB_NAME"] = "fileHost.db"
app.config["DB_LOCK"] = threading.Lock()
def db_execute(*args, **kwargs):
"""
Opens a connection to the app's database and executes the SQL statements
passed to this function.
"""
with sqlite3.connect(app.config.get("DB_NAME")) as con:
app.config.get("DB_LOCK").acquire()
cur = con.cursor()
res = cur.execute(*args, **kwargs)
app.config.get("DB_LOCK").release()
return res
def init(): def init():
@ -69,38 +85,33 @@ def init():
file.write(secret_key) file.write(secret_key)
app.secret_key = secret_key app.secret_key = secret_key
con = sqlite3.connect(app.config.get("DB_NAME"))
db = con.cursor()
try: try:
db.execute("SELECT * FROM users").fetchone() db_execute("SELECT * FROM users").fetchone()
db.execute("SELECT * FROM uploads").fetchone() db_execute("SELECT * FROM uploads").fetchone()
except sqlite3.OperationalError: except sqlite3.OperationalError:
db.execute("CREATE TABLE users(" db_execute("CREATE TABLE users("
"id INTEGER PRIMARY KEY," "id INTEGER PRIMARY KEY,"
"username TEXT," "username TEXT,"
"pw_hash TEXT," "pw_hash TEXT,"
"admin BOOL DEFAULT FALSE)") "admin BOOL DEFAULT FALSE)")
db.execute("CREATE TABLE uploads(" db_execute("CREATE TABLE uploads("
"filename TEXT, uploaded_by TEXT," "filename TEXT, uploaded_by TEXT,"
"uploaded_date INT DEFAULT (STRFTIME('%s', 'now')))") "uploaded_date INT DEFAULT (STRFTIME('%s', 'now')))")
return con, db
def add_user(username, password, admin="FALSE"): def add_user(username, password, admin="FALSE"):
""" """
Adds a user to the database. Adds a user to the database.
""" """
u = db.execute("SELECT username FROM users WHERE username = ?", u = db_execute("SELECT username FROM users WHERE username = ?",
(username,)).fetchone() (username,)).fetchone()
if u: if u:
return False return False
pw_hash = argon2.hash(password) pw_hash = argon2.hash(password)
db.execute("INSERT INTO users (username, pw_hash, admin) VALUES (?,?,?)", db_execute("INSERT INTO users (username, pw_hash, admin) VALUES (?,?,?)",
(username, pw_hash, admin)) (username, pw_hash, admin))
con.commit()
return True return True
@ -126,7 +137,7 @@ def verify_username(username):
""" """
Checks to see if the given username is in the database. Checks to see if the given username is in the database.
""" """
user = db.execute("SELECT * FROM users WHERE username = ?", user = db_execute("SELECT * FROM users WHERE username = ?",
(username,)).fetchone() (username,)).fetchone()
if user: if user:
return user return user
@ -212,7 +223,7 @@ def change_password():
return "The new passwords do not match!" return "The new passwords do not match!"
pw_hash = argon2.hash(new_password) pw_hash = argon2.hash(new_password)
db.execute("UPDATE users SET pw_hash = ? WHERE username = ?", db_execute("UPDATE users SET pw_hash = ? WHERE username = ?",
(pw_hash, username)) (pw_hash, username))
session.pop("username") session.pop("username")
return redirect(url_for("login")) return redirect(url_for("login"))
@ -267,9 +278,8 @@ def index():
fname = pre + "_" + fname fname = pre + "_" + fname
file.save(os.path.join(fdir, fname)) file.save(os.path.join(fdir, fname))
db.execute("INSERT INTO uploads (filename, uploaded_by) VALUES (?,?)", db_execute("INSERT INTO uploads (filename, uploaded_by) VALUES (?,?)",
(fname, username)) (fname, username))
con.commit()
#TODO: make this not hardcoded #TODO: make this not hardcoded
# url = request.url_root + "up/" + fname # url = request.url_root + "up/" + fname
@ -289,8 +299,8 @@ def get_rand_chars(n):
return "".join(chars) return "".join(chars)
con, db = init() init()
# TODO: make these not global variables # TODO: make these not global variables?
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
add_user(sys.argv[1], sys.argv[2], "TRUE") add_user(sys.argv[1], sys.argv[2], "TRUE")