
189 lines
5.0 KiB
Raw Normal View History

2023-12-29 14:02:11 +01:00
import os
import sys
import uuid
import redis
from cryptography.fernet import Fernet
from flask import abort, Flask, request, jsonify
from redis.exceptions import ConnectionError as RedisConnectionError
2023-12-29 14:02:11 +01:00
from .utils import strtobool
NO_SSL = bool(strtobool(os.environ.get("NO_SSL", "False")))
HOST_OVERRIDE = os.environ.get("HOST_OVERRIDE", None)
# Separates the Redis storage key from the Fernet key inside a token
# (tokens are tilde-separated). Used by parse_token() and set_password().
TOKEN_SEPARATOR = "~"

# Initialize Flask Application
app = Flask(__name__)
# Parse DEBUG as a boolean, like NO_SSL, so that DEBUG=False / 0 / no really
# leaves debug mode off; unset or empty means off.
# NOTE(review): assumes .utils.strtobool accepts the usual true/false
# spellings and raises ValueError on anything else — confirm.
if strtobool(os.environ.get("DEBUG") or "False"):
    app.debug = True
app.secret_key = os.environ.get("SECRET_KEY", "Secret Key")
app.config.update(dict(STATIC_URL=os.environ.get("STATIC_URL", "static")))
# Initialize Redis.
# Precedence: MOCK_REDIS (in-memory fake), then REDIS_URL, then the
# REDIS_HOST / REDIS_PORT / SNAPPASS_REDIS_DB triple.
if os.environ.get("MOCK_REDIS"):
    from fakeredis import FakeStrictRedis

    redis_client = FakeStrictRedis()
elif os.environ.get("REDIS_URL"):
    redis_client = redis.StrictRedis.from_url(os.environ.get("REDIS_URL"))
else:
    redis_host = os.environ.get("REDIS_HOST", "localhost")
    # Environment values are strings; convert here so a malformed port or db
    # number fails at startup rather than on the first request.
    redis_port = int(os.environ.get("REDIS_PORT", 6379))
    redis_db = int(os.environ.get("SNAPPASS_REDIS_DB", 0))
    redis_client = redis.StrictRedis(host=redis_host, port=redis_port, db=redis_db)
# Prepended to every storage key created by set_password().
REDIS_PREFIX = os.environ.get("REDIS_PREFIX", "snappass")
def check_redis_alive(fn):
    """
    Decorate ``fn`` so Redis connection failures are reported, not raised.

    For a function named ``main`` the wrapper first pings Redis and exits the
    process with status 1 if Redis is unreachable. For any other function a
    ``RedisConnectionError`` is turned into an HTTP 500 response.
    """
    import functools

    # wraps() keeps fn's __name__, so Flask endpoint names and the
    # name-based "main" checks below still see the original function name.
    @functools.wraps(fn)
    def inner(*args, **kwargs):
        try:
            if fn.__name__ == "main":
                # NOTE(review): this branch's body is not visible in the
                # reviewed source; an up-front ping (so startup fails fast
                # when Redis is down) is assumed — confirm against history.
                redis_client.ping()
            return fn(*args, **kwargs)
        except RedisConnectionError as error:
            print("Failed to connect to redis! %s" % error, file=sys.stderr)
            if fn.__name__ == "main":
                # NOTE(review): assumed CLI exit; a non-zero status tells
                # process supervisors that startup failed.
                sys.exit(1)
            return abort(500)

    return inner
def encrypt(password):
    """
    Encrypt a plain-text password with a one-off Fernet key.

    The password string is UTF-8 encoded and encrypted under a freshly
    generated key. Returns ``(encrypted_bytes, key_bytes)``; the key is
    required later to decrypt the value.
    """
    key = Fernet.generate_key()
    return Fernet(key).encrypt(password.encode("utf-8")), key
def decrypt(password, decryption_key):
    """
    Reverse encrypt(): decrypt Fernet ciphertext ``password`` (bytes) with
    ``decryption_key`` (bytes) and return the plain-text password as bytes.
    """
    return Fernet(decryption_key).decrypt(password)
def parse_token(token):
    """
    Split a token into ``(storage_key, decryption_key)``.

    Only the first TOKEN_SEPARATOR splits the token: the storage key is
    everything before it, and the decryption key is everything after it,
    encoded to UTF-8 bytes. A token without a separator yields ``None`` as
    the decryption key.
    """
    storage_key, separator, key_text = token.partition(TOKEN_SEPARATOR)
    if not separator:
        return storage_key, None
    return storage_key, key_text.encode("utf-8")
def set_password(password, ttl):
    """
    Encrypt ``password`` and keep the ciphertext in Redis for ``ttl`` seconds.

    Returns the token for the caller: the Redis storage key and the text of
    the Fernet key, joined by TOKEN_SEPARATOR.
    """
    storage_key = "{}{}".format(REDIS_PREFIX, uuid.uuid4().hex)
    ciphertext, key_bytes = encrypt(password)
    redis_client.setex(storage_key, ttl, ciphertext)
    return TOKEN_SEPARATOR.join((storage_key, key_bytes.decode("utf-8")))
def get_password(token):
    """
    From a given token, return the initial password and remove it from Redis.

    If the token is tilde-separated, the value fetched from Redis is
    decrypted with the key carried in the token; otherwise the stored value
    is used as is. The result is decoded from UTF-8 and returned as str.

    Returns None when nothing is stored under the token's storage key, or
    when the stored value cannot be decrypted with the supplied key.
    """
    from cryptography.fernet import InvalidToken

    storage_key, decryption_key = parse_token(token)
    # NOTE(review): read-once semantics are assumed from the snappass design
    # (password_exists offers a separate, non-consuming check) — confirm.
    # GET and DELETE run in one MULTI/EXEC transaction so two concurrent
    # requests for the same token cannot both obtain the secret.
    with redis_client.pipeline() as pipe:
        pipe.get(storage_key)
        pipe.delete(storage_key)
        password, _ = pipe.execute()
    if password is None:
        return None
    if decryption_key is not None:
        try:
            password = decrypt(password, decryption_key)
        except (InvalidToken, ValueError):
            # Wrong key (InvalidToken) or a key that is not valid Fernet
            # base64 (ValueError): treat like a missing secret, not a 500.
            return None
    return password.decode("utf-8")
def password_exists(token):
    """
    Return Redis' EXISTS count for the token's storage key (truthy if a
    secret is stored). The stored value is not read or removed.
    """
    return redis_client.exists(parse_token(token)[0])
def empty(value):
    """Return True if ``value`` is falsy (None, "", 0, empty container), else False."""
    return not value
def _clean_ttl(request):
if not request.form.get("ttl"):
return 604800
time_period = int(request.form.get("ttl"))
except ValueError:
abort(400, "TTL must be an integer")
if time_period > 2419200:
abort(400, "TTL must be less than 2419200 seconds (4 weeks)")
return time_period
2023-12-29 14:02:11 +01:00
def clean_input():
    """
    Make sure we're not getting bad data from the front end, and format the
    data to be machine readable.

    Returns ``(ttl, password)`` from the current request's form (or JSON
    object). Aborts with HTTP 400 when the password is missing, empty, or
    not a string; the TTL is validated by ``_clean_ttl``.
    """
    password = request.form.get("password", "")
    # JSON bodies may carry non-string values, which encrypt() cannot encode.
    if empty(password) or not isinstance(password, str):
        abort(400)
    time_period = _clean_ttl(request)
    return time_period, password
@app.route("/", methods=["POST"])
def handle_password():
    """
    Store a new secret and return its token as JSON: ``{"key": <token>}``.

    Accepts form data or a JSON object body with ``password`` and an
    optional ``ttl``; a JSON body that is not an object is rejected with 400.
    """
    if request.is_json:
        payload = request.get_json()
        if not isinstance(payload, dict):
            # A JSON array/scalar/null would crash later on .get().
            abort(400)
        request.form = payload
    ttl, password = clean_input()
    token = set_password(password, ttl)
    return jsonify(key=token)
@app.route("/get-secret", methods=["POST"])
def show_password():
    """
    Return the secret for the submitted ``key`` token as JSON: ``{"password": ...}``.

    Accepts form data or a JSON object body. Aborts with 400 when the body is
    not a JSON object or the key is missing, empty, or not a string, and with
    404 when get_password() finds nothing for the token.
    """
    if request.is_json:
        payload = request.get_json()
        if not isinstance(payload, dict):
            # A JSON array/scalar/null would crash on .get().
            abort(400)
        request.form = payload
    password_key = request.form.get("key", "")
    # parse_token() needs a str; a non-string JSON value would raise there.
    if empty(password_key) or not isinstance(password_key, str):
        abort(400)
    password = get_password(password_key)
    if not password:
        return abort(404)
    return jsonify(password=password)
def main():
    """Presumably the script entry point; its body is not recoverable here.

    NOTE(review): check_redis_alive special-cases functions named "main", so
    this function is presumably decorated with it — confirm. The original
    body survives only as the garbled fragment below; restore it from
    version control.
    """
2023-12-29 19:59:25 +01:00"")
2023-12-29 14:02:11 +01:00
2023-12-29 19:59:25 +01:00
# Run the entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()