Last active
May 11, 2023 21:47
-
-
Save dhermes/bc5bb087684afe32906c8d6537983db3 to your computer and use it in GitHub Desktop.
[2023-05-11] Script to generate an ERD using only foreign keys
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2023 - Present. Hardfin, Inc. All rights reserved.
# Hardfin Confidential - Restricted
import csv | |
import io | |
import pathlib | |
import subprocess | |
import tempfile | |
# Directory containing this script; the generated diagram is written here.
HERE = pathlib.Path(__file__).resolve().parent

# Query listing one (source_table, target_table) row per pair of tables that
# is linked by at least one foreign key.
#
# - ``DISTINCT`` collapses the repeated rows that a multi-column foreign key
#   (or several foreign keys between the same two tables) would otherwise
#   produce, so each relationship appears once.
# - The join matches on ``constraint_schema`` as well as ``constraint_name``
#   so that identically named constraints in different schemas are not
#   paired with each other.
# - ``key_column_usage`` is intentionally not joined: none of its columns are
#   selected and joining it only multiplies the result rows.
SQL_COMMAND = """\
SELECT DISTINCT
  tc.table_schema || '.' || tc.table_name AS source_table,
  ccu.table_schema || '.' || ccu.table_name AS target_table
FROM
  information_schema.table_constraints AS tc
  INNER JOIN information_schema.constraint_column_usage AS ccu
    ON ccu.constraint_schema = tc.constraint_schema
    AND ccu.constraint_name = tc.constraint_name
WHERE
  tc.constraint_type = 'FOREIGN KEY'
"""

# Connection string handed to ``psql --dbname``; placeholder, fill in before use.
DSN = "postgres://..."
def get_foreign_keys(dsn):
    """Run ``SQL_COMMAND`` through ``psql`` and return the parsed rows.

    The query is written to a temporary file, executed with ``psql --file``
    and the ``--csv`` output is parsed with ``csv.DictReader``.

    Args:
        dsn: Connection string passed to ``psql --dbname``.

    Returns:
        A list with one dict per CSV row, keyed by the CSV header columns.

    Raises:
        subprocess.CalledProcessError: If ``psql`` exits with a non-zero
            status (connection failure, or a SQL error in the script).
        UnicodeDecodeError: If the ``psql`` output is not valid UTF-8.
    """
    with tempfile.NamedTemporaryFile("w") as file_obj:
        file_obj.write(SQL_COMMAND)
        # Flush so the query is on disk before psql opens the file by name.
        file_obj.flush()

        print("[DEBUG] Invoking psql to get schema...")
        cmd = (
            "psql",
            "--dbname",
            dsn,
            # Without ON_ERROR_STOP, psql exits 0 even when the script fails,
            # which would silently yield an empty result instead of an error.
            "--set",
            "ON_ERROR_STOP=1",
            "--file",
            file_obj.name,
            "--csv",
        )
        csv_bytes = subprocess.check_output(cmd)

    # UTF-8 is a superset of ASCII, so plain-ASCII output decodes exactly as
    # before while schema / table names with non-ASCII characters no longer
    # raise.
    csv_text = csv_bytes.decode("utf-8")
    reader = csv.DictReader(io.StringIO(csv_text))
    return list(reader)
def main():
    """Write ``erd.dot``, a Graphviz digraph of the foreign key relationships.

    Fetches the foreign keys via ``get_foreign_keys(DSN)``, assigns every
    table a short node alias (``T0``, ``T1``, ... in sorted table order),
    emits one labelled node per table and one edge per distinct
    (source table, target table) pair, then saves the result next to this
    script.
    """
    foreign_keys = get_foreign_keys(DSN)

    print("[DEBUG] Processing foreign keys...")
    from_source = {}
    all_tables = set()
    for foreign_key in foreign_keys:
        source_table = foreign_key["source_table"]
        target_table = foreign_key["target_table"]
        all_tables.update([source_table, target_table])
        # A set (rather than a list) so that repeated rows for the same pair
        # of tables -- e.g. from a multi-column foreign key -- produce a
        # single edge instead of several identical ones.
        from_source.setdefault(source_table, set()).add(target_table)

    graphviz_lines = ["digraph G {", ""]
    aliases = {}
    for i, table in enumerate(sorted(all_tables)):
        alias = f"T{i}"
        aliases[table] = alias
        # Escape backslashes and double quotes so that an unusual (quoted)
        # identifier cannot terminate the DOT string early.
        label = table.replace("\\", "\\\\").replace('"', '\\"')
        graphviz_lines.append(f' {alias} [label="{label}"]')

    graphviz_lines.append("")
    for source_table in sorted(from_source):
        source_alias = aliases[source_table]
        for target_table in sorted(from_source[source_table]):
            target_alias = aliases[target_table]
            graphviz_lines.append(f" {source_alias} -> {target_alias};")

    graphviz_lines.extend(["}", ""])  # Trailing newline

    print("[DEBUG] Saving erd.dot file...")
    # Explicit encoding so non-ASCII table names are written the same way
    # regardless of the platform's default locale.
    with open(HERE / "erd.dot", "w", encoding="utf-8") as file_obj:
        file_obj.write("\n".join(graphviz_lines))
# Script entry point: generate the diagram only when this file is executed
# directly, never as a side effect of importing it.
if "__main__" == __name__:
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2023 - Present. Hardfin, Inc. All rights reserved.
# Hardfin Confidential - Restricted
import csv | |
import io | |
import pathlib | |
import subprocess | |
import tempfile | |
# Directory containing this script; the generated diagram is written here.
HERE = pathlib.Path(__file__).resolve().parent

# NOTE: This command / script **ASSUMES** all tables are involved in some
#       reference; a table with no foreign key to or from it never shows up
#       in the query result and is therefore missing from the diagram.
#
# Query listing one (source_table, target_table) row per pair of tables that
# is linked by at least one foreign key.
#
# - ``DISTINCT`` collapses the repeated rows that a multi-column foreign key
#   (or several foreign keys between the same two tables) would otherwise
#   produce, so each relationship appears once.
# - The join matches on ``constraint_schema`` as well as ``constraint_name``
#   so that identically named constraints in different schemas are not
#   paired with each other.
# - ``key_column_usage`` is intentionally not joined: none of its columns are
#   selected and joining it only multiplies the result rows.
SQL_COMMAND = """\
SELECT DISTINCT
  tc.table_schema || '.' || tc.table_name AS source_table,
  ccu.table_schema || '.' || ccu.table_name AS target_table
FROM
  information_schema.table_constraints AS tc
  INNER JOIN information_schema.constraint_column_usage AS ccu
    ON ccu.constraint_schema = tc.constraint_schema
    AND ccu.constraint_name = tc.constraint_name
WHERE
  tc.constraint_type = 'FOREIGN KEY'
"""

# Connection string handed to ``psql --dbname``; placeholder, fill in before use.
DSN = "postgres://..."
def get_foreign_keys(dsn):
    """Run ``SQL_COMMAND`` through ``psql`` and return the parsed rows.

    The query is written to a temporary file, executed with ``psql --file``
    and the ``--csv`` output is parsed with ``csv.DictReader``.

    Args:
        dsn: Connection string passed to ``psql --dbname``.

    Returns:
        A list with one dict per CSV row, keyed by the CSV header columns.

    Raises:
        subprocess.CalledProcessError: If ``psql`` exits with a non-zero
            status (connection failure, or a SQL error in the script).
        UnicodeDecodeError: If the ``psql`` output is not valid UTF-8.
    """
    with tempfile.NamedTemporaryFile("w") as file_obj:
        file_obj.write(SQL_COMMAND)
        # Flush so the query is on disk before psql opens the file by name.
        file_obj.flush()

        print("[DEBUG] Invoking psql to get schema...")
        cmd = (
            "psql",
            "--dbname",
            dsn,
            # Without ON_ERROR_STOP, psql exits 0 even when the script fails,
            # which would silently yield an empty result instead of an error.
            "--set",
            "ON_ERROR_STOP=1",
            "--file",
            file_obj.name,
            "--csv",
        )
        csv_bytes = subprocess.check_output(cmd)

    # UTF-8 is a superset of ASCII, so plain-ASCII output decodes exactly as
    # before while schema / table names with non-ASCII characters no longer
    # raise.
    csv_text = csv_bytes.decode("utf-8")
    reader = csv.DictReader(io.StringIO(csv_text))
    return list(reader)
def main():
    """Write ``erd.mermaid``, a Mermaid graph of the foreign key relationships.

    Fetches the foreign keys via ``get_foreign_keys(DSN)``, assigns every
    table a short node alias (``T0``, ``T1``, ... in sorted table order),
    emits one ``-->`` edge per distinct (source table, target table) pair
    with both endpoints labelled by their table name, then saves the result
    next to this script.
    """
    foreign_keys = get_foreign_keys(DSN)

    print("[DEBUG] Processing foreign keys...")
    from_source = {}
    all_tables = set()
    for foreign_key in foreign_keys:
        source_table = foreign_key["source_table"]
        target_table = foreign_key["target_table"]
        all_tables.update([source_table, target_table])
        # A set (rather than a list) so that repeated rows for the same pair
        # of tables -- e.g. from a multi-column foreign key -- produce a
        # single edge instead of several identical ones.
        from_source.setdefault(source_table, set()).add(target_table)

    aliases = {table: f"T{i}" for i, table in enumerate(sorted(all_tables))}

    mermaid_lines = ["graph TD;"]
    for source_table in sorted(from_source):
        source_alias = aliases[source_table]
        for target_table in sorted(from_source[source_table]):
            target_alias = aliases[target_table]
            mermaid_lines.append(
                f" {source_alias}[{source_table}] --> {target_alias}[{target_table}];"
            )

    mermaid_lines.append("")  # Trailing newline

    print("[DEBUG] Saving erd.mermaid file...")
    # Explicit encoding so non-ASCII table names are written the same way
    # regardless of the platform's default locale.
    with open(HERE / "erd.mermaid", "w", encoding="utf-8") as file_obj:
        file_obj.write("\n".join(mermaid_lines))
# Script entry point: generate the diagram only when this file is executed
# directly, never as a side effect of importing it.
if "__main__" == __name__:
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment