# Copyright 1999-2006 Gentoo Foundation
# Distributed under the terms of the GNU General Public License v2
# $Header: $
# kate: encoding utf-8; eol unix;
# kate: indent-width 4; mixedindent off; replace-tabs on; remove-trailing-space on; space-indent on;
# kate: word-wrap-column 120; word-wrap on;

#from cache import template, cache_errors
from cache import fs_template, cache_errors
from cache.template import reconstruct_eclasses
from portage_util import writemsg

import MySQLdb as db_module
import MySQLdb.cursors

# Error class exported by the MySQLdb driver module; aliased here so the cache
# class below can reference it through a single module-level name.
DBError = db_module.Error

class database(fs_template.FsBased):

    autocommits = False
    synchronous = False
    _db_module = db_module
    _db_error = DBError
    _db_table = None
    _utf8_unsafe = ("_eclasses_", "DEPEND", "RDEPEND", "SLOT", "RESTRICT",
                    "KEYWORDS", "INHERITED","IUSE", "CDEPEND", "PDEPEND",
                    "PROVIDE", "EAPI")

    def __init__(self, *args, **config):
        super(database, self).__init__(*args, **config)
        self._allowed_keys = ["_mtime_", "_eclasses_"] + self._known_keys
        config.setdefault("host", "localhost")
        config.setdefault("user", "gentoo")
        config.setdefault("passwd", "pass")
        config.setdefault("db", "gentoo")
        self._db_init_connection(config)
        self._db_init_structures()

    def _db_init_connection(self, config):
        connection_kwargs = {}
        connection_kwargs["host"] = config["host"]
        connection_kwargs["user"] = config["user"]
        connection_kwargs["passwd"] = config["passwd"]
        connection_kwargs["db"] = config["db"]
        connection_kwargs["cursorclass"] = MySQLdb.cursors.DictCursor
        try:
            self._db_connection = self._db_module.connect(**connection_kwargs)
            #__conn.ping(True)
            self._db_cursor = self._db_connection.cursor()
            self._db_cursor.execute("SET NAMES 'utf8'")
            self._db_cursor.execute("SET CHARACTER SET utf8")
        except self._db_error, e:
            raise cache_errors.InitializationError(self.__class__, e)

    def _create_table(self, table, columns, types, indexes):
        cols = []
        for i in range(len(columns)):
            cols.append("`%s` %s" % (columns[i], types[i]))

        create_statement = "CREATE TABLE %s (" % table + \
            ",".join(cols + indexes) + ")"
        cursor = self._db_cursor
        writemsg("mysql: dropping old table: %s\n" % table)
        cursor.execute("DROP TABLE IF EXISTS `%s`;" % table)
        cursor.execute(create_statement)

    def _db_init_structures(self):
        self._db_table = {}
        tbl = self._db_table["packages"] = {}
        tbl["table_name"] = "portage_packages"
        tbl["package_id"] = "portage_package_key"
        #noid: tbl["package_id"] = "internal_db_package_id"
        tbl["package_key"] = "portage_package_key"
        tbl["internal_columns"] = [tbl["package_key"]]
        #noid: tbl["internal_columns"] = [tbl["package_id"],tbl["package_key"]]
        tbl["columns"] = tbl["internal_columns"] + self._allowed_keys
        ct = tbl["columns_types"] = [
            #noid: "int(11) NOT NULL AUTO_INCREMENT",
            "varchar(128) COLLATE latin1_bin NOT NULL DEFAULT ''"]
        for c in tbl["columns"][1:]:
            if c in self._utf8_unsafe:
                ct.append("text CHARACTER SET latin1 COLLATE latin1_bin NOT NULL")
            else:
                ct.append("text NOT NULL")
        tbl["indexes"] = []
        tbl["indexes"].append("PRIMARY KEY (`%s`)" % tbl["package_id"])
        #tbl["indexes"].append("UNIQUE KEY (`%s`)" % tbl["package_key"])

        cursor = self._db_cursor
        for k, v in self._db_table.iteritems():
            try:
                cursor.execute("SELECT * FROM `%s` WHERE False" % v["table_name"])
                current_table_field_list = set([x[0].upper() for x in cursor.description])
                wanted_table_field_list = set([x.upper() for x in v["columns"]])
                if current_table_field_list < wanted_table_field_list:
                    self._create_table(v["table_name"], v["columns"], v["columns_types"], v["indexes"])
            except:
                self._create_table(v["table_name"], v["columns"], v["columns_types"], v["indexes"])

    def __getitem__(self, cpv):
        cursor = self._db_cursor
        tbl = self._db_table["packages"]
        result_set = set(tbl["columns"]) - set(tbl["internal_columns"])
        sql = "SELECT `%s` FROM `%s` WHERE `%s`=%%s LIMIT 2" % \
            ("`,`".join(result_set), tbl["table_name"], tbl["package_key"],)
        #print "get:", cpv
        cursor.execute(sql, cpv)
        result = cursor.fetchone()
        if not result:
            raise KeyError(cpv)
        if cursor.fetchone():
            raise cache_errors.CacheCorruption(cpv, "key is not unique")

        # XXX: The resolver chokes on unicode strings so we convert them here.
        for k in self._utf8_unsafe:
            try:
                result[k]=str(result[k]) # convert unicode strings to normal
            except UnicodeEncodeError, e:
                writemsg("%s: %s\n" % (cpv, str(e)))
        result["_eclasses_"] = reconstruct_eclasses(cpv, result["_eclasses_"])
        return result

    def _setitem(self, cpv, values):
        keys = [self._db_table["packages"]["package_key"]] + self._allowed_keys
        update_statement = "REPLACE INTO `%s`" % self._db_table["packages"]["table_name"] \
                         + "(`" \
                         + '`,`'.join(keys) \
                         + "`) VALUES (" \
                         + ("%s," * len(keys)).rstrip(",") \
                         + ")"
        values_parameters = []
        values_parameters.append(cpv)
        for k in self._allowed_keys:
            values_parameters.append(values.get(k, ''))
        cursor = self._db_cursor
        try:
            cursor.execute(update_statement, values_parameters)
        except self._db_error, e:
            writemsg("%s: %s\n" % (cpv, str(e)))
            raise

    def commit(self):
        self._db_connection.commit()

    def _delitem(self, cpv):
        cursor = self._db_cursor
        tbl = self._db_table["packages"]
        sql = "DELETE FROM `%s` WHERE `%s`=%%s" % (tbl["table_name"], tbl["package_key"],)
        cursor.execute(sql, cpv)

    def __contains__(self, cpv):
        cursor = self._db_cursor
        tbl = self._db_table["packages"]
        sql = "SELECT COUNT(*) AS cnt FROM `%s` WHERE `%s`=%%s" % (tbl["table_name"], tbl["package_key"],)
        #print "con:", sql % cpv
        cursor.execute(sql, cpv)
        cnt = cursor.fetchone()["cnt"]
        if cnt == 1:
            return True
        elif cnt == 0:
            return False
        else:
            raise cache_errors.CacheCorruption(cpv, "key is not unique")

    def __iter__(self):
        """generator for walking the dir struct"""
        cursor = self._db_cursor
        tbl = self._db_table["packages"]
        sql = "SELECT `%s` AS r FROM `%s` ORDER BY `%s`" % \
            (tbl["package_key"], tbl["table_name"], tbl["package_key"])
        #print "itr:", sql
        cursor.execute(sql)

        while True:
            row = cursor.fetchone()
            if not row:
                break
            yield row['r']
