# pylint: disable=C0103
import logging
import math
import random
import sqlite3
import time
from typing import Callable, Dict, List, Optional, Set
from configparser import ConfigParser

# Module logger; used by ReputationTracker when no logger is passed in.
_default_log = logging.getLogger(__name__)

def _solve(f: Callable[[float], float], eps: float, x0: float, x1: float) \
        -> float:
    """Solves f(x) = 0 for x ∈ (x0, x1] by bisection.

    Assumes f is continuous with a unique root in the interval.  The result
    ``x`` satisfies ``abs(f(x)) <= eps`` unless the bracket shrinks to two
    adjacent floats first; then the upper end of that bracket is returned,
    which is as close to the root as floating point allows.

    :param f: function whose root is sought.
    :param eps: tolerance on ``abs(f(x))`` for accepting a solution.
    :param x0: exclusive lower end of the search interval.
    :param x1: inclusive upper end of the search interval.
    :returns: an approximate root of ``f``.
    :raises RuntimeError: if ``f(x0)`` and ``f(x1)`` are non-zero and of the
        same sign, so no root is bracketed.
    """
    u0 = f(x0)
    u1 = f(x1)
    if abs(u1) <= eps:
        return x1
    # Compare signs explicitly instead of testing u0 * u1 > 0.0: the product
    # can underflow to zero for tiny values and hide a missing sign change.
    if (u0 > 0.0 and u1 > 0.0) or (u0 < 0.0 and u1 < 0.0):
        raise RuntimeError('Cannot solve, same sign.')
    while abs(u1) > eps:
        x = 0.5 * (x0 + x1)
        if x in (x0, x1):
            # x0 and x1 are adjacent floats; bisection cannot progress, and
            # looping further would never terminate.
            return x1
        u = f(x)
        # The root stays in (x0, x1]: move x1 whenever f(x) has the sign of
        # f(x1).  An exact hit (u == 0) must also become x1, since x0 is
        # excluded from the bracket and moving it there loses the root.
        if u == 0.0 or (u < 0.0) == (u1 < 0.0):
            x1, u1 = x, u
        else:
            x0 = x
    return x1

def _func(lambda_: float) -> Callable[[float], float]:
    """Return the unnormalised choice weight as a function of success rate.

    The returned callable maps a success probability ``pS`` to
    ``exp(1 - lambda_ / pS)``; ``pS`` must be non-zero.
    """
    return lambda pS: math.exp(1.0 - lambda_ / pS)

def _norm(success_dist: Dict[str, float]) -> Callable[[float], float]:
    """Build the normalisation residual for a success distribution.

    The returned callable maps a candidate ``lambda_`` to the sum of the
    ``_func(lambda_)`` weights over all values of ``success_dist``, minus
    one.  It is zero exactly when those weights sum to one.
    """
    def residual(lambda_: float) -> float:
        weight = _func(lambda_)
        return sum(map(weight, success_dist.values())) - 1.0
    return residual

def choice_dist_of_success_dist(success_dist: Dict[str, float]) \
        -> Dict[str, float]:
    """Convert per-choice success probabilities into choice probabilities.

    Each choice with success probability ``pS > 0`` gets the weight
    ``exp(1 - lambda0 / pS)``.  ``lambda0`` is solved numerically in
    (0, 20] so that the weights sum to one, within a tolerance of 1e-6.
    Choices with ``pS <= 0`` are left out of the result.

    :param success_dist: maps choice names to success probabilities.
    :returns: maps choice names to selection probabilities; empty when no
        choice has a positive success probability.
    :raises RuntimeError: if the solver cannot bracket a normalising
        ``lambda0`` in (0, 20].
    """
    positive = {k: p for (k, p) in success_dist.items() if p > 0.0}
    if not positive:
        # Nothing to choose from.  The residual would be the constant -1,
        # and the solver would fail with a misleading "same sign" error.
        return {}
    lambda0 = _solve(_norm(positive), 1e-6, 0.0, 20.0)
    weight = _func(lambda0)
    return {term: weight(pS) for (term, pS) in positive.items()}

# Schema of the single table holding all reputation samples, one row per
# (dist_name, choice_name).  Counts are stored as floating point because
# they are decayed over time.  IF NOT EXISTS lets the statement run safely
# on every connect.
_CREATE_SQL = """\
CREATE TABLE IF NOT EXISTS arcnagios_reputation (
  dist_name text NOT NULL,
  choice_name text NOT NULL,
  update_time double precision NOT NULL,
  recent_count double precision NOT NULL,
  recent_success double precision NOT NULL,
  PRIMARY KEY (dist_name, choice_name)
)"""

# All distribution names that have at least one recorded sample.
_PAST_DIST_NAMES_SQL = """\
SELECT DISTINCT dist_name FROM arcnagios_reputation
"""

# All choice names recorded for one distribution.
_PAST_CHOICE_NAMES_SQL = """\
SELECT choice_name FROM arcnagios_reputation WHERE dist_name = ?
"""

# Decayed attempt and success counts for every choice of one distribution.
_FETCH_SQL = """\
SELECT choice_name, recent_count, recent_success
FROM arcnagios_reputation WHERE dist_name = ?
"""

# The INSERT and UPDATE below take their parameters in the same order
# (update_time, recent_count, recent_success, dist_name, choice_name), so
# one parameter tuple can be passed to either.
_SUBMIT_SELECT_SQL = """\
SELECT update_time, recent_count, recent_success
FROM arcnagios_reputation WHERE dist_name = ? AND choice_name = ?
"""
_SUBMIT_INSERT_SQL = """\
INSERT INTO arcnagios_reputation \
    (update_time, recent_count, recent_success, dist_name, choice_name)
VALUES (?, ?, ?, ?, ?)
"""
_SUBMIT_UPDATE_SQL = """\
UPDATE arcnagios_reputation
SET update_time = ?, recent_count = ?, recent_success = ?
WHERE dist_name = ? AND choice_name = ?
"""

class ReputationTracker:
    """Tracks per-choice success rates in SQLite and picks choices from them.

    Outcomes are stored per ``(dist_name, choice_name)`` pair as
    exponentially decayed attempt and success counts (see :meth:`submit`).
    :meth:`choose` draws a choice at random, with weights that increase
    with each choice's estimated success rate.

    Configuration is read from the ``reputation`` section (``busy_timeout``
    and ``sample_lifetime``).  Per-distribution ``reputation_dist:<name>``
    sections may also set ``sample_lifetime``.
    """

    def __init__(self, config: ConfigParser, db_path: str,
                 log: logging.Logger = _default_log):
        self._log = log
        self._config = config
        self._db_path = db_path
        # Opened lazily by _connect(); None while disconnected.
        self._db: Optional[sqlite3.Connection] = None
        # Most recent pick made by choose(), keyed by distribution name.
        self._choices: Dict[str, str] = {}

    def _config_float(self, var: str, default: float) -> float:
        """Return option ``var`` of section ``reputation``, else ``default``."""
        if self._config.has_section('reputation') \
                and self._config.has_option('reputation', var):
            return self._config.getfloat('reputation', var)
        return default

    def _config_dist_float(self, dist_name: str, var: str,
                           default: Optional[float] = None) -> Optional[float]:
        """Return option ``var`` of section ``reputation_dist:<dist_name>``,
        else ``default``."""
        section_name = 'reputation_dist:' + dist_name
        if self._config.has_section(section_name) \
                and self._config.has_option(section_name, var):
            return self._config.getfloat(section_name, var)
        return default

    @property
    def _busy_timeout(self) -> float:
        """Seconds sqlite3 waits for a locked database (default 10)."""
        return self._config_float('busy_timeout', 10.0)

    @property
    def _default_sample_lifetime(self) -> float:
        """Default decay time constant in seconds (172800 s = 2 days)."""
        return self._config_float('sample_lifetime', 172800.0)

    @staticmethod
    def _ensure_schema(db: sqlite3.Connection) -> None:
        """Create the reputation table on ``db`` unless it already exists."""
        try:
            db.execute(_CREATE_SQL)
        except sqlite3.OperationalError:
            # "table already exists" is the expected failure.  Other
            # operational errors (locked or read-only database, ...) are
            # tolerable only if the table is present anyway; otherwise
            # re-raise instead of failing later with "no such table".
            row = db.execute(
                "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ?",
                ('arcnagios_reputation',)).fetchone()
            if row is None:
                raise

    def _connect(self) -> sqlite3.Connection:
        """Return the cached connection, opening it and ensuring the schema
        on first use."""
        if self._db is None:
            db = sqlite3.connect(self._db_path, self._busy_timeout)
            try:
                self._ensure_schema(db)
            except BaseException:
                # Do not leak or cache a connection without a usable schema.
                db.close()
                raise
            self._db = db
        return self._db

    def past_dist_names(self) -> List[str]:
        """Return every distribution name with recorded samples."""
        db = self._connect()
        return [name for (name,) in db.execute(_PAST_DIST_NAMES_SQL)]

    def past_choice_names(self, dist_name: str) -> List[str]:
        """Return every choice name recorded for ``dist_name``."""
        db = self._connect()
        return [name for (name,)
                in db.execute(_PAST_CHOICE_NAMES_SQL, (dist_name,))]

    def disconnect(self) -> None:
        """Close the database connection, if open.

        Any later method that needs the database reconnects.
        """
        if self._db is not None:
            self._db.close()
            self._db = None

    def success_dist(self, dist_name: str) -> Dict[str, float]:
        """Return the estimated success probability of each recorded choice.

        Uses ``(successes + 0.25) / (count + 0.5)`` on the decayed counts.
        This pulls sparsely sampled choices towards 0.5 and keeps estimates
        inside (0, 1) as long as ``0 <= successes <= count``.
        """
        db = self._connect()
        rows = db.execute(_FETCH_SQL, (dist_name,))
        return {name: (n_success + 0.25) / (n + 0.5)
                for (name, n, n_success) in rows}

    def choice_dist(self, dist_name: str, choice_names: Set[str],
                    success_dist: Optional[Dict[str, float]] = None) \
            -> Dict[str, float]:
        """Return selection probabilities over ``choice_names``.

        Choices without a recorded success rate are assumed to perform like
        the average recorded choice, or 0.5 if nothing is recorded.

        :param dist_name: distribution whose history is used.
        :param choice_names: candidates to distribute probability over.
        :param success_dist: pre-computed success rates.  When omitted, they
            are fetched from the database for ``dist_name``.
        """
        if success_dist is None:
            success_dist = self.success_dist(dist_name)
        if success_dist:
            avg_success = sum(success_dist.values()) / len(success_dist)
        else:
            avg_success = 0.5
        restricted_success_dist = {
            name: success_dist.get(name, avg_success) for name in choice_names}
        return choice_dist_of_success_dist(restricted_success_dist)

    def submit(self, dist_name: str, choice_name: str, is_success: bool) \
            -> None:
        """Record the outcome of one attempt with ``choice_name``.

        Stored counts are first decayed by ``exp(-elapsed / lifetime)``.
        ``lifetime`` is the per-distribution ``sample_lifetime`` if it is
        configured and non-zero, else the global default.  The attempt
        count is then incremented, and the success count as well if
        ``is_success`` is true.  The change is committed immediately.
        """
        db = self._connect()
        rows = db.execute(_SUBMIT_SELECT_SQL,
                          (dist_name, choice_name)).fetchall()
        t_now = time.time()
        is_new = not rows
        if is_new:
            t_past, recent_count, recent_success = (t_now, 0.0, 0.0)
        else:
            # (dist_name, choice_name) is the table's primary key.
            assert len(rows) == 1
            t_past, recent_count, recent_success = rows[0]

        sample_lifetime = self._config_dist_float(dist_name, 'sample_lifetime')
        # ``or`` also replaces a configured 0.0, which would divide by zero.
        scale = math.exp((t_past - t_now)
                         / (sample_lifetime or self._default_sample_lifetime))
        recent_count = scale * recent_count + 1.0
        recent_success *= scale
        if is_success:
            recent_success += 1.0
        # INSERT and UPDATE take their parameters in the same order.
        sql = _SUBMIT_INSERT_SQL if is_new else _SUBMIT_UPDATE_SQL
        db.execute(sql,
                   (t_now, recent_count, recent_success, dist_name, choice_name))
        db.commit()

    def _choose_otr(self, dist_name: str, choice_names: Set[str]) -> str:
        """Draw one of ``choice_names`` according to :meth:`choice_dist`."""
        choice_dist = self.choice_dist(dist_name, choice_names)
        p = random.uniform(0.0, 1.0)
        for (choice_name, pS) in choice_dist.items():
            p -= pS
            if p < 0.0:
                return choice_name
        # The probabilities are normalised only to a tolerance, so p can
        # survive the loop; return an arbitrary choice.  Read it with
        # next(iter(...)) rather than pop() so the caller's set is not
        # mutated.
        return next(iter(choice_names))

    def choose(self, dist_name: str, choice_names: Set[str]) -> str:
        """Randomly pick one of ``choice_names``, weighted by reputation.

        The pick is remembered per ``dist_name`` and reported by
        :meth:`choices`.

        :raises ValueError: if ``choice_names`` is empty.
        """
        if not choice_names:
            raise ValueError("ReputationTracker.choose expects a non-empty "
                             "sequence of choices.")
        choice_name = self._choose_otr(dist_name, choice_names)
        self._choices[dist_name] = choice_name
        return choice_name

    def choices(self) -> Dict[str, str]:
        """Return the most recent pick of :meth:`choose` per distribution.

        NOTE: this is the tracker's internal dict, not a copy.
        """
        return self._choices
