# GNU Enterprise Common - MySQL DB driver - Schema Introspection
#
# Copyright 2001-2005 Free Software Foundation
#
# This file is part of GNU Enterprise
#
# GNU Enterprise is free software; you can redistribute it
# and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation; either
# version 2, or (at your option) any later version.
#
# GNU Enterprise is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied
# warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
# PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public
# License along with program; see the file COPYING. If not,
# write to the Free Software Foundation, Inc., 59 Temple Place
# - Suite 330, Boston, MA 02111-1307, USA.
#
# $Id: Introspection.py 6851 2005-01-03 20:59:28Z jcater $

__all__ = ['Introspection']

import string

from gnue.common.datasources import GIntrospection

# =============================================================================
# This class implements schema introspection for MySQL backends
# =============================================================================

class Introspection (GIntrospection.Introspection):
  """
  Schema introspection for MySQL backends.

  Tables are listed with 'SHOW TABLES', their fields with 'SHOW COLUMNS' and
  their indices with 'SHOW INDEX'.
  """

  # list of the types of Schema objects this driver provides
  types = [('table', _('Tables'), 1)]

  # Maps a GNUe datatype group to the MySQL base type names belonging to it.
  # Native types not listed here are reported with the datatype 'text'.
  _TYPES = {'number': ['int','integer','bigint','mediumint',
                       'smallint','tinyint','float','real', 'double','decimal'],
            'date'  : ['date','time','timestamp','datetime']}

  # Values of the 'Default' column of SHOW COLUMNS that do not denote a real
  # default value: SQL NULL, an empty/missing value and the implicit "zero"
  # values MySQL reports for NOT NULL datetime and date columns.
  _NO_DEFAULT = ('NULL', '0000-00-00 00:00:00', '0000-00-00', '', None)

  # ---------------------------------------------------------------------------
  # Find a schema element by name and/or type
  # ---------------------------------------------------------------------------

  def find (self, name = None, type = None):
    """
    This function searches the schema for an element by name and/or type. If no
    name and no type is given, all elements will be retrieved.

    @param name: look for an element with this name
    @param type: look for an element with this type (not evaluated by this
        driver, which only provides tables)

    @return: A sequence of schema instances, one per element found, or None if
        no element could be found.
    """

    if name is None:
      cmd = u"SHOW TABLES"

    else:
      # The column listing is only used to check whether the table exists
      cmd = u"SHOW COLUMNS FROM %s" % name

    try:
      cursor = self._connection.makecursor (cmd)

    except:
      # A failing SHOW COLUMNS means there is no such table, while failing to
      # list the tables at all is a real error
      if name is not None:
        return None

      raise

    try:
      rows = cursor.fetchall ()

    finally:
      cursor.close ()

    result = []

    for row in rows:
      # Without an explicit name every row of SHOW TABLES is a table name
      tablename = name or row [0]
      indices   = self.__getIndices (tablename)

      primarykey = None
      if indices is not None:
        for index in indices.values ():
          if index ['primary']:
            primarykey = index ['fields']
            break

      attrs = {'id'        : tablename,
               'name'      : tablename,
               'type'      : 'table',
               'primarykey': primarykey,
               'indices'   : indices}

      result.append (GIntrospection.Schema (attrs,
                                            getChildSchema = self._getChildSchema))

      # For an explicit name the first column row proves the table exists;
      # the remaining rows are further columns of the same table
      if name is not None:
        break

    return result or None


  # ---------------------------------------------------------------------------
  # Get all fields of a table
  # ---------------------------------------------------------------------------

  def _getChildSchema (self, parent):
    """
    This function returns a list of all child elements for a given table.

    @param parent: schema object instance whose child elements should be
        fetched
    @return: sequence of schema instances
    """

    result = []

    cursor = self._connection.makecursor (u"SHOW COLUMNS FROM %s" % parent.id)

    try:
      # Rows are indexed as (Field, Type, Null, Key, Default, Extra)
      for rs in cursor.fetchall ():
        # Type strings look like "int(11) unsigned", "decimal(10,2)",
        # "enum('a','b')" or "double unsigned"
        nativetype = rs [1].replace (')', '').split ('(')

        # Only the first word is the base type; attributes like 'unsigned'
        # follow it directly when no length is given ("double unsigned")
        words = nativetype [0].split ()
        if words:
          basetype = words [0]
        else:
          basetype = nativetype [0]

        attrs = {'id'        : "%s.%s" % (parent.id, rs [0]),
                 'name'      : rs [0],
                 'type'      : 'field',
                 'nativetype': basetype,
                 'required'  : rs [2] != 'YES',
                 'datatype'  : 'text'}

        for (group, natives) in self._TYPES.items ():
          if basetype in natives:
            attrs ['datatype'] = group
            break

        if len (nativetype) == 2:
          # Split "10,2 unsigned" into ['10', '2', 'unsigned']
          parts = []
          for item in nativetype [1].split (','):
            parts.extend (item.split ())

          # Guard against an empty parenthesised part
          if parts and parts [0].isdigit ():
            attrs ['length'] = int (parts [0])

          if len (parts) > 1 and parts [1].isdigit ():
            attrs ['precision'] = int (parts [1])

        if rs [4] not in self._NO_DEFAULT:
          attrs ['defaulttype'] = 'constant'
          attrs ['defaultval']  = rs [4]

        if rs [5] == 'auto_increment':
          attrs ['defaulttype'] = 'serial'

        result.append (GIntrospection.Schema (attrs))

    finally:
      cursor.close ()

    return result


  # ---------------------------------------------------------------------------
  # Get a dictionary of all indices available for a given table
  # ---------------------------------------------------------------------------

  def __getIndices (self, table):
    """
    This function creates a dictionary with all indices of a given relation
    where the keys are the indexnames and the values are dictionaries
    describing the indices. Such a dictionary has the keys 'unique', 'primary'
    and 'fields', where 'unique' specifies whether the index is unique or not
    and 'primary' specifies whether the index is the primary key or not.
    'fields' holds a sequence with all field names building the index, in
    index order.

    @param table: name of the table to fetch indices for
    @return: dictionary with indices or None if no indices were found
    """

    result = {}

    cursor = self._connection.makecursor (u"SHOW INDEX FROM %s" % table)

    try:
      # Rows start with (Table, Non_unique, Key_name, Seq_in_index,
      # Column_name); one row per column of each index
      for rs in cursor.fetchall ():
        ixName = rs [2]
        if ixName not in result:
          # int () makes the check independent of whether the driver returns
          # Non_unique as a number or as a string
          result [ixName] = {'unique' : not int (rs [1]),
                             'primary': ixName == 'PRIMARY',
                             'fields' : []}

        # Keep the sequence number (as int, so it sorts numerically) to be
        # able to order the fields below
        result [ixName] ['fields'].append ((int (rs [3]), rs [4]))

    finally:
      cursor.close ()

    # Sort the field lists according to their sequence number and drop the
    # sequence numbers afterwards
    for index in result.values ():
      fields = index ['fields']
      fields.sort ()
      index ['fields'] = [field [1] for field in fields]

    return result or None
