# Source code for sdssdb.utils.ingest

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# @Author: José Sánchez-Gallego (gallegoj@uw.edu)
# @Date: 2019-09-21
# @Filename: ingest.py
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause)

import functools
import io
import multiprocessing
import os
import re
import warnings

import numpy
import peewee
from playhouse.postgres_ext import ArrayField
from playhouse.reflection import generate_models
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.ext.declarative import DeferredReflection, declarative_base

from sdssdb import log
from sdssdb.connection import SQLADatabaseConnection
from sdssdb.sqlalchemy import BaseModel


try:
    import progressbar
except ImportError:
    progressbar = False

try:
    import inflect
except ImportError:
    inflect = None


# Public API exported by ``from sdssdb.utils.ingest import *``.
__all__ = (
    "to_csv",
    "copy_data",
    "drop_table",
    "create_model_from_table",
    "bulk_insert",
    "file_to_db",
    "create_adhoc_database",
)


# Maps regular expressions to the peewee field class used for a column in
# ``create_model_from_table``. Each key is tried with ``re.match`` against the
# column's numpy dtype string with its byte-order character removed (e.g.
# ``"i4"``, ``"f8"``, ``"S10"``); the first matching entry wins. For byte-string
# dtypes the captured group is used as the field's ``max_length``. Dtypes with
# no matching entry cause ``create_model_from_table`` to raise ValueError.
DTYPE_TO_FIELD = {
    "i2": peewee.SmallIntegerField,
    "i4": peewee.IntegerField,
    "i8": peewee.BigIntegerField,
    "f4": peewee.FloatField,
    "f8": peewee.DoubleField,
    "S([0-9]+)": peewee.CharField,
}


def to_csv(table, path, header=True, delimiter="\t", use_multiprocessing=False, workers=4):
    """Creates a PostgreSQL-valid CSV file from a table, handling arrays.

    Each row is converted with `convert_row_to_psql` and rows are joined with
    newlines.

    Parameters
    ----------
    table : astropy.table.Table
        The table to convert.
    path : str
        The path to which to write the CSV file.
    header : bool
        Whether to add a header with the column names.
    delimiter : str
        The delimiter between columns in the CSV files.
    use_multiprocessing : bool
        Whether to use multiple cores to convert the rows. `Pool.map
        <multiprocessing.pool.Pool.map>` returns results in input order, so
        the rows keep the ordering of the original table.
    workers : int
        How many workers to use with multiprocessing.

    """

    if use_multiprocessing:
        # The context manager shuts the worker processes down once the
        # conversion is done; previously the pool was never closed.
        with multiprocessing.Pool(workers) as pool:
            tmp_list = pool.map(
                functools.partial(convert_row_to_psql, delimiter=delimiter),
                table,
                chunksize=1000,
            )
    else:
        tmp_list = [convert_row_to_psql(row, delimiter=delimiter) for row in table]

    csv_str = "\n".join(tmp_list)

    if header:
        csv_str = delimiter.join(table.colnames) + "\n" + csv_str

    # Close (and therefore flush) the file deterministically.
    with open(path, "w") as unit:
        unit.write(csv_str)
def table_exists(table_name, connection, schema=None):
    """Checks whether a table is present in the database.

    This simply delegates to the ``table_exists`` method of the connection.

    Parameters
    ----------
    table_name : str
        The name of the table to look for.
    connection : .PeeweeDatabaseConnection
        The Peewee database connection to query.
    schema : str
        The schema in which to look for the table.

    Returns
    -------
    exists : bool
        The value reported by ``connection.table_exists``.

    """

    exists = connection.table_exists(table_name, schema=schema)
    return exists
def drop_table(table_name, connection, cascade=False, schema=None):
    """Drops a table. Does nothing if the table does not exist.

    Parameters
    ----------
    table_name : str
        The name of the table to be dropped.
    connection : .PeeweeDatabaseConnection
        The Peewee database connection to use.
    cascade : bool
        Whether to drop related tables using ``CASCADE``.
    schema : str
        The schema in which the table lives. If not given, the table name is
        used unqualified.

    Returns
    -------
    result : bool
        Returns `True` if the table was correctly dropped or `False` if the
        table does not exists and nothing was done.

    """

    if not table_exists(table_name, connection, schema=schema):
        return False

    # Only qualify the name when a schema is given. Formatting ``None`` into
    # the statement would otherwise produce ``DROP TABLE None.<table>``.
    # NOTE: identifiers are interpolated verbatim (not quoted or escaped);
    # only pass trusted table/schema names.
    table_sql = f"{schema}.{table_name}" if schema else table_name

    connection.execute_sql(f"DROP TABLE {table_sql}" + (" CASCADE;" if cascade else ";"))

    return True
def create_model_from_table(table_name, table, schema=None, lowercase=False, primary_key=None):
    """Returns a `~peewee:Model` from the columns in a table.

    The field class of each column is chosen from ``DTYPE_TO_FIELD``. Scalar
    columns become regular fields and one-dimensional array columns become
    ``ArrayField``\\ s. All fields are nullable.

    Parameters
    ----------
    table_name : str
        The name of the table.
    table : ~astropy.table.Table
        An astropy table whose column names and types will be used to create
        the model.
    schema : str
        The schema in which the table lives.
    lowercase : bool
        If `True`, all column names will be converted to lower case.
    primary_key : str
        The name of the column to mark as primary key. When ``lowercase=True``
        it is compared against the lower-cased column names.

    Returns
    -------
    model : ~peewee:Model
        A new model class named after ``table_name``.

    Raises
    ------
    ValueError
        If a column is a multidimensional array or its dtype does not match any
        entry in ``DTYPE_TO_FIELD``.

    """

    # Prevents name confusion when setting schema in Meta.
    schema_ = schema

    attrs = {}

    class BaseModel(peewee.Model):
        class Meta:
            db_table = table_name
            schema = schema_
            # Do not add peewee's implicit auto-incrementing primary key.
            primary_key = False

    for ii, column_name in enumerate(table.dtype.names):
        if lowercase:
            column_name = column_name.lower()

        column_dtype = table.dtype[ii]

        # Use a separate flag. The original code overwrote ``primary_key``
        # itself here, so after the first column no other column could match.
        is_primary_key = bool(primary_key) and primary_key == column_name

        # Dtype string without the byte-order character, e.g. "i4" or "S10".
        dtype_str = column_dtype.base.str[1:]

        column_field = None
        for dtype, Field in DTYPE_TO_FIELD.items():
            match = re.match(dtype, dtype_str)
            if not match:
                continue

            field_kwargs = {}
            if dtype_str[0] == "S":
                field_kwargs["max_length"] = int(match.group(1))

            n_dims = len(column_dtype.shape)
            if n_dims == 0:
                column_field = Field(**field_kwargs, null=True, primary_key=is_primary_key)
            elif n_dims == 1:
                column_field = ArrayField(
                    Field,
                    field_kwargs=field_kwargs,
                    # ``dimensions`` is the number of array dimensions (how
                    # many ``[]`` are appended to the column type), not the
                    # array length that was passed here previously.
                    dimensions=1,
                    null=True,
                    primary_key=is_primary_key,
                )
            else:
                raise ValueError(
                    f"column {column_name} with dtype "
                    f"{column_dtype}: multidimensional arrays "
                    "are not supported."
                )
            break

        if column_field is None:
            raise ValueError(
                f"cannot find an appropriate field type for "
                f"column {column_name} with dtype {column_dtype}."
            )

        attrs[column_name] = column_field

    return type(str(table_name), (BaseModel,), attrs)
# Character substitutions that turn the ``str`` of a Python list into the
# array representation written by ``convert_row_to_psql``: newlines are
# dropped, single quotes become double quotes, and the brackets become a
# double-quoted ``{...}``.
_PSQL_ARRAY_TRANSLATION = {
    ord("\n"): None,
    ord("'"): '"',
    ord("["): '"{',
    ord("]"): '}"',
}


def convert_row_to_psql(row, delimiter="\t", null="\\N"):
    """Converts an astropy table row to a PostgreSQL-valid CSV string.

    Parameters
    ----------
    row
        An iterable of column values (e.g. an astropy table row).
    delimiter : str
        The string placed between the formatted values.
    null : str
        The string written for masked values (``\\N`` by default).

    Returns
    -------
    line : str
        The formatted row, without a trailing newline.

    """

    def _format(value):
        # Scalars are written with ``str``; masked values become ``null``.
        if numpy.isscalar(value):
            return str(value)
        if numpy.ma.is_masked(value):
            return null

        # Array values: decode byte strings, then convert the list repr.
        if value.dtype.base.str[1] == "S":
            value = value.astype("U")
        return str(value.tolist()).translate(_PSQL_ARRAY_TRANSLATION)

    return delimiter.join(_format(value) for value in row)
def copy_data(data, connection, table_name, schema=None, chunk_size=10000, show_progress=False):
    """Loads data into a DB table using ``COPY``.

    Rows are converted with `convert_row_to_psql` (tab-delimited, ``\\N`` for
    nulls) and sent to the server with ``COPY ... FROM STDIN`` in chunks of
    ``chunk_size`` rows. Each chunk is committed separately.

    Parameters
    ----------
    data : ~astropy.table.Table
        An astropy table whose column names and types will be used to create
        the model.
    connection : .PeeweeDatabaseConnection
        The Peewee database connection to use.
    table_name : str
        The name of the table.
    schema : str
        The schema in which the table lives.
    chunk_size : int
        How many rows to load at once.
    show_progress : bool
        If `True`, shows a progress bar. Requires the `progressbar2
        <https://progressbar-2.readthedocs.io/en/latest/>`__ module to be
        installed.

    """

    table_sql = "{0}.{1}".format(schema, table_name) if schema else table_name

    # psycopg2 >= 2.9 quotes the ``table`` argument of ``copy_from`` as a single
    # identifier, which breaks schema-qualified names, so issue the statement
    # explicitly. Text-format defaults (tab delimiter, \N for NULL) match the
    # defaults used by convert_row_to_psql.
    copy_sql = f"COPY {table_sql} FROM STDIN"

    indices = range(len(data))

    # If the progressbar package is installed, uses it to create a progress bar.
    # ``progressbar`` is falsy when the import failed. Without the bar we fall
    # back to the plain range; previously ``iterable`` was left unbound here.
    if show_progress:
        if not progressbar:
            warnings.warn("progressbar2 is not installed. Will not show a progress bar.")
        else:
            indices = progressbar.ProgressBar()(indices)

    # TODO: it's probably more efficient to convert each column to string first
    # (by chunks) and then unzip them into a single string. That way we only
    # iterate over columns instead of over rows and columns.

    cursor = connection.cursor()
    try:

        def copy_chunk(lines):
            # Copy one chunk of formatted rows and commit it.
            cursor.copy_expert(copy_sql, io.StringIO("\n".join(lines)))
            connection.commit()

        rows = []
        for ii in indices:
            rows.append(convert_row_to_psql(data[ii]))
            if len(rows) == chunk_size:
                copy_chunk(rows)
                rows = []

        # Copy whatever remains from the last, partial chunk.
        if rows:
            copy_chunk(rows)
    finally:
        # Always release the cursor, even if a COPY fails.
        cursor.close()

    return
def bulk_insert(data, connection, model, chunk_size=100000, show_progress=False):
    """Loads data into a DB table using bulk insert.

    All chunks are inserted inside a single ``connection.atomic()`` block.

    Parameters
    ----------
    data : ~astropy.table.Table
        An astropy table with the data to insert.
    connection : .PeeweeDatabaseConnection
        The Peewee database connection to use.
    model : ~peewee:Model
        The model representing the database table into which to insert
        the data.
    chunk_size : int
        How many rows to load at once.
    show_progress : bool
        If `True`, shows a progress bar. Requires the `progressbar2
        <https://progressbar-2.readthedocs.io/en/latest/>`__ module to be
        installed.

    """

    # Imported for its side effects; presumably it registers database adaptors
    # needed by ``insert_many`` (TODO confirm against the adaptors module).
    from . import adaptors  # noqa

    # ``progressbar`` is falsy when the import failed. ``bar`` is always bound
    # so the loop below cannot hit a NameError.
    bar = None
    if show_progress:
        if not progressbar:
            warnings.warn("progressbar2 is not installed. Will not show a progress bar.")
        else:
            bar = progressbar.ProgressBar(max_value=len(data)).start()

    n_inserted = 0
    with connection.atomic():
        for batch in peewee.chunked(data, chunk_size):
            model.insert_many(batch).execute()
            if bar is not None:
                # Count rows actually inserted; adding ``chunk_size`` overshoots
                # ``max_value`` on a final partial chunk.
                n_inserted += len(batch)
                bar.update(n_inserted)

    if bar is not None:
        bar.finish()

    return
def file_to_db(
    input_,
    connection,
    table_name,
    schema=None,
    lowercase=False,
    create=False,
    drop=False,
    truncate=False,
    primary_key=None,
    load_data=True,
    use_copy=True,
    chunk_size=100000,
    show_progress=False,
):
    """Loads a table from a file to a database.

    Loads a file or a `~astropy.table.Table` object into a database. If
    ``create=True`` a new table will be created, with column types matching
    the table ones. All columns are initially defined as ``NULL``.

    By default, the data are loaded using the ``COPY`` method to optimise
    performance. This can be disabled if needed.

    Parameters
    ----------
    input_ : str, os.PathLike, or ~astropy.table.Table
        The path to a file that will be opened using
        `Table.read <astropy.table.Table.read>` or an astropy
        `~astropy.table.Table`.
    connection : .PeeweeDatabaseConnection
        The Peewee database connection to use (SQLAlchemy connections are
        not supported).
    table_name : str
        The name of the table where to load the data, or to be created.
    schema : str
        The schema in which the table lives.
    lowercase : bool
        If `True`, all column names will be converted to lower case.
    create : bool
        Creates the table if it does not exist.
    drop : bool
        Drops the table before recreating it. Implies ``create=True``. The drop
        is not executed with ``CASCADE``, so it fails if other database
        objects depend on the table. Use with caution.
    truncate : bool
        Truncates the table before loading the data but maintains the existing
        columns. Ignored when ``drop=True``.
    primary_key : str
        The name of the column to mark as primary key (ignored if the table is
        not being created).
    load_data : bool
        If `True`, loads the data from the table; otherwise just creates the
        table in the database.
    use_copy : bool
        When `True` (recommended) uses the SQL ``COPY`` command to load the
        data from a CSV stream.
    chunk_size : int
        How many rows to load at once.
    show_progress : bool
        If `True`, shows a progress bar. Requires the ``progressbar2`` module
        to be installed.

    Returns
    -------
    model : ~peewee:Model
        The model for the table created.

    Raises
    ------
    FileNotFoundError
        If ``input_`` is a path that does not point to an existing file.
    TypeError
        If ``input_`` is neither a path nor an astropy table.
    ValueError
        If the table does not exist and ``create`` is `False`.

    """

    import astropy.table

    # If we drop we need to re-create but there is no need to truncate.
    if drop:
        create = True
        truncate = False

    # Validate explicitly instead of with ``assert``, which is stripped
    # under ``python -O``.
    if isinstance(input_, (str, os.PathLike)):
        if not os.path.isfile(input_):
            raise FileNotFoundError(f"file {input_} does not exist.")
        table = astropy.table.Table.read(input_)
    elif isinstance(input_, astropy.table.Table):
        table = input_
    else:
        raise TypeError("input_ must be a path to a file or an astropy Table.")

    if drop:
        drop_table(table_name, connection, schema=schema)

    if table_exists(table_name, connection, schema=schema):
        # Reflect the model of the existing table.
        Model = generate_models(connection, schema=schema, table_names=[table_name])[table_name]
    else:
        if not create:
            raise ValueError(
                f"table {table_name} does not exist. "
                "Call the function with create=True "
                "if you want to create it."
            )

        Model = create_model_from_table(
            table_name, table, schema=schema, lowercase=lowercase, primary_key=primary_key
        )
        Model._meta.database = connection

        Model.create_table()

    if truncate:
        Model.truncate_table()

    if load_data:
        if use_copy:
            copy_data(
                table,
                connection,
                table_name,
                schema=schema,
                chunk_size=chunk_size,
                show_progress=show_progress,
            )
        else:
            bulk_insert(
                table, connection, Model, chunk_size=chunk_size, show_progress=show_progress
            )

    return Model
def create_adhoc_database(dbname, schema=None, profile="local"):
    """Creates an ad hoc SQLAlchemy database connection and automapped models.

    Builds, at runtime, a `.SQLADatabaseConnection` subclass for the existing
    database ``dbname``, connects to it, and uses SQLAlchemy's automap
    extension to generate model classes for the tables in ``schema``. Only
    one schema can be mapped per call.

    This is useful for temporarily trying out a database connection and
    simple models without building and committing a full-fledged new
    database connection.

    Parameters
    ----------
    dbname : str
        The name of the database to create a connection for.
    schema : str
        The name of the schema to create mappings for.
    profile : str
        The database profile to connect with.

    Returns
    -------
    tuple
        ``(tempdb, models)``: the temporary database connection and the
        automapped classes (``tempdb.base.classes``). If the connection's
        ``connected`` attribute is `False`, a warning is logged and
        ``models`` is `None`.

    Example
    -------
    >>> from sdssdb.utils.ingest import create_adhoc_database
    >>> tempdb, models = create_adhoc_database('datamodel', schema='filespec')
    >>> tempdb
    <DatamodelDatabaseConnection (dbname='datamodel', profile='local', connected=True)>
    >>> models.File
    <class 'sqlalchemy.ext.automap.File'>

    """

    connection_name = f"{dbname.title()}DatabaseConnection"

    # Declarative base that defers reflection until ``prepare`` and mixes in
    # the sdssdb SQLAlchemy ``BaseModel``; automap is layered on top of it.
    declarative = declarative_base(cls=(DeferredReflection, BaseModel))
    automap = automap_base(declarative)

    connection_class = type(
        connection_name,
        (SQLADatabaseConnection,),
        {"dbname": dbname, "base": automap},
    )
    tempdb = connection_class(profile=profile, autoconnect=True)

    # Only an explicit ``False`` aborts the automapping.
    if tempdb.connected is False:
        log.warning(
            f"Could not connect to database: {dbname}. "
            "Please check that the database exists. Cannot automap models."
        )
        return tempdb, None

    # Reflect the schema and generate the model classes.
    tempdb.base.prepare(
        tempdb.engine,
        reflect=True,
        schema=schema,
        classname_for_table=camelize_classname,
        name_for_collection_relationship=pluralize_collection,
    )

    return tempdb, tempdb.base.classes
def camelize_classname(base, tablename, table):
    """Produces a camel-cased class name from a database table name.

    Underscores mark the start of a new hump, e.g.
    ``'words_and_underscores' -> 'WordsAndUnderscores'``. See
    https://docs.sqlalchemy.org/en/13/orm/extensions/automap.html#overriding-naming-schemes

    Parameters
    ----------
    base : ~sqlalchemy.ext.automap.AutomapBase
        The AutomapBase class doing the prepare.
    tablename : str
        The string name of the Table.
    table : ~sqlalchemy.schema.Table
        The Table object itself.

    Returns
    -------
    str
        A string class name.

    """

    # Slicing (rather than indexing) keeps an empty name from raising IndexError.
    head = tablename[:1].upper()
    tail = re.sub(r"_([a-z])", lambda m: m.group(1).upper(), tablename[1:])

    return str(head + tail)


def pluralize_collection(base, local_cls, referred_cls, constraint):
    """Produces an uncamelized, pluralized name for a relationship collection.

    Converts a camel-cased class name into an uncamelized, pluralized name,
    e.g. ``'SomeTerm' -> 'some_terms'``. Used when auto-defining relationship
    names. See
    https://docs.sqlalchemy.org/en/13/orm/extensions/automap.html#overriding-naming-schemes.

    Parameters
    ----------
    base : ~sqlalchemy.ext.automap.AutomapBase
        The AutomapBase class doing the prepare.
    local_cls : object
        The class to be mapped on the local side.
    referred_cls : object
        The class to be mapped on the referring side.
    constraint : ~sqlalchemy.schema.ForeignKeyConstraint
        The ForeignKeyConstraint that is being inspected to produce this
        relationship.

    Returns
    -------
    str
        An uncamelized, pluralized string class name.

    Raises
    ------
    ImportError
        If the ``inflect`` library is not installed.

    """

    # Raise explicitly; an ``assert`` would be stripped under ``python -O``.
    if not inflect:
        raise ImportError("pluralize_collection requires the inflect library.")

    referred_name = referred_cls.__name__

    # Prefix every uppercase letter with an underscore and lowercase it.
    uncamelized = re.sub(r"[A-Z]", lambda m: "_%s" % m.group(0).lower(), referred_name)

    # Strip the leading underscore only if there is one. The previous
    # unconditional ``[1:]`` dropped the first letter of names that do not
    # start with an uppercase letter.
    if uncamelized.startswith("_"):
        uncamelized = uncamelized[1:]

    _pluralizer = inflect.engine()
    pluralized = _pluralizer.plural(uncamelized)

    return pluralized