diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index adeda0ad..e8e0bbf9 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -867,6 +867,7 @@ def connect_to_postgres( port: int = None, **kwargs ): + """ Connect to postgres using the psycopg2 connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] **Example:** @@ -940,26 +941,44 @@ def connect_to_postgres( except psycopg2.Error as e: raise ValidationError(e) + def connect_to_db(): + return psycopg2.connect(host=host, dbname=dbname, + user=user, password=password, port=port, **kwargs) + + def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]: - if conn: - try: - cs = conn.cursor() - cs.execute(sql) - results = cs.fetchall() + conn = None + try: + conn = connect_to_db() # Initial connection attempt + cs = conn.cursor() + cs.execute(sql) + results = cs.fetchall() - # Create a pandas dataframe from the results - df = pd.DataFrame( - results, columns=[desc[0] for desc in cs.description] - ) - return df + # Create a pandas dataframe from the results + df = pd.DataFrame(results, columns=[desc[0] for desc in cs.description]) + return df + + except psycopg2.InterfaceError as e: + # Attempt to reconnect and retry the operation + if conn: + conn.close() # Ensure any existing connection is closed + conn = connect_to_db() + cs = conn.cursor() + cs.execute(sql) + results = cs.fetchall() + + # Create a pandas dataframe from the results + df = pd.DataFrame(results, columns=[desc[0] for desc in cs.description]) + return df - except psycopg2.Error as e: + except psycopg2.Error as e: + if conn: conn.rollback() raise ValidationError(e) - except Exception as e: - conn.rollback() - raise e + except Exception as e: + conn.rollback() + raise e self.dialect = "PostgreSQL" self.run_sql_is_set = True @@ -1306,7 +1325,6 @@ def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]: job = conn.query(sql) df = job.result().to_dataframe() return df - return None self.dialect = "BigQuery SQL"