###
### Copyright (C) 2018-2019 Intel Corporation
###
### SPDX-License-Identifier: BSD-3-Clause
###

from datetime import datetime as dt
import itertools
from lib import artifacts
from lib.baseline import Baseline
import lib.system
import os
import re
import slash
from slash.utils.traceback_utils import get_traceback_string
import sys
import xml.etree.cElementTree as et

# Absolute path of the directory containing this file.  It anchors the
# results directory configured below and is exposed to the plugin as
# MediaPlugin.mypath.
__SCRIPT_DIR__ = os.path.abspath(os.path.dirname(__file__))

# LibVA used to load the i965 driver by default, but now it loads the iHD driver
# by default.  For backwards compatibility, explicitly set the VAAPI driver to
# i965 via the user environment variable.  setdefault() only applies this when
# the user has not already exported LIBVA_DRIVER_NAME.
os.environ.setdefault("LIBVA_DRIVER_NAME", "i965")

# We don't support any of the built-in slash plugins
slash.plugins.manager.uninstall_all()

# Logging layout: everything is written below "<this directory>/results".
slash.config.root.log.root = os.path.join(__SCRIPT_DIR__, "results")
# Names of the "latest session" links/highlights file kept by slash.
slash.config.root.log.last_session_symlink = "session.latest.log"
slash.config.root.log.last_session_dir_symlink = "session.latest"
slash.config.root.log.highlights_subpath = "highlights.latest.log"
slash.config.root.log.colorize = False
slash.config.root.log.unified_session_log = True
slash.config.root.log.truncate_console_lines = False
slash.config.root.run.dump_variation = True
# Tests are collected from the "test" source when none is given explicitly.
slash.config.root.run.default_sources = ["test"]
# Per-test log file path template:
#   <session id>/<module name>/<class name>/<function name>(<variation>).log
# NOTE(review): presumably resolved relative to log.root -- confirm against the
# slash logging documentation.
slash.config.root.log.subpath = os.path.join(
  "{context.session.id}",
  "{context.test.__slash__.module_name}",
  "{context.test.__slash__.class_name}",
  "{context.test.__slash__.function_name}({context.test.__slash__.variation.safe_repr}).log")

def validate_unique_cases(tree):
  """Warn about test cases that occur more than once in a results tree.

  Two <testcase> elements are considered the same case when both their
  "name" and "classname" attributes are equal.  One warning is logged for
  every element that belongs to such a duplicated pair (so a case that occurs
  N times produces N warnings, each reporting N occurrences).

  The pairs are counted in a single pass instead of issuing one XPath query
  per element.  This keeps the check linear in the number of test cases and
  avoids building an XPath predicate from the attribute values, which raises
  SyntaxError as soon as a name or classname contains a single quote.

  Parameters:
    tree: an ElementTree (or Element) whose descendants include the
      <testcase> elements to check.
  """
  from collections import Counter

  def identity(element):
    return (element.get("name"), element.get("classname"))

  cases = tree.findall(".//testcase")
  occurrences = Counter(identity(e) for e in cases)

  for e in cases:
    count = occurrences[identity(e)]
    if count > 1:
      slash.logger.warn("{} occurrences of testcase found: {} {}".format(
        count, e.get("classname"), e.get("name")))

def ansi_escape(text):
  """Return a copy of text with all ANSI CSI escape sequences removed.

  A sequence starts with either the single byte CSI (0x9B) or ESC followed by
  "[", continues with any number of parameter characters ("0" to "?") and
  intermediate characters (" " to "/"), and ends with one final character in
  the range "@" to "~".
  """
  cleaned = re.sub(ansi_escape.prog, '', text)
  return cleaned

# The pattern is compiled once at import time and kept as an attribute of the
# function so that every call reuses it.
ansi_escape.prog = re.compile(r'(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]')

class MediaPlugin(slash.plugins.PluginInterface):
  testspec = dict()

  platform = None
  render_device = "/dev/dri/renderD128"
  suite = os.path.basename(sys.argv[0])
  mypath = __SCRIPT_DIR__
  blacklist = {}

  RETENTION_NONE = 0;
  RETENTION_FAIL = 1;
  RETENTION_ALL  = 2;

  def get_name(self):
    return "media"

  def configure_argument_parser(self, parser):
    parser.add_argument("--rebase", action = "store_true")
    parser.add_argument("--baseline-file",
      default = os.path.abspath(os.path.join(self.mypath, "baseline", "default")))
    parser.add_argument("--artifact-retention", default = artifacts.Retention.NONE,
      type = int, help = (
        f"{int(artifacts.Retention.NONE)} = Keep None; "
        f"{int(artifacts.Retention.FAIL)} = Keep Failed; "
        f"{int(artifacts.Retention.ALL)} = Keep All"
        " (default: %(default)i)"))
    parser.add_argument("--call-timeout", default = 300, type = int,
      help = "call timeout in seconds")
    parser.add_argument("--ctapt", default = -1, type = int,
      help = "number of call timeouts allowed per test function")
    parser.add_argument("--ctapr", default = -1, type = int,
      help = "number of call timeouts allowed per run")
    parser.add_argument("--parallel-metrics", action = "store_true")
    parser.add_argument("--platform", default = None, type = str, required = True)
    parser.add_argument("--device", default = '/dev/dri/renderD128', type = str, help = 'vaapi-fits run with render device')

  def configure_from_parsed_args(self, args):
    self.baseline = Baseline(args.baseline_file, args.rebase)
    self.artifacts = artifacts.Artifacts(args.artifact_retention)
    self.call_timeout = args.call_timeout
    self.parallel_metrics = args.parallel_metrics
    self.ctapt = args.ctapt
    self.ctapr = args.ctapr
    self.platform = args.platform
    self.render_device = args.device

    assert not (args.rebase and slash.config.root.parallel.num_workers > 0), "rebase in parallel mode is not supported"
    assert not (args.ctapt != -1 and slash.config.root.parallel.num_workers > 0), "ctapt in parallel mode is not supported"
    assert not (args.ctapr != -1 and slash.config.root.parallel.num_workers > 0), "ctapr in parallel mode is not supported"

    if self._get_os() == 'linux' and "NONE" != self.platform.upper():
      assert os.path.exists(args.device), "Target Device {} not exist".format(args.device)

    # Gstreamer msdk/vaapi plugins will register elements based on the
    # device.  In a multi-gpu system, this device is arbitrary by default.
    # The default may not represent the actual device we are targeting for
    # testing.  Therefore, we must set the target device env variable for the
    # msdk/vaapi plugins before any of the slash.requires are evaluated so
    # that test cases are not inadvertently skipped due to missing feature
    # on one device vs. the other.
    os.environ["GST_MSDK_DRM_DEVICE"] = self.render_device
    os.environ["GST_VAAPI_DRM_DEVICE"] = self.render_device

  def _calls_allowed(self):
    # only enabled during test function execution (see test_start/test_end)
    if not self.cta_enabled:
      return True

    meta = slash.context.result.test_metadata
    ntimeouts = self.num_ctapt.get((meta.file_path, meta.function_name), 0)

    allowed = self.ctapr <= -1 or self.num_ctapr <= self.ctapr
    allowed = allowed and (self.ctapt <= -1 or ntimeouts <= self.ctapt)
    if not allowed:
      slash.logger.notice("Call Timeouts Allowed: limit reached!")

    return allowed

  def _report_call_timeout(self):
    # only enabled during test function execution
    if not self.cta_enabled:
      return

    meta = slash.context.result.test_metadata
    self.num_ctapr += 1
    self.num_ctapt.setdefault((meta.file_path, meta.function_name), 0)
    self.num_ctapt[(meta.file_path, meta.function_name)] += 1

  def _test_state_value(self, key, default):
    data = slash.context.result.data.setdefault("_state_values_", dict())
    class state_value:
      def __init__(self, val):
        self.value = val
    return data.setdefault(key, state_value(default))

  def _expand_context(self, context):
    from lib.platform import info
    import types
    for c in context:
      if type(c) is types.FunctionType:
        c = c()
        if c is None: continue

      sc = str(c).strip().lower()
      if "driver" == sc:
        sc = "drv.{driver}"
      elif "platform" == sc:
        sc = "plat.{platform}"
      elif "gpu.gen" == sc:
        sc = "gpu.gen.{gpu[gen]}"
      elif sc.startswith("key:"):
        continue

      yield sc.format(**info())

  def _get_ref_addr(self, context):
    path, case = slash.context.test.__slash__.address.split(':')
    keyctx = list(filter(lambda c: str(c).startswith("key:"), context))
    if len(keyctx):
      key = keyctx[0].lstrip("key:")
    else:
      key = os.path.relpath(path, self.mypath)
    return ':'.join([key, case])

  def _set_test_details(self, **kwargs):
    for k, v in kwargs.items():
      slash.context.result.details.set(k, v)
      slash.logger.info("DETAIL: {} = {}".format(k, v))

  def _get_test_spec(self, *args):
    spec = self.testspec
    for key in args:
      spec = spec.setdefault(key, dict())
    return spec.setdefault("--spec--", dict())

  def _iter_test_spec(self):
    def inner(spec, *context):
      for k, v in spec.items():
        if "--spec--" == k:
          for name, params in v.items():
            yield context, name, params
        else:
          yield from inner(v, *context, k)
    yield from inner(self.testspec)

  def _get_driver_name(self):
    # TODO: query vaapi for driver name (i.e. use ctypes to call vaapi)
    return os.environ.get("LIBVA_DRIVER_NAME", None) or "i965"

  def _get_platform_name(self):
    return self.platform

  def _get_gpu_gen(self):
    from lib.platform import info
    return info()["gpu"]["gen"]

  def _get_call_timeout(self):
    if self.test_call_timeout > 0:
       return self.test_call_timeout
    else:
       return self.call_timeout

  def _get_os(self):
    from lib.platform import info
    return info()["os"]

  def _get_physical_mem_gb(self):
    from lib.platform import info
    return info()["mem"]

  def test_start(self):
    test = slash.context.test
    result = slash.context.result
    variation = test.get_variation().values
    self.test_call_timeout = 0
    #self._set_test_details(**variation)
    result.data.update(test_start = dt.now())

    # Begin system capture for test (i.e. dmesg).
    # NOTE: syscapture is not accurate for parallel runs
    if slash.config.root.parallel.worker_id is None:
      self.syscapture.checkpoint()

    # enable call timeout tracking during test function execution (see test_end)
    self.cta_enabled = True
    # check test-case whether at blacklist
    self._assert_not_blacklisted(test)

  def test_end(self):
    # only enabled during test function execution (see test_start)
    self.cta_enabled = False

    test = slash.context.test
    result = slash.context.result

    # Process system capture result (i.e. dmesg)
    # NOTE: syscapture is not accurate for parallel runs
    if slash.config.root.parallel.worker_id is None:
      capture = self.syscapture.checkpoint()
      hangmsgs = [
        r"\[.*\] i915 .*: \[drm\] Resetting .* after gpu hang",
        r"\[.*\] i915 .*: \[drm\] Resetting .* for hang on .*",
        r"\[.*\] i915 .*: \[drm\] .*GPU HANG",
        r"\[.*\] i915 .*: \[drm\] .*GPU hang",
        r"\[.*\] i915 .*: \[drm\] Resetting .* time out",
        r"\[.*\] xe .*: \[drm\] .* suspended",
        r"\[.*\] xe .*: \[drm\] Timedout job:",
        r"\[.*\] xe .*: \[drm\] Engine reset:",
      ]
      for msg in hangmsgs:
        if re.search(msg, capture, re.MULTILINE) is not None:
          slash.logger.error("GPU HANG DETECTED!")
      if re.search(r"\[.*\] i915 .*: \[drm\] .*Failed to reset", capture, re.MULTILINE) is not None:
        result.add_error('GPU RESET FAILED')
      for line in capture.split('\n'):
        if len(line):
          slash.logger.info(line)

    # Finally, calculate test execution time
    result.data.update(test_end = dt.now())
    time = (result.data["test_end"] - result.data["test_start"]).total_seconds()
    result.data.update(time = str(time))
    self._set_test_details(time = "{} seconds".format(time))

  def before_session_start(self):
    self.assets = artifacts.MediaAssets()
    for _, _, params in self._iter_test_spec():
      self.assets.register(params)

  def session_start(self):
    self.session_start = dt.now()

    self.syscapture = lib.system.Capture()

    # only enabled during test function execution (see test_start/test_end)
    self.cta_enabled = False
    self.num_ctapt = dict()
    self.num_ctapr = 0
    self.test_call_timeout = 0

    # setup metrics_pool
    self.metrics_pool = None
    if self.parallel_metrics:
      import multiprocessing, signal
      handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
      self.metrics_pool = multiprocessing.Pool()
      signal.signal(signal.SIGINT, handler)

  def session_end(self):
    if self.metrics_pool is not None:
      self.metrics_pool.close()
      self.metrics_pool.join()

    if slash.config.root.parallel.worker_id is not None:
      return

    self.baseline.finalize()

    time = (dt.now() - self.session_start).total_seconds()
    tests = slash.context.session.results.get_num_results()
    errors = slash.context.session.results.get_num_errors()
    failures = slash.context.session.results.get_num_failures()
    skipped = slash.context.session.results.get_num_skipped()

    from lib.platform import info

    suite = et.Element(
      "testsuite", name = self.suite, disabled = "0", tests = str(tests),
      errors = str(errors), failures = str(failures), skipped = str(skipped),
      time = str(time), timestamp = self.session_start.isoformat(), **info())

    for result in slash.context.session.results.iter_test_results():
      suitename, casename = result.test_metadata.address.split(':')
      classname = os.path.splitext(suitename)[0].replace(os.sep, '.').strip('.')
      classname = "{}.{}".format(self.suite, classname)

      case = et.SubElement(
        suite, "testcase", name = casename, classname = classname,
        time = result.data.get("time") or "0")

      outfile = result.get_log_path()
      if os.path.exists(outfile):
        with open(outfile, 'rb', 0) as out:
          value = "".join(line.decode("utf-8") for line in out) 
          et.SubElement(case, "system-out").text = ansi_escape(value)

      for error in itertools.chain(result.get_errors(), result.get_failures()):
        exc_type, exc_value, _ = exc_info = sys.exc_info()
        tag = "failure" if error.is_failure() else "error"
        et.SubElement(
          case, tag, message = error.message,
          type = exc_type.__name__ if exc_type else tag).text = ansi_escape(
            get_traceback_string(exc_info) if exc_value is not None else "")

      for skip in result.get_skips():
        case.set("skipped", "1")
        et.SubElement(case, "skipped", type = skip or '')

      for name, value in result.details.all().items():
        et.SubElement(case, "detail", name = name, value = str(value))

    tree = et.ElementTree(suite)

    validate_unique_cases(tree)

    filename = os.path.join(
      slash.context.session.results.global_result.get_log_dir(), "results.xml")
    tree.write(filename)

  def _load_blacklist(self, file_path = None):
    if file_path is None or len(file_path) == 0 or not os.path.exists(file_path):
      return

    import json, glob
    for file in glob.glob(os.path.join(str(file_path), "*.json")):
      with open(file, 'rb') as f:
        self.blacklist.update(json.load(f))

  def _assert_not_blacklisted(self, test):
    # Current platform and driver combine not in blacklist
    from slash.utils.pattern_matching import Matcher
    search= [
      str(test.__slash__),
      "driver=" + self._get_driver_name(),
      "platform=" + self._get_platform_name(),
    ]

    for filter,message in self.blacklist.items():
      assert not Matcher(filter).matches(" ".join(search)), 'Blacklisted: ' + message

# Create the single plugin instance and register it with slash right away, so
# that the user config executed below can refer to it as "media".
media = MediaPlugin()
slash.plugins.manager.install(media, activate = True, is_internal = True)

# Allow user to override test config file via environment variable.
# NOTE: It would be nice to use the configure_argument_parser mechanism in our
# media plugin instead.  However, it does not work since
# configure_argument_parser does not get called for "slash list"... it only gets
# invoked for "slash run".  Hence, we use an environment var so that we can
# always load the config file when slash loads this file.
config = os.environ.get(
  "VAAPI_FITS_CONFIG_FILE",
  os.path.abspath(os.path.join(media.mypath, "config", "default")))
# Loading fails early (AssertionError) when the config file is missing.
assert os.path.exists(config)

# extra imports for user config
from lib.codecs import Codec

# SECURITY NOTE: the config file is executed as Python code in this module's
# namespace (it can use every name defined above, e.g. media and Codec).
# Only point VAAPI_FITS_CONFIG_FILE at files you trust.
with open(config, 'rb') as f:
  exec(f.read())

# Directory with the blacklist "*.json" files; it can be overridden via the
# VAAPI_FITS_BLACKLIST environment variable.  A missing directory is silently
# ignored by _load_blacklist.
blacklist = os.environ.get(
  'VAAPI_FITS_BLACKLIST',
  os.path.abspath(os.path.join(media.mypath, "blacklist")))
media._load_blacklist(blacklist)

def validate_case_name(name, context):
  """Assert that name is usable as a test case name.

  A legal name is an integer or a string whose string form consists only of
  characters from slash's _PRINTABLE_CHARS set.

  Parameters:
    name: the test case name to check.
    context: the spec context the name belongs to; only used in the messages.

  Raises:
    AssertionError: if the type or any character of name is not legal.
  """
  import numbers
  from slash.core.variation import _PRINTABLE_CHARS
  _VALID_TYPES = (numbers.Integral, str)

  # First make sure the name has one of the legal types ...
  type_is_legal = isinstance(name, _VALID_TYPES)
  assert type_is_legal, f"""
    Illegal type for test case name: name = {name}, type = {type(name)}, context = {context}
    Legal types: {_VALID_TYPES}
  """

  # ... then check every character of its string representation.
  chars_are_legal = all(ch in _PRINTABLE_CHARS for ch in str(name))
  assert chars_are_legal, f"""
    Illegal character(s) in test case name: name = {name}, context = {context}
    Legal characters: {''.join(sorted(_PRINTABLE_CHARS))}
  """

# Validate the name of every test case registered in the test specification
# as soon as this file is loaded; an illegal name raises AssertionError here.
for context, name, params in media._iter_test_spec():
  validate_case_name(name, context)

###
### kate: syntax python;
###