import abc
import importlib
import os
import re
import socket

import pgpasslib
import six

from sqlalchemy import MetaData, create_engine
from sqlalchemy.engine import url
from sqlalchemy.exc import OperationalError as OpError
from sqlalchemy.orm import scoped_session, sessionmaker

import peewee
from peewee import OperationalError, PostgresqlDatabase
from playhouse.postgres_ext import ArrayField
from playhouse.reflection import Introspector, UnknownField

import sdssdb
from sdssdb import config, log
from sdssdb.utils.internals import get_database_columns

def _should_autoconnect():
    """Determines whether we should autoconnect.

    If the ``$SDSSDB_AUTOCONNECT`` environment variable is set to ``0`` or
    ``false`` (case-insensitive) autoconnection is disabled. In every other
    case, including when the variable is not set, the value of
    ``sdssdb.autoconnect`` is returned.

    Returns
    -------
    autoconnect : bool
        Whether database connections should autoconnect.

    """

    envvar_autoconnect = os.environ.get('SDSSDB_AUTOCONNECT')
    if envvar_autoconnect is not None:
        if envvar_autoconnect.strip().lower() in ('0', 'false'):
            return False

    # Fall back to the package-level default whenever the environment does not
    # explicitly disable autoconnection (this must also apply when the variable
    # is unset, otherwise the function would implicitly return None).
    return sdssdb.autoconnect

class DatabaseConnection(metaclass=abc.ABCMeta):
    """A PostgreSQL database connection with profile and autoconnect features.

    Provides a base class for PostgreSQL connections for either peewee_ or
    SQLAlchemy_.

    The parameters for the connection can be passed directly (see
    `.connect_from_parameters`) or, more conveniently, a profile can be used.

    By default `.dbname` is left undefined and needs to be passed when
    initiating the connection. This is useful for databases such as
    ``apodb/lcodb`` for which the model classes are identical but the
    database name is not. For databases for which the database name is fixed
    (e.g., ``sdss5db``), this class can be subclassed and `.dbname`
    overridden.

    Parameters
    ----------
    dbname : str
        The database name.
    profile : str
        The configuration profile to use. The profile defines the default
        user, database server hostname, and port for a given location. If not
        provided, the profile is automatically determined based on the
        current domain, or defaults to ``local``.
    autoconnect : bool or None
        Whether to autoconnect to the database using the profile parameters.
        Requires `.dbname` to be set. If `None`, whether to autoconnect is
        defined, in order, by the existence of an environment variable
        ``$SDSSDB_AUTOCONNECT`` or by ``sdssdb.autoconnect``. If they are set
        to ``0`` or ``false`` the database won't autoconnect. Note that this
        must be set before importing any model classes.
    dbversion : str
        A database version. If specified, appends to dbname as
        "dbname_dbversion" and becomes the dbname used for connection strings.

    """

    #: The database name.
    dbname = None

    #: Database version
    dbversion = None

    #: Whether to call Model.reflect() in Peewee after a connection.
    auto_reflect = True

    def __init__(self, dbname=None, profile=None, autoconnect=None, dbversion=None):

        self.profile = None
        self._config = {}

        self.dbname = dbname if dbname else self.dbname
        self.dbversion = dbversion or self.dbversion

        # Only append the version when there is a base name to append it to;
        # otherwise the name would become the literal string "None_<version>".
        if self.dbversion and self.dbname:
            self.dbname = f'{self.dbname}_{self.dbversion}'

        self.set_profile(profile=profile, connect=False)

        if autoconnect is None:
            autoconnect = _should_autoconnect()

        if autoconnect and self.dbname:
            self.connect(dbname=self.dbname, silent_on_fail=True)

    def __repr__(self):

        return '<{} (dbname={!r}, profile={!r}, connected={})>'.format(
            self.__class__.__name__, self.dbname, self.profile, self.connected)

    def set_profile(self, profile=None, connect=True):
        """Sets the profile from the configuration file.

        Parameters
        ----------
        profile : str
            The profile to set. If `None`, uses the domain name to determine
            the profile.
        connect : bool
            If True, tries to connect to the database using the new profile.

        Returns
        -------
        connected : bool
            Returns True if the database is connected.

        """

        previous_profile = self.profile

        if profile is not None:
            assert profile in config, 'profile not found in configuration file.'
            self.profile = profile
            self._config = config[profile].copy()

        else:
            hostname = socket.getfqdn()

            # Initially set location to local.
            self.profile = 'local'
            self._config = config[self.profile].copy()

            # Tries to find a profile whose domain regex matches the hostname.
            for candidate in config:
                domain = config[candidate].get('domain')
                if domain is None or not re.match(domain, hostname):
                    continue

                self.profile = candidate
                self._config = config[candidate].copy()

                # If the profile host matches the current hostname set the
                # value to None to force using localhost to prevent cases
                # in which the loopback is not configured properly in PostgreSQL.
                if hostname == self._config.get('host'):
                    self._config['host'] = None

                break

        if connect:
            # Avoid reconnecting if we are already connected with this profile.
            already_connected = self.connected and self.profile == previous_profile
            if not already_connected and self.dbname is not None:
                self.connect(**self._config)

        return self.connected

    @abc.abstractmethod
    def _conn(self, dbname, **params):
        """Actually initialises the database connection.

        This method should be overridden depending on the ORM library being
        used. At the end, `.connected` should be set to True if the connection
        was successful.

        """

        pass

    def connect(self, dbname=None, silent_on_fail=False, **connection_params):
        """Initialises the database using the profile information.

        Parameters
        ----------
        dbname : `str` or `None`
            The database name. If `None`, defaults to `.dbname`.
        user : str
            Overrides the profile database user.
        host : str
            Overrides the profile database host.
        port : str
            Overrides the profile database port.
        silent_on_fail : `bool`
            If `True`, does not show a warning if the connection fails.

        Returns
        -------
        connected : bool
            Returns True if the database is connected.

        Raises
        ------
        RuntimeError
            If no profile has been set or the database name is unknown.

        """

        if self.profile is None:
            raise RuntimeError('the profile was not set when '
                               'DatabaseConnection was instantiated. Use '
                               'set_profile to set the profile in runtime.')

        # Explicit keyword arguments (even if None) take precedence over the
        # values stored in the profile.
        db_configuration = {item: connection_params.get(item, self._config.get(item, None))
                            for item in ('user', 'host', 'port')}

        dbname = dbname or self.dbname
        if dbname is None:
            raise RuntimeError('the database name was not set when '
                               'DatabaseConnection was instantiated. '
                               'To set it in runtime change the dbname '
                               'attribute.')

        return self.connect_from_parameters(dbname=dbname,
                                            silent_on_fail=silent_on_fail,
                                            **db_configuration)

    def connect_from_parameters(self, dbname=None, **params):
        """Initialises the database from a dictionary of parameters.

        Parameters
        ----------
        dbname : `str` or `None`
            The database name. If `None`, defaults to `.dbname`.
        params : dict
            A dictionary of parameters, which should include ``user``,
            ``host``, and ``port``. ``hostname`` is accepted as an alias of
            ``host``.

        Returns
        -------
        connected : bool
            Returns True if the database is connected.

        Raises
        ------
        KeyError
            If both ``hostname`` and ``host`` are passed.
        RuntimeError
            If the database name is unknown.

        """

        # Make hostname an alias of host.
        if 'hostname' in params:
            if 'host' in params:
                raise KeyError('cannot use hostname and host at the same time.')
            params['host'] = params.pop('hostname')

        dbname = dbname or self.dbname
        if dbname is None:
            raise RuntimeError('the database name was not set when '
                               'DatabaseConnection was instantiated. '
                               'To set it in runtime change the dbname '
                               'attribute.')

        return self._conn(dbname, **params)

    @staticmethod
    def list_profiles(profile=None):
        """Returns a list of profiles.

        Parameters
        ----------
        profile : `str` or `None`
            If `None`, returns a list of profile keys. If profile is not `None`
            returns the parameters for the given profile.

        """

        if profile is None:
            return config.keys()

        return config[profile]

    @property
    @abc.abstractmethod
    def connection_params(self):
        """Returns a dictionary with the connection parameters.

        Returns
        -------
        connection_params : dict
            A dictionary with the ``user``, ``host``, and ``port`` of the
            current connection. E.g.,
            ``{'user': 'sdssdb', 'host': 'sdss4-db', 'port': 5432}``

        """

        pass

    def become(self, user):
        """Change the connection to a certain user.

        Parameters
        ----------
        user : str
            The user to reconnect as. The remaining connection parameters are
            taken from the current connection (the password is dropped).

        Returns
        -------
        connected : bool
            Returns True if the database is connected.

        Raises
        ------
        RuntimeError
            If the database is not connected or the connection parameters
            cannot be determined.

        """

        if not self.connected:
            raise RuntimeError('DB has not been initialised.')

        current_params = self.connection_params
        if current_params is None:
            raise RuntimeError('cannot determine the DSN parameters. '
                               'The DB may be disconnected.')

        # Work on a copy so the connection's stored parameters are not mutated.
        dsn_params = dict(current_params)
        dsn_params.pop('password', None)  # Do not keep the password since it may change.
        dsn_params['user'] = user
        dsn_params.setdefault('dbname', self.dbname)

        return self.connect_from_parameters(**dsn_params)

    def become_admin(self, admin=None):
        """Becomes the admin user.

        If ``admin=None`` defaults to the ``admin`` value in the current
        profile.

        """

        assert self.profile is not None, \
            'this connection was not initialised from a profile. Try using become().'

        assert 'admin' in self._config, 'admin user not defined in profile'

        self.become(admin or self._config['admin'])

    def become_user(self, user=None):
        """Becomes the read-only user.

        If ``user=None`` defaults to the ``user`` value in the current profile.

        """

        assert self.profile is not None, \
            'this connection was not initialised from a profile. Try using become().'

        if user is None:
            user = self._config.get('user')

        self.become(user)

    def change_version(self, dbversion=None):
        """Change database version and attempt to reconnect.

        The suffix corresponding to the current `.dbversion` (if any) is
        removed from `.dbname` and the new version, if given, is appended as
        ``<dbname>_<dbversion>``.

        Parameters
        ----------
        dbversion : str or None
            A database version. If `None`, the unversioned database name is
            used.

        Returns
        -------
        connected : bool
            Returns True if the database is connected.

        Raises
        ------
        RuntimeError
            If the database name is not set.

        """

        if self.dbname is None:
            raise RuntimeError('the database name was not set when '
                               'DatabaseConnection was instantiated. '
                               'To set it in runtime change the dbname '
                               'attribute.')

        # Strip only the suffix added for the current version so that base
        # names which themselves contain underscores are preserved.
        base_name = self.dbname
        if self.dbversion:
            old_suffix = f'_{self.dbversion}'
            if base_name.endswith(old_suffix):
                base_name = base_name[:-len(old_suffix)]

        self.dbversion = dbversion
        self.dbname = f'{base_name}_{dbversion}' if dbversion else base_name

        return self.connect(dbname=self.dbname, silent_on_fail=True)

    def post_connect(self):
        """Hook called after a successful connection."""

        pass
class PeeweeDatabaseConnection(DatabaseConnection, PostgresqlDatabase):
    """Peewee database connection implementation.

    Attributes
    ----------
    models : dict
        Models bound to this database (the model classes are the dictionary
        values). Only models that are bound using `~sdssdb.peewee.BaseModel`
        are handled.

    """

    def __init__(self, *args, **kwargs):

        self.models = {}
        self.introspector = {}
        self._metadata = {}

        autorollback = kwargs.pop('autorollback', True)

        PostgresqlDatabase.__init__(self, None, autorollback=autorollback)
        DatabaseConnection.__init__(self, *args, **kwargs)

    @property
    def connected(self):
        """Reports whether the connection is active."""

        return self.is_connection_usable()

    @property
    def connection_params(self):
        """Returns a dictionary with the connection parameters."""

        if self.connected:
            return self.connect_params.copy()

        return None

    @staticmethod
    def _get_pgpass_password(dbname, params):
        """Looks up the password for the connection in the ``.pgpass`` file.

        Only ``host``, ``port``, and ``user`` are forwarded to
        `pgpasslib.getpass` (`None` values are dropped so its defaults apply);
        any other connection parameter would make it raise `TypeError`.
        Returns `None` if there is no ``.pgpass`` file.

        """

        pgpass_params = {key: params[key] for key in ('host', 'port', 'user')
                         if params.get(key) is not None}

        try:
            return pgpasslib.getpass(dbname=dbname, **pgpass_params)
        except pgpasslib.FileNotFound:
            return None

    def _reflect_models(self):
        """Calls ``reflect()`` on every bound model that uses reflection."""

        with self.atomic():
            for model in self.models.values():
                if getattr(model._meta, 'use_reflection', False) and hasattr(model, 'reflect'):
                    model.reflect()

    def _conn(self, dbname, silent_on_fail=False, **params):
        """Connects to the DB and tests the connection.

        Returns
        -------
        connected : bool
            Returns True if the database is connected.

        """

        if 'password' not in params:
            params['password'] = self._get_pgpass_password(dbname, params)

        PostgresqlDatabase.init(self, dbname, **params)

        # Cached column metadata belongs to the previous database.
        self._metadata = {}

        try:
            PostgresqlDatabase.connect(self)
            self.dbname = dbname
        except OperationalError as ee:
            if not silent_on_fail:
                log.warning(f'failed connecting to database {self.database!r}: {ee}')
            PostgresqlDatabase.init(self, None)

        if self.is_connection_usable() and self.auto_reflect:
            self._reflect_models()

        if self.connected:
            self.post_connect()

        return self.connected

    def get_model(self, table_name, schema=None):
        """Returns the model for a table.

        Parameters
        ----------
        table_name : str
            The name of the table whose model will be returned.
        schema : str
            The schema for the table. If `None`, the first model that matches
            the table name will be returned.

        Returns
        -------
        :class:`peewee:Model` or `None`
            The model associated with the table, or `None` if no model was
            found.

        """

        for model in self.models.values():
            if schema and model._meta.schema != schema:
                continue
            if model._meta.table_name == table_name:
                return model

        return None

    def get_introspector(self, schema=None):
        """Gets a Peewee database :class:`peewee:Introspector`.

        Introspectors are cached per schema (``None`` is cached under ``''``).

        """

        schema_key = schema or ''

        if schema_key not in self.introspector:
            self.introspector[schema_key] = Introspector.from_database(
                self, schema=schema)

        return self.introspector[schema_key]

    def _get_table_metadata(self, table_name, schema=None, cache=True):
        """Returns the column metadata for a table, or `None` if not found.

        Metadata is loaded per schema (defaulting to ``public``) and cached
        unless ``cache=False``.

        """

        schema = schema or 'public'

        if schema not in self._metadata or not cache:
            self._metadata[schema] = get_database_columns(self, schema=schema)

        schema_metadata = self._metadata[schema]
        if table_name not in schema_metadata:
            return None

        return schema_metadata[table_name]

    def get_fields(self, table_name, schema=None, cache=True):
        """Returns a list of Peewee fields for a table.

        Returns an empty list if the table is not found.

        """

        table_metadata = self._get_table_metadata(table_name, schema=schema, cache=cache)
        if table_metadata is None:
            return []

        pk = table_metadata['pk']

        # Only a single-column primary key is flagged on the field itself;
        # composite (or empty) keys leave every field with primary_key=False.
        single_pk = pk[0] if pk and len(pk) == 1 else None

        fields = []
        for col_name, field_type, array_type, nullable in table_metadata['columns']:

            is_pk = single_pk is not None and col_name == single_pk
            params = {'column_name': col_name,
                      'null': nullable,
                      'primary_key': is_pk,
                      'unique': is_pk}

            if array_type:
                field = ArrayField(array_type, **params)
            elif array_type is False and field_type is UnknownField:
                field = peewee.BareField(**params)
            else:
                field = field_type(**params)

            fields.append(field)

        return fields

    def get_primary_keys(self, table_name, schema=None, cache=True):
        """Returns the primary keys for a table.

        Returns an empty list if the table is not found or has no primary key.

        """

        table_metadata = self._get_table_metadata(table_name, schema=schema, cache=cache)
        if table_metadata is None:
            return []

        return table_metadata['pk'] or []
class SQLADatabaseConnection(DatabaseConnection):
    """SQLAlchemy database connection implementation."""

    #: The SQLAlchemy engine; `None` when no engine has been created.
    engine = None
    #: Declarative bases bound to this connection. Kept for backwards
    #: compatibility; each instance gets its own list in ``__init__`` so that
    #: `.add_base` never appends to this shared class-level list.
    bases = []
    #: The scoped session factory; `None` when no engine has been created.
    Session = None
    #: The metadata bound to the engine; `None` when no engine has been created.
    metadata = None

    def __init__(self, *args, **kwargs):

        #: Reports whether the connection is active.
        self.connected = False
        self._connect_params = None
        self.bases = []

        DatabaseConnection.__init__(self, *args, **kwargs)

    @property
    def connection_params(self):
        """Returns a dictionary with the connection parameters."""

        return self._connect_params

    def _get_password(self, **params):
        """Get a db password from a pgpass file.

        Parameters
        ----------
        params : dict
            A dictionary of database connection parameters. Must contain
            ``host``, ``port``, ``database``, and ``username`` unless a
            ``password`` is already given.

        Returns
        -------
        password : str or None
            The database password for a given set of connection parameters,
            or `None` if there is no ``.pgpass`` file or no matching entry.

        Raises
        ------
        RuntimeError
            If one of the required connection parameters is missing.

        """

        password = params.get('password', None)
        if password:
            return password

        try:
            return pgpasslib.getpass(params['host'], params['port'],
                                     params['database'], params['username'])
        except KeyError as err:
            raise RuntimeError('ERROR: invalid server configuration') from err
        except pgpasslib.FileNotFound:
            # No .pgpass file: fall back to password-less authentication, as
            # the Peewee connection does.
            return None

    def _make_connection_string(self, dbname, **params):
        """Build a db connection string.

        Also stores ``params`` as the current `.connection_params`.

        Parameters
        ----------
        dbname : str
            The name of the database to connect to.
        params : dict
            A dictionary of database connection parameters.

        Returns
        -------
        url : `sqlalchemy.engine.url.URL`
            A database connection URL.

        """

        db_params = params.copy()
        db_params['drivername'] = 'postgresql+psycopg2'
        db_params['database'] = dbname
        db_params['username'] = db_params.pop('user', None)
        db_params['host'] = db_params.pop('host', 'localhost')
        db_params['port'] = db_params.pop('port', 5432)

        if db_params['username']:
            db_params['password'] = self._get_password(**db_params)

        # SQLAlchemy >= 1.4 deprecates calling URL() directly in favour of
        # URL.create(); older versions only provide the constructor.
        build_url = getattr(url.URL, 'create', url.URL)
        db_connection_string = build_url(**db_params)

        self._connect_params = params

        return db_connection_string

    def _conn(self, dbname, silent_on_fail=False, **params):
        """Connects to the DB and tests the connection.

        Returns
        -------
        connected : bool
            Returns True if the database is connected.

        """

        db_connection_string = self._make_connection_string(dbname, **params)

        try:
            self.create_engine(db_connection_string, echo=False,
                               pool_size=10, pool_recycle=1800)
            # Open a connection to verify the engine works, then release it
            # back to the pool instead of leaving it for the garbage collector.
            self.engine.connect().close()
        except OpError:
            if not silent_on_fail:
                log.warning('Failed to connect to database {0}'.format(dbname))
            if self.engine is not None:
                self.engine.dispose()
            self.engine = None
            self.connected = False
            self.Session = None
            self.metadata = None
        else:
            self.connected = True
            self.dbname = dbname
            self.prepare_bases()

        if self.connected:
            self.post_connect()

        return self.connected

    def reset_engine(self):
        """Reset the engine, metadata, and session.

        Also clears the list of bound bases.

        """

        self.bases = []

        if self.engine:
            self.engine.dispose()
            self.engine = None
            self.metadata = None
            if self.Session is not None:
                # remove() closes and discards the current thread-local
                # session; close() would first create one if none existed.
                self.Session.remove()
            self.Session = None

    def create_engine(self, db_connection_string=None, echo=False, pool_size=10,
                      pool_recycle=1800, expire_on_commit=True):
        """Create a new database engine.

        Resets and creates a new sqlalchemy database engine. Also creates and
        binds engine metadata and a new scoped session.

        Parameters
        ----------
        db_connection_string : str or `sqlalchemy.engine.url.URL` or None
            The connection URL. If `None`, it is built from `.dbname` (or a
            ``DATABASE_NAME`` attribute, if a subclass defines one) and the
            current `.connection_params`.

        Raises
        ------
        RuntimeError
            If no connection string is given and the database name is unknown.

        """

        if not db_connection_string:
            dbname = self.dbname or getattr(self, 'DATABASE_NAME', None)
            if dbname is None:
                raise RuntimeError('the database name was not set when '
                                   'DatabaseConnection was instantiated. '
                                   'To set it in runtime change the dbname '
                                   'attribute.')
            connection_params = self.connection_params or {}
            db_connection_string = self._make_connection_string(dbname,
                                                                 **connection_params)

        # Build the connection string before tearing down the current engine so
        # that a failure there leaves the existing engine untouched.
        self.reset_engine()

        self.engine = create_engine(db_connection_string, echo=echo,
                                    pool_size=pool_size, pool_recycle=pool_recycle)
        self.metadata = MetaData(bind=self.engine)
        self.Session = scoped_session(sessionmaker(bind=self.engine, autocommit=True,
                                                   expire_on_commit=expire_on_commit))

    def add_base(self, base, prepare=True):
        """Binds a base to this connection.

        If ``prepare=True`` and the connection is active, the base is
        prepared immediately.

        """

        if base not in self.bases:
            self.bases.append(base)

        if prepare and self.connected:
            self.prepare_bases(base=base)

    def prepare_bases(self, base=None):
        """Prepare a Model Base.

        Prepares a SQLalchemy Base for reflection. This binds a database
        engine to a specific Base which maps to a set of ModelClasses.
        If ``base`` is passed only that base will be prepared. Otherwise, all
        the bases bound to this database connection will be prepared.

        """

        do_bases = [base] if base else self.bases

        for model_base in do_bases:
            model_base.prepare(self.engine)

            # If the base has an attribute _relations that's the function
            # (or the name of a function in the base's module) to call to set
            # up the relationships once the engine has been bound to the base.
            relations = getattr(model_base, '_relations', None)
            if isinstance(relations, str):
                module = importlib.import_module(model_base.__module__)
                getattr(module, relations)()
            elif callable(relations):
                relations()