import _thread
import asyncio
import datetime
import json
import logging
import os
import re
import socket
import ssl
import subprocess
import sys

import yaml
from aiohttp import web

from . import become
from . import config
from . import deploy
from . import dev
from . import dirs
from . import fetch
from . import fs
from . import services
from . import sys_util
from . import util

# Compatibility shim: asyncio.all_tasks() was added in Python 3.7 and
# asyncio.Task.all_tasks() was removed in 3.9, so fall back to the latter only
# on Python 3.6 and earlier. Both accept an explicit event loop argument.
try:
  asyncio_all_tasks = asyncio.all_tasks
except AttributeError:
  asyncio_all_tasks = asyncio.Task.all_tasks

class GroveAgentClient():
  """Agent that connects to a Grove server and executes the commands it sends.

  One TCP (optionally TLS) connection to the server is kept open. Outgoing
  messages (heartbeats, resource info, command replies) are pushed onto
  ``message_queue`` and written by ``sender``; incoming ``do`` commands are
  parsed and dispatched by ``listen``. A small HTTP server exposes ``/health``.
  When a background task hits a fatal error it calls
  ``_thread.interrupt_main()`` so that ``run`` unwinds and cleans up.
  """

  class ConnectionClosed(Exception):
    """Raised internally by ``listen`` when the server closes the connection."""

  # Wire format of a server command: ``do <numeric id> <json payload>``.
  _DO_PATTERN = re.compile(r'do (\d+) (.+)$')

  def __init__(self, config, debug):
    """Read agent settings from ``config``.

    Args:
      config: agent configuration mapping. ``homedir``, ``svdir``,
        ``server_host`` and ``server_port`` are required; the other keys are
        optional and use the defaults shown below. (Inside ``__init__`` this
        parameter shadows the module-level ``config`` import.)
      debug: default for ``access_logs`` when the config does not set it.

    Raises:
      KeyError: if a required key is missing.
      ValueError: if ``server_timeout`` is not positive.
    """
    self.config = config
    self.homedir = config["homedir"]
    self.svdir = config["svdir"]
    self.name = config.get("name")
    self.server_host = config["server_host"]
    self.server_port = config["server_port"]
    self.ssl_cert = config.get("server_ca_cert")
    self.username = config.get("server_username")
    self.password = config.get("server_password")
    self.timeout = config.get("server_timeout", 10)
    self.heartbeat_secs = config.get("heartbeat_secs", 15)
    self.metadata = config.get("metadata", {})
    self.http_server_port = config.get("http_server_port", 9997)
    self.access_logs = config.get("access_logs", debug)

    if self.timeout <= 0:
      raise ValueError('Expected positive timeout')

    self.log = logging.getLogger(__name__)
    # Outgoing encoded protocol lines, drained by ``sender``.
    self.message_queue = asyncio.Queue()
    # True once ``connect`` succeeded; drives the /health status code.
    self.connected = False

  def server_str(self):
    """Return ``"host:port"`` of the configured server, for log messages."""
    return '{}:{}'.format(self.server_host, str(self.server_port))

  async def get_my_status(self):
    """Build the status payload sent with each heartbeat.

    In ``dev_mode`` the result of ``dev.mock_status()`` is returned instead.
    Otherwise the payload combines deploy.json, the activated deployment
    hash, the service statuses, the local hostname and configured metadata.
    """
    if self.config.get("dev_mode"):
      return dev.mock_status()

    with open(dirs.deployjson(self.homedir), 'r') as f:
      deploy_json = json.load(f)

    deployed_hash = deploy.activated_hash(self.homedir)

    statuses = await services.status(dirs.svstagedir(self.homedir), self.svdir)

    return {
      "hash": str(deployed_hash),
      "deploy": deploy_json,
      "status": statuses,
      "hostname": socket.gethostname(),
      "metadata": self.metadata
    }

  async def connect(self):
    """Open the server connection and send the handshake messages.

    The ``__ip__`` / ``__hostname__`` placeholders in ``self.name`` are
    replaced with the local IP / hostname. A ``connect`` message is always
    written, followed by ``auth`` when both username and password are set.
    TLS is used when ``server_ca_cert`` is configured.

    Returns:
      ``(reader, writer)`` on success, or ``(None, None)`` on a connect
      timeout, a ``ConnectionError``, or a failure to resolve the server
      host. Errors resolving the *local* hostname are not caught here.
    """
    self.log.info('Connecting to server: {}'.format(self.server_str()))
    ssl_context = None
    hostname = socket.gethostname()
    ip = socket.gethostbyname(hostname)
    if self.name == "__ip__":
      self.name = ip
    elif self.name == "__hostname__":
      self.name = hostname
    if (self.ssl_cert is not None):
      ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
      ssl_context.load_verify_locations(self.ssl_cert)
    connect_info = {
      'version': 1,
      'name': self.name,
      'ip': ip,
      'hostname': hostname
    }
    if os.path.exists('/root/image.version'):
      with open('/root/image.version', 'r') as version_file:
        connect_info['image_version'] = version_file.read().replace('\n', '')
    try:
      reader, writer = await asyncio.wait_for(
        asyncio.open_connection(self.server_host, self.server_port, family=socket.AF_INET, ssl=ssl_context), self.timeout)
      writer.write('connect {}\r\n'.format(json.dumps(connect_info)).encode())
      if (self.username is not None and self.password is not None):
        writer.write('auth {}\r\n'.format(json.dumps({'username': self.username, 'password': self.password})).encode())
      self.connected = True
      self.log.info('Connected to server: {}'.format(self.server_str()))
      return reader, writer
    except asyncio.TimeoutError:
      self.log.warning('Timeout while connecting to: {}'.format(self.server_str()))
    except ConnectionError as e:
      self.log.exception(e)
    except socket.gaierror as e:
      self.log.warning('Error while connecting to [{}]: {}'.format(self.server_str(), e.strerror))

    return None, None

  async def heartbeat(self):
    """Queue a ``status`` message every ``heartbeat_secs`` seconds until cancelled.

    Any error other than cancellation is logged and interrupts the main thread.
    """
    try:
      while True:
        status_json = json.dumps(await self.get_my_status())
        self.log.debug('Sending heartbeat to server: {}: {}'.format(self.server_str(), status_json))

        await self.message_queue.put('status {}\r\n'.format(status_json).encode())
        await asyncio.sleep(self.heartbeat_secs)
    except asyncio.CancelledError:
      pass
    except Exception as e:
      self.log.exception(e)
      _thread.interrupt_main()

  async def send_resources(self):
    """Queue a one-off ``resources`` message built from ``sys_util.get_resources()``."""
    resources_json = json.dumps(sys_util.get_resources())
    self.log.debug('Sending resource info to server: {}: {}'.format(self.server_str(), resources_json))
    await self.message_queue.put('resources {}\r\n'.format(resources_json).encode())

  async def sender(self, writer):
    """Forward queued messages to ``writer`` until cancelled.

    ``drain()`` after each write applies flow control, so the transport
    buffer cannot grow without bound, and lets connection errors surface.
    Any error other than cancellation is logged and interrupts the main thread.
    """
    try:
      while True:
        writer.write(await self.message_queue.get())
        await writer.drain()
    except asyncio.CancelledError:
      pass
    except Exception as e:
      self.log.exception(e)
      _thread.interrupt_main()

  async def listen(self, reader):
    """Read CRLF-delimited commands from ``reader`` until the connection ends.

    Bytes are buffered and split on CRLF *before* UTF-8 decoding, so a
    multi-byte character split across two reads is decoded correctly.
    When the server closes the connection, or any error escapes a command,
    the main thread is interrupted so that ``run`` shuts the agent down.
    """
    buf = b''
    try:
      while True:
        chunk = await reader.read(8192)
        if not chunk:
          self.log.info("Connection to server closed")
          raise self.ConnectionClosed

        buf += chunk

        nl_index = buf.find(b"\r\n")
        while nl_index > -1:
          line = buf[:nl_index].decode()
          buf = buf[(nl_index + 2):]
          nl_index = buf.find(b"\r\n")
          await self._handle_line(line)
    except asyncio.CancelledError:
      pass
    except self.ConnectionClosed:
      # Stop reporting a healthy connection on /health while shutting down.
      self.connected = False
      _thread.interrupt_main()
    except Exception as e:
      self.log.exception(e)
      _thread.interrupt_main()

  async def _handle_line(self, line):
    """Execute one ``do`` command line and queue its ``ok``/``fail`` reply.

    Raises:
      ValueError: if the JSON payload is malformed, or the reply would
        contain CR/LF (which would break the line protocol).
      Exception: if ``line`` is not a ``do`` command.
    """
    self.log.info('Read message from server: {}'.format(util.sanitizedLog(line)))

    m = self._DO_PATTERN.match(line)
    if not m:
      raise Exception('Unexpected message from server: {}'.format(line))

    do_id = m.group(1)
    # A malformed payload is treated as a protocol violation and propagates.
    do_obj = json.loads(m.group(2))

    ok, response = await self._execute(do_obj)

    if '\r' in response or '\n' in response:
      raise ValueError('Cannot have CR or LF in response: {}'.format(response))

    full_response = '{} {} {}'.format('ok' if ok else 'fail', do_id, response)

    self.log.info('Writing back to server: {}'.format(full_response))
    await self.message_queue.put((full_response + '\r\n').encode())

  async def _execute(self, do_obj):
    """Dispatch ``do_obj`` on its ``type`` field and return ``(ok, response)``.

    Unknown types yield ``(False, 'nil')``. An exception from a handler
    (including a missing ``type`` key) is logged and returned as
    ``(False, repr(exc))`` with CR/LF escaped.
    """
    try:
      do_type = do_obj['type']
      if do_type == 'become':
        return True, await self.do_become(do_obj)
      if do_type == 'bounce':
        return True, await self.do_bounce(do_obj)
      if do_type == 'history':
        return True, self.do_history(do_obj)
      if do_type == 'run':
        return True, await self.do_run(do_obj)
      if do_type == 'write':
        return True, await self.do_write(do_obj)
      if do_type == 'sys':
        return True, self.do_sys(do_obj)
      return False, 'nil'
    except Exception as e:
      self.log.exception('Command failed: {}'.format(do_obj))
      return False, repr(e).replace("\n", "\\n").replace("\r", "\\r")

  async def init_http_server(self):
    """Start the ``/health`` HTTP server on 0.0.0.0:``http_server_port``.

    Access logging is disabled unless ``access_logs`` is truthy.
    Returns the ``web.AppRunner`` so that ``run`` can clean it up.
    """
    app = web.Application()
    app.router.add_route('GET', '/health', self.get_health)

    runner = web.AppRunner(app) if self.access_logs else web.AppRunner(app, access_log=None)
    await runner.setup()

    site = web.TCPSite(runner, '0.0.0.0', self.http_server_port)
    await site.start()

    return runner

  async def get_health(self, request):
    """Answer ``GET /health``.

    Responds 200 with the current UTC time and agent name while connected,
    otherwise 503 with only the time.
    """
    payload = {'date': datetime.datetime.now(datetime.timezone.utc).isoformat()}
    if self.connected:
      payload['name'] = self.name
    return web.Response(status=200 if self.connected else 503,
                        body=json.dumps(payload).encode())

  def run(self):
    """Start the HTTP server, connect, and run the agent until interrupted.

    Exits with status 1 if the server connection cannot be established.
    On shutdown every pending task on the loop is cancelled and awaited
    before the writer, the HTTP server and the loop are closed.
    """
    loop = asyncio.get_event_loop()

    http_server = None
    writer = None

    try:
      http_server = loop.run_until_complete(self.init_http_server())
      reader, writer = loop.run_until_complete(self.connect())
      if reader is None or writer is None:
        sys.exit(1)

      loop.create_task(self.sender(writer))
      loop.create_task(self.heartbeat())
      loop.create_task(self.send_resources())
      loop.create_task(self.listen(reader))

      loop.run_forever()
    except asyncio.CancelledError:
      pass
    finally:
      # The loop is not running here, so pass it explicitly: calling
      # all_tasks() with no argument raises RuntimeError on Python 3.7+.
      pending = [task for task in asyncio_all_tasks(loop) if not task.done()]
      for task in pending:
        task.cancel()
      if pending:
        # Run the loop once more so the cancelled tasks observe CancelledError.
        loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
      if writer:
        writer.close()
      if http_server:
        loop.run_until_complete(http_server.cleanup())
      loop.close()

  async def do_bounce(self, do_obj):
    """Bounce the staged services via ``services.bounce``.

    ``do_obj["timeout"]`` is forwarded (0 when absent). Returns ``"{}"``.
    """
    # Exclude services marked as always running (Grove agent, Grove server, Sensu, Imply Onprem manager)
    await services.bounce(dirs.svstagedir(self.homedir), self.svdir, services.always_running_services(),
                          do_obj.get("timeout", 0))
    return json.dumps({})

  async def do_become(self, do_obj):
    """Deploy a service definition and return the pulled service parts.

    The definition is taken from, in order of precedence:
      * ``service_yaml_path``: a YAML file read by ``config.load_yaml_file_with_sha1``;
      * ``service_yaml``: inline YAML text;
      * otherwise ``hashed_service_parts`` of a prior deployment (``become.become_prior``).
    ``post_deploy_params``, ``always_bounce`` (default False) and
    ``skip_start`` (default False) are forwarded unchanged.

    Returns:
      JSON ``{"hashed_service_parts": <value returned by become>}``.
    """
    post_deploy_params = do_obj.get("post_deploy_params")
    always_bounce = do_obj.get("always_bounce", False)
    skip_start = do_obj.get("skip_start", False)

    if do_obj.get("service_yaml_path"):
      # ``config`` here is the module-level import, not the __init__ argument.
      service_yaml, _service_yaml_sha1 = config.load_yaml_file_with_sha1(do_obj["service_yaml_path"])
      pulled = await become.become(self.config, service_yaml, do_obj["service_type"], post_deploy_params,
                                   always_bounce, skip_start)

    elif do_obj.get("service_yaml"):
      # NOTE(review): this parses server-supplied text with FullLoader, which
      # accepts more tags than SafeLoader; consider yaml.safe_load unless the
      # server is fully trusted.
      service_yaml = dict(yaml.load(do_obj["service_yaml"], Loader=yaml.FullLoader))
      pulled = await become.become(self.config, service_yaml, do_obj["service_type"], post_deploy_params,
                                   always_bounce, skip_start)

    else:
      pulled = await become.become_prior(self.config, do_obj["hashed_service_parts"], post_deploy_params,
                                         always_bounce, skip_start)

    return json.dumps({'hashed_service_parts': pulled})

  async def do_run(self, do_obj):
    """Run ``do_obj["args"]`` (after env-token substitution) with cwd ``/``.

    Returns JSON ``{"exit_code": 0, "output": ...}`` on success, where the
    output is included only when ``do_obj["output"] == True``. On a
    ``subprocess.CalledProcessError`` it returns that error's return code
    and output instead. ``timeout`` is forwarded to ``util.subprocess_call``.
    """
    try:
      command = util.replace_env_tokens(do_obj["args"], self.config)
      self.log.info('<Executing: {}>'.format(util.sanitizedLog(command)))
      output = await util.subprocess_call(command, do_obj.get("timeout"), cwd="/")
      return json.dumps({'exit_code': 0, 'output': output if do_obj.get("output") == True else None})
    except subprocess.CalledProcessError as e:
      return json.dumps({'exit_code': e.returncode, 'output': e.output})

  def do_history(self, do_obj):
    """Return the ``history`` entry of ``deploy.read_history`` as JSON."""
    return json.dumps(deploy.read_history(self.config)["history"])

  async def do_write(self, do_obj):
    """Write ``do_obj["path"]`` from inline ``content`` or a fetched ``source``.

    Exactly one of ``content`` / ``source`` must be present. An optional
    ``mode`` is parsed as an octal integer and forwarded. All string fields
    go through ``util.replace_env_tokens`` with the agent config.

    Raises:
      ValueError: if neither or both of ``content`` / ``source`` are given.
    """
    if "content" not in do_obj and "source" not in do_obj:
      raise ValueError("Either 'content' or 'source' must be specified")
    if "content" in do_obj and "source" in do_obj:
      raise ValueError("'content' and 'source' cannot both be specified")

    mode = int(str(do_obj["mode"]), 8) if "mode" in do_obj else None

    if "content" in do_obj:
      # Pass the config here as well, like every other replace_env_tokens call.
      fs.write_file(util.replace_env_tokens(do_obj["path"], self.config),
                    util.replace_env_tokens(do_obj["content"], self.config),
                    mode)
    else:
      await fetch.fetch_http_async(util.replace_env_tokens(do_obj["source"], self.config),
                                   util.replace_env_tokens(do_obj["path"], self.config),
                                   util.replace_env_tokens(do_obj.get("user"), self.config),
                                   util.replace_env_tokens(do_obj.get("password"), self.config),
                                   mode)

    return json.dumps({})

  def do_sys(self, do_obj):
    """Return ``sys_util.get_system_stats()`` as JSON."""
    return json.dumps(sys_util.get_system_stats())
