#!/usr/bin/env python2
# Copyright 2013 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cPickle
import gzip
import multiprocessing
import optparse
import os
import subprocess
import sys
import threading
import time
import zlib

def term_width(out):
  """Return the width of the terminal, or None if it can't be determined.

  None is returned when `out` is not a tty (e.g. because we're not being run
  interactively), or when `stty size` fails, writes to stderr, or prints
  something that can't be parsed.
  """
  if out.isatty():
    try:
      # NOTE(review): `stty` inspects its own stdin, which here is inherited
      # from this process rather than being `out` -- confirm that's intended.
      proc = subprocess.Popen(["stty", "size"],
                              stdout=subprocess.PIPE, stderr=subprocess.PIPE)
      size_text, error_text = proc.communicate()
      if proc.returncode == 0 and not error_text:
        # `stty size` prints "<rows> <columns>"; the width is the second field.
        return int(size_text.split()[1])
    except (IndexError, OSError, ValueError):
      pass
  return None

# Output transient and permanent lines of text. When writing to a terminal,
# each transient line overwrites the previous transient line. We use this to
# ensure that lots of unimportant info (tests passing) won't drown out
# important info (tests failing).
class Outputter(object):
  def __init__(self, out_file):
    self.__out_file = out_file
    self.__previous_line_was_transient = False
    self.__width = term_width(out_file)  # Line width, or None if not a tty.

  def transient_line(self, msg):
    """Write `msg` so that the next line written can replace it (tty only)."""
    if self.__width is None:
      self.__out_file.write(msg + "\n")
    else:
      # Return to column 0 and pad to the full width so a shorter message
      # completely covers the previous one.
      self.__out_file.write("\r" + msg[:self.__width].ljust(self.__width))
      self.__previous_line_was_transient = True
    # The tty branch writes no newline, so a buffered stream may hold the text
    # back; flush so progress is shown immediately.
    self.__out_file.flush()

  def flush_transient_output(self):
    """Terminate a pending transient line so later output starts fresh."""
    if self.__previous_line_was_transient:
      self.__out_file.write("\n")
      self.__previous_line_was_transient = False
      self.__out_file.flush()

  def permanent_line(self, msg):
    """Write `msg` on its own line; it will not be overwritten."""
    self.flush_transient_output()
    self.__out_file.write(msg + "\n")
    self.__out_file.flush()

# Serializes logger output from all worker threads.
stdout_lock = threading.Lock()

class FilterFormat(object):
  """Logger showing test progress transiently and failures permanently.

  log() accepts lines of the form "<job_id>: <COMMAND> <args>" for metadata
  (TEST, EXIT, TESTCNT) or "<job_id>> <text>" for one line of test output.
  Output of a test is only printed if that test fails.
  """

  def __init__(self, out=None):
    """Create a formatter writing to `out` (an Outputter on stdout if None)."""
    # Per-instance state: instances don't share containers, and the terminal
    # is only probed once a FilterFormat is actually created.
    self.out = out if out is not None else Outputter(sys.stdout)
    self.total_tests = 0
    self.finished_tests = 0
    self.tests = {}     # job_id -> (binary, test name)
    self.outputs = {}   # job_id -> captured output lines of an unfinished job
    self.failures = []  # (binary, test name) of every failed job

  def print_test_status(self, last_finished_test, time_ms):
    self.out.transient_line("[%d/%d] %s (%d ms)"
                            % (self.finished_tests, self.total_tests,
                               last_finished_test, time_ms))

  def handle_meta(self, job_id, args):
    """Process a metadata message ("TEST", "EXIT" or "TESTCNT") for a job."""
    (command, arg) = args.split(' ', 1)
    if command == "TEST":
      (binary, test) = arg.split(' ', 1)
      self.tests[job_id] = (binary, test.strip())
      self.outputs[job_id] = []
    elif command == "EXIT":
      (exit_code, time_ms) = [int(x) for x in arg.split(' ', 1)]
      self.finished_tests += 1
      (binary, test) = self.tests[job_id]
      self.print_test_status(test, time_ms)
      # Captured output is only needed to report a failure; release it either
      # way so memory doesn't grow with the number of tests run.
      output = self.outputs.pop(job_id, [])
      if exit_code != 0:
        self.failures.append(self.tests[job_id])
        for line in output:
          self.out.permanent_line(line)
        self.out.permanent_line(
          "[%d/%d] %s returned/aborted with exit code %d (%d ms)"
          % (self.finished_tests, self.total_tests, test, exit_code, time_ms))
    elif command == "TESTCNT":
      # The count is the last field; tolerate any amount of separating space.
      self.total_tests = int(arg.split()[-1])
      self.out.transient_line("[0/%d] Running tests..." % self.total_tests)

  def add_stdout(self, job_id, output):
    self.outputs[job_id].append(output)

  def log(self, line):
    # `with` releases the lock even if parsing raises; a leaked lock would
    # block every other thread that logs.
    with stdout_lock:
      (prefix, output) = line.split(' ', 1)
      # A trailing ':' marks metadata; anything else is captured test output.
      if prefix[-1] == ':':
        self.handle_meta(int(prefix[:-1]), output)
      else:
        self.add_stdout(int(prefix[:-1]), output)

  def end(self):
    """Finish the last progress line and summarize any failures."""
    self.out.flush_transient_output()
    if self.failures:
      self.out.permanent_line("FAILED TESTS (%d/%d):"
                              % (len(self.failures), self.total_tests))
      for (binary, test) in self.failures:
        self.out.permanent_line(" " + binary + ": " + test)

class RawFormat(object):
  """Logger that echoes every line verbatim to stdout."""

  def log(self, line):
    # `with` guarantees the lock is released even if the write raises;
    # otherwise every other thread that logs would block forever.
    with stdout_lock:
      sys.stdout.write(line + "\n")
      sys.stdout.flush()

  def end(self):
    """Nothing is buffered, so there is nothing left to emit."""
    pass

# Record of test runtimes. Has built-in locking.
class TestTimes(object):
  def __init__(self, save_file):
    "Create new object seeded with saved test times from the given file."
    self.__times = {}  # (test binary, test name) -> runtime in ms

    # Protects calls to record_test_time(); other calls are not
    # expected to be made concurrently.
    self.__lock = threading.Lock()

    try:
      with gzip.GzipFile(save_file, "rb") as f:
        times = cPickle.load(f)
    except (EOFError, IOError, cPickle.UnpicklingError, zlib.error):
      # File doesn't exist, isn't readable, is malformed---whatever.
      # Just ignore it.
      return

    # Discard saved times if the format isn't right.
    if type(times) is not dict:
      return
    for ((test_binary, test_name), runtime) in times.items():
      if (type(test_binary) is not str or type(test_name) is not str
          or type(runtime) not in {int, long}):
        return

    self.__times = times

  def get_test_time(self, binary, testname):
    "Return the last duration for the given test, or 0 if there's no record."
    return self.__times.get((binary, testname), 0)

  def record_test_time(self, binary, testname, runtime_ms):
    "Record that the given test ran in the specified number of milliseconds."
    with self.__lock:
      self.__times[(binary, testname)] = runtime_ms

  def write_to_file(self, save_file):
    "Write all the times to file."
    try:
      with open(save_file, "wb") as f:
        with gzip.GzipFile("", "wb", 9, f) as gzf:
          cPickle.dump(self.__times, gzf, cPickle.HIGHEST_PROTOCOL)
    except IOError:
      pass  # ignore errors---saving the times isn't that important

# Remove additional arguments (anything after the first --). They are later
# appended to the command used to run each test.
additional_args = []
try:
  separator_index = sys.argv.index('--')
except ValueError:
  pass  # No '--' given, so there are no additional arguments.
else:
  additional_args = sys.argv[separator_index + 1:]
  sys.argv = sys.argv[:separator_index]

parser = optparse.OptionParser(
    usage='usage: %prog [options] binary [binary ...] -- [additional args]')

parser.add_option('-r', '--repeat', type='int', default=1,
                  help='repeat tests')
parser.add_option('-w', '--workers', type='int',
                  default=multiprocessing.cpu_count(),
                  help='number of workers to spawn')
parser.add_option('--gtest_color', type='string', default='yes',
                  help='color output')
parser.add_option('--gtest_filter', type='string', default='',
                  help='test filter')
parser.add_option('--gtest_also_run_disabled_tests', action='store_true',
                  default=False, help='run disabled tests too')
parser.add_option('--format', type='string', default='filter',
                  help='output format (raw,filter)')

(options, binaries) = parser.parse_args()

if not binaries:
  parser.print_usage()
  sys.exit(1)

# With zero (or negative) workers no thread is started: no test would run, yet
# the script would exit with status 0 as if everything had passed.
if options.workers < 1:
  parser.error('--workers must be at least 1')

# Select the output format; both loggers provide log(line) and end().
if options.format == 'raw':
  logger = RawFormat()
elif options.format == 'filter':
  logger = FilterFormat()
else:
  sys.exit("Unknown output format: " + options.format)

def _strip_list_comment(text):
  """Return `text` without any trailing '# ...' annotation, whitespace-stripped.

  --gtest_list_tests appends such annotations to typed test groups
  ("# TypeParam = ...") and to value-parameterized tests ("# GetParam() = ...");
  they are not part of the name that --gtest_filter expects.
  """
  return text.split('#')[0].strip()


def _is_disabled_name(name):
  """Return True if gtest would consider a test group or test `name` disabled.

  Follows gtest's rule: the name starts with DISABLED_ or contains a
  "/DISABLED_" component (e.g. "Instantiation/DISABLED_Suite.").
  """
  return name.startswith('DISABLED_') or '/DISABLED_' in name


# Find tests.
save_file = os.path.join(os.path.expanduser("~"), ".gtest-parallel-times")
times = TestTimes(save_file)
tests = []
for test_binary in binaries:
  command = [test_binary]
  if options.gtest_also_run_disabled_tests:
    command += ['--gtest_also_run_disabled_tests']

  list_command = list(command)
  if options.gtest_filter != '':
    list_command += ['--gtest_filter=' + options.gtest_filter]

  try:
    test_list = subprocess.check_output(list_command + ['--gtest_list_tests'])
  except (OSError, subprocess.CalledProcessError) as e:
    # A failed listing (missing binary, crash, non-zero exit) must not be
    # mistaken for "no tests", which would make the run look like a pass.
    sys.exit("%s: %s" % (test_binary, str(e)))

  # Only the commands that actually run tests get the pass-through arguments.
  command += additional_args

  test_group = ''
  for line in test_list.split('\n'):
    if not line.strip():
      continue
    if line[0] != " ":
      # Unindented lines name the test group, e.g. "Suite." or
      # "TypedSuite/0.  # TypeParam = int"; the annotation must be dropped or
      # names built from it would not be valid --gtest_filter patterns.
      test_group = _strip_list_comment(line)
      continue
    line = _strip_list_comment(line)
    if not line:
      continue
    if (not options.gtest_also_run_disabled_tests and
        (_is_disabled_name(test_group) or _is_disabled_name(line))):
      continue

    test = test_group + line
    tests.append((times.get_test_time(test_binary, test),
                  test_binary, test, command))
# Tuples sort by recorded runtime first, so this schedules the slowest (and
# previously failed) tests first.
tests.sort(reverse=True)

# Repeat tests (-r flag): the sorted list is concatenated with itself.
tests *= options.repeat

# Guards job_id, the index into `tests` of the next job to hand out.
test_lock = threading.Lock()
job_id = 0

# Announce the total number of jobs. Real job ids start at 0, so -1 tags a
# message that belongs to no particular job.
logger.log('%d: TESTCNT  %d' % (-1, len(tests)))

# Process exit status; run_job() stores a failing test's exit code here.
exit_code = 0

def run_job(job):
  """Run a single test, forwarding its output to the logger line by line.

  Args:
    job: a (command, job_id, test) tuple. `command` is the argv prefix for
      the test binary, `job_id` tags every logged line, and `test` is the full
      test name passed via --gtest_filter.

  Returns:
    The elapsed time in milliseconds if the test exits with status 0.
    Otherwise sys.maxint, which is larger than any reasonable elapsed time, so
    that failing tests will run first the next time. A failing exit status is
    also stored in the global exit_code.
  """
  global exit_code
  (command, job_id, test) = job
  begin = time.time()
  sub = subprocess.Popen(command + ['--gtest_filter=' + test] +
                           ['--gtest_color=' + options.gtest_color],
                         stdout=subprocess.PIPE,
                         stderr=subprocess.STDOUT)

  # Read with readline() rather than `for line in sub.stdout`, whose Python 2
  # read-ahead buffering would delay forwarding the test's output.
  for line in iter(sub.stdout.readline, ''):
    logger.log(str(job_id) + '> ' + line.rstrip())

  code = sub.wait()
  runtime_ms = int(1000 * (time.time() - begin))
  logger.log("%s: EXIT %s %d" % (job_id, code, runtime_ms))
  if code == 0:
    return runtime_ms
  exit_code = code
  return sys.maxint

def worker():
  """Repeatedly claim the next unstarted entry of `tests` and run it.

  Returns once every job has been handed out. Each job's result (as returned
  by run_job) is recorded in `times`.
  """
  global job_id
  while True:
    # `with` releases the lock even if logging raises; a leaked lock would
    # stall every other worker forever.
    with test_lock:
      # Claim and advance job_id atomically with respect to other workers.
      claimed_id = job_id
      job_id += 1
      if claimed_id >= len(tests):
        return
      (_, test_binary, test, command) = tests[claimed_id]
      logger.log(str(claimed_id) + ': TEST ' + test_binary + ' ' + test)
    times.record_test_time(test_binary, test,
                           run_job((command, claimed_id, test)))

def start_daemon(func):
  """Run `func` on a freshly started daemon thread and return that thread.

  Daemon threads do not keep the interpreter alive, so the script can exit
  while they are still running.
  """
  thread = threading.Thread(target=func)
  thread.daemon = True
  thread.start()
  return thread

workers = [start_daemon(worker) for _ in range(options.workers)]

# Wait for all workers. A timed join in a loop keeps the main thread
# responsive to Ctrl-C: under Python 2 an untimed join() cannot be
# interrupted. Workers are daemon threads, so they don't block process exit.
for t in workers:
  while t.is_alive():
    t.join(1)
logger.end()
times.write_to_file(save_file)
sys.exit(exit_code)
