import functools
import io
import multiprocessing
import os
import re
import warnings

import numpy
import peewee
from playhouse.postgres_ext import ArrayField
from playhouse.reflection import generate_models
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.ext.declarative import DeferredReflection, declarative_base

from sdssdb import log
from sdssdb.connection import SQLADatabaseConnection
from sdssdb.sqlalchemy import BaseModel

    import progressbar
except ImportError:
    progressbar = False

    import inflect
except ImportError:
    inflect = None

# Names exported by a star-import of this module.
__all__ = ('to_csv', 'copy_data', 'drop_table', 'create_model_from_table',
           'bulk_insert', 'file_to_db', 'create_adhoc_database')

    'i2': peewee.SmallIntegerField,
    'i4': peewee.IntegerField,
    'i8': peewee.BigIntegerField,
    'f4': peewee.FloatField,
    'f8': peewee.DoubleField,
    'S([0-9]+)': peewee.CharField

def to_csv(table, path, header=True, delimiter='\t',
           use_multiprocessing=False, workers=4):
    """Creates a PostgreSQL-valid CSV file from a table, handling arrays.

    Parameters
    ----------
    table : astropy.table.Table
        The table to convert.
    path : str
        The path to which to write the CSV file.
    header : bool
        Whether to add a header with the column names.
    delimiter : str
        The delimiter between columns in the CSV files.
    use_multiprocessing : bool
        Whether to use multiple cores. Rows are converted with
        ``multiprocessing.Pool.map``, which returns results in input order.
    workers : int
        How many workers to use with multiprocessing.

    """

    if use_multiprocessing:
        convert = functools.partial(convert_row_to_psql, delimiter=delimiter)
        # The context manager makes sure the worker processes are released.
        with multiprocessing.Pool(workers) as pool:
            tmp_list = pool.map(convert, table, chunksize=1000)
    else:
        tmp_list = [convert_row_to_psql(row, delimiter=delimiter)
                    for row in table]

    csv_str = '\n'.join(tmp_list)

    if header:
        csv_str = delimiter.join(table.colnames) + '\n' + csv_str

    # Close the file deterministically so the data are flushed to disk.
    with open(path, 'w') as unit:
        unit.write(csv_str)
def table_exists(table_name, connection, schema=None):
    """Reports whether a table is present in a database.

    Thin wrapper that delegates to the ``table_exists`` method of the
    connection.

    Parameters
    ----------
    table_name : str
        The name of the table.
    connection : .PeeweeDatabaseConnection
        The Peewee database connection to use.
    schema : str
        The schema in which the table lives.

    Returns
    -------
    exists
        Whatever ``connection.table_exists`` returns for the table.

    """

    exists = connection.table_exists(table_name, schema=schema)
    return exists
def drop_table(table_name, connection, cascade=False, schema=None):
    """Drops a table. Does nothing if the table does not exist.

    Parameters
    ----------
    table_name : str
        The name of the table to be dropped.
    connection : .PeeweeDatabaseConnection
        The Peewee database connection to use.
    cascade : bool
        Whether to drop related tables using ``CASCADE``.
    schema : str
        The schema in which the table lives. If `None`, the table name is
        used unqualified.

    Returns
    -------
    result : bool
        Returns `True` if the table was correctly dropped or `False` if the
        table does not exist and nothing was done.

    """

    if not table_exists(table_name, connection, schema=schema):
        return False

    # Only qualify the name when a schema is given; otherwise the statement
    # would reference a literal ``None`` schema.
    table_sql = f'{schema}.{table_name}' if schema else table_name

    # NOTE: identifiers are interpolated unquoted, so ``schema`` and
    # ``table_name`` must come from a trusted source.
    connection.execute_sql(f'DROP TABLE {table_sql}' +
                           (' CASCADE;' if cascade else ';'))

    return True
def create_model_from_table(table_name, table, schema=None, lowercase=False,
                            primary_key=None):
    """Returns a `~peewee:Model` from the columns in a table.

    Parameters
    ----------
    table_name : str
        The name of the table.
    table : ~astropy.table.Table
        An astropy table whose column names and types will be used to create
        the model.
    schema : str
        The schema in which the table lives.
    lowercase : bool
        If `True`, all column names will be converted to lower case.
    primary_key : str
        The name of the column to mark as primary key. It is compared with
        the column name after the optional lower-casing.

    Returns
    -------
    model : ~peewee:Model
        A new model class named after ``table_name`` with one field per
        column. All fields are nullable.

    Raises
    ------
    ValueError
        If a column has a dtype with no entry in ``DTYPE_TO_FIELD`` or is a
        multidimensional array.

    """

    # Prevents name confusion when setting schema in Meta.
    schema_ = schema

    attrs = {}

    class BaseModel(peewee.Model):

        class Meta:
            db_table = table_name
            schema = schema_
            primary_key = False

    for ii, column_name in enumerate(table.dtype.names):

        if lowercase:
            column_name = column_name.lower()

        column_dtype = table.dtype[ii]

        # dtype code without the leading byte-order character, e.g. 'i4'.
        type_code = column_dtype.base.str[1:]

        # Use a separate flag; overwriting ``primary_key`` here would make
        # the comparison fail for every column after the first one.
        is_primary_key = bool(primary_key) and primary_key == column_name

        for pattern, Field in DTYPE_TO_FIELD.items():

            match = re.match(pattern, type_code)
            if not match:
                continue

            field_kwargs = {}
            if type_code[0] == 'S':
                # The captured group is the fixed length of the string.
                field_kwargs['max_length'] = int(match.group(1))

            n_dimensions = len(column_dtype.shape)

            if n_dimensions == 1:
                ColumnField = ArrayField(Field,
                                         field_kwargs=field_kwargs,
                                         dimensions=column_dtype.shape[0],
                                         null=True,
                                         primary_key=is_primary_key)
            elif n_dimensions > 1:
                raise ValueError(f'column {column_name} with dtype '
                                 f'{column_dtype}: multidimensional arrays '
                                 'are not supported.')
            else:
                ColumnField = Field(**field_kwargs, null=True,
                                    primary_key=is_primary_key)

            break

        else:
            # No pattern in DTYPE_TO_FIELD matched this dtype.
            raise ValueError(f'cannot find an appropriate field type for '
                             f'column {column_name} with dtype {column_dtype}.')

        attrs[column_name] = ColumnField

    return type(str(table_name), (BaseModel,), attrs)
# Single-pass character mapping that turns the ``str`` of a Python list into
# a quoted PostgreSQL array literal: drops newlines, swaps single quotes for
# double quotes and replaces the brackets with quoted braces.
_PSQL_ARRAY_TRANSLATION = str.maketrans({
    '\n': '',
    '\'': '\"',
    '[': '\"{',
    ']': '}\"',
})


def convert_row_to_psql(row, delimiter='\t', null='\\N'):
    """Converts an astropy table row to a Postgresql-valid CSV string.

    Parameters
    ----------
    row : iterable
        The values of the row, one per column.
    delimiter : str
        The string used to join the converted column values.
    null : str
        The string written for masked values.

    Returns
    -------
    line : str
        The column values joined with ``delimiter``. Scalars are written
        with `str`, masked values as ``null`` and arrays as a quoted
        ``{...}`` array literal.

    """

    row_data = []

    for col_value in row:
        if numpy.isscalar(col_value):
            row_data.append(str(col_value))
        elif numpy.ma.is_masked(col_value):
            row_data.append(null)
        else:
            # Bytes arrays are decoded first so that the elements are not
            # written with a ``b`` prefix.
            if col_value.dtype.base.str[1] == 'S':
                col_value = col_value.astype('U')
            row_data.append(
                str(col_value.tolist()).translate(_PSQL_ARRAY_TRANSLATION))

    return delimiter.join(row_data)
def copy_data(data, connection, table_name, schema=None, chunk_size=10000,
              show_progress=False):
    """Loads data into a DB table using ``COPY``.

    Parameters
    ----------
    data : ~astropy.table.Table
        An astropy table with the rows to load.
    connection : .PeeweeDatabaseConnection
        The Peewee database connection to use.
    table_name : str
        The name of the table.
    schema : str
        The schema in which the table lives.
    chunk_size : int
        How many rows to load at once. The connection is committed after
        each chunk.
    show_progress : bool
        If `True`, shows a progress bar. Requires the ``progressbar2``
        module to be installed; otherwise a warning is issued and the data
        are loaded without a progress bar.

    """

    table_sql = f'{schema}.{table_name}' if schema else table_name

    n_rows = len(data)
    iterable = range(n_rows)

    # If the progressbar package is installed, uses it to create a progress
    # bar. The module-level placeholder is falsy when the import failed.
    if show_progress:
        if not progressbar:
            warnings.warn('progressbar2 is not installed. '
                          'Will not show a progress bar.')
        else:
            bar = progressbar.ProgressBar()
            iterable = bar(iterable)

    # TODO: it's probably more efficient to convert each column to string first
    # (by chunks) and then unzip them into a single string. That way we only
    # iterate over columns instead of over rows and columns.

    cursor = connection.cursor()

    try:
        tmp_list = []

        for ii in iterable:
            tmp_list.append(convert_row_to_psql(data[ii]))

            # If we have reached a chunk commit point, or this is the last
            # item, copy and commit to the database.
            if len(tmp_list) == chunk_size or ii == n_rows - 1:
                ss = io.StringIO('\n'.join(tmp_list))
                cursor.copy_from(ss, table_sql)
                connection.commit()
                tmp_list = []
    finally:
        # Release the cursor even if the conversion or the copy fails.
        cursor.close()

    return
def bulk_insert(data, connection, model, chunk_size=100000, show_progress=False):
    """Loads data into a DB table using bulk insert.

    Parameters
    ----------
    data : ~astropy.table.Table
        An astropy table with the data to insert.
    connection : .PeeweeDatabaseConnection
        The Peewee database connection to use.
    model : ~peewee:Model
        The model representing the database table into which to insert
        the data.
    chunk_size : int
        How many rows to load at once.
    show_progress : bool
        If `True`, shows a progress bar. Requires the ``progressbar2``
        module to be installed; otherwise a warning is issued and the data
        are inserted without a progress bar.

    """

    # Imported for its side effects only.
    # NOTE(review): presumably registers database adaptors for numpy types;
    # confirm against the ``adaptors`` module.
    from . import adaptors  # noqa

    bar = None
    n_rows = 0

    if show_progress:
        # The module-level placeholder is falsy when the import failed.
        if not progressbar:
            warnings.warn('progressbar2 is not installed. '
                          'Will not show a progress bar.')
        else:
            n_rows = len(data)
            bar = progressbar.ProgressBar(max_value=n_rows).start()

    n_inserted = 0

    # All batches are inserted inside a single atomic block.
    with connection.atomic():
        for batch in peewee.chunked(data, chunk_size):
            model.insert_many(batch).execute()
            if bar is not None:
                # The last batch may be shorter than chunk_size; clamp so
                # the bar is never updated past its maximum value.
                n_inserted = min(n_inserted + chunk_size, n_rows)
                bar.update(n_inserted)

    if bar is not None:
        bar.finish()

    return
def file_to_db(input_, connection, table_name, schema=None, lowercase=False,
               create=False, drop=False, truncate=False, primary_key=None,
               load_data=True, use_copy=True, chunk_size=100000,
               show_progress=False):
    """Loads a table from a file to a database.

    Loads a file or a `~astropy.table.Table` object into a database. If
    ``create=True`` a new table will be created, with column types matching
    the table ones. All columns are initially defined as ``NULL``.

    By default, the data are loaded using the ``COPY`` method to optimise
    performance. This can be disabled if needed.

    Parameters
    ----------
    input_ : str or ~astropy.table.Table
        The path to a file that will be opened using
        `Table.read <astropy.table.Table.read>` or an astropy
        `~astropy.table.Table`.
    connection : .PeeweeDatabaseConnection
        The Peewee database connection to use (SQLAlchemy connections are
        not supported).
    table_name : str
        The name of the table where to load the data, or to be created.
    schema : str
        The schema in which the table lives.
    lowercase : bool
        If `True`, all column names will be converted to lower case.
    create : bool
        Creates the table if it does not exist.
    drop : bool
        Drops the table before recreating it. Implies ``create=True``.
        The table is dropped without ``CASCADE``. Use with caution.
    truncate : bool
        Truncates the table before loading the data but maintains the
        existing columns.
    primary_key : str
        The name of the column to mark as primary key (ignored if the table
        is not being created).
    load_data : bool
        If `True`, loads the data from the table; otherwise just creates
        the table in the database.
    use_copy : bool
        When `True` (recommended) uses the SQL ``COPY`` command to load the
        data from a CSV stream.
    chunk_size : int
        How many rows to load at once.
    show_progress : bool
        If `True`, shows a progress bar. Requires the ``progressbar2``
        module to be installed.

    Returns
    -------
    model : ~peewee:Model
        The model for the table created.

    Raises
    ------
    TypeError
        If ``input_`` is neither a string nor an astropy table.
    ValueError
        If ``input_`` is a string that is not the path to an existing file,
        or if the table does not exist and ``create`` is `False`.

    """

    import astropy.table

    # If we drop we need to re-create but there is no need to truncate.
    if drop:
        create = True
        truncate = False

    # Validate explicitly instead of asserting: asserts are stripped when
    # Python runs with -O.
    if isinstance(input_, astropy.table.Table):
        table = input_
    elif isinstance(input_, str):
        if not os.path.isfile(input_):
            raise ValueError(f'{input_!r} is not the path to an existing file.')
        table = astropy.table.Table.read(input_)
    else:
        raise TypeError('input_ must be a path or an astropy Table, '
                        f'got {type(input_).__name__}.')

    if drop:
        drop_table(table_name, connection, schema=schema)

    if table_exists(table_name, connection, schema=schema):
        # Reflect the model from the existing table.
        Model = generate_models(connection, schema=schema,
                                table_names=[table_name])[table_name]
    else:
        if not create:
            raise ValueError(f'table {table_name} does not exist. '
                             'Call the function with create=True '
                             'if you want to create it.')

        Model = create_model_from_table(table_name, table, schema=schema,
                                        lowercase=lowercase,
                                        primary_key=primary_key)
        Model._meta.database = connection

        Model.create_table()

    if truncate:
        Model.truncate_table()

    if load_data:
        if use_copy:
            copy_data(table, connection, table_name, schema=schema,
                      chunk_size=chunk_size, show_progress=show_progress)
        else:
            bulk_insert(table, connection, Model, chunk_size=chunk_size,
                        show_progress=show_progress)

    return Model
def create_adhoc_database(dbname, schema=None, profile='local'):
    """Builds a throwaway SQLAlchemy connection and automapped models.

    Defines a new connection class on the fly for an existing database,
    connects to it, and automaps model classes for the tables of one
    schema. Only one schema can be mapped per call. Handy for trying out a
    database without writing a dedicated connection and model classes.

    Parameters
    ----------
    dbname : str
        The name of the database to create a connection for.
    schema : str
        The name of the schema to create mappings for.
    profile : str
        The database profile to connect with.

    Returns
    -------
    tuple
        The temporary database connection and the collection of automapped
        model classes. The second element is `None` if the connection
        could not be established.

    Example
    -------
    >>> from sdssdb.utils.ingest import create_adhoc_database
    >>> tempdb, models = create_adhoc_database('datamodel', schema='filespec')
    >>> tempdb
    <DatamodelDatabaseConnection (dbname='datamodel', profile='local', connected=True)>
    >>> models.File
    sqlalchemy.ext.automap.File

    """

    # Build the connection class dynamically, named after the database.
    connection_name = f'{dbname.title()}DatabaseConnection'
    declarative = declarative_base(cls=(DeferredReflection, BaseModel,))
    namespace = {'dbname': dbname, 'base': automap_base(declarative)}
    connection_class = type(connection_name, (SQLADatabaseConnection,), namespace)

    tempdb = connection_class(profile=profile, autoconnect=True)

    if tempdb.connected is False:
        log.warning(f'Could not connect to database: {dbname}. '
                    'Please check that the database exists. Cannot automap models.')
        return tempdb, None

    # Reflect the schema and generate the model classes.
    tempdb.base.prepare(tempdb.engine,
                        reflect=True,
                        schema=schema,
                        classname_for_table=camelize_classname,
                        name_for_collection_relationship=pluralize_collection)

    return tempdb, tempdb.base.classes
def camelize_classname(base, tablename, table):
    """Produce a 'camelized' class name.

    Converts a database table name to camelcase. Uses underscores to denote
    a new hump, e.g. ``'words_and_underscores' -> 'WordsAndUnderscores'``.
    Only underscores followed by a lower-case ASCII letter are replaced.
    See the naming-scheme section of the SQLAlchemy automap documentation.

    Parameters
    ----------
    base : ~sqlalchemy.ext.automap.AutomapBase
        The AutomapBase class doing the prepare.
    tablename : str
        The string name of the Table.
    table : ~sqlalchemy.schema.Table
        The Table object itself.

    Returns
    -------
    str
        A string class name.

    """

    return str(tablename[0].upper() +
               re.sub(r'_([a-z])', lambda m: m.group(1).upper(), tablename[1:]))


def pluralize_collection(base, local_cls, referred_cls, constraint):
    """Produce an 'uncamelized', 'pluralized' class name.

    Converts a camel-cased class name into an uncamelized, pluralized class
    name, e.g. ``'SomeTerm' -> 'some_terms'``. Used when auto-defining
    relationship names. See the naming-scheme section of the SQLAlchemy
    automap documentation.

    Parameters
    ----------
    base : ~sqlalchemy.ext.automap.AutomapBase
        The AutomapBase class doing the prepare.
    local_cls : object
        The class to be mapped on the local side.
    referred_cls : object
        The class to be mapped on the referring side.
    constraint : ~sqlalchemy.schema.ForeignKeyConstraint
        The ForeignKeyConstraint that is being inspected to produce this
        relationship.

    Returns
    -------
    str
        An uncamelized, pluralized string class name.

    Raises
    ------
    ImportError
        If the optional ``inflect`` library is not installed.

    """

    # Raise explicitly: an assert would be stripped when running with -O.
    if not inflect:
        raise ImportError('pluralize_collection requires the inflect library.')

    referred_name = referred_cls.__name__

    # Prefix every capital letter with an underscore and lower-case it; the
    # slice drops the first character of the result.
    uncamelized = re.sub(r'[A-Z]',
                         lambda m: "_%s" % m.group(0).lower(),
                         referred_name)[1:]

    _pluralizer = inflect.engine()
    pluralized = _pluralizer.plural(uncamelized)

    return pluralized