Source code for sdssdb.connection

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# @Author: José Sánchez-Gallego (gallegoj@uw.edu)
# @Date: 2018-09-21
# @Filename: database.py
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause)

from __future__ import annotations

import abc
import importlib
import os
import re
import socket

import pgpasslib
import six

from sqlalchemy import MetaData, create_engine
from sqlalchemy.engine import url
from sqlalchemy.exc import OperationalError as OpError
from sqlalchemy.ext.declarative import DeferredReflection
from sqlalchemy.orm import scoped_session, sessionmaker

import peewee
from peewee import OperationalError, PostgresqlDatabase
from playhouse.postgres_ext import ArrayField
from playhouse.reflection import Introspector, UnknownField

import sdssdb
from sdssdb import config, log
from sdssdb.utils.internals import get_database_columns


__all__ = ["DatabaseConnection", "PeeweeDatabaseConnection", "SQLADatabaseConnection"]


def _should_autoconnect():
    """Determines whether we should autoconnect."""

    if "SDSSDB_AUTOCONNECT" in os.environ:
        envvar_autoconnect = os.environ["SDSSDB_AUTOCONNECT"].lower()
        if envvar_autoconnect == "0" or envvar_autoconnect == "false":
            return False
    else:
        return sdssdb.autoconnect


def get_database_uri(
    dbname: str,
    host: str | None = None,
    port: int | None = None,
    user: str | None = None,
    password: str | None = None,
):
    """Returns the URI to the database.

    Parameters
    ----------
    dbname
        The name of the database.
    host
        The database server hostname. Omitted from the URI if `None`.
    port
        The database server port. Omitted from the URI if `None`.
    user
        The user name. Omitted from the URI if `None`.
    password
        The password. Omitted from the URI if `None`.

    Returns
    -------
    uri : str
        A ``postgresql://`` URI built from the given pieces.

    """

    if password is not None:
        # A password without a user used to produce the literal text
        # "None:password@"; emit an empty user component instead.
        auth = f"{user if user is not None else ''}:{password}@"
    elif user is not None:
        auth = f"{user}@"
    else:
        auth = ""

    host_port = f"{host or ''}" if port is None else f"{host or ''}:{port}"

    # NOTE(review): with no authority at all the historical output is
    # "postgresql://<dbname>" (not "postgresql:///<dbname>"). It is kept
    # unchanged here because callers may depend on it -- confirm.
    if auth == "" and host_port == "":
        return f"postgresql://{dbname}"

    return f"postgresql://{auth}{host_port}/{dbname}"


class DatabaseConnection(metaclass=abc.ABCMeta):
    """A PostgreSQL database connection with profile and autoconnect features.

    Provides a base class for PostgreSQL connections for either peewee_ or
    SQLAlchemy_. The parameters for the connection can be passed directly
    (see `.connect_from_parameters`) or, more conveniently, a profile can be
    used. By default `.dbname` is left undefined and needs to be passed when
    initiating the connection. This is useful for databases such as
    ``apodb/lcodb`` for which the model classes are identical but the database
    name is not. For databases for which the database name is fixed (e.g.,
    ``sdss5db``), this class can be subclassed and `.dbname` overridden.

    Subclasses must implement `._conn` and `.connection_params` and expose a
    ``connected`` attribute or property.

    Parameters
    ----------
    dbname : str
        The database name.
    profile : str
        The configuration profile to use. The profile defines the default
        user, database server hostname, and port for a given location. If
        not provided, the profile is automatically determined based on the
        current domain, or defaults to ``local``.
    autoconnect : bool or None
        Whether to autoconnect to the database using the profile parameters.
        Requires `.dbname` to be set. If `None`, whether to autoconnect is
        defined, in order, by the existence of an environment variable
        ``$SDSSDB_AUTOCONNECT`` or by ``sdssdb.autoconnect``. If they are set
        to ``0`` or ``false`` the database won't autoconnect. Note that this
        must be set before importing any model classes.
    dbversion : str
        A database version. If specified, appends to dbname as
        "dbname_dbversion" and becomes the dbname used for connection strings.
    use_psycopg3 : bool
        Whether to use psycopg3 instead of psycopg2. If `None`, defaults to
        the value of the environment variable ``$SDSSDB_PSYCOPG3`` (which
        defaults to `True` if not set).

    """

    #: The database name.
    dbname = None

    #: Database version
    dbversion = None

    #: Whether to call Model.reflect() in Peewee after a connection.
    auto_reflect = True

    def __init__(
        self,
        dbname=None,
        profile=None,
        autoconnect=None,
        dbversion=None,
        use_psycopg3=None,
    ):
        self.profile = None
        self._config = {}

        self.dbname = dbname if dbname else self.dbname
        self.dbversion = dbversion or self.dbversion

        # Only append the version when there is a name to append it to;
        # otherwise the database name would become the string "None_<version>".
        if self.dbversion and self.dbname:
            self.dbname = f"{self.dbname}_{self.dbversion}"

        self.set_profile(profile=profile, connect=False)

        if use_psycopg3 is None:
            use_psycopg3 = os.environ.get("SDSSDB_PSYCOPG3", "true").lower() in ["true", "1"]
        self.use_psycopg3 = use_psycopg3

        if autoconnect is None:
            autoconnect = _should_autoconnect()

        if autoconnect and self.dbname:
            self.connect(dbname=self.dbname, silent_on_fail=True)

    def __repr__(self):
        return "<{} (dbname={!r}, profile={!r}, connected={})>".format(
            self.__class__.__name__,
            self.dbname,
            self.profile,
            self.connected,
        )

    def set_profile(self, profile=None, connect=True, **params):
        """Sets the profile from the configuration file.

        Parameters
        -----------
        profile : str
            The profile to set. If `None`, uses the domain name to
            determine the profile.
        connect : bool
            If True, tries to connect to the database using the new profile.
        params
            Connection parameters (``user``, ``host``, ``port``,
            ``password``) that will override the profile values.

        Returns
        -------
        connected : bool
            Returns True if the database is connected.

        Raises
        ------
        AssertionError
            If ``profile`` is given but is not in the configuration.

        """

        previous_profile = self.profile

        if profile is not None:
            assert profile in config, "profile not found in configuration file."
            self.profile = profile
            self._config = config[profile].copy()
        else:
            # Get hostname
            hostname = socket.getfqdn()

            # Initially set location to local.
            self.profile = "local"
            self._config = config[self.profile].copy()

            # Tries to find a profile whose domain matches the hostname
            for candidate in config:
                domain = config[candidate].get("domain", None)
                if domain is None or not re.match(domain, hostname):
                    continue

                self.profile = candidate
                self._config = config[candidate].copy()

                # If the profile host matches the current hostname set the
                # value to None to force using localhost to prevent cases
                # in which the loopback is not configured properly in
                # PostgreSQL. ``.get`` tolerates profiles without a host.
                if hostname == self._config.get("host", None):
                    self._config["host"] = None

                break

        self._config.update(params)

        if connect:
            if self.connected and self.profile == previous_profile:
                pass
            elif self.dbname is not None:
                self.connect(**self._config)

        return self.connected

    @abc.abstractmethod
    def _conn(self, dbname, **params):
        """Actually initialises the database connection.

        This method should be overridden depending on the ORM library being
        used. At the end, `.connected` should be set to True if the
        connection was successful.

        """

        pass

    def connect(self, dbname=None, silent_on_fail=False, **connection_params):
        """Initialises the database using the profile information.

        Parameters
        ----------
        dbname : `str` or `None`
            The database name. If `None`, defaults to `.dbname`. ``dbname``
            can also be a full database URI, in which case the other
            connection parameters are ignored.
        user : str
            Overrides the profile database user.
        host : str
            Overrides the profile database host.
        port : str
            Overrides the profile database port.
        silent_on_fail : `bool`
            If `True`, does not show a warning if the connection fails.

        Returns
        -------
        connected : bool
            Returns True if the database is connected.

        Raises
        ------
        RuntimeError
            If the profile or the database name have not been set.

        """

        if self.profile is None:
            raise RuntimeError(
                "the profile was not set when "
                "DatabaseConnection was instantiated. Use "
                "set_profile to set the profile in runtime."
            )

        # Gets the necessary configuration values from the profile. Only
        # user, host, and port are forwarded; other keys are ignored.
        db_configuration = {}
        for item in ["user", "host", "port"]:
            if item in connection_params:
                db_configuration[item] = connection_params[item]
            else:
                db_configuration[item] = self._config.get(item, None)

        dbname = dbname or self.dbname
        if dbname is None:
            raise RuntimeError(
                "the database name was not set when "
                "DatabaseConnection was instantiated. "
                "To set it in runtime change the dbname "
                "attribute."
            )

        return self.connect_from_parameters(
            dbname=dbname,
            silent_on_fail=silent_on_fail,
            **db_configuration,
        )

    def connect_from_parameters(self, dbname=None, **params):
        """Initialises the database from a dictionary of parameters.

        Parameters
        ----------
        dbname : `str` or `None`
            The database name. If `None`, defaults to `.dbname`.
        params : dict
            A dictionary of parameters, which should include ``user``,
            ``host``, and ``port``.

        Returns
        -------
        connected : bool
            Returns True if the database is connected.

        Raises
        ------
        KeyError
            If both ``hostname`` and ``host`` are passed.
        RuntimeError
            If the database name has not been set.

        """

        # Make hostname an alias of host.
        if "hostname" in params:
            if "host" not in params:
                params["host"] = params.pop("hostname")
            else:
                raise KeyError("cannot use hostname and host at the same time.")

        dbname = dbname or self.dbname
        if dbname is None:
            raise RuntimeError(
                "the database name was not set when "
                "DatabaseConnection was instantiated. "
                "To set it in runtime change the dbname "
                "attribute."
            )

        return self._conn(dbname, **params)

    @staticmethod
    def list_profiles(profile=None):
        """Returns a list of profiles.

        Parameters
        ----------
        profile : `str` or `None`
            If `None`, returns the profile keys of the configuration (the
            object returned by ``config.keys()``). If profile is not `None`
            returns the parameters for the given profile.

        """

        if profile is None:
            return config.keys()

        return config[profile]

    def get_connection_uri(self):
        """Returns the URI to the database connection.

        Raises
        ------
        RuntimeError
            If the database is not connected or the connection parameters
            or database name are not available.

        """

        params = self.connection_params

        if not self.connected or params is None or self.dbname is None:
            raise RuntimeError("The database is not connected.")

        valid_params = {
            "user": params.get("user", None),
            "host": params.get("host", None),
            "port": params.get("port", None),
            "password": params.get("password", None),
        }

        return get_database_uri(self.dbname, **valid_params)

    @property
    @abc.abstractmethod
    def connection_params(self) -> dict | None:
        """Returns a dictionary with the connection parameters.

        Returns
        -------
        connection_params : dict
            A dictionary with the ``user``, ``host``, and ``port`` of the
            current connection. E.g.,
            ``{'user': 'sdssdb', 'host': 'sdss4-db', 'port': 5432}``

        """

        pass

    def become(self, user):
        """Change the connection to a certain user.

        Parameters
        ----------
        user : str
            The user to reconnect as. The remaining parameters of the
            current connection are reused, except for the password.

        Raises
        ------
        RuntimeError
            If the database is not connected or the current connection
            parameters cannot be determined.

        """

        current_params = self.connection_params

        if not self.connected or current_params is None:
            raise RuntimeError("DB has not been initialised.")

        # Work on a copy: ``connection_params`` may hand back the dictionary
        # the connection stores internally, and popping from it would
        # corrupt that state if the reconnection below fails.
        dsn_params = dict(current_params)

        # Do not keep the password since it may change.
        dsn_params.pop("password", None)

        dsn_params["user"] = user
        dsn_params.setdefault("dbname", self.dbname)

        self.connect_from_parameters(**dsn_params)

    def become_admin(self, admin=None):
        """Becomes the admin user.

        If ``admin=None`` defaults to the ``admin`` value in the current
        profile.

        """

        assert self.profile is not None, (
            "this connection was not initialised from a profile. Try using become()."
        )
        assert "admin" in self._config, "admin user not defined in profile"

        self.become(admin or self._config["admin"])

    def become_user(self, user=None):
        """Becomes the read-only user.

        If ``user=None`` defaults to the ``user`` value in the current
        profile.

        """

        assert self.profile is not None, (
            "this connection was not initialised from a profile. Try using become()."
        )

        if user is None:
            user = self._config.get("user", None)

        self.become(user)

    def change_version(self, dbversion=None):
        """Change database version and attempt to reconnect.

        Parameters
        ----------
        dbversion : str
            A database version. If `None`, the version suffix is removed
            from the database name.

        Raises
        ------
        RuntimeError
            If the database name has not been set.

        """

        if self.dbname is None:
            raise RuntimeError("the database name is not set; cannot change the database version.")

        # Strip the currently tracked version suffix when the name carries
        # it, so base names that contain underscores (``my_db_v1``) survive.
        # Otherwise fall back to the historical behaviour of keeping only
        # the text before the first underscore.
        previous_suffix = f"_{self.dbversion}" if self.dbversion else None
        if previous_suffix and self.dbname.endswith(previous_suffix):
            base_name = self.dbname[: -len(previous_suffix)]
        else:
            base_name = self.dbname.split("_")[0]

        self.dbversion = dbversion
        self.dbname = f"{base_name}_{self.dbversion}" if dbversion else base_name

        self.connect(dbname=self.dbname, silent_on_fail=True)

    def post_connect(self):
        """Hook called after a successful connection."""

        pass
class PeeweeDatabaseConnection(DatabaseConnection, PostgresqlDatabase):
    """Peewee database connection implementation.

    Attributes
    ----------
    models : dict
        Models bound to this database. Only models that are bound using
        `~sdssdb.peewee.BaseModel` are handled.

    """

    def __init__(self, *args, **kwargs):
        self.models = {}
        self.introspector = {}
        self._metadata = {}

        # The peewee side starts uninitialised (database ``None``); the
        # actual database is set later in ``_conn``.
        PostgresqlDatabase.__init__(self, None)
        DatabaseConnection.__init__(self, *args, **kwargs)

    @property
    def connected(self):
        """Reports whether the connection is active."""

        return self.is_connection_usable()

    @property
    def connection_params(self):
        """Returns a dictionary with the connection parameters.

        Returns `None` when the database is not connected.

        """

        if not self.connected:
            return None

        version = self.psycopg_version

        if version == "psycopg2":
            return self.connection().info.dsn_parameters
        if version == "psycopg3":
            return self.connection().info.get_parameters()

        raise RuntimeError("unknown psycopg version in use.")

    @property
    def psycopg_version(self):
        """Returns the version of psycopg in use.

        One of ``"psycopg2"``, ``"psycopg3"``, or ``"unknown"``. Raises
        `RuntimeError` if the database is not connected.

        """

        if not self.connected:
            raise RuntimeError("The database is not connected.")

        adapter = self._adapter

        if isinstance(adapter, self.psycopg2_adapter):
            return "psycopg2"
        if isinstance(adapter, self.psycopg3_adapter):
            return "psycopg3"

        return "unknown"

    def _conn(self, dbname, silent_on_fail=False, **params):
        """Connects to the DB and tests the connection.

        ``dbname`` may be a database name or a ``postgresql://`` URI. When
        no ``password`` is passed, one is looked up with ``pgpasslib``.
        Returns the value of `.connected` after the attempt.

        """

        is_uri = dbname.startswith("postgresql://")

        if is_uri:
            PostgresqlDatabase.__init__(self, dbname, prefer_psycopg3=self.use_psycopg3)
        else:
            if "password" not in params:
                # pgpasslib is only given the parameters that have a value.
                lookup = {name: val for name, val in params.items() if val is not None}
                try:
                    params["password"] = pgpasslib.getpass(dbname=dbname, **lookup)
                except pgpasslib.FileNotFound:
                    params["password"] = None

            PostgresqlDatabase.init(
                self,
                dbname,
                prefer_psycopg3=self.use_psycopg3,
                **params,
            )

        # Cached column metadata belongs to the previous connection.
        self._metadata = {}

        try:
            PostgresqlDatabase.connect(self)
            live_params = self.connection_params
            dbname = live_params.get("dbname", dbname)
            self.dbname = dbname
        except OperationalError as err:
            if not silent_on_fail:
                log.warning(f"failed connecting to database {self.database!r}: {err}")
            # Leave the peewee database uninitialised after a failure.
            PostgresqlDatabase.init(self, None)

        if self.is_connection_usable() and self.auto_reflect:
            with self.atomic():
                for bound_model in self.models.values():
                    wants_reflection = getattr(bound_model._meta, "use_reflection", False)
                    if wants_reflection and hasattr(bound_model, "reflect"):
                        bound_model.reflect()

        if self.connected:
            self.post_connect()

        return self.connected

    def get_model(self, table_name, schema=None):
        """Returns the model for a table.

        Parameters
        ----------
        table_name : str
            The name of the table whose model will be returned.
        schema : str
            The schema for the table. If `None`, the first model that
            matches the table name will be returned.

        Returns
        -------
        :class:`peewee:Model` or `None`
            The model associated with the table, or `None` if no model
            was found.

        """

        matches = (
            candidate
            for candidate in self.models.values()
            if (not schema or candidate._meta.schema == schema)
            and candidate._meta.table_name == table_name
        )

        return next(matches, None)

    def get_introspector(self, schema=None):
        """Gets a Peewee database :class:`peewee:Introspector`.

        Introspectors are created on first use and cached per schema.

        """

        cache_key = schema or ""

        try:
            return self.introspector[cache_key]
        except KeyError:
            pass

        self.introspector[cache_key] = Introspector.from_database(self, schema=schema)
        return self.introspector[cache_key]

    def _schema_metadata(self, schema, cache):
        """Returns the column metadata for ``schema``, refreshing if needed."""

        if not cache or schema not in self._metadata:
            self._metadata[schema] = get_database_columns(self, schema=schema)

        return self._metadata[schema]

    def get_fields(self, table_name, schema=None, cache=True):
        """Returns a list of Peewee fields for a table.

        ``schema`` defaults to ``public``. With ``cache=False`` the schema
        metadata is queried again. An empty list is returned for a table
        that is not in the metadata.

        """

        tables = self._schema_metadata(schema or "public", cache)

        if table_name not in tables:
            return []

        entry = tables[table_name]
        pk = entry["pk"]
        composite_key = pk is not None and len(pk) > 1

        def build(col_name, field_type, array_type, nullable):
            # Only a single-column primary key is flagged on the field.
            is_pk = pk is not None and not composite_key and pk[0] == col_name
            options = {
                "column_name": col_name,
                "null": nullable,
                "primary_key": is_pk,
                "unique": is_pk,
            }

            if array_type:
                return ArrayField(array_type, **options)
            if array_type is False and field_type is UnknownField:
                return peewee.BareField(**options)
            return field_type(**options)

        return [build(*column) for column in entry["columns"]]

    def get_primary_keys(self, table_name, schema=None, cache=True):
        """Returns the primary keys for a table.

        ``schema`` defaults to ``public``. An empty list is returned for a
        table that is not in the metadata or has no primary key.

        """

        tables = self._schema_metadata(schema or "public", cache)

        if table_name in tables:
            return tables[table_name]["pk"] or []

        return []
class SQLADatabaseConnection(DatabaseConnection):
    """SQLAlchemy database connection implementation"""

    engine = None
    bases = []
    Session = None
    metadata = None

    def __init__(self, *args, **kwargs):
        #: Reports whether the connection is active.
        self.connected = False
        self._connect_params = None

        # Give every connection its own list of bases. ``add_base`` appends
        # to ``self.bases`` and would otherwise mutate the class-level list
        # shared by all connections. This must happen before the parent
        # initialiser, which may autoconnect and prepare the bases.
        self.bases = list(self.bases)

        DatabaseConnection.__init__(self, *args, **kwargs)

    @property
    def connection_params(self):
        """Returns a dictionary with the connection parameters.

        This is the dictionary of parameters last used to build a
        connection string from a database name, or `None` if none has
        been built yet.

        """

        return self._connect_params

    def _get_password(self, **params):
        """Get a db password from a pgpass file

        Parameters:
            params (dict):
                A dictionary of database connection parameters

        Returns:
            The database password for a given set of connection parameters,
            or `None` if no password was given and there is no pgpass file.

        Raises:
            RuntimeError:
                If ``host``, ``port``, ``database``, or ``username`` are
                missing from ``params`` when a lookup is needed.

        """

        password = params.get("password", None)
        if password:
            return password

        try:
            return pgpasslib.getpass(
                params["host"],
                params["port"],
                params["database"],
                params["username"],
            )
        except KeyError as err:
            raise RuntimeError("ERROR: invalid server configuration") from err
        except pgpasslib.FileNotFound:
            # Same policy as the Peewee connection: without a pgpass file
            # attempt the connection with no password.
            return None

    def _make_connection_string(self, dbname_or_uri, **params):
        """Build a db connection string

        Parameters:
            dbname_or_uri (str):
                The name of the database or the URI to connect to
            params (dict):
                A dictionary of database connection parameters

        Returns:
            A database connection string

        """

        drivername = "postgresql+psycopg" if self.use_psycopg3 else "postgresql+psycopg2"

        # Handle the case dbname_or_uri is a URI. Only the scheme is
        # rewritten; a blanket ``str.replace`` would also touch database
        # names, users, or passwords containing "psycopg" and would turn an
        # explicit "psycopg2" scheme into "psycopg22".
        scheme, separator, remainder = dbname_or_uri.partition("://")
        if separator and scheme.startswith("postgresql"):
            if scheme == "postgresql":
                scheme = drivername
            elif scheme == "postgresql+psycopg" and not self.use_psycopg3:
                scheme = drivername
            return f"{scheme}://{remainder}"

        # Now the case in which dbname_or_uri is a database name and parameters.
        db_params = params.copy()
        db_params["drivername"] = drivername
        db_params["database"] = dbname_or_uri
        db_params["username"] = db_params.pop("user", None)
        db_params["host"] = db_params.pop("host", "localhost")
        db_params["port"] = db_params.pop("port", 5432)

        if db_params["username"]:
            db_params["password"] = self._get_password(**db_params)

        db_connection_string = url.URL.create(**db_params)

        self._connect_params = params

        return db_connection_string

    def _conn(self, dbname, silent_on_fail=False, **params):
        """Connects to the DB and tests the connection.

        Returns the value of ``connected`` after the attempt.

        """

        # get connection string
        db_connection_string = self._make_connection_string(dbname, **params)

        try:
            self.create_engine(db_connection_string, echo=False, pool_size=10, pool_recycle=1800)
            # Open a probe connection and release it right away instead of
            # leaving it checked out.
            with self.engine.connect():
                pass
        except OpError:
            if not silent_on_fail:
                log.warning("Failed to connect to database {0}".format(dbname))
            # Disposes the engine (if any) and clears metadata and session.
            self.reset_engine()
            self.connected = False
        else:
            self.connected = True
            self.dbname = dbname
            self.prepare_bases()

        if self.connected:
            self.post_connect()

        return self.connected

    def reset_engine(self):
        """Reset the engine, metadata, and session"""

        if self.engine is not None:
            self.engine.dispose()

        # The session may not exist yet (first call) or may already have
        # been cleared; guard instead of assuming it is set.
        if self.Session is not None:
            self.Session.close()

        self.engine = None
        self.metadata = None
        self.Session = None

    def create_engine(
        self,
        db_connection_string=None,
        echo=False,
        pool_size=10,
        pool_recycle=1800,
        expire_on_commit=True,
    ):
        """Create a new database engine

        Resets and creates a new sqlalchemy database engine. Also creates and
        binds engine metadata and a new scoped session.

        If ``db_connection_string`` is not given, one is built from the
        database name and the stored connection parameters; a `RuntimeError`
        is raised if no database name is available.

        """

        self.reset_engine()

        if not db_connection_string:
            # ``DATABASE_NAME`` is an optional fallback that is not defined
            # on this class, hence the ``getattr``.
            dbname = self.dbname or getattr(self, "DATABASE_NAME", None)
            if dbname is None:
                raise RuntimeError("cannot create the engine: the database name is not set.")
            db_connection_string = self._make_connection_string(
                dbname,
                **(self.connection_params or {}),
            )

        # Inside this method the bare name resolves to the module-level
        # sqlalchemy ``create_engine`` function, not to this method.
        self.engine = create_engine(
            db_connection_string,
            echo=echo,
            pool_size=pool_size,
            pool_recycle=pool_recycle,
            future=True,
        )
        self.metadata = MetaData()
        self.Session = scoped_session(
            sessionmaker(bind=self.engine, expire_on_commit=expire_on_commit, future=True)
        )

    def add_base(self, base, prepare=True):
        """Binds a base to this connection.

        If ``prepare`` is true and the database is connected, the base is
        prepared immediately.

        """

        if base not in self.bases:
            self.bases.append(base)

        if prepare and self.connected:
            self.prepare_bases(base=base)

    def prepare_bases(self, base=None):
        """Prepare a Model Base

        Prepares a SQLalchemy Base for reflection. This binds a database
        engine to a specific Base which maps to a set of ModelClasses. If
        ``base`` is passed only that base will be prepared. Otherwise, all
        the bases bound to this database connection will be prepared.

        """

        # Iterate over a snapshot so a relations hook that registers another
        # base does not modify the list while it is being traversed.
        do_bases = [base] if base else list(self.bases)

        for current_base in do_bases:
            if issubclass(current_base, DeferredReflection):
                current_base.prepare(self.engine, views=True)

            # If the base has an attribute _relations that's the function
            # to call to set up the relationships once the engine has been
            # bound to the base.
            relations = getattr(current_base, "_relations", None)
            if isinstance(relations, str):
                module = importlib.import_module(current_base.__module__)
                getattr(module, relations)()
            elif callable(relations):
                relations()