import asyncio
import datetime
import json
import logging
import os
import time
import ssl
import sys
from enum import Enum
from operator import itemgetter

from aiohttp import web

from . import fetch
from . import fs
from . import util
from . import auth_middleware
from .throttle import throttle

class MessageType(Enum):
  """Kinds of agent updates handled by GroveServer.receive_agent_update.

  Member names match the protocol verbs ``status`` and ``resources``, so a verb
  can be mapped with ``MessageType[verb]``.
  """
  status = 1
  resources = 2

class GroveServer:
  """Main class for the grove-server.

  Holds the configuration, the table of known agents (``self.agents``: agent
  host -> AgentTracker), the on-disk status cache and the file cache, and runs
  the agent TCP listener together with the HTTP API server.
  """

  def __init__(self, config, http_port, agent_port, agent_timeout, directory, debug):
    """Resolve settings and load the persisted agent status cache.

    Values present in ``config`` take precedence over the corresponding
    constructor arguments; ``directory``, when not None, takes precedence over
    ``server_home_dir``.
    """
    self.config = config
    self.http_port = config.get("server_http_port", http_port)
    self.ssl_key = config.get("server_ssl_key")
    self.ssl_cert = config.get("server_ssl_cert")
    self.username = config.get("server_username")
    self.password = config.get("server_password")
    self.agent_port = config.get("server_port", agent_port)
    self.agent_timeout = config.get("agent_timeout", agent_timeout)
    self.skip_connect = config.get("skip_connect", False)
    self.home_dir = directory if directory is not None else config.get("server_home_dir", ".")
    self.cache_dir = config.get("server_cache_dir", os.path.join(self.home_dir, 'cache'))
    self.resource_dir = config.get("server_resource_dir", os.path.join(self.home_dir, 'resource'))
    self.user_resource_dir = config.get("server_user_resource_dir", os.path.join(self.home_dir, 'user'))
    # Upper bound, in bytes, for the file cache (enforced by clean_tgz_cache).
    self.cache_size = config.get("cache_size", 3000000000)
    # Max seconds since a cached file's last access before eviction; -1 disables expiry.
    self.cache_expiry = config.get("cache_expiry", -1)
    self.access_logs = config.get("access_logs", debug)
    self.manager_notifier = util.ThrottledManagerNotifier(
      config.get("notify_endpoint", ""),
      config.get("throttle_time_sec", 5))
    self.agents = {}
    self.log = logging.getLogger(__name__)
    self.read_status_cache()

    self.log.info(
      "http_port [%d], agent_port [%d], agent_timeout [%d], home_dir [%s], cache_dir [%s], resource_dir [%s], user_resource_dir [%s], cache_size [%d], cache_expiry [%d], access_logs [%s]",
      self.http_port, self.agent_port, self.agent_timeout, self.home_dir, self.cache_dir, self.resource_dir,
      self.user_resource_dir, self.cache_size, self.cache_expiry, self.access_logs)

  def new_agent_protocol(self):
    """Protocol factory for the agent listener: one protocol per connection."""
    return GroveAgentServerProtocol(self, self.agent_timeout)

  def status_cache_file(self):
    """Path of the JSON file that persists agent state across restarts."""
    return os.path.join(self.home_dir, 'status.json')

  def file_cache_dir(self):
    """Directory holding fetched files and symlinks to resource files."""
    return self.cache_dir

  def clean_tgz_cache(self):
    """Evict files from the file cache, least recently accessed first.

    Walking files in ascending atime order, each file is removed while it is
    expired (accessed more than ``cache_expiry`` seconds ago, unless that is -1)
    or the cache is still above ``cache_size`` bytes. Eviction stops at the
    first file that is neither. Size and atime follow symlinks, so links into
    the resource directories count toward the total (removing one only deletes
    the link).
    """
    cache_dir = self.file_cache_dir()
    if not os.path.exists(cache_dir):
      return

    file_info_list = []
    for dirpath, _, filenames in os.walk(cache_dir):
      for name in filenames:
        path = os.path.join(dirpath, name)
        if not os.path.isfile(path):
          continue  # e.g. a dangling symlink
        try:
          st = os.stat(path)
        except OSError:
          # Removed or made unreadable since os.walk listed it.
          continue
        file_info_list.append((path, st.st_size, st.st_atime))
    file_info_list.sort(key=itemgetter(2))

    now = time.time()
    cache_size = sum(size for _, size, _ in file_info_list)
    self.log.debug("Cache size: %d bytes", cache_size)

    # Collect exactly the files that met the eviction condition; no file past
    # the first non-evictable one is ever included.
    evict = []
    for path, size, atime in file_info_list:
      expired = self.cache_expiry != -1 and now - self.cache_expiry > atime
      if not expired and cache_size <= self.cache_size:
        break
      cache_size -= size
      evict.append(path)

    for path in evict:
      self.log.debug("Removing from cache: %s", path)
      try:
        os.remove(path)
      except FileNotFoundError:
        self.log.debug("Already removed from cache: %s", path)

  def read_status_cache(self):
    """Populate ``self.agents`` from the status cache file, if it exists.

    Must run before any agent is registered. A cache file that cannot be parsed
    or has an unexpected layout is logged and ignored, so the server still
    starts; reconnecting agents get fresh trackers via register_agent_handler.

    Raises:
      Exception: if ``self.agents`` is already populated.
    """
    if self.agents:
      raise Exception('must be called first, before populating agents')

    cache_file = self.status_cache_file()
    if not os.path.exists(cache_file):
      self.log.debug('Cache file missing, skipping cache load.')
      return

    loaded = {}
    try:
      with open(cache_file, 'r') as f:
        status_cache = json.loads(f.read())
      for agent_host, agent_obj in status_cache["agents"].items():
        self.log.debug('Loaded status for agent {}: {}'.format(agent_host, agent_obj))
        loaded[agent_host] = AgentTracker(agent_obj)
    except (ValueError, KeyError, TypeError, AttributeError) as e:
      # ValueError covers JSONDecodeError and UnicodeDecodeError.
      self.log.warning('Ignoring unreadable status cache %s: %r', cache_file, e)
      return

    self.agents.update(loaded)
    self.log.debug('Loaded {} agents from cache.'.format(len(self.agents)))

  @throttle(seconds=5)
  def write_status_cache(self):
    """Write every agent's AgentTracker.mk_cache_obj() to the status cache file.

    NOTE(review): rate-limited by @throttle(seconds=5); whether calls inside the
    window are dropped or deferred depends on .throttle -- confirm.
    """
    status_cache = {
      "agents": {}
    }
    for agent_host, agent_obj in self.agents.items():
      agent_cache_obj = agent_obj.mk_cache_obj()
      self.log.debug('Writing status for agent {}: {}'.format(agent_host, agent_cache_obj))
      status_cache["agents"][agent_host] = agent_cache_obj

    fs.write_file(self.status_cache_file(), json.dumps(status_cache))
    self.log.debug('Wrote {} agents to cache.'.format(len(self.agents)))

  def symlink_resources_to_cache(self):
    """Mirror every file of the resource directories into the cache as symlinks.

    Each file keeps its path relative to its resource directory. Existing cache
    entries are left alone; dangling links are replaced. Link targets are
    absolute, because a relative target would be resolved against the link's
    own directory rather than the current working directory.
    """
    cache_root = self.file_cache_dir()
    for resource_root in (self.resource_dir, self.user_resource_dir):
      for dirpath, _, filenames in os.walk(resource_root):
        for name in filenames:
          resource_file = os.path.join(dirpath, name)
          # relpath strips only the leading directory, unlike str.replace which
          # would also rewrite repeated occurrences deeper in the path.
          cache_file_path = os.path.join(cache_root, os.path.relpath(resource_file, resource_root))
          if os.path.exists(cache_file_path):
            continue
          if os.path.islink(cache_file_path):
            # Dangling link: os.symlink would raise FileExistsError.
            os.remove(cache_file_path)
          fs.mkdirs(os.path.dirname(cache_file_path))
          os.symlink(os.path.abspath(resource_file), cache_file_path)
          self.log.info('[Symlink] {} -> {}'.format(resource_file, cache_file_path))

  def _is_current_handler(self, host, handler):
    """Return True if ``handler`` is the connection currently registered for ``host``."""
    tracker = self.agents.get(host)
    return (tracker is not None and tracker.handler is not None
            and tracker.handler.hostport() == handler.hostport())

  def receive_agent_update(self, handler, message, type):
    """Record a ``status``/``resources`` update sent by an agent connection.

    Updates from a handler that is not the one currently registered for its host
    are ignored. ``status_time`` is refreshed on every accepted update; the status
    cache is rewritten only when the stored value actually changed.
    """
    host = handler.host()
    if not self._is_current_handler(host, handler):
      self.log.debug("Skipping [{}] update: {} for {}".format(type, handler.hostport(), host))
      return

    self.log.debug("Got [{}] update from agent: {} for {}: {}".format(type.name, handler.hostport(), host, message))
    tracker = self.agents[host]
    tracker.status_time = time.time()

    needs_cache_update = False
    if type == MessageType.status:
      if tracker.status != message:
        tracker.status = message
        needs_cache_update = True
    elif type == MessageType.resources:
      if tracker.resources != message:
        tracker.resources = message
        needs_cache_update = True
    else:
      self.log.warning("Asked to handle unsupported message type [{}]".format(type.name))

    if needs_cache_update:
      self.write_status_cache()

  def register_agent_handler(self, handler):
    """Make ``handler`` the live connection for its host.

    Creates a tracker for unknown hosts. Closes any other connection previously
    registered for the host, copies connection metadata onto the tracker, and
    notifies the manager.
    """
    host = handler.host()
    tracker = self.agents.get(host)
    if tracker is None:
      tracker = self.agents[host] = AgentTracker(None)

    old_handler = tracker.handler
    # Re-registering the same connection (e.g. a repeated `auth`) must not close it.
    if old_handler is not None and old_handler is not handler:
      self.log.info("Disconnecting old agent: {} for {}".format(old_handler.hostport(), host))
      old_handler.transport.close()

    self.log.debug("Registering agent: {} for {}".format(handler.hostport(), host))
    tracker.handler = handler
    tracker.ip = handler.peeraddr
    tracker.hostname = handler.hostname
    tracker.image_version = handler.image_version
    self.manager_notifier.maybe_notify()

  def unregister_agent_handler(self, handler):
    """Mark the host offline if ``handler`` is its current connection; else ignore."""
    host = handler.host()
    if self._is_current_handler(host, handler):
      self.log.debug("Unregistering agent: {} for {}".format(handler.hostport(), host))
      self.agents[host].handler = None
      self.manager_notifier.maybe_notify()
    else:
      self.log.warning("Skipping unregister request: {} for {}".format(handler.hostport(), host))

  def run(self):
    """Serve agents and HTTP on 0.0.0.0 until the event loop stops.

    Both listeners share one SSL context when both ``server_ssl_key`` and
    ``server_ssl_cert`` are configured. On exit the servers and the loop are
    closed and the status cache is written.
    """
    ssl_context = None
    if self.ssl_key is not None and self.ssl_cert is not None:
      ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
      ssl_context.load_cert_chain(self.ssl_cert, self.ssl_key)
      # Refuse TLS 1.0/1.1. SSLContext.minimum_version is available from Python
      # 3.7; 3.10+ already disables TLS 1.0/1.1 by default, but setting it there
      # too keeps the guarantee independent of interpreter/OpenSSL defaults.
      if sys.version_info >= (3, 7):
        ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
      else:
        ssl_context.options |= ssl.OP_NO_TLSv1
        ssl_context.options |= ssl.OP_NO_TLSv1_1

    # NOTE(review): asyncio.get_event_loop() without a running loop is deprecated
    # on newer Pythons; kept because GroveHttpServer creates its asyncio.Lock
    # against this loop on older versions.
    loop = asyncio.get_event_loop()
    agent_server = loop.run_until_complete(
      loop.create_server(
        self.new_agent_protocol,
        '0.0.0.0',
        self.agent_port,
        ssl=ssl_context
      )
    )

    http_handler = GroveHttpServer(self).new_handler()
    http_server = loop.run_until_complete(
      loop.create_server(
        http_handler,
        '0.0.0.0',
        self.http_port,
        ssl=ssl_context
      )
    )

    try:
      loop.run_forever()
    finally:
      agent_server.close()
      http_server.close()
      loop.run_until_complete(agent_server.wait_closed())
      loop.run_until_complete(http_server.wait_closed())
      loop.close()
      self.write_status_cache()

class AgentTracker:
  """Per-agent record kept by GroveServer: the live handler plus last known state.

  ``cache_obj`` is either falsy (fresh tracker) or a dict shaped like the output
  of ``mk_cache_obj``, in which ``status`` and ``status_time`` are required and
  the remaining fields are optional.
  """

  def __init__(self, cache_obj):
    # Live connection handler; never restored from the cache.
    self.handler = None

    if cache_obj:
      self.status = cache_obj["status"]
      self.status_time = cache_obj["status_time"]
    else:
      self.status = None
      self.status_time = 0

    for field in ("resources", "ip", "hostname", "image_version"):
      setattr(self, field, cache_obj.get(field) if cache_obj else None)

  def mk_cache_obj(self):
    """Return a dict snapshot of every tracked field except the handler."""
    fields = ("status", "status_time", "resources", "ip", "hostname", "image_version")
    return {field: getattr(self, field) for field in fields}

class GroveHttpServer:
  """aiohttp application serving the grove-server HTTP API.

  Handlers read and mutate the owning GroveServer (``parent``): its agent
  table, configuration and file cache.
  """

  def __init__(self, parent):
    self.parent = parent
    self.app = web.Application()
    # Single global lock serializing fetches into the file cache.
    self.fetch_lock = asyncio.Lock()
    self.app.router.add_route('GET', '/', self.get_root)
    self.app.router.add_route('GET', '/agents', self.get_agents)
    self.app.router.add_route('GET', r'/repo/{file:[^{}]+}', self.get_repo)
    self.app.router.add_route('POST', '/do', self.get_do)
    self.app.router.add_route('POST', '/reap', self.get_reap)
    self.app.router.add_route('POST', '/fetch', self.get_fetch)
    self.app.router.add_route('GET', '/health', self.get_health)

    if self.parent.username is not None and self.parent.password is not None:
      # NOTE(review): presumably the regex selects the paths requiring basic
      # auth (everything except /health) -- confirm in auth_middleware.
      self.app.middlewares.append(
        auth_middleware.BasicAuthMiddleware(
          r"^(?!\/health$).*",
          { self.parent.username: self.parent.password }
        )
      )

    self.log = logging.getLogger(__name__)

    self.parent.symlink_resources_to_cache()

  def new_handler(self):
    """Build the protocol factory passed to loop.create_server.

    NOTE(review): Application.make_handler() is deprecated in aiohttp 3.x in
    favour of AppRunner/TCPSite -- consider migrating.
    """
    return self.app.make_handler() if self.parent.access_logs else self.app.make_handler(access_log=None)

  def _safe_cache_path(self, filename):
    """Return the absolute cache path for ``filename``, or None if it escapes the cache.

    Absolute names and '..' segments that would leave the cache directory are
    rejected. Symlinks are deliberately not resolved, because the cache holds
    links into the resource directories.
    """
    cache_root = os.path.abspath(self.parent.file_cache_dir())
    candidate = os.path.normpath(os.path.join(cache_root, filename))
    if candidate == cache_root or os.path.commonpath([cache_root, candidate]) != cache_root:
      return None
    return candidate

  async def get_root(self, request):
    return web.Response(body='hello world!'.encode())

  async def get_do(self, request):
    """Send ``command`` to every agent in ``agents`` and gather the results.

    Unknown or offline agents yield an ``Agent offline`` result immediately.
    NOTE(review): no per-command timeout is armed here, so the response waits
    until every online agent replies or its connection drops.
    """
    request_obj = await request.json()

    futures = []
    for agent_host in request_obj['agents']:
      agent_tracker = self.parent.agents.get(agent_host)
      if agent_tracker and agent_tracker.handler:
        future = agent_tracker.handler.send_command(request_obj['command'])
      else:
        future = asyncio.Future()
        future.set_result({
          "host": agent_host,
          "ok": False,
          "result": 'Agent offline'
        })

      futures.append(future)

    results = await asyncio.gather(*futures)

    return web.Response(
      body=json.dumps({"results": results}).encode()
    )

  async def get_reap(self, request):
    """Forget offline agents from ``agents``; report which requested hosts are still online."""
    request_obj = await request.json()

    remaining = []
    removed_any = False
    for agent_host in request_obj['agents']:
      agent_tracker = self.parent.agents.get(agent_host)
      if agent_tracker:
        if agent_tracker.handler:
          # Agent exists, and is online
          remaining.append(agent_host)
        else:
          # Agent exists, and is offline
          del self.parent.agents[agent_host]
          removed_any = True

    if removed_any:
      # Persist the removal so reaped agents don't reappear after a restart.
      self.parent.write_status_cache()

    return web.Response(
      body=json.dumps({"remaining": remaining}).encode()
    )

  async def get_agents(self, request):
    """Return every known agent with its online flag and last reported state."""
    result = {}
    for agent_host, agent_tracker in self.parent.agents.items():
      result[agent_host] = {
        'online': bool(agent_tracker.handler),
        'status': agent_tracker.status,
        'status_time': agent_tracker.status_time,
        'ip': agent_tracker.ip,
        'hostname': agent_tracker.hostname,
        'resources': agent_tracker.resources,
        'image_version': agent_tracker.image_version
      }

    return web.Response(
      body=json.dumps({"agents": result}).encode()
    )

  # TODO: Handle this on a per-file basis instead of a global lock (self.fetch_lock)
  async def get_fetch(self, request):
    """Start a background fetch of ``filenames`` into the file cache.

    Waits for the global fetch lock, schedules the fetch, and returns an empty
    200 without waiting for it to finish. The lock stays held until the
    background fetch completes, which is what makes ``/repo?check_exists``
    report 202 meanwhile.
    """
    request_obj = await request.json()
    filenames = request_obj.get("filenames")
    if not isinstance(filenames, list) or not all(
        isinstance(name, str) and self._safe_cache_path(name) is not None for name in filenames):
      raise web.HTTPBadRequest(text="Invalid 'filenames': {}".format(filenames))

    refresh = request_obj.get("refresh")
    # Accept both the historical string form ("true") and a JSON boolean.
    force_refresh = refresh is True or (isinstance(refresh, str) and refresh.lower() == 'true')
    repositories = request_obj.get("server_repositories", self.parent.config.get("server_repositories"))

    self.parent.clean_tgz_cache()

    await self.fetch_lock.acquire()
    try:
      task = asyncio.ensure_future(
        fetch.fetch_files_async(filenames, self.parent.file_cache_dir(), repositories, force_refresh))
    except Exception as e:
      # Nothing was scheduled, so nothing else will release the lock.
      self.fetch_lock.release()
      self.log.warning("While trying to fetch %s: %s", filenames, e)
      raise web.HTTPNotFound(text="Error fetching {}: {}".format(filenames, e))

    def on_fetch_done(done_task):
      self.fetch_lock.release()
      if done_task.cancelled():
        return
      # Retrieve the exception so failures are logged instead of being reported
      # as "Task exception was never retrieved".
      error = done_task.exception()
      if error is not None:
        self.log.warning("Background fetch of %s failed: %s", filenames, error)

    task.add_done_callback(on_fetch_done)
    return web.Response()

  async def get_repo(self, request):
    """Serve ``file`` from the file cache, fetching it from a repository if needed.

    Query parameters: ``repo`` (single repository URL overriding the configured
    ``server_repositories``), ``refresh=true`` (force re-fetch), ``timeout``
    (integer passed to the fetcher) and ``check_exists`` (only report a status:
    202 while a fetch holds the lock, 200 if cached, else 404).
    """
    filename = request.match_info['file']
    cache_path = self._safe_cache_path(filename)
    if cache_path is None:
      # Absolute paths or '..' segments would read (and presumably make the
      # fetcher write) outside the cache directory.
      raise web.HTTPBadRequest(text="Invalid file path [{}]".format(filename))

    query = request.rel_url.query
    repositories = self.parent.config.get("server_repositories")
    if query.get("repo") is not None:
      repositories = [{"url": query["repo"]}]

    force_refresh = query.get("refresh", "").lower() == 'true'
    existence_only = "check_exists" in query
    try:
      timeout = int(query["timeout"]) if "timeout" in query else None
    except ValueError:
      raise web.HTTPBadRequest(text="Invalid timeout [{}]".format(query["timeout"])) from None

    self.parent.clean_tgz_cache()

    if existence_only:
      if self.fetch_lock.locked():
        status = 202
      elif os.path.exists(cache_path):
        status = 200
      else:
        self.parent.symlink_resources_to_cache()
        status = 200 if os.path.exists(cache_path) else 404

      return web.Response(status=status)

    try:
      async with self.fetch_lock:
        resp = await fetch.fetch_files_async(
          [filename],
          self.parent.file_cache_dir(),
          repositories,
          force_refresh,
          timeout
        )
        return web.FileResponse(resp[0])
    except Exception as e:
      # If we couldn't pull from remote or cache, update symlinks from resource directory and try again
      self.log.debug("Failed to fetch [%s], attempting to symlink grove resource file", filename)
      self.parent.symlink_resources_to_cache()
      if os.path.exists(cache_path):
        return web.FileResponse(cache_path)

      self.log.warning("While trying to fetch [%s]: %s", filename, e)
      raise web.HTTPNotFound(
        text="Could not fetch [{}] from any remote repository {}: {}".format(filename, repositories, e))

  async def get_health(self, request):
    """Unauthenticated liveness probe returning the current UTC time (ISO 8601)."""
    return web.Response(status=200, text=datetime.datetime.now(datetime.timezone.utc).isoformat())

class GroveAgentServerProtocol(asyncio.Protocol):
  """asyncio Protocol handling one agent's TCP connection.

  Agents speak a CRLF-delimited line protocol of the form ``<verb> [<payload>]``:
  ``imok`` (heartbeat), ``connect <json>``, ``auth <json>``, ``status <json>``,
  ``resources <json>`` and ``ok|fail <id> <result>``. The last two are replies
  to commands the server sends as ``do <id> <json>``. Every line handled
  without a close re-arms the heartbeat timer; ``agent_timeout`` seconds
  without one closes the connection.
  """

  def __init__(self, parent, agent_timeout):
    self.parent = parent
    self.agent_timeout = agent_timeout
    # Raw bytes not yet terminated by CRLF. Lines are decoded only once complete,
    # so a multi-byte UTF-8 character split across reads is not corrupted.
    self.buffer = b''
    self.log = logging.getLogger(__name__)
    # Pending command futures keyed by the id sent in `do <id> ...`.
    self.futures = {}
    self.futurecounter = 0
    self.hostname = None
    self.image_version = None

  # asyncio protocol function
  def connection_made(self, transport):
    self.transport = transport
    peername_raw = transport.get_extra_info('peername')
    self.peername = "{}:{}".format(peername_raw[0], str(peername_raw[1]))
    self.peeraddr = peername_raw[0]
    self.peerport = peername_raw[1]
    self.set_heartbeat_timeout()
    # Until a `connect` message names the agent, it is identified by its address.
    self.name = self.peeraddr
    self.connected = self.parent.skip_connect
    if self.connected and (self.parent.username is None or self.parent.password is None):
      self.log.info('agent: {}: Agent connected (skipped)'.format(self.peername))
      self.parent.register_agent_handler(self)

  # asyncio protocol function
  def data_received(self, data):
    try:
      self.buffer += data
      while True:
        nl_index = self.buffer.find(b"\r\n")
        if nl_index < 0:
          break
        command_str = self.buffer[:nl_index].decode()
        self.buffer = self.buffer[nl_index + 2:]

        if self._handle_command(command_str):
          self.log.info('agent: {}: Closing connection'.format(self.peername))
          self.transport.close()
          # Do not act on anything else this peer already sent.
          break

        self.cancel_heartbeat_timeout()
        self.set_heartbeat_timeout()
    except Exception as e:
      self.log.warning('agent: {}: Error reading from agent: {}'.format(self.peername, repr(e)))
      self.transport.close()

  def _handle_command(self, command_str):
    """Process one decoded protocol line; return True if the connection must be closed.

    Malformed payloads raise (e.g. JSON or index errors); data_received turns
    those into a logged close.
    """
    command = command_str.split(" ", 1)
    self.log.debug('agent: {}: Received command: {}'.format(self.peername, command))
    verb = command[0]

    if command == ['imok']:
      return False

    if verb == 'connect':
      if self.connected:
        self.log.debug("agent: %s: Superfluous connect message", self.peername)
        return False
      self.connected = True
      connect_obj = json.loads(command[1])
      self.log.debug('agent: {}: conn: {}'.format(self.host(), connect_obj))
      self.name = connect_obj.get('name') or self.peeraddr
      self.hostname = connect_obj.get('hostname')
      self.image_version = connect_obj.get('image_version')
      # With credentials configured, registration waits for a valid `auth`.
      if self.parent.username is None or self.parent.password is None:
        self.log.info('agent: {}: Agent connected'.format(self.name))
        self.parent.register_agent_handler(self)
      return False

    if verb == 'auth':
      auth_obj = json.loads(command[1])
      if (self.connected and auth_obj.get('username') == self.parent.username
          and auth_obj.get('password') == self.parent.password):
        self.log.info('agent: {}: Agent connected (authenticated)'.format(self.peername))
        self.parent.register_agent_handler(self)
        return False
      self.log.warning('agent: {}: Agent provided invalid authentication'.format(self.peername))
      return True

    if verb in ('status', 'resources'):
      message_obj = json.loads(command[1])
      self.parent.receive_agent_update(self, message_obj, MessageType[verb])
      return False

    if verb in ('ok', 'fail'):
      rest = command[1].split(" ", 1)
      future_id = int(rest[0])
      future = self.futures.get(future_id)
      if future is None:
        self.log.warning("agent: {}: Completed bogus future: {}".format(self.peername, future_id))
        return True
      future_result = rest[1]
      if not future.done():
        self.log.info('agent: {}: Completed future {}: {}'.format(self.peername, future_id, future_result))
        future.set_result({
          "host": self.host(),
          "ok": verb == 'ok',
          "result": future_result
        })
      del self.futures[future_id]
      return False

    self.log.warning('agent: {}: Received bad command: {}'.format(self.peername, command_str))
    return True

  # asyncio protocol function
  def connection_lost(self, e):
    self.log.info('agent: {}: Connection lost ({})'.format(self.peername, e))
    self.cancel_heartbeat_timeout()
    # Resolve waiters first so an error while unregistering cannot leave
    # /do requests blocked on this connection forever.
    pending, self.futures = self.futures, {}
    for future_id, future in pending.items():
      if not future.done():
        self.log.info('agent: {}: Future {} failed on connection loss'.format(
          self.peername,
          future_id
        ))
        future.set_result({
          "host": self.host(),
          "ok": False,
          "result": 'Connection lost'
        })
    self.parent.unregister_agent_handler(self)

  def host(self):
    """Peer host, identifies a particular remote agent."""
    return self.name

  def hostport(self):
    """Peer host & port, identifies a particular handler."""
    return self.peername

  def send_command(self, command_obj):
    """Send ``command_obj`` as ``do <id> <json>``; return a future for the agent's reply.

    The future resolves to ``{"host", "ok", "result"}`` when the agent answers
    ``ok``/``fail`` or the connection is lost. NOTE(review): set_command_timeout
    is not armed here, so a silent but heartbeating agent leaves it pending.

    Raises:
      ValueError: if the serialized command contains CR or LF.
    """
    command = json.dumps(command_obj)
    if "\r" in command or "\n" in command:
      raise ValueError("Command cannot contain CR or LF")
    self.futurecounter += 1
    future_id = self.futurecounter
    future = asyncio.Future()
    self.futures[future_id] = future
    self.log.info('agent: {}: Sending command {}: {}'.format(self.peername, str(future_id), util.sanitizedLog(command)))
    self.transport.write("do {} {}\r\n".format(str(future_id), command).encode())
    return future

  def cancel_heartbeat_timeout(self):
    self.h_timeout.cancel()

  def set_heartbeat_timeout(self):
    """(Re)arm the timer that closes the connection after ``agent_timeout`` seconds."""
    self.h_timeout = asyncio.get_event_loop().call_later(self.agent_timeout, self.timed_out)

  def set_command_timeout(self, future_id):
    """Each command has its own timeout that keeps ticking even though client heartbeats."""
    asyncio.get_event_loop().call_later(
      self.agent_timeout,
      lambda: self.command_timed_out(future_id)
    )

  def command_timed_out(self, future_id):
    """Fail command ``future_id`` with ``Timed out`` if it is still pending."""
    future = self.futures.get(future_id)
    if future and not future.done():
      self.log.info('agent: {}: Future {} timed out'.format(self.peername, future_id))
      future.set_result({
        "host": self.host(),
        "ok": False,
        "result": 'Timed out'
      })
      del self.futures[future_id]

  def timed_out(self):
    self.log.info('agent: {}: Timed out'.format(self.peername))
    self.transport.close()
