diff --git a/README.md b/README.md index 5e88654..42b4dc6 100644 --- a/README.md +++ b/README.md @@ -1 +1,39 @@ # shakecities + +``` +version: '3' +services: + main: + image: git.woodburn.au/nathanwoodburn/shakecities:latest + depends_on: + - db + environment: + DB_HOST: db + DB_USER: main + DB_PASSWORD: your-db-password + DB_NAME: main + WORKERS: 2 # number of workers to run (should be 2 * number of cores) + + sites: + image: git.woodburn.au/nathanwoodburn/shakecities-sites:latest + depends_on: + - db + environment: + DB_HOST: db + DB_USER: main + DB_PASSWORD: your-db-password + DB_NAME: main + WORKERS: 2 # number of workers to run (should be 2 * number of cores) + + db: + image: linuxserver/mariadb:latest + environment: + MYSQL_ROOT_PASSWORD: your-root-password + MYSQL_DATABASE: main + MYSQL_USER: main + MYSQL_PASSWORD: your-db-password + volumes: + - db_data:/var/lib/mysql +volumes: + db_data: +``` \ No newline at end of file diff --git a/accounts.py b/accounts.py index 9d6adf5..65b3f51 100644 --- a/accounts.py +++ b/accounts.py @@ -2,6 +2,7 @@ import os import dotenv from passlib.hash import argon2 import json +import db dotenv.load_dotenv() local = os.getenv('LOCAL') @@ -9,6 +10,15 @@ local = os.getenv('LOCAL') def hash_password(password): return argon2.using(rounds=16).hash(password) +def convert_db_users(db_entry): + return { + 'id': db_entry[0], + 'email': db_entry[1], + 'domain': db_entry[2], + 'password': db_entry[3], + 'tokens': db_entry[4].split(',') + } + # Verify a password against a hashed password def verify_password(password, hashed_password): return argon2.verify(password, hashed_password) @@ -40,30 +50,28 @@ def create_user(email, domain, password): token = generate_cookie() user['tokens'] = [token] - # If file doesn't exist, create it - if not os.path.isfile('users.json'): - with open('users.json', 'w') as f: - json.dump([], f) + # Check if user exists + if db.search_users(email) != []: + return {'success': False, 'message': 'User already exists'} - - - # Write to file - with open('users.json', 'r') as f: - users = json.load(f) - - for u in users: - if u['email'] == email: - return {'success': False, 'message': 'Email already exists'} - - users.append(user) - with open('users.json', 'w') as f: - json.dump(users, f) + db.add_user(email, domain, hashed_password, token) return {'success': True, 'message': 'User created', 'token': token} def validate_token(token): - with open('users.json', 'r') as f: - users = json.load(f) - for user in users: - if token in user['tokens']: - return user - return False \ No newline at end of file + search = db.search_users_token(token) + if search == []: + return False + else: + return convert_db_users(search[0]) + +def logout(token): + # Remove token from user + user = validate_token(token) + if not user: + return {'success': False, 'message': 'Invalid token'} + user['tokens'].remove(token) + # Update user + db.update_tokens(user['id'], user['tokens']) + + + return {'success': True, 'message': 'Logged out'} \ No newline at end of file diff --git a/db.py b/db.py new file mode 100644 index 0000000..8a129b2 --- /dev/null +++ b/db.py @@ -0,0 +1,78 @@ +import mysql.connector +import os +import dotenv + +dotenv.load_dotenv() + +# Database connection +dbargs = { + 'host':os.getenv('DB_HOST'), + 'user':os.getenv('DB_USER'), + 'password':os.getenv('DB_PASSWORD'), + 'database':os.getenv('DB_NAME') +} + + + + +def check_tables(): + connection = mysql.connector.connect(**dbargs) + cursor = connection.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS users ( + id INT(11) NOT NULL AUTO_INCREMENT, + email VARCHAR(255) NOT NULL, + domain VARCHAR(255) NOT NULL, + password VARCHAR(255) NOT NULL, + token VARCHAR(255) NOT NULL, + PRIMARY KEY (id) + ) + """) + + cursor.close() + connection.close() + +def add_user(email,domain,password,token): + connection = mysql.connector.connect(**dbargs) + cursor = connection.cursor() + + cursor.execute(""" + INSERT INTO users (email, domain, password, token) + VALUES (%s, %s, %s, %s) + """, (email, domain, password, token)) + connection.commit() + cursor.close() + connection.close() + +def search_users(email): + connection = mysql.connector.connect(**dbargs) + cursor = connection.cursor() + cursor.execute(""" + SELECT * FROM users WHERE email = %s + """, (email,)) + users = cursor.fetchall() + cursor.close() + connection.close() + return users + +def search_users_token(token): + connection = mysql.connector.connect(**dbargs) + cursor = connection.cursor() + query = "SELECT * FROM users WHERE token LIKE %s" + cursor.execute(query, ('%' + token + '%',)) + + users = cursor.fetchall() + cursor.close() + connection.close() + return users + +def update_tokens(id,tokens): + tokens = ','.join(tokens) + connection = mysql.connector.connect(**dbargs) + cursor = connection.cursor() + cursor.execute(""" + UPDATE users SET token = %s WHERE id = %s + """, (tokens, id)) + connection.commit() + cursor.close() + connection.close() diff --git a/main.py b/main.py index 1865a0b..95e8071 100644 --- a/main.py +++ b/main.py @@ -7,17 +7,25 @@ import schedule import time from email_validator import validate_email, EmailNotValidError import accounts +import db app = Flask(__name__) dotenv.load_dotenv() +# Database connection +dbargs = { + 'host':os.getenv('DB_HOST'), + 'user':os.getenv('DB_USER'), + 'password':os.getenv('DB_PASSWORD'), + 'database':os.getenv('DB_NAME') +} #Assets routes @app.route('/assets/') def assets(path): return send_from_directory('templates/assets', path) -#! TODO make prettier + def error(message): return jsonify({'success': False, 'message': message}), 400 @@ -56,6 +64,17 @@ def signup(): except EmailNotValidError as e: return jsonify({'success': False, 'message': 'Invalid email'}), 400 +@app.route('/logout') +def logout(): + token = request.cookies['token'] + if not accounts.logout(token)['success']: + return error('Invalid token') + + # Remove cookie + resp = make_response(redirect('/')) + resp.set_cookie('token', '', expires=0) + return resp + @app.route('/') def catch_all(path): # If file exists, load it @@ -73,5 +92,7 @@ def not_found(e): return redirect('/') + if __name__ == '__main__': + db.check_tables() app.run(debug=False, port=5000, host='0.0.0.0') \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f69f776..81354d0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ schedule email-validator py3dns passlib -argon2-cffi \ No newline at end of file +argon2-cffi +mysql-connector-python \ No newline at end of file