#!/usr/bin/env python
# Copyright 1999-2012 Gentoo Foundation
# Distributed under the terms of the GNU General Public License v2
# $Header: $
#
# Zac Medico <zmedico@gentoo.org>
#

from __future__ import print_function

import datetime
import errno
import logging
import optparse
import os
import signal
import sys
import time

try:
	import thread
except ImportError:
	# Python >=3.0
	import _thread as thread

try:
	import threading
except ImportError:
	import dummy_threading as threading

import portage
from portage.dep import Atom, use_reduce, paren_enclose
from portage.repository.config import _gen_valid_repo
from portage.update import fixdbentries, grab_updates
from itertools import chain

def eval_use(s, use):
	"""Reduce metadata string *s* against the flag collection *use*.

	The string is reduced with use_reduce() (tokens built as Atom) and the
	result is re-serialized with paren_enclose(); presumably this resolves
	USE-conditional groups so two values can be compared as strings.

	@param s: metadata value to reduce
	@param use: enabled USE flags passed as use_reduce(uselist=...)
	@return: the re-serialized reduced value
	"""
	reduced = use_reduce(s, uselist=use, token_class=Atom)
	return paren_enclose(reduced)

def create_syncronized_func(myfunc, mylock):
	"""Wrap *myfunc* so that every call runs while holding *mylock*.

	(The misspelled name is kept so existing callers keep working.)

	@param myfunc: callable to wrap
	@param mylock: lock-like object providing acquire() and release()
	@return: a wrapper that acquires *mylock*, calls *myfunc* with the given
		arguments, releases the lock even if the call raises, and returns
		whatever *myfunc* returned
	"""
	def newfunc(*pargs, **kwargs):
		mylock.acquire()
		try:
			return myfunc(*pargs, **kwargs)
		finally:
			mylock.release()
	# Return the wrapper; returning myfunc itself (as before) meant callers
	# silently received no synchronization at all.
	return newfunc

class ConsoleUpdate(object):
	"""Single-line status writer for the console.

	Text is written to ``stream`` (sys.stdout) unless ``quiet`` is true.
	``offset`` counts the characters appended since the last carriage
	return or newline.
	"""

	# Methods re-bound in __init__ to versions wrapped by
	# create_syncronized_func() with the instance's RLock.
	_synchronized_methods = ["append", "carriageReturn",
		"newLine", "reset", "update"]

	# ANSI code that clears from the cursor to the end of the line; used
	# when curses cannot provide the terminal's own sequence.
	_DEFAULT_CLEAR_EOL = '\x1b[K'

	def __init__(self):
		self.offset = 0
		self.stream = sys.stdout
		self.quiet = False
		self._lock = threading.RLock()
		for method_name in self._synchronized_methods:
			myfunc = create_syncronized_func(
				getattr(self, method_name), self._lock)
			setattr(self, method_name, myfunc)
		# ANSI code that clears from the cursor to the end of the line
		self._CLEAR_EOL = None
		try:
			import curses
			try:
				curses.setupterm()
				self._CLEAR_EOL = self._decode_capability(
					curses.tigetstr('el'))
			except curses.error:
				pass
		except ImportError:
			pass
		if not self._CLEAR_EOL:
			self._CLEAR_EOL = self._DEFAULT_CLEAR_EOL

	@staticmethod
	def _decode_capability(value):
		"""Return a terminfo capability *value* as a native str, or None.

		curses.tigetstr() returns bytes on Python 3 (None when the
		capability is absent).  Writing bytes to a text stream such as
		sys.stdout raises TypeError, so the value is decoded here.  Non-ASCII
		values yield None so that the caller falls back to the default.
		"""
		if value is None or isinstance(value, str):
			return value
		try:
			return value.decode('ascii')
		except (UnicodeDecodeError, AttributeError):
			return None

	def acquire(self, **kwargs):
		"""Acquire the console lock; keyword arguments go to RLock.acquire()."""
		return self._lock.acquire(**kwargs)

	def release(self):
		"""Release the console lock."""
		self._lock.release()

	def reset(self):
		"""Forget the current line length."""
		self.offset = 0

	def carriageReturn(self):
		"""Return to column 0 and clear the rest of the line."""
		if not self.quiet:
			self.stream.write("\r")
			self.stream.write(self._CLEAR_EOL)
			self.offset = 0

	def newLine(self):
		"""Finish the current line and flush the stream."""
		if not self.quiet:
			self.stream.write("\n")
			self.stream.flush()
			self.reset()

	def update(self, msg):
		"""Replace the current line's content with *msg*."""
		if not self.quiet:
			self.carriageReturn()
			self.append(msg)

	def append(self, msg):
		"""Write *msg* after the current content and flush."""
		if not self.quiet:
			self.offset += len(msg)
			self.stream.write(msg)
			self.stream.flush()

class ProgressCounter(object):
	"""Holder for a progress position: ``current`` out of ``total``."""

	def __init__(self):
		self.total = 0
		self.current = 0

class ProgressAnalyzer(ProgressCounter):
	"""Estimate elapsed, total and remaining time from progress samples.

	Callers set ``current``, ``total`` and ``currentTime`` and then query the
	estimates.  The rate is derived from a sliding window of at most
	``sampleCount`` (time, current) samples.
	"""

	def __init__(self):
		# Initialize total/current; this used to be skipped, so percentage()
		# and totalTime() raised AttributeError until a caller set them.
		ProgressCounter.__init__(self)
		self.start_time = time.time()
		self.currentTime = self.start_time
		self._samples = []
		self.sampleCount = 20

	def percentage(self, digs=0):
		"""Return 100*current/total formatted with *digs* decimal places.

		0.0 is used when total is not positive.
		"""
		if self.total > 0:
			float_percent = 100 * float(self.current) / float(self.total)
		else:
			float_percent = 0.0
		return ("%%.%df" % digs) % float_percent

	def totalTime(self):
		"""Return the estimated total run time in seconds.

		Returns 0 while no positive rate can be derived from the window.
		"""
		sample = (self.currentTime, self.current)
		# remaining_time() and a direct totalTime() call are typically made
		# for the same update; record each distinct sample only once so the
		# window really spans sampleCount updates.
		if not self._samples or self._samples[-1] != sample:
			self._samples.append(sample)
		while len(self._samples) > self.sampleCount:
			self._samples.pop(0)
		prev_time, prev_count = self._samples[0]
		time_delta = self.currentTime - prev_time
		if time_delta > 0:
			rate = (self.current - prev_count) / time_delta
			if rate > 0:
				return self.total / rate
		return 0

	def remaining_time(self):
		"""Return the estimated remaining time in seconds, never negative."""
		# While no rate is known totalTime() is 0; without the clamp the
		# result would be negative (rendered as e.g. 23:59:55 by gmtime()).
		return max(0, self.totalTime() - self.elapsed_time())

	def elapsed_time(self):
		"""Return the seconds between start_time and currentTime."""
		return self.currentTime - self.start_time

class ConsoleProgress(object):
	"""Draw a one-line progress summary (percent, elapsed, remaining, total)
	through a ConsoleUpdate."""

	def __init__(self, name="Progress", console=None):
		self.name = name
		self.analyzer = ProgressAnalyzer()
		self.console = ConsoleUpdate() if console is None else console
		self.time_format = "%H:%M:%S"
		self.quiet = False
		self.lastUpdate = 0
		# Minimum number of seconds between two redraws.
		self.latency = 0.5

	def formatTime(self, t):
		"""Format a duration of *t* seconds using time_format (UTC-based)."""
		stamp = time.gmtime(t)
		return time.strftime(self.time_format, stamp)

	def displayProgress(self, current, total):
		"""Redraw the progress line for *current* of *total*, throttled by
		latency; does nothing when quiet."""
		if self.quiet:
			return

		analyzer = self.analyzer
		now = time.time()
		analyzer.currentTime = now
		if now - self.lastUpdate < self.latency:
			return
		self.lastUpdate = now
		analyzer.current = current
		analyzer.total = total

		# Keep this order: remaining_time() and totalTime() both record
		# samples in the analyzer.
		fields = []
		fields.append((self.name, analyzer.percentage(1).rjust(4) + "%"))
		fields.append(("Elapsed", self.formatTime(analyzer.elapsed_time())))
		fields.append(("Remaining", self.formatTime(analyzer.remaining_time())))
		fields.append(("Total", self.formatTime(analyzer.totalTime())))
		line = " ".join(label + ": " + value for label, value in fields)
		self.console.update(line)

class ProgressHandler(object):
	"""Throttle progress notifications before forwarding them to display().

	display() must be supplied by a subclass or assigned on the instance;
	it is called at most once per ``min_display_latency`` seconds.
	"""

	def __init__(self):
		self.curval = self.maxval = 0
		self.last_update = 0
		self.min_display_latency = 0.2

	def onProgress(self, maxval, curval):
		"""Store the latest position and redisplay if the latency elapsed."""
		self.maxval, self.curval = maxval, curval
		now = time.time()
		if now - self.last_update < self.min_display_latency:
			return
		self.last_update = now
		self.display()

	def display(self):
		"""Render the current position; must be overridden."""
		raise NotImplementedError(self)

def open_file(filename=None):
	"""Return a writable stream selected by *filename*.

	None selects sys.stderr and "-" selects sys.stdout.  Any other value is
	expanded with os.path.expanduser() and opened in append mode; if that
	fails, the error is written to stderr and the process exits with the
	error's errno.
	"""
	if filename is None:
		return sys.stderr
	if filename == "-":
		return sys.stdout
	try:
		return open(os.path.expanduser(filename), "a")
	except (IOError, OSError) as e:
		sys.stderr.write("%s\n" % e)
		sys.exit(e.errno)

def create_log(name="", logfile=None, loglevel=0):
	"""Return logger *name* with a stream handler attached.

	The logger's level is set to *loglevel*, and the handler writes to the
	stream returned by open_file(logfile) using the format
	"<levelname> <message>".  Each call attaches another handler to the
	same logger.
	"""
	logger = logging.getLogger(name)
	logger.setLevel(loglevel)
	formatter = logging.Formatter("%(levelname)s %(message)s")
	handler = logging.StreamHandler(open_file(logfile))
	handler.setFormatter(formatter)
	logger.addHandler(handler)
	return logger

def is_interrupt(e):
	"""Return True if exception *e* means the program was interrupted.

	That is SystemExit, KeyboardInterrupt, or any exception whose errno
	is EINTR.
	"""
	if isinstance(e, (SystemExit, KeyboardInterrupt)):
		return True
	return getattr(e, "errno", None) == errno.EINTR

class PackageMoves(object):
	"""Apply the update entries from $PORTDIR/profiles/updates to every
	entry of the installed-package database."""

	def __init__(self):
		self.maxval = 0

	def run(self, onProgress=None):
		"""Run fixdbentries() over each installed package directory.

		@param onProgress: optional callable(maxval, curval) invoked once
			before the first package and after each package
		Sets self.maxval to the number of packages and self.updateCount to
		the number of packages for which fixdbentries() returned true.
		"""
		settings = portage.settings
		trees = portage.db

		# Gather the valid update commands from all update files; parse
		# errors reported by parse_updates() are ignored.
		updpath = os.path.join(portage.settings["PORTDIR"], "profiles", "updates")
		myupd = []
		for _key, _stat, content in grab_updates(updpath):
			valid_updates, _errors = portage.parse_updates(content)
			myupd.extend(valid_updates)

		root = settings["ROOT"]
		cpv_all = trees[root]["vartree"].dbapi.cpv_all()

		total = len(cpv_all)
		self.maxval = total
		done = 0
		if onProgress:
			onProgress(total, done)

		vdb_root = os.path.join(root, portage.VDB_PATH)
		self.updateCount = 0
		# Packages are consumed from the end of the list.
		while cpv_all:
			pkg_dir = os.path.join(vdb_root, cpv_all.pop())
			if fixdbentries(myupd, pkg_dir):
				self.updateCount += 1
			done += 1
			if onProgress:
				onProgress(total, done)

class DepTransfer(object):
	"""Transfer metadata from the porttree to installed or binary packages.

	For each selected package, the values of ``metadata_keys`` stored in the
	target database are compared with the values the porttree reports for
	the same cpv and repository.  If any key differs, the differing keys are
	written back with aux_update(), or the cpv is only printed when
	``pretend`` is set.
	"""

	# Maps a package type to the name of its tree in portage.db.
	pkg_tree_map = {
		"ebuild":"porttree",
		"binary":"bintree",
		"installed":"vartree"}

	# Keys compared after eval_use() with the package's USE flags; all other
	# keys are compared as unordered sets of whitespace-separated tokens.
	use_evaluated_keys = frozenset(("LICENSE", "RDEPEND", "DEPEND",
		"PDEPEND", "PROPERTIES", "PROVIDE"))

	def __init__(self, pkgtype="installed", metadata_keys=None,
		package_list=None, pretend=False):
		"""
		@param pkgtype: key into pkg_tree_map selecting the target database
		@param metadata_keys: whitespace-separated metadata key names;
			DEPEND, RDEPEND and PDEPEND when empty or None
		@param package_list: cpvs to process; every cpv of the target
			database when empty or None
		@param pretend: print the cpvs that would change instead of
			updating them
		"""
		self.pkgtype = pkgtype
		self.pretend = pretend
		if metadata_keys:
			self.metadata_keys = metadata_keys.split()
		else:
			self.metadata_keys = ["DEPEND", "RDEPEND", "PDEPEND"]
		self.package_list = package_list
		self.maxval = 0

	def run(self, onProgress=None):
		"""Compare and update every selected package.

		@param onProgress: optional callable(maxval, curval) invoked once
			before the loop and after each package
		Sets self.maxval, self.updateCount (packages with differing
		metadata) and self.missingCount (packages that could not be
		compared).
		"""

		auxdbkeys = set(self.metadata_keys)
		auxdbkeys.add("EAPI")
		auxdbkeys.add("repository")
		auxdbkeys = tuple(auxdbkeys)
		# The target entry's USE is needed as well, for eval_use().
		dest_auxdbkeys = set(auxdbkeys)
		dest_auxdbkeys.add('USE')
		dest_auxdbkeys = sorted(dest_auxdbkeys)
		settings = portage.settings
		trees = portage.db
		portdb = trees[settings["ROOT"]]["porttree"].dbapi
		mydbapi = \
			trees[settings["ROOT"]][self.pkg_tree_map[self.pkgtype]].dbapi

		if self.package_list:
			package_list = self.package_list
		else:
			package_list = mydbapi.cpv_all()
		self.updateCount = 0
		self.missingCount = 0

		maxval = len(package_list)
		self.maxval = maxval
		curval = 0
		if onProgress:
			onProgress(maxval, curval)

		for cpv in package_list:
			# The finally clause advances progress even on `continue`.
			try:
				try:
					existing_data = dict(zip(dest_auxdbkeys,
						mydbapi.aux_get(cpv, dest_auxdbkeys)))
				except KeyError:
					# cpv is not in the target database (e.g. a mistyped
					# command-line argument): count it instead of aborting
					# the whole run.  Assumes the target dbapi, like portdb
					# below, raises KeyError for an unknown cpv.
					self.missingCount += 1
					continue

				myrepo = _gen_valid_repo(existing_data["repository"])
				try:
					port_data = dict(zip(auxdbkeys,
						portdb.aux_get(cpv, auxdbkeys, myrepo=myrepo)))
				except KeyError:
					# portdb has no entry for this cpv/repository.
					self.missingCount += 1
					continue

				# Only transfer metadata between identical EAPIs.
				if port_data["EAPI"] != existing_data["EAPI"]:
					self.missingCount += 1
					continue

				updates = {}
				use = frozenset(existing_data['USE'].split())
				for k in self.metadata_keys:
					v1 = port_data[k]
					v2 = existing_data[k]
					if k in self.use_evaluated_keys:
						v1 = eval_use(v1, use)
						v2 = eval_use(v2, use)
						if v2 != v1:
							updates[k] = v1
					else:
						if set(v2.split()) != set(v1.split()):
							updates[k] = v1
				if updates:
					if self.pretend:
						print(cpv)
					else:
						mydbapi.aux_update(cpv, updates)
					self.updateCount += 1
			finally:
				curval += 1
				if onProgress:
					onProgress(maxval, curval)

def parse_args(myargv):
	"""Parse the command line.

	@param myargv: argument list without the program name
	@return: (opts, args) - opts is a dict of option values keyed by their
		dest names, args is the list of positional arguments (cpvs)
	"""
	parser = optparse.OptionParser(
		description="This program does maintenance on the database of installed packages.",
		usage="usage: vdb-tools [options] --move || --transfer [cpv] [cpv] ...")
	parser.add_option("--move",
		help="perform package moves and slot moves",
		action="store_true", dest="move", default=False)
	parser.add_option("--transfer",
		help="transfer metadata from the portage tree to the installed packages",
		action="store_true", dest="transfer", default=False)
	parser.add_option("--metadata-keys",
		help="transfer specific keys instead of the default ones",
		dest="metadata_keys", default=None)
	parser.add_option("--pkgtype",
		help="specify the package type: installed or binary",
		action="store", dest="pkgtype", type="choice",
		choices=("installed", "binary"), default="installed")
	parser.add_option("--reportfile",
		help="send a report to a file",
		dest="reportfile", default=None)
	parser.add_option("--no-progress",
		action="store_false", dest="progress", default=True,
		help="disable progress output to tty")
	parser.add_option("--pretend",
		action="store_true", dest="pretend", default=False,
		help="show a list of packages, but do not operate on them")
	options, args = parser.parse_args(args=myargv)

	# A plain dict lets callers pass the values on as **opts.
	option_names = ("progress", "transfer", "metadata_keys", "move",
		"pkgtype", "pretend", "reportfile")
	opts = dict((opt_name, getattr(options, opt_name))
		for opt_name in option_names)
	# Pretend mode turns the progress display off (presumably so it does
	# not mix with the printed package list).
	if opts["pretend"]:
		opts["progress"] = False
	return opts, args

def run_command(args=None):
	"""Parse *args*, run the selected job and optionally write a report.

	@param args: command-line arguments without the program name; when
		None, sys.argv[1:] is used
	Normally ends with sys.exit(0); exits with os.EX_USAGE when neither
	--move nor --transfer was given.  SIGINT and SIGTERM are routed through
	handle_interrupt() below.
	"""
	if args is None:
		args = sys.argv[1:]
	# Parse the caller's argument list (sys.argv[1:] used to be re-read
	# here, silently ignoring the parameter).
	opts, args = parse_args(args)
	console = ConsoleUpdate()
	# Progress is only drawn when enabled and stdout is a terminal.
	if not opts["progress"] or not sys.stdout.isatty():
		console.quiet = True
	job = None
	shutdown_initiated = threading.Event()
	shutdown_complete = threading.Event()
	def shutdown_console():
		# Runs in the helper thread started by handle_interrupt(): waits for
		# the console lock, prints the notice, silences the console and then
		# interrupts the main thread.
		console.acquire()
		try:
			console.update("Interrupted.")
			console.newLine()
			console.quiet = True
			shutdown_complete.set()
			# Kill the main thread if necessary.
			# This causes the SIGINT signal handler to be invoked in the
			# main thread.  The signal handler needs to be an actual
			# callable object (rather than something like signal.SIG_DFL)
			# in order to avoid TypeError: 'int' object is not callable.
			thread.interrupt_main()
			thread.exit()
		finally:
			console.release()

	def handle_interrupt(*args):
		# Once the console shutdown has completed, just exit.
		if shutdown_complete.is_set():
			sys.exit(1)
		# Lock the console from a new thread so that the main thread is allowed
		# to cleanly complete any console interaction that may have been in
		# progress when this interrupt arrived.
		if not shutdown_initiated.is_set():
			thread.start_new_thread(shutdown_console, ())
			shutdown_initiated.set()

	signal.signal(signal.SIGINT, handle_interrupt)
	signal.signal(signal.SIGTERM, handle_interrupt)

	try:
		datestamp = str(datetime.datetime.now())
		time_begin = time.time()
		if opts["reportfile"]:
			reportfile = open_file(opts["reportfile"])
		if opts["move"]:
			job = PackageMoves()
			name = "Package moves"
			complete_msg = "Package moves are complete."
		elif opts["transfer"]:
			job = DepTransfer(pkgtype=opts["pkgtype"],
				metadata_keys=opts["metadata_keys"], package_list=args,
				pretend=opts["pretend"])
			name = "Dependency transfer"
			complete_msg = "Dependency transfer is complete."
		else:
			sys.stderr.write("required options: --move || --transfer\n")
			sys.exit(os.EX_USAGE)
		job.opts = opts

		onProgress = None
		if not console.quiet:
			ui = ConsoleProgress(name=name, console=console)
			progressHandler = ProgressHandler()
			onProgress = progressHandler.onProgress
			def display():
				ui.displayProgress(progressHandler.curval, progressHandler.maxval)
			progressHandler.display = display

		job.run(onProgress=onProgress)

		if not console.quiet:
			# make sure the final progress is displayed
			progressHandler.display()

		total_count = job.maxval
		update_count = job.updateCount
		missingCount = getattr(job, "missingCount", None)

		console.update(complete_msg)
		console.newLine()
		time_end = time.time()
		if opts["reportfile"]:
			width = 20
			reportfile.write(name.ljust(width) + "%s\n" % datestamp)
			reportfile.write("Elapsed seconds".ljust(width) + "%f\n" % (time_end - time_begin))
			reportfile.write("Total packages".ljust(width) + "%i\n" % total_count)
			reportfile.write("Updated packages".ljust(width) + "%i\n" % update_count)
			if missingCount:
				reportfile.write("Missing packages".ljust(width) + "%i\n" % missingCount)
			reportfile.write(("-"*50)+"\n")
			# Close a file opened by open_file(); "-" maps to sys.stdout,
			# which must stay open.
			if reportfile is not sys.stdout and reportfile is not sys.stderr:
				reportfile.close()
	except Exception as e:
		# SystemExit/KeyboardInterrupt are not Exception subclasses, so only
		# errno == EINTR errors are treated as interrupts here.
		if not is_interrupt(e):
			raise
		del e
		handle_interrupt()
	sys.exit(0)

# Script entry point: run with the command-line arguments minus the program
# name.  run_command() ends the process through sys.exit().
if __name__ == "__main__":
	run_command(sys.argv[1:])
