from functools import wraps
import json
from fastapi import HTTPException
from database import get_db_connection

def log_api_call(endpoint_name):
    """Return a decorator that records calls carrying an ``api_id`` in ``api_logs``.

    When the decorated function is called with a truthy ``api_id`` keyword
    argument, the wrapper:

    1. looks up ``users.id`` by ``users.app_id = api_id`` and raises
       ``HTTPException(401)`` if no user matches;
    2. inserts and commits an ``api_logs`` row (status 200). The row holds all
       keyword arguments except ``api_id``, JSON-encoded;
    3. calls the function and stores its JSON-encoded return value in
       ``response_payload``;
    4. if anything fails after the row was inserted, updates that row's
       ``status_code`` and ``response_payload`` (best-effort).

    Calls without a truthy ``api_id`` keyword bypass logging entirely.

    Values that ``json`` cannot encode are stored as their ``str()`` so that
    logging never fails an otherwise valid request.

    Args:
        endpoint_name: Stored in ``api_logs.api_endpoint``.

    Returns:
        A decorator for a synchronous endpoint function.

    Raises (from the wrapped call):
        HTTPException: re-raised with its original status code. This covers
            the 401 for an unknown ``api_id`` and any ``HTTPException`` raised
            by the endpoint itself.
        HTTPException: status 500 for any other exception. The traceback is
            logged server-side and is not sent to the client.
    """
    import logging

    logger = logging.getLogger(__name__)

    def decorator(func):
        # NOTE(review): the wrapper calls func synchronously; an ``async def``
        # endpoint would return an un-awaited coroutine here. Confirm that
        # only sync endpoints are decorated.
        @wraps(func)
        def wrapper(*args, **kwargs):
            api_id = kwargs.get("api_id")
            if not api_id:
                return func(*args, **kwargs)

            db_connection = None
            cursor = None
            user_id = None
            request_json = None
            log_inserted = False  # True once the api_logs row is committed

            def record_failure(status_code, error):
                """Best-effort update of the log row with a failure status."""
                if not log_inserted:
                    return
                try:
                    # Discard any aborted/partial transaction left by the
                    # statement that failed before writing the error row.
                    db_connection.rollback()
                    cursor.execute(
                        "UPDATE api_logs SET status_code = %s, response_payload = %s "
                        "WHERE user_id = %s AND api_endpoint = %s AND request_payload = %s",
                        (
                            status_code,
                            json.dumps({"error": error}, default=str),
                            user_id,
                            endpoint_name,
                            request_json,
                        ),
                    )
                    db_connection.commit()
                except Exception:
                    # A logging failure must not mask the original error.
                    logger.exception(
                        "Could not record failure of %s in api_logs", endpoint_name
                    )

            try:
                db_connection = get_db_connection()
                cursor = db_connection.cursor()

                cursor.execute("SELECT id FROM users WHERE app_id = %s", (api_id,))
                user = cursor.fetchone()
                if not user:
                    raise HTTPException(status_code=401, detail="Invalid API ID")
                user_id = user[0]

                request_payload = {k: v for k, v in kwargs.items() if k != "api_id"}
                # Serialize once so the INSERT and every later UPDATE use the
                # exact same text in the request_payload match.
                request_json = json.dumps(request_payload, default=str)

                cursor.execute(
                    "INSERT INTO api_logs (user_id, api_endpoint, request_payload, status_code) VALUES (%s, %s, %s, %s)",
                    (user_id, endpoint_name, request_json, 200),
                )
                db_connection.commit()
                log_inserted = True

                response = func(*args, **kwargs)

                # NOTE(review): this WHERE clause also matches earlier log rows
                # with the same user/endpoint/payload. If api_logs has a
                # primary key, capture it after the INSERT and update by it.
                cursor.execute(
                    "UPDATE api_logs SET response_payload = %s WHERE user_id = %s AND api_endpoint = %s AND request_payload = %s",
                    (json.dumps(response, default=str), user_id, endpoint_name, request_json),
                )
                db_connection.commit()

                return response

            except HTTPException as exc:
                # Deliberate HTTP errors (e.g. the 401 above) keep their status.
                record_failure(exc.status_code, exc.detail)
                raise

            except Exception as e:
                logger.exception("Error in %s", endpoint_name)
                record_failure(500, str(e))
                raise HTTPException(
                    status_code=500, detail=f"Error in {endpoint_name}: {str(e)}"
                ) from e

            finally:
                if cursor:
                    cursor.close()
                if db_connection:
                    db_connection.close()

        return wrapper

    return decorator