from google import genai
from google.genai import types
from PIL import Image
import json
import os
import random
import base64
import io

from dotenv import load_dotenv


class BBoxResult:
    """Result of one bounding-box detection call: box, center, and usage accounting."""

    def __init__(
        self,
        bounding_box: list[int],
        center: dict[str, int],
        tokens_used: dict,
        costs_usd: dict,
    ):
        # [xmin, ymin, xmax, ymax] in absolute pixel coordinates.
        self.bounding_box = bounding_box
        # Center point of the box, e.g. {"x": ..., "y": ...}.
        # (Annotation fixed: the caller passes a dict, not a list[int].)
        self.center = center
        # Token counts per category, e.g. {"prompt": ..., "thought": ..., "answer": ...}.
        self.tokens_used = tokens_used
        # USD costs per category, e.g. {"input_cost": ..., "output_cost": ...}.
        self.costs_usd = costs_usd

    def token_total(self) -> int:
        """Return the total number of tokens consumed across all categories."""
        return sum(self.tokens_used.values())

    def cost_total(self) -> float:
        """Return the total cost in USD across all categories."""
        return sum(self.costs_usd.values())

    def to_dict(self) -> dict:
        """Serialize this result to a plain dict (preferred over the legacy __dict__ method)."""
        return {
            "bounding_box": self.bounding_box,
            "center": self.center,
            "tokens_used": self.tokens_used,
            "costs_usd": self.costs_usd,
        }

    # NOTE(review): overriding __dict__ with a method shadows the instance
    # attribute-dict descriptor, which breaks vars(obj) and pickling. Kept
    # only for backward compatibility with existing callers; migrate them to
    # to_dict() and then delete this method.
    def __dict__(self):
        return self.to_dict()

    def __repr__(self):
        return f"BBoxResult(bounding_box={self.bounding_box}, center={self.center}, tokens_used={self.tokens_used}, costs_usd={self.costs_usd})"


def get_gemini_api_key_by_random_threshold(threshold: float) -> str:
    """
    Randomly pick one of the two Gemini API keys.

    Returns GEMINI_API_KEY_PAGA (paid) when a uniform random float in [0, 1)
    is strictly greater than ``threshold``, otherwise GEMINI_API_KEY_GRATIS
    (free). The bigger the threshold, the more likely the GRATIS key is chosen.
    (The original docstring stated the opposite of what the code does; the
    code's behavior is kept and the docstring corrected.)

    A missing environment variable yields an empty string.
    """
    if random.random() > threshold:
        return os.getenv("GEMINI_API_KEY_PAGA", "")
    return os.getenv("GEMINI_API_KEY_GRATIS", "")


def detect_bounding_boxes(object_description: str, image_base64: str) -> BBoxResult:
    """
    Calls Gemini to detect the single bounding box matching a description.

    Args:
        object_description: Natural-language description of the object to find.
        image_base64: Base64-encoded image string.

    Returns:
        BBoxResult with the box as [xmin, ymin, xmax, ymax] in absolute pixel
        coordinates, the box center point, and token/cost accounting.

    Raises:
        json.JSONDecodeError: if the model reply is not valid JSON.
        KeyError: if the reply lacks the expected "box_2d" field.
    """
    load_dotenv()

    # Randomly split traffic between the paid and free API keys.
    client = genai.Client(
        api_key=get_gemini_api_key_by_random_threshold(0.5)
    )

    # Decode base64 to a PIL Image; we need the pixel dimensions later to
    # de-normalize the model's 0-1000 coordinates.
    image_data = base64.b64decode(image_base64)
    image = Image.open(io.BytesIO(image_data))

    prompt = f"""
    Detect the one object in the image that matches the following description: 
    {object_description}
    
    Return slight larger bounding box of the object in the following JSON format:
    {{"box_2d": [ymin, xmin, ymax, xmax]}}
    The box_2d should be [ymin, xmin, ymax, xmax] normalized to 0-1000. Only one object should be detected.
    """
    config = types.GenerateContentConfig(
        response_mime_type="application/json",
        thinking_config=types.ThinkingConfig(
            thinking_budget=500,
        ),
        media_resolution="MEDIA_RESOLUTION_MEDIUM",
    )

    response = client.models.generate_content(
        model="models/gemini-flash-latest",
        contents=[image, prompt],
        config=config
    )

    # Token counts from the SDK can be None (notably thoughts_token_count when
    # no thinking occurred), which would break the arithmetic below — default to 0.
    usage = response.usage_metadata
    prompt_len = usage.prompt_token_count or 0
    thought_len = usage.thoughts_token_count or 0
    answer_len = usage.candidates_token_count or 0

    input_cost_per_1M_tokens = 0.3  # USD per 1M prompt tokens
    output_cost_per_1M_tokens = 2.5  # USD per 1M output (answer + thinking) tokens

    width, height = image.size
    print(f"🖼️ Image size: {width}x{height} pixels")
    print(f"💬 Gemini response: {response.text}")
    bounding_box = json.loads(response.text)

    # Model returns [ymin, xmin, ymax, xmax] normalized to 0-1000; convert to
    # absolute pixels in [xmin, ymin, xmax, ymax] order.
    ymin, xmin, ymax, xmax = bounding_box["box_2d"]
    converted_bounding_boxes = [
        int(xmin / 1000 * width),   # xmin
        int(ymin / 1000 * height),  # ymin
        int(xmax / 1000 * width),   # xmax
        int(ymax / 1000 * height),  # ymax
    ]
    print(f"📦 Converted bounding box: {converted_bounding_boxes}")

    center = {
        "x": (converted_bounding_boxes[0] + converted_bounding_boxes[2]) // 2,  # (xmin + xmax) / 2
        "y": (converted_bounding_boxes[1] + converted_bounding_boxes[3]) // 2,  # (ymin + ymax) / 2
    }

    return BBoxResult(
        bounding_box=converted_bounding_boxes,
        center=center,
        tokens_used={
            "prompt": prompt_len,
            "thought": thought_len,
            "answer": answer_len,
        },
        costs_usd={
            # Gemini pricing bills thinking tokens as OUTPUT tokens; the
            # original code lumped them into the input cost at the cheaper rate.
            "input_cost": prompt_len / 1_000_000 * input_cost_per_1M_tokens,
            "output_cost": (thought_len + answer_len) / 1_000_000 * output_cost_per_1M_tokens,
        },
    )