#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# UPMC - c.durr - 2016-2017
# Conception et Pratique de l'Algorithmique
# TME 4 - algorithme k-means++

from math import sin, cos, asin, sqrt, radians
from sys import argv, stderr
import random


def dist(p, q):
    """distance à vol d'oiseau entre deux sommets sur le globe, en kilomètres
    """
    dlon = p[1] - q[1]
    dlat = p[0] - q[0]
    a = sin(radians(dlat / 2))**2 + cos(radians(p[0])) * cos(radians(q[0])) * sin(radians(dlon / 2))**2
    return asin( sqrt(a) ) * 6373


def distList(p, centers):
    '''distance entre un point p et une liste de points centers
    retourne le couple (mindist, argmin)
    '''
    assert centers != []
    return min((dist(p, cj), j) for j, cj in enumerate(centers))


def read_gpsfile(gpsfile):
    """lit un fichier texte dont les lignes sont dans le format
    v [identificateur] [latitude] [longitude]
    returns: liste de points
    """
    points = []
    for line in open(gpsfile,'r'):
        assert line[0] == 'v'
        tab = line.split()
        latitude  = float(tab[2])
        longitude = float(tab[3])
        points.append( (latitude, longitude) )
    return points

def init(points, k):
    """retourne une première liste de k points
    complexity: O(n * k * k), n=len(points)
    """
    centers = [random.choice(points)]         # premier point choisit uniformément au hasard
    for _ in range(1, k):                     # ajoute les k-1 points suivants
        select = None
        total = 0
        for pi in points:                     # selon la distribution D^2
            d2 = distList(pi, centers)[0]**2
            total += d2
            if random.uniform(0, total) <= d2:
                select = pi                   # choisit pi avec probabilité proportionnelle à d2
        centers.append(select)
    return centers


def Lloyd(points, centers):
    """recherche locale de Lloyd.
    calculer les cellules de Voronoi
    changer les centres de chaque cellule en son centroïd.
    recalculer les cellules de Voronoï.
    continuer tant qu'il y a des changement d'affectation.
    """
    changed = True
    k = len(centers)
    # affecter chaque point i au centre le plus proche dans L
    cluster = {i: distList(pi, centers)[1] for i, pi in enumerate(points)}
    while changed:
        changed = False
        lat = [0.0] * k                      # calculer les centres de gravité de chaque cellule
        lon = [0.0] * k
        siz = [0] * k
        for i, pi in enumerate(points):
            lat[cluster[i]] += pi[0]
            lon[cluster[i]] += pi[1]
            siz[cluster[i]] += 1
        for a in range(k):
            if siz[a] > 0 :
                centers[a] = (lat[a] / siz[a], lon[a] / siz[a])
        delta = [float('inf')] * k           # delta[j] = distance minimale vers les autres centres
        for j1, cj1 in enumerate(centers):
            for j2, cj2 in enumerate(centers):
                if j1 != j2:                 # seulement vers les *autres* centres
                    d = dist(cj1, cj2)
                    if d < delta[j1]:
                        delta[j1] = d
        obj = 0
        for i, pi in enumerate(points):      # trouver le centre le plus proche pour chaque point
            d = dist(pi, centers[cluster[i]])
            if d * 2 > delta[cluster[i]]:
                d, closest = distList(pi, centers)
                if closest != cluster[i]:    # changement dans l'association point->centre ?
                    cluster[i] = closest
                    changed = True
            obj += d*d
        print("objective value = %.10g" % obj, file = stderr)   # pour montrer la progression
    return (centers, obj)


def print_solution(points, centers, obj):
    """affiche une solution donnée centers dans le format demandé
    """
    for i, pi in enumerate(points):
        print("v %i %.6f %.6f" % (i, pi[0], pi[1]))
    for cj in centers:
        print("s %.6f %.6f" % cj)
    print("o %.10g" % obj)


if __name__ == "__main__":
    if len(argv) != 3:
        print("usage: kmeansplusplus.py <k> <gpsfilename>")
        exit(1)

    k = int(argv[1])
    points = read_gpsfile(argv[2])
    centers, obj = Lloyd(points, init(points, k))
    print_solution(points, centers, obj)

