"""Tom's solution to Asteracer. Moves on the vertices of the asteroid graph by simulating all angles."""
import os
from typing import Optional

import drawsvg as draw
import networkx as nx
import numpy as np

from asteracer import *

# Memo for _get_directional_instructions: maps the number of candidate directions n
# to the (unsorted) list of instructions generated for that n.
directional_instructions_cache = {}

# A graph vertex given by its (x, y) coordinates. Note: points inserted by the
# edge splitting in get_checkpoints_path may have float coordinates.
Vertex = tuple[int, int]


def load_asteroid_graph(path: str):
    """Read an asteroid graph description from ``path``.

    Lines that are blank or start with ``#`` (after stripping) are ignored. The
    first remaining line is ``n_racer n_asteroid n_goal m``. It is followed by one
    line per vertex: racers first, then asteroids, then goals. Each vertex line
    starts with ``x y``; asteroid and goal lines carry a third integer. The file
    ends with ``m`` edge lines ``u v``.

    :param path: path of the graph file.
    :return: ``(vertices, edges, vertex_objects)``. ``vertices`` is a list of
        ``(x, y)`` tuples and ``edges`` a list of ``(u, v)`` vertex-index pairs.
        ``vertex_objects`` tags each vertex as ``('S', racer_index)``,
        ``('A', third_column)`` or ``('G', third_column)``. For goals the third
        column is used as a goal index by the caller.
    """
    with open(path) as file:
        stripped = (raw.strip() for raw in file.readlines())
        contents = [text for text in stripped if text and not text.startswith('#')]

    # lazily parse each meaningful line into a list of ints
    rows = (list(map(int, text.split())) for text in contents)
    n_racer, n_asteroid, n_goal, m = next(rows)

    def tag(index, row):
        """Classify the vertex at ``index`` by its position in the file."""
        if index < n_racer:
            return 'S', index
        if index < n_racer + n_asteroid:
            return 'A', row[2]
        return 'G', row[2]

    vertices, vertex_objects = [], []
    for index in range(n_racer + n_asteroid + n_goal):
        row = next(rows)
        vertices.append((row[0], row[1]))
        vertex_objects.append(tag(index, row))

    edges = []
    for _ in range(m):
        row = next(rows)
        edges.append((row[0], row[1]))

    return vertices, edges, vertex_objects


def _yield_points_at_distance(x: float, y: float, r: float, n: int):
    """Generate n points uniformly at distance r from the coordinates (x, y)."""
    for i in range(n):
        t = (i / n) * np.pi * 2
        yield x + np.cos(t) * r, y + np.sin(t) * r


def _get_directional_instructions(
        n: int = 1000,
        sort_by_angle_to: Optional[tuple[float, float]] = None,
        limit_angle: Optional[float] = None,
) -> list[Instruction]:
    """Return maximum-thrust instructions pointing in at most ``n`` distinct directions.

    Candidates come from ``n`` evenly spaced points on a circle of radius 127,
    each passed to ``Instruction``. A candidate is dropped if it equals an
    already-kept instruction, or if a kept instruction has the same signs in both
    components and magnitudes at least as large. The unsorted result for each
    ``n`` is memoized in ``directional_instructions_cache``.

    :param n: number of candidate directions.
    :param sort_by_angle_to: optional vector. If given, instructions are sorted by
        their misalignment with it, best aligned first. If it is the zero vector,
        there is no preferred direction and all instructions are returned
        unsorted.
    :param limit_angle: optional threshold on the misalignment score
        ``(1 - cos(angle)) / 2``, which lies in [0, 1]. Only instructions with a
        score <= ``limit_angle`` are kept; 0.5 keeps directions within 90 degrees.
        Only used together with ``sort_by_angle_to``.
    :return: a new list; the cached list is never handed out directly.
    """
    if n in directional_instructions_cache:
        instructions = directional_instructions_cache[n]
    else:
        instructions = []
        for point in _yield_points_at_distance(0, 0, 127, n):
            i1 = Instruction(*point)

            if i1 in instructions:
                continue

            # skip i1 if an existing instruction "dominates" it (same quadrant, at least as strong)
            for i2 in instructions:
                if np.sign(i1.vx) == np.sign(i2.vx) and abs(i1.vx) <= abs(i2.vx) \
                        and np.sign(i1.vy) == np.sign(i2.vy) and abs(i1.vy) <= abs(i2.vy):
                    break
            else:
                instructions.append(i1)

        directional_instructions_cache[n] = instructions

    if sort_by_angle_to is None:
        return list(instructions)

    target = np.asarray(sort_by_angle_to, dtype=float)
    target_norm = np.linalg.norm(target)
    if target_norm == 0:
        # no preferred direction (would otherwise divide by zero and produce NaN scores)
        return list(instructions)

    def misalignment(instruction: Instruction) -> float:
        """(1 - cos of the angle between the instruction and the target) / 2, in [0, 1]."""
        vec = np.array([instruction.vx, instruction.vy])
        return (1 - np.dot(vec, target) / (np.linalg.norm(vec) * target_norm)) / 2

    # score once per instruction; sorted() is stable so ties keep generation order
    scored = sorted(((misalignment(inst), inst) for inst in instructions), key=lambda pair: pair[0])

    if limit_angle is not None:
        scored = [pair for pair in scored if abs(pair[0]) <= limit_angle]

    return [inst for _, inst in scored]


def instructions_to_position(
        simulation: Simulation,
        position: Vertex,
        try_ticks=7,
        max_distance=1.25,
) -> Optional[list[Instruction]]:
    """Return a sequence of instructions that brings the racer closest to ``position``.

    Each direction whose misalignment score with the target is <= 0.5 (see
    ``_get_directional_instructions``) is tried in turn. The same instruction is
    repeated until the racer collides, or until it moves away again after having
    started to approach the target. The shortest repetition that achieves a new
    best distance while leaving the racer in a "safe" state (see ``is_safe``) is
    remembered.

    The simulation is ``push()``-ed before the search and ``pop()``-ed after it
    (presumably restoring its state — confirm against ``Simulation``).

    :param simulation: the simulation in its current state.
    :param position: target coordinates ``(x, y)``.
    :param try_ticks: number of ticks some escape direction must survive for a state to be safe.
    :param max_distance: acceptance threshold, in multiples of the racer radius.
    :return: the instructions, or None when the best reachable distance exceeds
        ``max_distance * simulation.racer.radius``.
    """
    x, y = position

    best_distance = float('inf')
    best_instructions = None

    def is_safe(sim: Simulation) -> bool:
        """Return True if some direction survives ``try_ticks`` ticks from the current state without colliding."""
        sim.push()
        for instruction in _get_directional_instructions(10):
            sim.apply()

            for _ in range(try_ticks):
                if sim.tick(instruction) & TickFlag.COLLIDED:
                    break
            else:
                sim.pop()
                return True

        sim.pop()
        return False

    simulation.push()
    for inst in _get_directional_instructions(
            sort_by_angle_to=(x - int(simulation.racer.x), y - int(simulation.racer.y)),
            limit_angle=0.5
    ):
        simulation.apply()
        started_improving = False
        prev_dist = 0
        ticks = 0  # number of times `inst` has been applied in this attempt

        while True:
            result = simulation.tick(inst)
            ticks += 1

            if result & TickFlag.COLLIDED:
                break

            distance = euclidean_distance(simulation.racer.x, simulation.racer.y, x, y)

            # if we're closer than before, we started improving
            if distance < prev_dist:
                started_improving = True

            # if we've started improving but are now getting worse, terminate this instruction
            if started_improving and distance > prev_dist:
                break

            prev_dist = distance

            # if we improved and are safe, save results
            if distance < best_distance and is_safe(simulation):
                best_distance = distance
                # the evaluated state is reached after exactly `ticks` repetitions of `inst`
                best_instructions = [inst] * ticks

    simulation.pop()

    if best_distance > max_distance * simulation.racer.radius:
        return None

    return best_instructions


def get_preview(simulation: Simulation, size=1000):
    """Render the simulation as a ``size`` x ``size`` SVG drawing centred on the origin.

    Simulation coordinates are divided by ``bounding_box.width() / size``. On a
    white background, asteroids are painted black, reached goals light green,
    unreached goals red and finally the racer gray.
    """
    half = size / 2
    drawing = draw.Drawing(size, size, origin='center')
    drawing.append(draw.Rectangle(-half, -half, size, size, fill="White"))  # background

    def paint(bodies: list[Asteroid | Goal | Racer], color: str):
        """Append one filled circle per body, scaled into drawing coordinates."""
        scale = simulation.bounding_box.width() / size
        for body in bodies:
            drawing.append(
                draw.Circle(body.x / scale, body.y / scale, body.radius / scale, fill=color, stroke=color)
            )

    paint(simulation.asteroids, "Black")
    paint([goal for idx, goal in enumerate(simulation.goals) if simulation.reached_goals[idx]], "LightGreen")
    paint([goal for idx, goal in enumerate(simulation.goals) if not simulation.reached_goals[idx]], "Red")
    paint([simulation.racer], "Gray")

    return drawing


def get_solution_preview(
        simulation: Simulation,
        simulation_path: Optional[list[Vertex]],
        color: str = "Green"
) -> draw.Drawing:
    """Return an SVG drawing of the simulation with the racer's path drawn over it.

    :param simulation: the simulation to render (via ``get_preview``).
    :param simulation_path: points the racer visited, in simulation coordinates;
        nothing extra is drawn when it is None or empty.
    :param color: stroke colour of the path.
    :return: the drawing.
    """
    d = get_preview(simulation)
    s = simulation.bounding_box.width() / d.width

    def make_path(points, stroke_color, stroke_width, stroke_opacity) -> draw.Path:
        """Build (without appending) a polyline through ``points``, scaled into drawing coordinates."""
        p = draw.Path(stroke_width=stroke_width, stroke=stroke_color, opacity=stroke_opacity, fill_opacity=0)

        p.M(points[0][0] / s, points[0][1] / s)
        for point in points[1:]:
            p.L(point[0] / s, point[1] / s)
        return p

    if simulation_path:
        # the helper returns the path, so it is appended exactly once (never append None)
        d.append(make_path(simulation_path, color, simulation.racer.radius / (s * 2), 1))

    return d


def get_checkpoints_path(
        simulation, vertices, edges, goal_vertices,
        longest_edges,
):
    """Get a path (list of points) that the racer should follow to obtain all goals.

    The asteroid graph is turned into an undirected ``nx.Graph`` whose edge
    weights are Euclidean lengths.

    With several goals:
      1. The visiting order is chosen by approximate TSP (simulated annealing)
         over A* distances.
      2. The concatenated A* path is shortened by dropping removable vertices.
      3. Goal vertices are swapped for cheaper vertices of the same goal.

    With a single goal, the A* shortest path is randomly perturbed instead.

    Finally, every segment longer than ``longest_edges`` is repeatedly halved.

    :param simulation: the simulation; the racer position and goals are read from it.
    :param vertices: graph vertices as coordinate tuples.
    :param edges: graph edges as pairs of coordinate tuples.
    :param goal_vertices: maps a graph vertex to the centre of the goal it belongs to.
    :param longest_edges: maximal allowed distance between consecutive path points.
    :return: list of points; points inserted by edge splitting have float coordinates.
    """

    def astar_heuristic(p1, p2):
        """Straight-line distance; never overestimates since edge weights are Euclidean lengths."""
        return np.linalg.norm(np.array(p1) - np.array(p2))

    def perturb_shortest_path(G, best_path, penalty_factor=2.0):
        """Multiply the weight of one random inner edge of ``best_path`` and re-run A*.

        The first two edges and the last edge are never chosen, so paths with
        fewer than 5 vertices are returned unchanged. Modifies ``G`` in place.
        """
        import random  # explicit, instead of relying on `from asteracer import *` to provide it

        # randint(2, len(best_path) - 3) needs a non-empty range, i.e. len(best_path) >= 5
        if len(best_path) < 5:
            return best_path  # No modification possible

        # Randomly choose an edge along the best path
        idx = random.randint(2, len(best_path) - 3)
        u, v = best_path[idx], best_path[idx + 1]

        if G.has_edge(u, v):
            original_weight = G[u][v].get("weight", 1.0)
            G[u][v]["weight"] = original_weight * penalty_factor
            if G.has_edge(v, u):  # redundant for an undirected nx.Graph, needed for a DiGraph
                G[v][u]["weight"] = original_weight * penalty_factor

        return nx.astar_path(G, best_path[0], best_path[-1], heuristic=astar_heuristic)

    # convert vertices, edges to nx.Graph (weight = Euclidean edge length)
    G = nx.Graph()
    for v in vertices:
        G.add_node(v)
    for u, v in edges:
        G.add_edge(u, v, weight=np.linalg.norm(np.array(u) - np.array(v)))

    # for more than 1 goal, we have to solve TSP (at least approximately)
    if len(simulation.goals) > 1:
        G_tsp = nx.DiGraph()

        # TSP nodes: the racer's start position followed by all goal centres
        tsp_vertices = [(simulation.racer.x, simulation.racer.y)]

        for g in simulation.goals:
            tsp_vertices.append((g.x, g.y))
            G_tsp.add_node(tsp_vertices[-1])

        # complete graph weighted by A* shortest-path lengths in G
        count = 0
        for i, u in enumerate(tsp_vertices):
            for v in tsp_vertices[i + 1:]:
                print(f"TSP: building graph {count + 1}/{len(tsp_vertices) * (len(tsp_vertices) - 1) // 2}")
                w = nx.astar_path_length(G, u, v, heuristic=astar_heuristic)
                G_tsp.add_edge(u, v, weight=w)
                G_tsp.add_edge(v, u, weight=w)

                count += 1

        print("TSP: graph built")

        # nx doesn't play nice with infinite values, but this might as well be
        INF = 100000000000

        # We want an open path start -> ... -> last goal, encoded as a cycle through an extra
        # "hack" vertex: start -> hack and hack -> (chosen last goal) cost 0, every other edge
        # touching "hack" costs INF. Each goal is tried as the last one (one TSP per goal).
        # The start must be the same node the TSP graph uses for the racer position.
        start = tsp_vertices[0]
        G_tsp.add_node("hack")
        G_tsp.add_edge(start, "hack", weight=0)
        G_tsp.add_edge("hack", start, weight=INF)

        best_cost = float('inf')
        best_path = None

        for i, g in enumerate(simulation.goals):
            print(f"TSP: solving {i + 1}/{len(simulation.goals)}")

            for g_hack in simulation.goals:
                v = (g_hack.x, g_hack.y)
                G_tsp.add_edge("hack", v, weight=0 if g is g_hack else INF)
                G_tsp.add_edge(v, "hack", weight=INF)

            path = nx.approximation.simulated_annealing_tsp(G_tsp, init_cycle="greedy")
            path.pop()  # the returned cycle repeats its first node at the end

            cost = 0
            for j in range(len(path)):
                cost += G_tsp[path[j - 1]][path[j]]["weight"]  # j - 1 == -1 closes the cycle

            for g_hack in simulation.goals:
                v = (g_hack.x, g_hack.y)
                G_tsp.remove_edge("hack", v)
                G_tsp.remove_edge(v, "hack")

            # reverse and rotate so the path begins right after "hack" (normally the start)
            # and ends at the chosen last goal g
            path = list(reversed(path))
            j = path.index("hack")
            path.remove("hack")
            path = path[j:] + path[:j]

            assert len(path) == len(simulation.goals) + 1, "Not all goals visited!"

            if cost < best_cost:
                best_path = path
                best_cost = cost
                print(f"TSP: new optimum of length {best_cost}")

        # expand the visiting order into a vertex path in G
        path = [(int(simulation.racer.x), int(simulation.racer.y))]
        for i in range(len(best_path) - 1):
            path += nx.astar_path(G, best_path[i], best_path[i + 1], heuristic=astar_heuristic)[1:]

        # last vertex is a center of some goal
        path = path[:-1]

        # shorten the path by cutting out vertices
        while True:
            # try to remove vertices for goals
            for i in range(len(path) - 2):
                if not G.has_edge(path[i], path[i + 2]):
                    continue

                # if it's not a part of a goal, we can remove for free
                if path[i + 1] not in goal_vertices:
                    path.pop(i + 1)
                    break

                # if it is and is not the only one, we can remove it too
                goal_vertices_count = 0
                for v in path:
                    if v in goal_vertices and goal_vertices[v] == goal_vertices[path[i + 1]]:
                        goal_vertices_count += 1

                if goal_vertices_count != 1:
                    path.pop(i + 1)
                    break
            else:
                break

        # rotate goal vertices (we can take an arbitrary one to get the goal)
        while True:
            improved = False

            for i in range(len(path) - 2):
                u = path[i + 1]

                if u not in goal_vertices:
                    continue

                for v in vertices:
                    if v not in goal_vertices:
                        continue

                    if goal_vertices[u] != goal_vertices[v]:
                        continue

                    if not G.has_edge(path[i], v) or not G.has_edge(v, path[i + 2]):
                        continue

                    u_dist = G[path[i]][u]["weight"] + G[u][path[i + 2]]["weight"]
                    v_dist = G[path[i]][v]["weight"] + G[v][path[i + 2]]["weight"]

                    # strict improvement only, so this loop terminates
                    if u_dist > v_dist:
                        path = path[:i + 1] + [v] + path[i + 2:]
                        improved = True
                        break

            if not improved:
                break

        # TODO: remove goal vertices entirely if the edge exists and intersects it
        # TODO: for each goal, attempt to remove vertex and connect astar to next and to previous

        print("Solved with TSP.")
    else:
        path = nx.astar_path(
            G,
            (simulation.racer.x, simulation.racer.y),
            (simulation.goals[0].x, simulation.goals[0].y),
            heuristic=astar_heuristic,
        )

        # randomize the route so that repeated calls explore different paths
        path = perturb_shortest_path(G, path, penalty_factor=1000)

        print("Solved with A*.")

    # split long edges by inserting midpoints until no segment exceeds longest_edges
    while True:
        split = False

        i = 0
        while i < len(path) - 1:
            if euclidean_distance(*path[i], *path[i + 1]) > longest_edges:
                x = path[i][0] + path[i + 1][0]
                y = path[i][1] + path[i + 1][1]
                path.insert(i + 1, (x / 2, y / 2))
                split = True
            i += 1

        if not split:
            break

    return path


def finalize_instructions(simulation, instructions, max_extra_ticks=1000):
    """Finalize the instructions: cut useless ones and extend them if the simulation isn't finished.

    The instructions are replayed from a restarted simulation and truncated right
    after the first tick at which ``simulation.finished()`` holds. If the
    simulation is still unfinished afterwards, the last instruction is repeated
    until it finishes, or until ``max_extra_ticks + 1`` instructions have been
    appended.

    :param simulation: the simulation to replay on; it is left at the final state.
    :param instructions: the candidate instructions; the caller's list is not modified.
    :param max_extra_ticks: cap on the number of extra repetitions (the loop stops
        once the count exceeds it).
    :return: a new list of instructions; an empty input yields an empty list.
    """
    if not instructions:
        # nothing to cut and no last instruction to repeat
        return []

    simulation.restart()
    kept = len(instructions)
    for index, instruction in enumerate(instructions):
        simulation.tick(instruction)

        # there might be some redundant instructions at the very end
        if simulation.finished():
            kept = index + 1
            break

    instructions = instructions[:kept]  # slicing copies, so the caller's list stays intact

    # if we ended just before the goal, repeat last instruction
    last = instructions[-1]
    simulation.simulate(instructions)
    extra = 0
    while not simulation.finished():
        simulation.tick(last)
        instructions.append(last)
        extra += 1

        if extra > max_extra_ticks:
            break

    return instructions


def get_solution_path(simulation, instructions):
    """Replay ``instructions`` from a restarted simulation and return the visited positions.

    The result starts with the racer's initial position and has one more entry
    per instruction (the position after each tick).
    """
    def racer_position():
        """Current racer coordinates as an (x, y) tuple."""
        return simulation.racer.x, simulation.racer.y

    simulation.restart()
    trail = [racer_position()]
    for step in instructions:
        simulation.tick(step)
        trail.append(racer_position())

    return trail


def solve(simulation, path):
    """Solve the simulation by following the path point by point.

    ``path[0]`` is treated as the start. For each later point,
    ``instructions_to_position`` is asked for a way to reach it. Points inside a
    goal (closer than goal radius + racer radius) use a tighter tolerance.

    When no acceptable instruction sequence is found, ``randomize_step_counter``
    is increased. The solver then drops the instructions of up to that many
    previously reached points (never more than exist) and retries them with
    randomly perturbed targets. Every successful step lowers the counter by one.

    :param simulation: the simulation; it is re-simulated after every step.
    :param path: list of points, as produced by ``get_checkpoints_path``.
    :return: the finalized instruction list (see ``finalize_instructions``).
    """
    # steps_instructions[k] holds the instructions that reached path[k + 1];
    # invariant: len(steps_instructions) == i - 1
    steps_instructions = []
    randomize_step_counter = 0

    def flat_instructions():
        """All step instructions concatenated into one list (empty if there are no steps)."""
        return [instruction for step in steps_instructions for instruction in step]

    i = 1
    while i < len(path):
        v = path[i]

        target_goal = None
        for goal in simulation.goals:
            if euclidean_distance(goal.x, goal.y, v[0], v[1]) < goal.radius + simulation.racer.radius:
                target_goal = goal
                break
        is_goal_position = target_goal is not None

        print(f"Simulating {i}/{len(path) - 1}{'' if not is_goal_position else ' to goal'}", end=": ", flush=True)

        if randomize_step_counter:
            print("randomized, ", end="")

            # only move towards the goal: shift v along the direction to the goal's centre
            if is_goal_position:
                to_goal = np.array([target_goal.x, target_goal.y]) - np.array(v)
                to_goal_length = np.linalg.norm(to_goal)
                if to_goal_length > 0:
                    offset = abs(np.random.normal(0, simulation.racer.radius))
                    v = np.array(v) + (to_goal / to_goal_length) * offset
            else:
                v = np.array(v) + np.random.normal(0, simulation.racer.radius, 2)

        new_instructions = instructions_to_position(simulation, v, max_distance=2 if not is_goal_position else 0.75)

        if new_instructions is None:
            # always escalate (so a failing first step also gets randomized), but never
            # backtrack past the start
            randomize_step_counter += 1
            backtrack = min(randomize_step_counter, len(steps_instructions))

            i -= backtrack
            del steps_instructions[len(steps_instructions) - backtrack:]

            print(f"(- backtrack {backtrack}x)")

            if steps_instructions:
                simulation.simulate(flat_instructions())
            else:
                simulation.restart()

            continue

        randomize_step_counter = max(randomize_step_counter - 1, 0)

        steps_instructions.append(new_instructions)
        simulation.simulate(flat_instructions())

        print(f"(+ {len(new_instructions)})")

        i += 1

    return finalize_instructions(simulation, flat_instructions())


if __name__ == "__main__":
    iterations = 5000

    for task in ["sprint"]:
        print(f"Solving {task} ({iterations} attempts):")

        os.makedirs(task, exist_ok=True)

        simulation = Simulation.load(f"../mapy/{task}.txt")
        print("Simulation loaded.")

        # the graph file refers to vertices by index; resolve edges to coordinate pairs
        vertices, index_edges, object_types = load_asteroid_graph(f"../grafy/{task}.txt")
        edges = [(vertices[a], vertices[b]) for a, b in index_edges]

        print("Graph loaded.")

        asteroid_radii = [asteroid.radius for asteroid in simulation.asteroids]
        average_asteroid_radius = sum(asteroid_radii) / len(simulation.asteroids)

        # map every goal vertex of the graph to the centre of the goal it belongs to
        goal_vertices = {
            vertices[vertex_index]: (simulation.goals[goal_index].x, simulation.goals[goal_index].y)
            for vertex_index, (kind, goal_index) in enumerate(object_types)
            if kind == "G"
        }

        best_instructions = None
        best_instructions_length = float('inf')

        for _attempt in range(iterations):
            simulation.restart()
            path = get_checkpoints_path(
                simulation,
                vertices, edges, goal_vertices,
                longest_edges=average_asteroid_radius * np.random.uniform(0.5, 5),
            )

            instructions = solve(simulation, path)

            if len(instructions) >= best_instructions_length:
                continue

            best_instructions = instructions
            best_instructions_length = len(instructions)
            print(f"Solved with {len(instructions)} instructions.")

            # save each new best solution (preview SVG + instruction file) as soon as it is found
            solution_path = get_solution_path(simulation, best_instructions)
            d = get_solution_preview(simulation, solution_path, "Green")
            d.save_svg(f"{task}/solution.svg")
            save_instructions(f"{task}/solution.txt", best_instructions)

        print(f"Best solution: {len(best_instructions)} instructions.")