#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# @Author: José Sánchez-Gallego (gallegoj@uw.edu)
# @Date: 2018-12-14
# @Filename: peewee_schemadisplay.py
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause)
# The following functions are adapted from the sqlalchemy_schemadisplay by
# Florian Schulze (https://github.com/fschulze/sqlalchemy_schemadisplay).
import re
from peewee import ForeignKeyField, IndexMetadata
try:
import pydot
except ImportError:
pydot = None
__all__ = ['create_schema_graph', 'show_schema_graph']
# Translation from Peewee's generic field-type names to the PostgreSQL type
# names displayed next to each column. Field types missing from this mapping
# are displayed unchanged.
field_type_psql = dict(
    AUTO='SERIAL',
    BIGAUTO='BIGSERIAL',
    BIGINT='BIGINT',
    BLOB='BYTEA',
    BOOL='BOOLEAN',
    CHAR='CHAR',
    DATE='DATE',
    DATETIME='TIMESTAMP',
    DECIMAL='NUMERIC',
    DEFAULT='',
    DOUBLE='DOUBLE PRECISION',
    FLOAT='REAL',
    INT='INTEGER',
    SMALLINT='SMALLINT',
    TEXT='TEXT',
    TIME='TIME',
    UUID='UUID',
    UUIDB='BYTEA',
    VARCHAR='VARCHAR',
)
# Extracts the two column names from a Q3C index definition, e.g. the SQL
# text ``... (q3c_ang2ipix(ra, "dec")) ...``.
_Q3C_INDEX_RE = re.compile(r'.+q3c_ang2ipix\("*(\w+)"*, "*(\w+)"*\).+')


def _column_label(field):
    """Returns the displayed name of a column.

    A composite primary key is shown as ``(field1, field2, ...)``; any other
    field is shown by its ``column_name``.
    """

    if field.column_name == '__composite_key__':
        return '(' + ', '.join(field.field_names) + ')'
    return field.column_name


def _format_field(field, show_datatypes=True):
    """Returns the row text for a column, with (PK)/(FK) markers and type."""

    suffixes = []
    if field.column_name == '__composite_key__':
        # Composite keys report primary_key == False, so flag them explicitly.
        suffixes.append('PK')
    if field.primary_key:
        suffixes.append('PK')
    if isinstance(field, ForeignKeyField):
        suffixes.append('FK')

    suffix = ' (' + ', '.join(suffixes) + ')' if suffixes else ''
    text = f'- {_column_label(field)}{suffix}'

    # A composite key has no single data type to show.
    if show_datatypes and field.column_name != '__composite_key__':
        field_type = field.field_type
        text += f' : {field_type_psql.get(field_type, field_type)}'

    return text


def _get_indexes(model):
    """Returns the indices of ``model``.

    If the model's database is connected, the result of
    ``Database.get_indexes`` for the table is returned. Otherwise,
    ``IndexMetadata`` tuples are built from the indices declared in the model
    (``fields_to_index()``), skipping those whose first expression is a
    foreign key.
    """

    meta = model._meta

    if meta.database.connected:
        return meta.database.get_indexes(meta.table_name, schema=meta.schema)

    indexes = []
    for index in meta.fields_to_index():
        expressions = index._expressions
        if not expressions or isinstance(expressions[0], ForeignKeyField):
            continue
        # Expressions that are not plain fields have no column_name; they are
        # given an empty name, which is how expression columns are treated
        # when rendering.
        columns = [getattr(expr, 'column_name', None) or ''
                   for expr in expressions]
        indexes.append(IndexMetadata(name=getattr(index, '_name', None),
                                     sql='',
                                     columns=columns,
                                     unique=bool(getattr(index, '_unique',
                                                         False)),
                                     table=meta.table_name))

    return indexes


def _index_labels(indexes, pk_label=None):
    """Yields ``(label, columns)`` pairs for the index rows of a table.

    ``label`` is ``'INDEX'``, ``'Q3C'``, or ``'UNIQUE'``. The following are
    skipped:

    - entries that are not ``IndexMetadata``;
    - single-expression indices that are not Q3C indices;
    - the unique index whose columns match ``pk_label``.
    """

    for index in indexes:
        if not isinstance(index, IndexMetadata):
            continue

        column_names = index.columns
        label = 'INDEX'

        if len(column_names) == 1:
            column_name = column_names[0]
            if column_name == '':
                # An expression index; only Q3C indices are displayed.
                match = _Q3C_INDEX_RE.match(index.sql or '')
                if not match:
                    continue
                column_name = '(' + ', '.join(match.groups()) + ')'
                label = 'Q3C'
        else:
            column_name = '(' + ', '.join(column_names) + ')'

        if index.unique:
            # The unique index backing the primary key is redundant.
            if pk_label is not None and column_name == pk_label:
                continue
            label = 'UNIQUE'

        yield label, column_name


def _render_table_html(model, show_columns=True, show_pks=True,
                       show_indices=True, show_datatypes=True):
    """Creates the HTML tags for a table, including PKs, FKs, and indices.

    The result is a Graphviz HTML-like label. It contains a header with the
    table and model names, then one row per column, then (optionally) one
    row per index or unique constraint. Groups of rows are separated by
    ruled rows.

    Parameters
    ----------
    model : `peewee.Model`
        The Peewee model for which to create the table.
    show_columns : bool
        Whether to show the column names.
    show_pks : bool
        Whether to show the primary key. Supersedes ``show_columns``.
    show_indices : `bool`
        Whether to show the indices from the table as separate rows. If the
        model's database is connected, the indices are read from the
        database. Otherwise, they are taken from the model definition.
    show_datatypes : `bool`
        Whether to show the data type of each column.

    Returns
    -------
    html : str
        The label, including the enclosing ``<`` and ``>`` Graphviz requires
        for HTML-like labels.
    """

    meta = model._meta
    pk = meta.primary_key
    pk_label = _column_label(pk) if pk else None

    row_template = '<TR><TD ALIGN="LEFT" PORT="{}">{}</TD></TR>'
    separator = '<TR><TD BORDER="1" CELLPADDING="0"></TD></TR>'

    html = (f'<<TABLE BORDER="1" CELLBORDER="0" CELLSPACING="0">'
            f'<TR><TD ALIGN="CENTER"><font face="Lucida Sans Demibold Roman">'
            f'{meta.table_name}</font><BR/>({model.__name__})</TD></TR>')

    fields_html = []

    if show_pks and pk:
        fields_html.append(
            row_template.format(pk_label, _format_field(pk, show_datatypes)))

    if show_columns:
        seen_columns = set()
        for field in meta.fields.values():
            if field.primary_key:
                continue
            column_name = field.column_name
            # Avoids repeating columns. This can happen if there are multiple
            # FKs pointing to the same column.
            if column_name in seen_columns:
                continue
            seen_columns.add(column_name)
            fields_html.append(
                row_template.format(column_name,
                                    _format_field(field, show_datatypes)))

    if fields_html:
        html += separator + ''.join(fields_html)

    # Add indexes and unique constraints.
    if show_indices:
        index_rows = [f'<TR><TD ALIGN="LEFT">{label} {columns}</TD></TR>'
                      for label, columns in _index_labels(_get_indexes(model),
                                                          pk_label)]
        if index_rows:
            html += separator + ''.join(index_rows)

    html += '</TABLE>>'

    return html
def _collect_models(models=None, base=None, schema=None):
    """Returns the collection of models to graph.

    If ``models`` is empty and ``base`` is given, every direct or indirect
    subclass of ``base`` is used (``base`` itself is not included). If
    ``schema`` is given, only models whose ``_meta.schema`` equals it are
    kept.
    """

    if base and not models:
        found = set()
        pending = list(base.__subclasses__())
        while pending:
            model = pending.pop()
            if model in found:
                continue
            found.add(model)
            pending.extend(model.__subclasses__())
        models = found

    if schema:
        models = [model for model in models if model._meta.schema == schema]

    return models


def create_schema_graph(models=None, base=None, schema=None, show_columns=True,
                        show_pks=True, show_indices=True, show_datatypes=True,
                        skip_tables=None, font='Bitstream-Vera Sans',
                        graph_options=None, relation_options=None):
    """Creates a graph visualisation from a series of Peewee models.

    Produces a `pydot <https://pypi.org/project/pydot/>`__ graph including the
    tables and relationships from a series of models or from a base model
    class.

    Parameters
    ----------
    models : list
        A list of Peewee `models <peewee:Model>` to be graphed.
    base : peewee:Model
        A base model class. If passed, all the model classes that were created
        by subclassing from the base model will be used.
    schema : str
        A schema name. If passed, will be used to limit the list of models or
        ``base`` subclasses to only the models that match the schema name.
    show_columns : bool
        Whether to show the column names.
    show_pks : bool
        Whether to show the primary key. Supersedes ``show_columns``.
    show_indices : bool
        Whether to show the indices from the table as separate rows.
    show_datatypes : bool
        Whether to show the data type of each column.
    skip_tables : list
        List of table names to skip. Relationships to skipped tables are not
        drawn either.
    font : str
        The name of the font to use.
    graph_options : dict
        Options for creating the graph. Any valid Graphviz option. They are
        merged over the defaults ``prog='dot'``, ``program='dot'``,
        ``rankdir='TB'``, ``sep='0.01'``, ``mode='ipsep'``, and
        ``overlap='ipsep'``.
    relation_options : dict
        Additional parameters to be passed to ``pydot.Edge`` when creating the
        relationships. They take precedence over the default edge attributes.

    Returns
    -------
    graph : `pydot.Dot`
        A ``pydot.Dot`` object with the graph representation of the schema.

    Example
    -------
    ::

        >>> graph = create_schema_graph([User, Tweet])
        >>> graph.write_pdf('tweetdb.pdf')

    """

    assert models or base, 'either model or base must be passed.'
    assert pydot, ('pydot is required for create_schema_graph. '
                   'Try running "pip install sdssdb[all]"')

    skip_tables = skip_tables or []

    relation_kwargs = {'fontsize': '7.0'}
    relation_kwargs.update(relation_options or {})

    # The defaults are merged with the user options into a single mapping, so
    # user values (including ``prog``) override them instead of clashing.
    dot_kwargs = dict(prog='dot',
                      program='dot',
                      rankdir='TB',
                      sep='0.01',
                      mode='ipsep',
                      overlap='ipsep')
    dot_kwargs.update(graph_options or {})

    graph = pydot.Dot(**dot_kwargs)

    models = _collect_models(models=models, base=base, schema=schema)

    # Decide up front which models get a node. Edges are then only drawn
    # between those tables, so a skipped table is not re-created as a bare
    # node by an edge pointing to it.
    graphed = []
    for model in models:
        if model._meta.table_name in skip_tables:
            continue
        if model._meta.database.connected and not model.table_exists():
            continue
        graphed.append(model)
    graphed_set = set(graphed)

    for model in graphed:

        table_name = model._meta.table_name

        graph.add_node(
            pydot.Node(str(table_name),
                       shape='plaintext',
                       label=_render_table_html(model,
                                                show_columns=show_columns,
                                                show_pks=show_pks,
                                                show_indices=show_indices,
                                                show_datatypes=show_datatypes),
                       fontname=font,
                       fontsize='7.0'))

        for field in model._meta.fields.values():

            if (not isinstance(field, ForeignKeyField) or
                    field.rel_model not in graphed_set):
                continue

            from_col_name = '+ ' + field.column_name

            # The target column is only labelled when it is not the primary
            # key of the referenced table.
            if field.rel_field.primary_key:
                to_col_name = ''
            else:
                to_col_name = '+ ' + field.rel_field.column_name

            edge_kwargs = dict(dir='both',
                               headlabel=to_col_name,
                               taillabel=from_col_name,
                               arrowhead='none',
                               arrowtail='none',
                               fontname=font)
            edge_kwargs.update(relation_kwargs)

            graph.add_edge(pydot.Edge(table_name,
                                      field.rel_model._meta.table_name,
                                      **edge_kwargs))

    return graph
def show_schema_graph(*args, **kwargs):
    """Creates and displays a schema graph.

    All arguments are passed to `create_schema_graph`, except the optional
    ``command`` keyword. ``command`` selects the image viewer and defaults to
    ``'gwenview'``. The graph is rendered to PNG and displayed with
    ``PIL.Image.Image.show``.

    Notes
    -----
    Pillow 9.0 removed the ``command`` argument of ``Image.show``. With those
    versions, ``command`` is ignored and Pillow's default viewer is used.
    """

    import inspect
    from io import BytesIO

    from PIL import Image

    # ``command`` configures the viewer only; create_schema_graph does not
    # accept it, so it must be removed before forwarding the kwargs.
    command = kwargs.pop('command', 'gwenview')

    # create_png() returns bytes, so a binary buffer is required.
    png = create_schema_graph(*args, **kwargs).create_png()
    image = Image.open(BytesIO(png))

    if 'command' in inspect.signature(image.show).parameters:
        image.show(command=command)
    else:
        image.show()