#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Author: José Sánchez-Gallego
# @Date: 2018-12-14
# @Filename:
# @License: BSD 3-clause (https://opensource.org/licenses/BSD-3-Clause)

# The following functions are adapted from the sqlalchemy_schemadisplay
# package by Florian Schulze
# (https://github.com/fschulze/sqlalchemy_schemadisplay).

import re

from peewee import ForeignKeyField, IndexMetadata

# pydot is an optional dependency; it is set to None when unavailable.
try:
    import pydot
except ImportError:
    pydot = None

# Names exported by a star-import of this module.
__all__ = [
    'create_schema_graph',
    'show_schema_graph',
]

# Lookup from a field's ``field_type`` string to the type name shown in the
# rendered table (PostgreSQL spelling). Types missing from this table are
# displayed unchanged.
field_type_psql = {
    'AUTO': 'SERIAL',
    'BIGAUTO': 'BIGSERIAL',
    'BIGINT': 'BIGINT',
    'BLOB': 'BYTEA',
    'BOOL': 'BOOLEAN',
    'CHAR': 'CHAR',
    'DATE': 'DATE',
    'DATETIME': 'TIMESTAMP',
    'DECIMAL': 'NUMERIC',
    'DEFAULT': '',
    'DOUBLE': 'DOUBLE PRECISION',
    'FLOAT': 'REAL',
    'INT': 'INTEGER',
    'SMALLINT': 'SMALLINT',
    'TEXT': 'TEXT',
    'TIME': 'TIME',
    'UUID': 'UUID',
    'UUIDB': 'BYTEA',
    'VARCHAR': 'VARCHAR',
}

def _render_table_html(model, show_columns=True, show_pks=True,
                       show_indices=True, show_datatypes=True):
    """Creates the HTML tags for a table, including PKs, FKs, and indices.

    model : `peewee.Model`
        The Peewee model for which to create the table.
    show_columns : bool
        Whether to show the column names.
    show_pks : bool
        Whether to show the primary key. Supersedes ``show_columns``.
    show_indices : `bool`
        Whether to show the indices from the table as separate rows.
    show_datatypes : `bool`
        Whether to show the data type of each column.


    table_name = model._meta.table_name
    fields = model._meta.fields

    # pk_col_names = set([fields[field_name].column_name for field_name in fields
    #                     if fields[field_name].primary_key])

    # fk_col_names = set([fields[field_name].column_name for field_name in fields
    #                     if isinstance(fields[field_name], ForeignKeyField)])

    def format_field_str(field):
        """Add in (PK) OR (FK) suffixes to column names."""

        suffixes = []

        column_name = field.column_name
        if column_name == '__composite_key__':
            column_name = '(' + ', '.join(pk.field_names) + ')'
            suffixes.append('PK')  # Composite keys get .primary_key == False

        if field.primary_key:
        if isinstance(field, ForeignKeyField):

        suffix = ' (' + ', '.join(suffixes) + ')' if len(suffixes) > 0 else ''

        if show_datatypes and field.column_name != '__composite_key__':
            field_type = field.field_type
            if field_type in field_type_psql:
                field_type = field_type_psql[field_type]
            return f'- {column_name}{suffix} : {field_type}'
            return f'- {column_name}{suffix}'

    html = (f'<<TABLE BORDER="1" CELLBORDER="0" CELLSPACING="0">'
            f'<TR><TD ALIGN="CENTER"><font face="Lucida Sans Demibold Roman">'

    added_col_name = []
    fields_html = []

    pk = model._meta.primary_key
    if show_pks and pk:
        if model._meta.composite_key:
            column_name = '(' + ', '.join(pk.field_names) + ')'
            column_name = pk.column_name
            '<TR><TD ALIGN="LEFT" PORT="{}">{}</TD></TR>'.format(
                column_name, format_field_str(pk)))

    # Add a row for each column in the table.
    if show_columns:
        for field in fields.values():

            if field.primary_key:

            column_name = field.column_name

            # Avoids repeating columns. This can happen if there are multiple
            # FKs pointing to the same column.
            if column_name in added_col_name:

                '<TR><TD ALIGN="LEFT" PORT="{}">{}</TD></TR>'.format(
                    column_name, format_field_str(field)))


    if len(fields_html) > 0:
        html += '<TR><TD BORDER="1" CELLPADDING="0"></TD></TR>'
        html += ''.join(fields_html)

    # Add indexes and unique constraints
    if show_indices:

        if model._meta.database.connected:
            indexes = model._meta.database.get_indexes(model._meta.table_name,
            indexes = [index._expressions[0] for index in model._meta.fields_to_index()
                       if not isinstance(index._expressions[0], ForeignKeyField)]

        if len(indexes) > 0:
            first = True

            for index in indexes:
                if not isinstance(index, IndexMetadata):

                column_names = index.columns
                ilabel = 'INDEX'

                if len(column_names) == 1:
                    column_name = column_names[0]
                    if column_name == '':
                        match = re.match(r'.+q3c_ang2ipix\("*(\w+)"*, "*(\w+)"*\).+',
                        if match:
                            column_name = '(' + ', '.join(match.groups()) + ')'
                            ilabel = 'Q3C'
                    column_name = '(' + ', '.join(column_names) + ')'

                if index.unique:
                    if pk and column_name == pk.column_name:
                    ilabel = 'UNIQUE'

                if first:
                    html += '<TR><TD BORDER="1" CELLPADDING="0"></TD></TR>'
                    first = False

                html += f'<TR><TD ALIGN="LEFT">{ilabel} {column_name}</TD></TR>'

    html += '</TABLE>>'

    return html

def create_schema_graph(models=None, base=None, schema=None, show_columns=True,
                        show_pks=True, show_indices=True, show_datatypes=True,
                        skip_tables=None, font='Bitstream-Vera Sans',
                        graph_options=None, relation_options=None):
    """Creates a graph visualisation from a series of Peewee models.

    Produces a `pydot <https://pypi.org/project/pydot/>`__ graph including the
    tables and relationships from a series of models or from a base model
    class.

    Parameters
    ----------
    models : list
        A list of Peewee `models <peewee:Model>` to be graphed.
    base : peewee:Model
        A base model class. If passed, all the model classes that were
        created by subclassing from the base model will be used. Ignored if
        ``models`` is also passed.
    schema : str
        A schema name. If passed, will be used to limit the list of models or
        ``base`` subclasses to only the models that match the schema name.
    show_columns : bool
        Whether to show the column names.
    show_pks : bool
        Whether to show the primary key. Supersedes ``show_columns``.
    show_indices : bool
        Whether to show the indices from the table as separate rows.
    show_datatypes : bool
        Whether to show the data type of each column.
    skip_tables : list
        List of table names to skip. `None` skips nothing.
    font : str
        The name of the font to use.
    graph_options : dict
        Options for creating the graph. Any valid Graphviz option. These
        are merged on top of the defaults ``program='dot'``,
        ``rankdir='TB'``, ``sep='0.01'``, ``mode='ipsep'``, and
        ``overlap='ipsep'``.
    relation_options : dict
        Additional parameters to be passed to ``pydot.Edge`` when creating the
        relationships. Merged on top of the default ``fontsize='7.0'``.

    Returns
    -------
    graph : `pydot.Dot`
        A ``pydot.Dot`` object with the graph representation of the schema.

    Raises
    ------
    AssertionError
        If neither ``models`` nor ``base`` is passed, or if ``pydot`` is not
        available.

    Example
    -------
    ::

        >>> graph = create_schema_graph([User, Tweet])
        >>> graph.write_pdf('tweetdb.pdf')

    """

    assert models or base, 'either model or base must be passed.'
    assert pydot, ('pydot is required for create_schema_graph. '
                   'Try running "pip install sdssdb[all]"')

    # None defaults avoid sharing mutable default arguments between calls.
    skip_tables = skip_tables if skip_tables is not None else ()

    relation_kwargs = {'fontsize': '7.0'}
    relation_kwargs.update(relation_options or {})

    if base and not models:
        # Collect all direct and indirect subclasses of ``base``, repeating
        # until a pass adds no new classes.
        models = set(base.__subclasses__())
        while True:
            old_models = models.copy()
            for model in old_models:
                models |= set(model.__subclasses__())
            if models == old_models:
                break

    if schema:
        models = [model for model in models if model._meta.schema == schema]

    merged_graph_options = dict(program='dot',
                                rankdir='TB',
                                sep='0.01',
                                mode='ipsep',
                                overlap='ipsep')
    merged_graph_options.update(graph_options or {})

    # Pass the merged options: the defaults above were previously built and
    # then ignored in favour of the bare ``graph_options``.
    graph = pydot.Dot(prog='dot', **merged_graph_options)

    for model in models:

        if model._meta.table_name in skip_tables:
            continue

        # With a live connection, leave out models whose table does not exist.
        if model._meta.database.connected and not model.table_exists():
            continue

        graph.add_node(
            pydot.Node(str(model._meta.table_name),
                       shape='plaintext',
                       label=_render_table_html(model,
                                                show_columns=show_columns,
                                                show_pks=show_pks,
                                                show_indices=show_indices,
                                                show_datatypes=show_datatypes),
                       fontname=font,
                       fontsize='7.0')
        )

        for field in model._meta.fields.values():

            # Only draw foreign keys that point to another graphed model.
            if (not isinstance(field, ForeignKeyField) or
                    field.rel_model not in models):
                continue

            from_col_name = '+ ' + field.column_name

            # The head label is left empty when the target is a primary key.
            if field.rel_field.primary_key:
                to_col_name = ''
            else:
                to_col_name = '+ ' + field.rel_field.column_name

            graph_edge = pydot.Edge(
                model._meta.table_name,
                field.rel_model._meta.table_name,
                dir='both',
                headlabel=to_col_name,
                taillabel=from_col_name,
                arrowhead='none',
                arrowtail='none',
                fontname=font,
                **relation_kwargs
            )

            graph.add_edge(graph_edge)

    return graph
def show_schema_graph(*args, **kwargs):
    """Creates and displays a schema graph.

    Accepts the same arguments as `create_schema_graph`, plus an optional
    ``command`` keyword naming the program used to display the image.

    Parameters
    ----------
    command : str, optional
        The viewer command forwarded to ``PIL.Image.Image.show``. If not
        provided, ``show`` is called without arguments.
        NOTE(review): recent Pillow releases no longer accept ``command``
        in ``Image.show``; confirm against the supported Pillow version.

    """

    # Imported here so that Pillow is only needed when displaying a graph.
    from io import BytesIO

    from PIL import Image

    # ``command`` is not a create_schema_graph argument, so remove it before
    # forwarding the remaining keywords.
    command = kwargs.pop('command', None)

    # The PNG is passed as binary data, so it is wrapped in BytesIO
    # (StringIO rejects bytes).
    png_data = create_schema_graph(*args, **kwargs).create_png()
    image = Image.open(BytesIO(png_data))

    if command is None:
        image.show()
    else:
        image.show(command=command)