Skip to content

Instantly share code, notes, and snippets.

@dhermes
Last active May 11, 2023 21:47
Show Gist options
  • Save dhermes/bc5bb087684afe32906c8d6537983db3 to your computer and use it in GitHub Desktop.
Save dhermes/bc5bb087684afe32906c8d6537983db3 to your computer and use it in GitHub Desktop.
[2023-05-11] Script to generate an ERD using only foreign keys
# Copyright (c) 2023 - Present. Hardfin, Inc. All rights reserved.
# Hardfin Confidential - Restricted
import csv
import io
import pathlib
import subprocess
import tempfile
# Directory containing this script; the generated diagram is written next to it.
HERE = pathlib.Path(__file__).resolve().parent

# One row per distinct (referencing table, referenced table) pair.
#
# - `DISTINCT` collapses the repeated rows that multi-column foreign keys (and
#   several foreign keys between the same two tables) would otherwise produce.
# - The join also matches on `constraint_schema`: PostgreSQL does not require
#   constraint names to be unique across schemas, so matching on the name
#   alone can pair a constraint with the referenced table of an unrelated one.
# - `key_column_usage` is intentionally not joined; none of its columns are
#   selected and joining it only multiplies rows.
SQL_COMMAND = """\
SELECT DISTINCT
  tc.table_schema || '.' || tc.table_name AS source_table,
  ccu.table_schema || '.' || ccu.table_name AS target_table
FROM
  information_schema.table_constraints AS tc
  INNER JOIN information_schema.constraint_column_usage AS ccu
    ON ccu.constraint_schema = tc.constraint_schema
    AND ccu.constraint_name = tc.constraint_name
WHERE
  tc.constraint_type = 'FOREIGN KEY'
"""

# Placeholder connection string; replace with the real database DSN before running.
DSN = "postgres://..."
def get_foreign_keys(dsn):
    """Run ``SQL_COMMAND`` through ``psql`` and return the resulting rows.

    The query is written to a temporary file, executed with the ``psql``
    command line client in CSV mode, and the CSV output is parsed.

    Args:
        dsn: Connection string passed to ``psql --dbname``.

    Returns:
        A list of dictionaries, one per CSV row printed by ``psql``, keyed by
        the column names in the CSV header row.

    Raises:
        subprocess.CalledProcessError: If ``psql`` exits with a non-zero
            status (including SQL errors, because ``ON_ERROR_STOP`` is set).
        FileNotFoundError: If the ``psql`` executable cannot be found.
    """
    import os

    # Pin the client encoding so the bytes psql prints can always be decoded
    # as UTF-8, including identifiers that contain non-ASCII characters.
    env = dict(os.environ, PGCLIENTENCODING="UTF8")

    with tempfile.NamedTemporaryFile("w", suffix=".sql") as file_obj:
        file_obj.write(SQL_COMMAND)
        # psql is a separate process; flush so it sees the full query on disk.
        file_obj.flush()

        print("[DEBUG] Invoking psql to get schema...")
        cmd = (
            "psql",
            # Ignore any user psqlrc that could change the output format.
            "--no-psqlrc",
            # Make psql exit non-zero when the query itself fails.
            "--set",
            "ON_ERROR_STOP=1",
            "--dbname",
            dsn,
            "--file",
            file_obj.name,
            "--csv",
        )
        csv_bytes = subprocess.check_output(cmd, env=env)

    csv_text = csv_bytes.decode("utf-8")
    reader = csv.DictReader(io.StringIO(csv_text))
    return list(reader)
def main():
    """Write ``erd.dot``, a Graphviz digraph of foreign key references.

    Foreign keys are fetched with ``get_foreign_keys(DSN)``. Every table that
    appears as a source or a target becomes a node with a short alias
    (``T0``, ``T1``, ... in sorted table-name order) labelled with the table
    name, and every distinct (source, target) pair becomes one directed edge.
    The result is saved as ``erd.dot`` in the directory given by ``HERE``.
    """
    foreign_keys = get_foreign_keys(DSN)

    print("[DEBUG] Processing foreign keys...")
    # Map each source table to the *set* of tables it references, so that
    # repeated rows for the same pair do not turn into duplicate edges.
    from_source = {}
    all_tables = set()
    for foreign_key in foreign_keys:
        source_table = foreign_key["source_table"]
        target_table = foreign_key["target_table"]
        all_tables.update([source_table, target_table])
        from_source.setdefault(source_table, set()).add(target_table)

    graphviz_lines = ["digraph G {", ""]
    aliases = {}
    for i, table in enumerate(sorted(all_tables)):
        alias = f"T{i}"
        aliases[table] = alias
        # Escape backslashes and double quotes so an unusual identifier
        # cannot terminate the quoted DOT label early.
        label = table.replace("\\", "\\\\").replace('"', '\\"')
        graphviz_lines.append(f' {alias} [label="{label}"]')
    graphviz_lines.append("")

    for source_table in sorted(from_source):
        source_alias = aliases[source_table]
        for target_table in sorted(from_source[source_table]):
            target_alias = aliases[target_table]
            graphviz_lines.append(f" {source_alias} -> {target_alias};")
    graphviz_lines.extend(["}", ""])  # Trailing newline

    print("[DEBUG] Saving erd.dot file...")
    # Explicit encoding: table names may contain non-ASCII characters and the
    # platform default encoding is not guaranteed to represent them.
    with open(HERE / "erd.dot", "w", encoding="utf-8") as file_obj:
        file_obj.write("\n".join(graphviz_lines))


if __name__ == "__main__":
    main()
# Copyright (c) 2023 - Present. Hardfin, Inc. All rights reserved.
# Hardfin Confidential - Restricted
import csv
import io
import pathlib
import subprocess
import tempfile
# Directory containing this script; the generated diagram is written next to it.
HERE = pathlib.Path(__file__).resolve().parent

# NOTE: This command / script **ASSUMES** all tables are involved in some
#       reference; a table with no foreign key pointing to or from it never
#       shows up in the query result and therefore not in the diagram.
#
# One row per distinct (referencing table, referenced table) pair.
#
# - `DISTINCT` collapses the repeated rows that multi-column foreign keys (and
#   several foreign keys between the same two tables) would otherwise produce.
# - The join also matches on `constraint_schema`: PostgreSQL does not require
#   constraint names to be unique across schemas, so matching on the name
#   alone can pair a constraint with the referenced table of an unrelated one.
# - `key_column_usage` is intentionally not joined; none of its columns are
#   selected and joining it only multiplies rows.
SQL_COMMAND = """\
SELECT DISTINCT
  tc.table_schema || '.' || tc.table_name AS source_table,
  ccu.table_schema || '.' || ccu.table_name AS target_table
FROM
  information_schema.table_constraints AS tc
  INNER JOIN information_schema.constraint_column_usage AS ccu
    ON ccu.constraint_schema = tc.constraint_schema
    AND ccu.constraint_name = tc.constraint_name
WHERE
  tc.constraint_type = 'FOREIGN KEY'
"""

# Placeholder connection string; replace with the real database DSN before running.
DSN = "postgres://..."
def get_foreign_keys(dsn):
    """Run ``SQL_COMMAND`` through ``psql`` and return the resulting rows.

    The query is written to a temporary file, executed with the ``psql``
    command line client in CSV mode, and the CSV output is parsed.

    Args:
        dsn: Connection string passed to ``psql --dbname``.

    Returns:
        A list of dictionaries, one per CSV row printed by ``psql``, keyed by
        the column names in the CSV header row.

    Raises:
        subprocess.CalledProcessError: If ``psql`` exits with a non-zero
            status (including SQL errors, because ``ON_ERROR_STOP`` is set).
        FileNotFoundError: If the ``psql`` executable cannot be found.
    """
    import os

    # Pin the client encoding so the bytes psql prints can always be decoded
    # as UTF-8, including identifiers that contain non-ASCII characters.
    env = dict(os.environ, PGCLIENTENCODING="UTF8")

    with tempfile.NamedTemporaryFile("w", suffix=".sql") as file_obj:
        file_obj.write(SQL_COMMAND)
        # psql is a separate process; flush so it sees the full query on disk.
        file_obj.flush()

        print("[DEBUG] Invoking psql to get schema...")
        cmd = (
            "psql",
            # Ignore any user psqlrc that could change the output format.
            "--no-psqlrc",
            # Make psql exit non-zero when the query itself fails.
            "--set",
            "ON_ERROR_STOP=1",
            "--dbname",
            dsn,
            "--file",
            file_obj.name,
            "--csv",
        )
        csv_bytes = subprocess.check_output(cmd, env=env)

    csv_text = csv_bytes.decode("utf-8")
    reader = csv.DictReader(io.StringIO(csv_text))
    return list(reader)
def main():
    """Write ``erd.mermaid``, a Mermaid flowchart of foreign key references.

    Foreign keys are fetched with ``get_foreign_keys(DSN)``. Every table that
    appears as a source or a target gets a short alias (``T0``, ``T1``, ...
    in sorted table-name order), and every distinct (source, target) pair
    becomes one ``-->`` edge whose endpoints carry the table names as node
    text. The result is saved as ``erd.mermaid`` in the directory given by
    ``HERE``.
    """
    foreign_keys = get_foreign_keys(DSN)

    print("[DEBUG] Processing foreign keys...")
    # Map each source table to the *set* of tables it references, so that
    # repeated rows for the same pair do not turn into duplicate edges.
    from_source = {}
    all_tables = set()
    for foreign_key in foreign_keys:
        source_table = foreign_key["source_table"]
        target_table = foreign_key["target_table"]
        all_tables.update([source_table, target_table])
        from_source.setdefault(source_table, set()).add(target_table)

    aliases = {table: f"T{i}" for i, table in enumerate(sorted(all_tables))}

    mermaid_lines = ["graph TD;"]
    for source_table in sorted(from_source):
        source_alias = aliases[source_table]
        for target_table in sorted(from_source[source_table]):
            target_alias = aliases[target_table]
            mermaid_lines.append(
                f" {source_alias}[{source_table}] --> {target_alias}[{target_table}];"
            )
    mermaid_lines.append("")  # Trailing newline

    print("[DEBUG] Saving erd.mermaid file...")
    # Explicit encoding: table names may contain non-ASCII characters and the
    # platform default encoding is not guaranteed to represent them.
    with open(HERE / "erd.mermaid", "w", encoding="utf-8") as file_obj:
        file_obj.write("\n".join(mermaid_lines))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment