|
from typing import Optional |
|
from duckdb import DuckDBPyConnection, DuckDBPyRelation |
|
|
|
|
|
class ConnectionWrapper():

    """Convenience wrapper around a DuckDB connection for easier querying.

    Unknown attributes are delegated to the wrapped connection (see
    ``__getattr__``), so the wrapper can be used in most places a raw
    ``DuckDBPyConnection`` is expected.

    NOTE(review): the helper methods interpolate their arguments directly
    into SQL text — callers must only pass trusted identifiers/strings.
    """

    def __init__(self, con: DuckDBPyConnection):
        """Wrap *con* and cache the name of its current database."""
        self.con = con
        # Single-row, single-column result: the active database name.
        (self.database, ) = self.con.sql("select current_database()").fetchone()

    def attach(self, database: str, name: str):
        """ATTACH the database at path *database* under the alias *name*."""
        return self.con.sql(f"ATTACH '{database}' AS {name}")

    def describe(self, table_name: str):
        """Return the DESCRIBE relation (column names/types) for *table_name*."""
        return self.con.sql(f"DESCRIBE {table_name}")

    def summarize(self, table_name: str):
        """Return the SUMMARIZE relation (summary stats) for *table_name*."""
        return self.con.sql(f"SUMMARIZE {table_name}")

    def tables(self):
        """List tables that live in-memory or in the current database."""
        return self.con.sql(f"SELECT database, schema, name FROM (SHOW ALL TABLES) WHERE database IN ('memdb', '{self.database}')")

    def table(self, table_name: str) -> DuckDBPyRelation:
        """Return a relation for *table_name*; supports qualified names.

        ``con.table()`` only takes a bare table name, so qualified names
        (``db.schema.table``) are routed through a ``FROM`` query instead.
        """
        if "." in table_name:
            return self.con.sql(f"FROM {table_name}")
        return self.con.table(table_name)

    def to_table(self, relation: DuckDBPyRelation, table_name: str, replace: bool = False, temp: bool = False):
        """Materialize *relation* as a table named *table_name*.

        ``replace`` adds ``OR REPLACE``; ``temp`` creates a TEMP table.
        """
        or_replace = " OR REPLACE " if replace is True else " "
        # BUG FIX: "TEMP " was previously emitted unconditionally, so every
        # table created through this method was temporary even when
        # temp=False. Honor the flag.
        temp_tbl = "TEMP " if temp else ""
        self.con.sql(f"CREATE{or_replace}{temp_tbl}TABLE {table_name} AS ({relation.sql_query()})")

    def to_temp_table(self, relation: DuckDBPyRelation, table_name: str):
        """Materialize *relation* as a TEMP table, replacing any existing one."""
        self.to_table(relation, table_name, replace=True, temp=True)

    def prompt_query(self, prompt: str):
        """Run the prompt_query PRAGMA for *prompt* (prompt extension)."""
        return self.con.sql(f"PRAGMA prompt_query('{prompt}')")

    def prompt_sql(self, prompt: str):
        """Return the single SQL string produced by ``prompt_sql`` for *prompt*."""
        (query,) = self.con.sql(f"CALL prompt_sql('{prompt}')").fetchone()
        return query

    def prompt(self, prompt: str):
        """Return the scalar result of the ``prompt()`` SQL function."""
        return self.con.sql(f"SELECT prompt('{prompt}')").fetchone()[0]

    def __getattr__(self, name: str):
        """Delegate unknown attribute access to the wrapped connection.

        Raises:
            AttributeError: if the wrapped connection lacks *name*.
        """
        if hasattr(self.con, name):
            return getattr(self.con, name)
        # BUG FIX: previously fell through and implicitly returned None,
        # which breaks hasattr() on this wrapper and hides typos.
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
|
|
|
|
|
def get_observations(
    con: ConnectionWrapper,
    common_name: str,
    year: int,
    source_table: str = "ebd",
    h3_resolution: int = 6,
):
    """Return observations of *common_name* during *year* as a Pandas DataFrame.

    Filters the eBird relation *source_table* down to one species and year,
    materializes the result as the ``bird_obs`` temp table, enriches it with
    an H3 cell id (plus the cell centroid lat/lng) at *h3_resolution*, and
    returns the date-ordered rows as a Pandas DataFrame.

    Args:
        con: Wrapped DuckDB connection holding the source table.
        common_name: Species common name to filter on (exact match).
        year: Observation year to filter on.
        source_table: Name of the source table (default ``"ebd"``).
        h3_resolution: H3 cell resolution used for downsampling (default 6).
    """
    # Imported lazily: the experimental Spark API is only needed here.
    from duckdb.experimental.spark.sql import SparkSession
    from duckdb.experimental.spark.sql.functions import col, month
    from duckdb.experimental.spark.sql.functions import year as apply_year
    from duckdb.experimental.spark.sql.dataframe import DataFrame

    spark = SparkSession.builder.appName(common_name).getOrCreate()
    df = DataFrame(con.table(source_table), spark)

    result = (
        df
        .filter((apply_year(col("OBSERVATION DATE")) == year) & (col("COMMON NAME") == common_name))
        .orderBy(col("OBSERVATION DATE").asc())
        .select(
            col("COMMON NAME").alias("name"),
            col("OBSERVATION DATE").alias("obs_dt"),
            month("OBSERVATION DATE").alias("month"),
            col("OBSERVATION COUNT").alias("count"),
            col("LATITUDE").alias("lat"),
            col("LONGITUDE").alias("lng"),
            col("COUNTRY").alias("country")
        )
    )

    con.to_temp_table(result.relation, table_name="bird_obs")

    # Use H3 to downsample the data for plotting (community extension).
    con.install_extension("h3", repository="community")
    con.load_extension("h3")

    # DuckDB lateral column aliases let cell_id be reused in the same SELECT.
    rel = (
        con
        .table("bird_obs")
        .select(f"""
            *,
            h3_latlng_to_cell(lat::double, lng::double, {h3_resolution}) AS cell_id,
            h3_cell_to_lat(cell_id) AS cell_lat,
            h3_cell_to_lng(cell_id) AS cell_lng
        """)
    )

    df = DataFrame(rel, spark)

    # Order chronologically by observation date for downstream plotting.
    result = df.orderBy(col("obs_dt").asc())

    # Export to Pandas
    df_result = result.toPandas()

    return df_result