#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# @Author: José Sánchez-Gallego (gallegoj@uw.edu)
# @Date: 2018-09-21
# @Filename: database.py
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause)
#
# @Last modified by: José Sánchez-Gallego (gallegoj@uw.edu)
# @Last modified time: 2019-09-21 23:09:51
from __future__ import absolute_import, division, print_function
import abc
import importlib
import re
import socket
import six
from pgpasslib import getpass
from sdssdb import _peewee, _sqla, config, log
if _peewee:
from peewee import OperationalError, PostgresqlDatabase
if _sqla:
from sqlalchemy import create_engine, MetaData
from sqlalchemy.engine import url
from sqlalchemy.exc import OperationalError as OpError
from sqlalchemy.orm import sessionmaker, scoped_session
__all__ = ['DatabaseConnection', 'PeeweeDatabaseConnection', 'SQLADatabaseConnection']
[docs]class DatabaseConnection(six.with_metaclass(abc.ABCMeta)):
"""A PostgreSQL database connection with profile and autoconnect features.
Provides a base class for PostgreSQL connections for either peewee_ or
SQLAlchemy_. The parameters for the connection can be passed directly (see
`.connect_from_parameters`) or, more conveniently, a profile can be used.
By default `.dbname` is left undefined and needs to be passed when
initiating the connection. This is useful for databases such as
``apodb/lcodb`` for which the model classes are identical but the database
name is not. For databases for which the database name is fixed (e.g.,
``sdss5db``), this class can be subclassed and `.dbname` overridden.
Parameters
----------
dbname : str
The database name.
profile : str
The configuration profile to use. The profile defines the default
user, database server hostname, and port for a given location. If
not provided, the profile is automatically determined based on the
current domain, or defaults to ``local``.
autoconnect : bool
Whether to autoconnect to the database using the profile parameters.
Requires `.dbname` to be set.
dbversion : str
A database version. If specified, appends to dbname as "dbname_dbversion"
and becomes the dbname used for connection strings.
"""
#: The database name.
dbname = None
dbversion = None
def __init__(self, dbname=None, profile=None, autoconnect=True, dbversion=None):
    """Stores the connection settings and, optionally, connects.

    Parameters
    ----------
    dbname : str
        The database name. If not given, the class-level `.dbname` is used.
    profile : str
        The configuration profile, passed to `.set_profile`.
    autoconnect : bool
        If `True` and a database name is available, tries to connect
        (failures are silent).
    dbversion : str
        Optional version; the database name becomes ``dbname_dbversion``.

    """

    #: Reports whether the connection is active.
    self.connected = False
    self.profile = None

    self.dbname = dbname if dbname else self.dbname
    self.dbversion = dbversion or self.dbversion

    # Only build the versioned name when there is a base name to append
    # to; otherwise the result would be the bogus name 'None_<version>'.
    if self.dbversion and self.dbname:
        self.dbname = f'{self.dbname}_{self.dbversion}'

    # Resolve the profile without connecting. The explicit connect() call
    # below is then the single connection attempt made at construction.
    self.set_profile(profile=profile, connect=False)

    if autoconnect and self.dbname:
        self.connect(dbname=self.dbname, silent_on_fail=True)
def __repr__(self):
    """Returns a short summary: class name, database, profile and status."""

    template = '<{} (dbname={!r}, profile={!r}, connected={})>'
    class_name = self.__class__.__name__

    return template.format(class_name, self.dbname, self.profile, self.connected)
def set_profile(self, profile=None, connect=True):
    """Sets the profile from the configuration file.

    Parameters
    ----------
    profile : str
        The profile to set. If `None`, the profile is chosen by matching
        the fully qualified host name against the ``domain`` pattern of
        each configured profile, falling back to ``local``.
    connect : bool
        If True, tries to connect to the database using the new profile.
        Nothing is done if the connection is already active and the
        profile did not change, or if `.dbname` is not set.

    Returns
    -------
    connected : bool
        Returns True if the database is connected.

    """

    old_profile = self.profile

    if profile is None:
        fqdn = socket.getfqdn()

        # Default profile, replaced by the first profile whose domain
        # pattern matches the start of the host name.
        self.profile = 'local'

        for name in config:
            entry = config[name]
            if 'domain' not in entry or entry['domain'] is None:
                continue
            if re.match(entry['domain'], fqdn):
                self.profile = name
                break
    else:
        assert profile in config, 'profile not found in configuration file.'
        self.profile = profile

    if connect:
        unchanged = self.connected and self.profile == old_profile
        if not unchanged and self.dbname is not None:
            self.connect(silent_on_fail=True)

    return self.connected
@abc.abstractmethod
def _conn(self, dbname, **params):
    """Opens the actual database connection.

    Each ORM-specific subclass must provide its own implementation. Once
    the method finishes, `.connected` should be True if the connection
    was successful.

    """
def connect(self, dbname=None, silent_on_fail=False, **connection_params):
    """Initialises the database using the profile information.

    Parameters
    ----------
    dbname : `str` or `None`
        The database name. If `None`, defaults to `.dbname`.
    user : str
        Overrides the profile database user.
    host : str
        Overrides the profile database host.
    port : str
        Overrides the profile database port.
    silent_on_fail : `bool`
        If `True`, does not show a warning if the connection fails.

    Returns
    -------
    connected : bool
        Returns True if the database is connected.

    Raises
    ------
    RuntimeError
        If the profile or the database name have not been set.

    """

    if self.profile is None:
        raise RuntimeError('the profile was not set when '
                           'DatabaseConnection was instantiated. Use '
                           'set_profile to set the profile in runtime.')

    # Explicit keyword arguments win; whatever is missing is read from the
    # current profile (None when the profile does not define it). Other
    # keyword arguments are not forwarded.
    settings = {
        key: (connection_params[key] if key in connection_params
              else config[self.profile].get(key, None))
        for key in ('user', 'host', 'port')
    }

    target = dbname or self.dbname
    if target is None:
        raise RuntimeError('the database name was not set when '
                           'DatabaseConnection was instantiated. '
                           'To set it in runtime change the dbname '
                           'attribute.')

    return self.connect_from_parameters(dbname=target,
                                        silent_on_fail=silent_on_fail,
                                        **settings)
def connect_from_parameters(self, dbname=None, **params):
    """Initialises the database from a dictionary of parameters.

    Parameters
    ----------
    dbname : `str` or `None`
        The database name. If `None`, defaults to `.dbname`.
    params : dict
        A dictionary of parameters, which should include ``user``,
        ``host``, and ``port``. ``hostname`` is accepted as an alias of
        ``host``.

    Returns
    -------
    connected : bool
        Returns True if the database is connected.

    Raises
    ------
    KeyError
        If both ``host`` and ``hostname`` are given.
    RuntimeError
        If no database name is available.

    """

    # ``hostname`` is only an alias; passing both spellings is ambiguous.
    if 'hostname' in params:
        if 'host' in params:
            raise KeyError('cannot use hostname and host at the same time.')
        params['host'] = params.pop('hostname')

    database = dbname or self.dbname
    if database is None:
        raise RuntimeError('the database name was not set when '
                           'DatabaseConnection was instantiated. '
                           'To set it in runtime change the dbname '
                           'attribute.')

    return self._conn(database, **params)
@staticmethod
def list_profiles(profile=None):
    """Returns the available profiles, or the settings of one of them.

    Parameters
    ----------
    profile : `str` or `None`
        If `None`, returns the profile keys of the configuration. If
        profile is not `None`, returns the parameters for that profile.

    """

    return config.keys() if profile is None else config[profile]
@property
@abc.abstractmethod
def connection_params(self):
    """Returns a dictionary with the connection parameters.

    Abstract property: each ORM-specific subclass provides its own
    implementation. Declared with ``property`` + ``abc.abstractmethod``
    because ``abc.abstractproperty`` is deprecated.

    Returns
    -------
    connection_params : dict
        A dictionary with the ``user``, ``host``, and ``port`` of the
        current connection. E.g.,
        ``{'user': 'sdssdb', 'host': 'sdss4-db', 'port': 5432}``

    """
def become(self, user):
    """Change the connection to a certain user.

    Reconnects using the parameters of the current connection, replacing
    only the user.

    Parameters
    ----------
    user : str
        The user to connect as.

    Raises
    ------
    RuntimeError
        If the database is not connected or the parameters of the current
        connection cannot be determined.

    """

    if not self.connected:
        raise RuntimeError('DB has not been initialised.')

    current_params = self.connection_params
    if current_params is None:
        raise RuntimeError('cannot determine the DSN parameters. '
                           'The DB may be disconnected.')

    # Work on a copy: `connection_params` may hand out the very dictionary
    # the connection stores internally, which must not be edited in place.
    dsn_params = dict(current_params)
    dsn_params['user'] = user
    dsn_params.setdefault('dbname', self.dbname)

    self.connect_from_parameters(**dsn_params)
def become_admin(self):
    """Becomes the admin user.

    The admin user name is read from the ``admin`` entry of the current
    profile and passed to `.become`.

    """

    profile_name = self.profile
    assert profile_name is not None, \
        'this connection was not initialised from a profile. Try using become().'

    settings = config[profile_name]
    assert 'admin' in settings, 'admin user not defined in profile'

    admin_user = settings['admin']
    self.become(admin_user)
def become_user(self):
    """Becomes the read-only user.

    The user name is read from the ``user`` entry of the current profile
    (`None` if the profile does not define it) and passed to `.become`.

    """

    profile_name = self.profile
    assert profile_name is not None, \
        'this connection was not initialised from a profile. Try using become().'

    settings = config[profile_name]
    if 'user' in settings:
        user = settings['user']
    else:
        user = None

    self.become(user)
def change_version(self, dbversion=None):
    """Change database version and attempt to reconnect.

    The suffix of the version currently stored in `.dbversion` (if any)
    is removed from `.dbname` and the new one is appended as
    ``dbname_dbversion``. A database name with no recorded version is
    taken as is, so names that contain underscores are preserved.

    Parameters
    ----------
    dbversion : `str` or `None`
        A database version. If `None`, the unversioned database name is
        used.

    Raises
    ------
    RuntimeError
        If the database name has not been set.

    """

    if self.dbname is None:
        raise RuntimeError('the database name was not set; cannot change '
                           'the database version.')

    # Strip exactly the suffix that was added for the current version.
    # Splitting on the first underscore would truncate a name such as
    # 'my_db_v1' to 'my'.
    base_name = self.dbname
    if self.dbversion:
        suffix = f'_{self.dbversion}'
        if base_name.endswith(suffix):
            base_name = base_name[:-len(suffix)]

    self.dbversion = dbversion
    self.dbname = f'{base_name}_{dbversion}' if dbversion else base_name

    self.connect(dbname=self.dbname, silent_on_fail=True)
if _peewee:
[docs] class PeeweeDatabaseConnection(DatabaseConnection, PostgresqlDatabase):
"""Peewee database connection implementation.
Attributes
----------
models : dict
Models bound to this database. Only models that are bound using
`~sdssdb.peewee.BaseModel` are handled.
"""
def __init__(self, *args, **kwargs):
    """Initialises the peewee database and then the connection settings.

    All the arguments are passed on to ``DatabaseConnection.__init__``.
    """

    # Registry of the models bound to this database. It must exist before
    # DatabaseConnection.__init__ runs, because that call may connect
    # and `_conn` iterates over the values of this dictionary.
    self.models = {}

    # The peewee side is initialised with no database name; the name is
    # supplied later, in `_conn`, through ``PostgresqlDatabase.init``.
    PostgresqlDatabase.__init__(self, None)
    DatabaseConnection.__init__(self, *args, **kwargs)
@property
def connection_params(self):
    """Returns a dictionary with the connection parameters.

    The parameters are the DSN parameters reported by the open
    connection, with ``dbname`` set to `.dbname`. Returns `None` when the
    database is not connected.

    """

    if not self.connected:
        return None

    params = self.connection().get_dsn_parameters()
    params['dbname'] = self.dbname

    return params
def _conn(self, dbname, silent_on_fail=False, **params):
    """Connects to the DB and tests the connection.

    Parameters
    ----------
    dbname : str
        The name of the database to connect to.
    silent_on_fail : bool
        If `True`, no warning is logged when the connection fails.
    params : dict
        Connection parameters passed to ``PostgresqlDatabase.init``.

    Returns
    -------
    connected : bool
        The value of `.connected` after the attempt.

    """

    # Re-initialise the peewee database and bind it to the requested
    # database name and parameters.
    PostgresqlDatabase.__init__(self, None)
    PostgresqlDatabase.init(self, dbname, **params)

    try:
        self.connected = PostgresqlDatabase.connect(self)
        self.dbname = dbname
    except OperationalError:
        if not silent_on_fail:
            log.warning(f'failed to connect to database {self.database!r}.')
        # Leave the database without a name again after a failed attempt.
        PostgresqlDatabase.init(self, None)
        self.connected = False

    if not self.is_connection_usable():
        return self.connected

    # Reflect the models that ask for it, all in a single transaction.
    with self.atomic():
        for model in self.models.values():
            wants_reflection = getattr(model._meta, 'use_reflection', False)
            if wants_reflection and hasattr(model, 'reflect'):
                model.reflect()

    return self.connected
def get_model(self, table_name, schema=None):
    """Returns the model for a table.

    Parameters
    ----------
    table_name : str
        The name of the table whose model will be returned.
    schema : str
        The schema for the table. If `None`, the first model that
        matches the table name will be returned.

    Returns
    -------
    :class:`peewee:Model` or `None`
        The model associated with the table, or `None` if no model
        was found.

    """

    # `models` is a dictionary, so the model classes are its values;
    # iterating the dictionary itself would yield the keys. A plain
    # iterable of models is accepted as well.
    registry = self.models
    candidates = registry.values() if isinstance(registry, dict) else registry

    for model in candidates:
        if schema and model._meta.schema != schema:
            continue
        if model._meta.table_name == table_name:
            return model

    return None
if _sqla:
[docs] class SQLADatabaseConnection(DatabaseConnection):
''' SQLAlchemy database connection implementation '''
engine = None
bases = []
Session = None
metadata = None
def __init__(self, *args, **kwargs):
    """Initialises the connection state and the base class.

    All the arguments are passed on to ``DatabaseConnection.__init__``.
    """

    # The class-level ``bases`` is one list object shared by every
    # instance and `add_base` appends to it in place. Give each connection
    # its own copy so bases bound to one connection do not leak into
    # others.
    self.bases = list(self.bases)

    self._connect_params = None

    DatabaseConnection.__init__(self, *args, **kwargs)
@property
def connection_params(self):
    """Returns a dictionary with the connection parameters.

    This is the value stored in ``_connect_params`` (`None` until a
    connection string has been built).

    """

    stored_params = self._connect_params
    return stored_params
def _get_password(self, **params):
    ''' Get a db password, falling back to a pgpass file

    A non-empty ``password`` entry in ``params`` is returned as is.
    Otherwise the password is looked up with ``getpass`` using the
    ``host``, ``port``, ``database`` and ``username`` entries.

    Parameters:
        params (dict):
            A dictionary of database connection parameters

    Returns:
        The database password for a given set of connection parameters

    Raises:
        RuntimeError:
            If the lookup raises a ``KeyError`` (e.g. a required
            parameter is missing)

    '''

    explicit = params.get('password', None)
    if explicit:
        return explicit

    try:
        return getpass(params['host'], params['port'], params['database'],
                       params['username'])
    except KeyError:
        raise RuntimeError('ERROR: invalid server configuration')
def _make_connection_string(self, dbname, **params):
    ''' Build a db connection string

    The ``user`` entry is renamed to ``username``; ``host`` and ``port``
    default to ``localhost`` and ``5432`` only when the keys are absent.
    When a user name is available the password is resolved with
    ``_get_password``. The parameters, as received, are remembered in
    ``_connect_params``.

    Parameters:
        dbname (str):
            The name of the database to connect to
        params (dict):
            A dictionary of database connection parameters

    Returns:
        A database connection string

    '''

    fields = dict(params)

    fields['username'] = fields.pop('user', None)
    fields.setdefault('host', 'localhost')
    fields.setdefault('port', 5432)
    fields['drivername'] = 'postgresql+psycopg2'
    fields['database'] = dbname

    if fields['username']:
        fields['password'] = self._get_password(**fields)

    connection_url = url.URL(**fields)

    # Stored only once the URL has been built successfully.
    self._connect_params = params

    return connection_url
def _conn(self, dbname, silent_on_fail=False, **params):
    '''Connects to the DB and tests the connection.

    Parameters:
        dbname (str):
            The name of the database to connect to
        silent_on_fail (bool):
            If True, no warning is logged when the connection fails
        params (dict):
            A dictionary of database connection parameters

    Returns:
        True if the database is connected, False otherwise

    '''

    # get connection string
    db_connection_string = self._make_connection_string(dbname, **params)

    try:
        self.create_engine(db_connection_string, echo=False,
                           pool_size=10, pool_recycle=1800)
        # The connection is opened only to check that the server can be
        # reached; close it straight away so it is not left checked out
        # of the pool.
        self.engine.connect().close()
    except OpError:
        if not silent_on_fail:
            log.warning('Failed to connect to database {0}'.format(dbname))
        # The engine may not exist if the failure happened while it was
        # being created.
        if self.engine is not None:
            self.engine.dispose()
        self.engine = None
        self.connected = False
        self.Session = None
        self.metadata = None
    else:
        self.connected = True
        self.dbname = dbname
        self.prepare_bases()

    return self.connected
def reset_engine(self):
    ''' Reset the engine, metadata, and session

    The bases bound with ``add_base`` are kept: ``create_engine`` calls
    this method on every connection attempt, and discarding them here
    would leave ``prepare_bases`` with nothing to prepare once the
    connection is established.

    '''

    # Keep the registered bases, but make sure this instance owns its
    # list instead of sharing the class-level one.
    self.bases = list(self.bases)

    if self.engine:
        self.engine.dispose()
        self.engine = None
        self.metadata = None
        # The session may be missing if a previous engine creation did
        # not complete.
        if self.Session is not None:
            self.Session.close()
        self.Session = None
def create_engine(self, db_connection_string=None, echo=False, pool_size=10,
                  pool_recycle=1800, expire_on_commit=True):
    ''' Create a new database engine

    Resets and creates a new sqlalchemy database engine. Also creates and binds
    engine metadata and a new scoped session.

    Parameters:
        db_connection_string:
            The connection string or URL. If not given, it is built from
            the database name and the current connection parameters
        echo (bool):
            Passed to the sqlalchemy engine
        pool_size (int):
            Passed to the sqlalchemy engine
        pool_recycle (int):
            Passed to the sqlalchemy engine
        expire_on_commit (bool):
            Passed to the session maker

    Raises:
        RuntimeError:
            If no connection string is given and the database name
            is not set

    '''

    self.reset_engine()

    if not db_connection_string:
        # DATABASE_NAME is an optional fallback that a subclass may
        # define; it is not required to exist.
        dbname = self.dbname or getattr(self, 'DATABASE_NAME', None)
        if dbname is None:
            raise RuntimeError('cannot create the engine: no connection string '
                               'was given and the database name is not set.')
        # connection_params is None until a connection string has been
        # built at least once.
        db_connection_string = self._make_connection_string(
            dbname, **(self.connection_params or {}))

    self.engine = create_engine(db_connection_string, echo=echo, pool_size=pool_size,
                                pool_recycle=pool_recycle)
    self.metadata = MetaData(bind=self.engine)
    self.Session = scoped_session(sessionmaker(bind=self.engine, autocommit=True,
                                               expire_on_commit=expire_on_commit))
def add_base(self, base, prepare=True):
    """Binds a base to this connection.

    The base is added to `.bases` unless it is already there. If
    ``prepare`` is True and the database is connected, the base is
    prepared right away with `.prepare_bases`.

    """

    already_bound = base in self.bases
    if not already_bound:
        self.bases.append(base)

    if prepare and self.connected:
        self.prepare_bases(base=base)
def prepare_bases(self, base=None):
    """Prepare a Model Base

    Prepares a SQLalchemy Base for reflection. This binds a database
    engine to a specific Base which maps to a set of ModelClasses.
    If ``base`` is passed only that base will be prepared. Otherwise,
    all the bases bound to this database connection will be prepared.

    """

    targets = [base] if base else self.bases

    for target in targets:
        target.prepare(self.engine)

        # A base may carry a ``_relations`` attribute: either the name of
        # a function in the module where the base is defined, or a
        # callable. It is called once the engine has been bound to the
        # base to set up the relationships.
        if not hasattr(target, '_relations'):
            continue

        relations = target._relations
        if isinstance(relations, str):
            owner_module = importlib.import_module(target.__module__)
            getattr(owner_module, relations)()
        elif callable(relations):
            relations()