import pygame
from client.resource_path import get_asset_path
from network.client import GameClient
import time
import threading
import re
from client.data import config

class UserCreation:
    def __init__(self, screen, client: GameClient):
        self.screen = screen
        self.client = client

        # ---------------- Input fields ----------------
        self.fields = ["first_name", "last_name", "username", "email", "password", "confirm_password"]
        self.buttons = ["sign_up_btn", "cancel_btn"]
        self.focus_order = self.fields + self.buttons
        self.focus_index = 0
        self.active_field = self.focus_order[self.focus_index]

        self.text_values = {f: "" for f in self.fields}

        # ---------------- Load window & mask ----------------
        self.window_img = pygame.image.load(get_asset_path("images/ui/signup_window.png")).convert_alpha()
        self.mask_img_orig = pygame.image.load(get_asset_path("images/ui/signup_window_mask.png")).convert()

        self.base_w, self.base_h = self.window_img.get_size()
        scale_ratio = config.SCREEN_HEIGHT * 1 / self.base_h
        self.scaled_w = int(self.base_w * scale_ratio)
        self.scaled_h = int(self.base_h * scale_ratio)
        self.window_img = pygame.transform.scale(self.window_img, (self.scaled_w, self.scaled_h))
        self.mask_img = pygame.transform.scale(self.mask_img_orig, (self.scaled_w, self.scaled_h))
        self.window_rect = self.window_img.get_rect(center=(config.SCREEN_WIDTH // 2, config.SCREEN_HEIGHT // 2))

        # ---------------- Map colors to fields/buttons ----------------
        self.color_map = {
            (0, 0, 255): "first_name",
            (0, 255, 255): "last_name",
            (0, 255, 0): "username",
            (255, 0, 255): "email",
            (255, 255, 0): "password",
            (255, 128, 0): "confirm_password",
            (0, 255, 128): "sign_up_btn",
            (255, 0, 0): "cancel_btn"
        }

        # ---------------- Compute field rects ----------------
        self.fields_rects = {}
        for color, name in self.color_map.items():
            rect = self._find_color_bounds(color)
            if rect:
                rect.inflate_ip(2, 2)
                self.fields_rects[name] = rect.move(self.window_rect.topleft)

        # ---------------- Font & cursor ----------------
        # Load fonts
        font_path = get_asset_path('fonts/UncialAntiqua-Regular.ttf')
        self.font = pygame.font.Font(font_path, 24)
        self.cursor_visible = True
        self.last_blink = time.time()

        # ---------------- Networking ----------------
        self.server_event = threading.Event()
        self.server_action = None
        self.server_payload = None

    # ---------------- Networking callback ----------------
    def _on_server_message(self, message: dict):
        self.server_action = message.get("action")
        self.server_payload = message
        self.server_event.set()

    # ---------------- Mask helpers ----------------
    def _find_color_bounds(self, color):
        pixels = pygame.PixelArray(self.mask_img)
        coords = [(x, y) for x in range(self.mask_img.get_width())
                          for y in range(self.mask_img.get_height())
                          if self.mask_img.get_at((x, y))[:3] == color]
        pixels.close()
        if not coords:
            return None
        xs, ys = zip(*coords)
        return pygame.Rect(min(xs), min(ys), max(xs)-min(xs), max(ys)-min(ys))

    # ---------------- Validation ----------------
    def _validate_fields(self):
        first = self.text_values["first_name"].strip()
        last = self.text_values["last_name"].strip()
        username = self.text_values["username"].strip()
        email = self.text_values["email"].strip()
        password = self.text_values["password"]
        confirm = self.text_values["confirm_password"]

        if not first or not last:
            return False, "First/Last name cannot be empty"
        if not username or len(username) < 6:
            return False, "Username must be at least 6 characters"
        if not re.match(r"^[A-Za-z0-9_]{3,12}$", username):
            return False, "Username must be 3-12 chars, letters/numbers/_ only"
        if not email or not re.match(r"[^@]+@[^@]+\.[^@]+", email):
            return False, "Invalid email"
        if not password or len(password) < 6:
            return False, "Password must be at least 6 characters"
        if password != confirm:
            return False, "Passwords do not match"
        return True, None

    # ---------------- Draw ----------------
    def draw(self):
        self.screen.blit(self.window_img, self.window_rect)

        # Highlight active field/button
        if self.active_field in self.fields_rects:
            rect = self.fields_rects[self.active_field]
            s = pygame.Surface((rect.w, rect.h), pygame.SRCALPHA)
            s.fill((255, 215, 0, 50))
            self.screen.blit(s, (rect.x, rect.y))

        # Cursor blink
        if time.time() - self.last_blink > 0.5:
            self.cursor_visible = not self.cursor_visible
            self.last_blink = time.time()

        # Draw text
        for field in self.fields:
            if field in self.fields_rects:
                rect = self.fields_rects[field]
                text = self.text_values[field]
                if field in ["password", "confirm_password"]:
                    text = "*" * len(text)
                text_surface = self.font.render(text, True, (255, 255, 255))
                self.screen.blit(text_surface, (rect.x + 5, rect.y + (rect.h - text_surface.get_height()) // 2))

                if self.active_field == field and self.cursor_visible:
                    cursor_h = int(text_surface.get_height() * 0.6)
                    cursor_x = rect.x + 5 + text_surface.get_width() + 2
                    cursor_y = rect.y + (rect.h - cursor_h) // 2
                    pygame.draw.line(self.screen, (255, 255, 255), (cursor_x, cursor_y), (cursor_x, cursor_y + cursor_h), 1)

        pygame.display.flip()

    # ---------------- Signup ----------------
    def attempt_create_async(self):
        def task():
            valid, reason = self._validate_fields()
            if not valid:
                print("[!] Validation failed:", reason)
                return

            if not self.client.connected:
                print("[*] Connecting to server...")
                self.client.connect()

            payload = {
                "action": "signup",
                "data": {
                    "first_name": self.text_values["first_name"].strip(),
                    "last_name": self.text_values["last_name"].strip(),
                    "username": self.text_values["username"].strip(),
                    "email": self.text_values["email"].strip(),
                    "password": self.text_values["password"],
                    "confirm_password": self.text_values["confirm_password"],
                },
            }

            print("[→] Sending signup data:", payload)
            self.client.request(payload, expect_action="signup_ok", timeout=5)

        threading.Thread(target=task, daemon=True).start()

    # ---------------- Main loop ----------------
    def run(self):
        print("Active field:", self.active_field)
        print("Mouse pos:", pygame.mouse.get_pos())
        clock = pygame.time.Clock()
        running = True

        # Install temporary on_message handler and restore when leaving
        original_callback = self.client.on_message
        def temp_callback(message):
            try:
                # Only handle signup-related messages here
                if message.get("action") in ["signup_ok", "signup_failed"]:
                    self._on_server_message(message)
                else:
                    # forward other messages to the original handler if present
                    if original_callback:
                        try:
                            original_callback(message)
                        except Exception:
                            pass
            except Exception as e:
                print("[!] UserCreation on_message error:", e)

        self.client.on_message = temp_callback

        try:
            while running:
                self.draw()

                # Handle server responses
                if self.server_event.is_set():
                    action = self.server_action
                    payload = self.server_payload or {}
                    self.server_event.clear()
                    self.server_action = None
                    self.server_payload = None

                    if action == "signup_ok":
                        print("[+] Account created successfully!")
                        return "login"
                    elif action == "signup_failed":
                        print("[!] Signup failed:", payload.get("reason"))

                # Handle input events
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        return None
                    elif event.type == pygame.KEYDOWN:
                        if event.key == pygame.K_ESCAPE:
                            return None
                        elif event.key == pygame.K_TAB:
                            self.focus_index = (self.focus_index + 1) % len(self.focus_order)
                            self.active_field = self.focus_order[self.focus_index]
                        elif event.key == pygame.K_BACKSPACE:
                            if self.active_field in self.fields:
                                self.text_values[self.active_field] = self.text_values[self.active_field][:-1]
                        elif event.key == pygame.K_RETURN:
                            if self.active_field == "sign_up_btn":
                                self.attempt_create_async()
                            elif self.active_field == "cancel_btn":
                                return "login"
                        else:
                            char = event.unicode
                            if char.isprintable() and self.active_field in self.fields:
                                max_len = 12 if self.active_field in ["username"] else 22
                                if len(self.text_values[self.active_field]) < max_len:
                                    self.text_values[self.active_field] += char

                    elif event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
                        mouse_pos = event.pos
                        for name, rect in self.fields_rects.items():
                            if rect.collidepoint(mouse_pos):
                                self.active_field = name
                                if name == "sign_up_btn":
                                    self.attempt_create_async()
                                elif name == "cancel_btn":
                                    return "login"

                clock.tick(config.FPS)
        finally:
            # Restore original on_message handler
            self.client.on_message = original_callback

        return None
