Skip to content

Commit

Permalink
Merge pull request #607 from sinjup/main
Browse files Browse the repository at this point in the history
fix: Postgres doesn't reconnect after idle time #541
  • Loading branch information
zainhoda authored Aug 21, 2024
2 parents 7cb744a + 78c7efc commit ece3ef9
Showing 1 changed file with 33 additions and 15 deletions.
48 changes: 33 additions & 15 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,7 @@ def connect_to_postgres(
port: int = None,
**kwargs
):

"""
Connect to postgres using the psycopg2 connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
**Example:**
Expand Down Expand Up @@ -940,26 +941,44 @@ def connect_to_postgres(
except psycopg2.Error as e:
raise ValidationError(e)

def connect_to_db():
return psycopg2.connect(host=host, dbname=dbname,
user=user, password=password, port=port, **kwargs)


def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
if conn:
try:
cs = conn.cursor()
cs.execute(sql)
results = cs.fetchall()
conn = None
try:
conn = connect_to_db() # Initial connection attempt
cs = conn.cursor()
cs.execute(sql)
results = cs.fetchall()

# Create a pandas dataframe from the results
df = pd.DataFrame(
results, columns=[desc[0] for desc in cs.description]
)
return df
# Create a pandas dataframe from the results
df = pd.DataFrame(results, columns=[desc[0] for desc in cs.description])
return df

except psycopg2.InterfaceError as e:
# Attempt to reconnect and retry the operation
if conn:
conn.close() # Ensure any existing connection is closed
conn = connect_to_db()
cs = conn.cursor()
cs.execute(sql)
results = cs.fetchall()

# Create a pandas dataframe from the results
df = pd.DataFrame(results, columns=[desc[0] for desc in cs.description])
return df

except psycopg2.Error as e:
except psycopg2.Error as e:
if conn:
conn.rollback()
raise ValidationError(e)

except Exception as e:
conn.rollback()
raise e
except Exception as e:
conn.rollback()
raise e

self.dialect = "PostgreSQL"
self.run_sql_is_set = True
Expand Down Expand Up @@ -1306,7 +1325,6 @@ def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]:
job = conn.query(sql)
df = job.result().to_dataframe()
return df

return None

self.dialect = "BigQuery SQL"
Expand Down

0 comments on commit ece3ef9

Please sign in to comment.