diff --git a/src/bapi/blueprints/discord.py b/src/bapi/blueprints/discord.py index 173f5ec..3f833d3 100644 --- a/src/bapi/blueprints/discord.py +++ b/src/bapi/blueprints/discord.py @@ -1,4 +1,5 @@ import ipaddress +import re import secrets import urllib.parse @@ -26,18 +27,22 @@ @discord_blueprint.route("/auth", methods=["GET"]) def discord_auth(): ip = request.args.get("ip") + ip_str = None if not isinstance(ip, str): return jsonify({"error": "provided IP address invalid"}), 400 try: + ip_str = ip ip = ipaddress.ip_address(ip) except ValueError: return jsonify({"error": "provided IP address invalid"}), 400 + if "," in ip_str: + return jsonify({"error": "provided IP address invalid"}), 400 if ip.version == 6: return jsonify({"error": "IPv6 address not allowed"}), 400 if ip.is_multicast or ip.is_unspecified: return jsonify({"error": "multicast or unspecified address not allowed"}), 400 seeker_port = request.args.get("seeker_port") - if not isinstance(seeker_port, str) or not seeker_port.isdigit(): + if not isinstance(seeker_port, str) or not re.match("^[0-9]+$", seeker_port): seeker_port = "" try: seeker_port = int(seeker_port) @@ -45,8 +50,11 @@ def discord_auth(): seeker_port = "" if not isinstance(seeker_port, int) or seeker_port > 65535 or seeker_port < 1023: seeker_port = "" + nonce = request.args.get("nonce") + if not isinstance(nonce, str) or len(nonce) != 64 or not re.match("^[a-z0-9]+$", nonce): + return jsonify({"error": "bad nonce"}), 400 session["oauth2_state"] = ( - f"{urllib.parse.quote(ip.exploded, safe="", encoding="utf-8")},{seeker_port},{secrets.token_urlsafe(16)}" + f"{urllib.parse.quote(ip_str, safe="", encoding="utf-8")},{seeker_port},{nonce},{secrets.token_urlsafe(16)}" ) return redirect(discord_client.generate_uri(scope=["identify"], state=session["oauth2_state"])) @@ -65,8 +73,34 @@ def discord_callback(): del session["oauth2_state"] state_attrs = state.split(",") + if len(state_attrs) != 4: + return jsonify({"error": "bad state"}), 400 ip = urllib.parse.unquote(state_attrs[0]) - seeker_port = state_attrs[1] + try: + seeker_port = int(state_attrs[1]) + except ValueError: + return jsonify({"error": "bad state"}), 400 + nonce = state_attrs[2] + if not isinstance(nonce, str) or len(nonce) != 64 or not re.match("^[a-z0-9]+$", nonce): + return jsonify({"error": "bad state"}), 400 + nonce_duration = cfg.API.get("nonce-valid-duration") + if nonce_duration is None: + nonce_duration = 240 + try: + nonce_valid, reason_invalid = db.SessionCreationNonce.is_valid_session_creation( + ip, seeker_port, nonce, nonce_duration + ) + if not nonce_valid: + notice = "" + if reason_invalid == "invalid": + notice = " account security risk: check if connected to a genuine BeeStation game server." + elif reason_invalid == "expired": + notice = " log in within a shorter time period." + return jsonify({"error": f"{reason_invalid or "invalid"} nonce.{notice}"}), 401 + except Exception as e: + current_app.logger.error(f"error while checking nonce: {e}") + return jsonify({"error": "error checking nonce"}), 500 + discord_uid = None discord_username = None diff --git a/src/bapi/config/api.yml b/src/bapi/config/api.yml index 23e671e..d9706c7 100644 --- a/src/bapi/config/api.yml +++ b/src/bapi/config/api.yml @@ -6,3 +6,4 @@ api-url: "https://api.beestation13.com" request-source: "bapi" game-session-duration: 90 # valid duration of created game session tokens, in days. +nonce-valid-duration: 240 # valid duration of a session creation nonce, in seconds. diff --git a/src/bapi/db.py b/src/bapi/db.py index 0f25c8d..eb84ebf 100644 --- a/src/bapi/db.py +++ b/src/bapi/db.py @@ -15,6 +15,7 @@ from sqlalchemy import SmallInteger from sqlalchemy import String from sqlalchemy import Text +from sqlalchemy.orm import column_property from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.sql.expression import text @@ -55,6 +56,38 @@ def create_session(cls, ip, external_method, external_uid, external_display_name return random_token +class SessionCreationNonce(sqlalchemy_ext.Model): + __bind_key__ = "session" + __tablename__ = "SS13_session_creation_nonce" + + created = Column("created", DateTime()) + id = Column("id", Integer(), primary_key=True) + ip = Column("ip", String(32)) + session_nonce = Column("session_nonce", String(64)) + seeker_port = Column("seeker_port", Integer()) + seconds_since_creation = column_property(func.timestampdiff(text("SECOND"), created, func.now())) + + @classmethod + def is_valid_session_creation(cls, ip, seeker_port, nonce, valid_duration): + valid_nonce = None + try: + valid_nonce = ( + db_session.query(cls) + .filter(and_(cls.session_nonce == nonce, cls.ip == ip, cls.seeker_port == seeker_port)) + .one() + ) + except NoResultFound: + return (False, "invalid") + if valid_nonce is None: + return (False, "invalid") + else: + db_session.delete(valid_nonce) + print(valid_nonce.seconds_since_creation) + if valid_nonce.seconds_since_creation > (valid_duration or 240): + return (False, "expired") + return (True, "") + + class Player(sqlalchemy_ext.Model): __bind_key__ = "game" __tablename__ = "SS13_player"