diff --git a/CLAUDE.md b/CLAUDE.md
index 1fb7d115..d15e5578 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -59,6 +59,14 @@ make py-build
make py-docs
```
+Before finishing your implementation or committing any code, you should run:
+
+```bash
+uv run ruff check --fix pkg-py --config pyproject.toml
+```
+
+This helps ensure that the code adheres to project standards.
+
### R Package
```bash
diff --git a/pkg-py/CHANGELOG.md b/pkg-py/CHANGELOG.md
index 4d6e38f4..2f70ae3e 100644
--- a/pkg-py/CHANGELOG.md
+++ b/pkg-py/CHANGELOG.md
@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [UNRELEASED]
+### Breaking Changes
+
+* Methods like `execute_query()`, `get_data()`, and `df()` now return a `narwhals.DataFrame` instead of a `pandas.DataFrame`. This allows querychat to drop its `pandas` dependency and lets you use any `narwhals`-compatible dataframe of your choosing.
+ * If this breaks existing code, note you can call `.to_native()` on the new dataframe value to get your `pandas` dataframe back.
+ * Note that `polars` or `pandas` will be needed to realize a `sqlalchemy` connection query as a dataframe. Install with `pip install querychat[pandas]` or `pip install querychat[polars]`.
+
### New features
* `QueryChat.sidebar()`, `QueryChat.ui()`, and `QueryChat.server()` now support an optional `id` parameter to create multiple chat instances from a single `QueryChat` object. (#172)
diff --git a/pkg-py/docs/build.qmd b/pkg-py/docs/build.qmd
index 71971b24..512cec3e 100644
--- a/pkg-py/docs/build.qmd
+++ b/pkg-py/docs/build.qmd
@@ -203,7 +203,7 @@ with ui.layout_columns():
@render_plotly
def survival_plot():
- d = qc.df()
+ d = qc.df().to_native() # Convert for pandas groupby()
summary = d.groupby('pclass')['survived'].mean().reset_index()
return px.bar(summary, x='pclass', y='survived')
```
@@ -271,7 +271,7 @@ with ui.layout_columns():
@render_plotly
def survival_by_class():
- df = qc.df()
+ df = qc.df().to_native() # Convert for pandas groupby()
summary = df.groupby('pclass')['survived'].mean().reset_index()
return px.bar(
summary,
@@ -286,16 +286,14 @@ with ui.layout_columns():
@render_plotly
def age_dist():
- df = qc.df()
- return px.histogram(df, x='age', nbins=30)
+ return px.histogram(qc.df(), x='age', nbins=30)
with ui.card():
ui.card_header("Fare by Class")
@render_plotly
def fare_by_class():
- df = qc.df()
- return px.box(df, x='pclass', y='fare', color='survived')
+ return px.box(qc.df(), x='pclass', y='fare', color='survived')
ui.page_opts(
title="Titanic Survival Analysis",
@@ -461,7 +459,7 @@ with ui.layout_columns():
@render.plot
def survival_by_class():
- df = qc.df()
+ df = qc.df().to_native() # Convert for pandas groupby()
summary = df.groupby('pclass')['survived'].mean().reset_index()
fig = px.bar(
summary,
@@ -477,18 +475,14 @@ with ui.layout_columns():
@render.plot
def age_dist():
- df = qc.df()
- fig = px.histogram(df, x='age', nbins=30)
- return fig
+ return px.histogram(qc.df(), x='age', nbins=30)
with ui.card():
ui.card_header("Fare by Class")
@render.plot
def fare_by_class():
- df = qc.df()
- fig = px.box(df, x='pclass', y='fare', color='survived')
- return fig
+ return px.box(qc.df(), x='pclass', y='fare', color='survived')
# Reset button handler
@reactive.effect
diff --git a/pkg-py/docs/data-sources.qmd b/pkg-py/docs/data-sources.qmd
index 5ac97e27..327f6adf 100644
--- a/pkg-py/docs/data-sources.qmd
+++ b/pkg-py/docs/data-sources.qmd
@@ -63,7 +63,7 @@ app = qc.app()
:::
-If you're [building an app](build.qmd), note you can read the queried data frame reactively using the `df()` method, which returns a `pandas.DataFrame` by default.
+If you're [building an app](build.qmd), note you can read the queried data frame reactively using the `df()` method, which returns a `narwhals.DataFrame`. Call `.to_native()` on the result to get the underlying pandas or polars DataFrame.
## Databases
diff --git a/pkg-py/src/querychat/_datasource.py b/pkg-py/src/querychat/_datasource.py
index fd35e3c4..bb082aee 100644
--- a/pkg-py/src/querychat/_datasource.py
+++ b/pkg-py/src/querychat/_datasource.py
@@ -5,12 +5,12 @@
import duckdb
import narwhals.stable.v1 as nw
-import pandas as pd
from sqlalchemy import inspect, text
from sqlalchemy.sql import sqltypes
+from ._df_compat import duckdb_result_to_nw, read_sql
+
if TYPE_CHECKING:
- from narwhals.stable.v1.typing import IntoFrame
from sqlalchemy.engine import Connection, Engine
@@ -53,7 +53,7 @@ def get_schema(self, *, categorical_threshold: int) -> str:
...
@abstractmethod
- def execute_query(self, query: str) -> pd.DataFrame:
+ def execute_query(self, query: str) -> nw.DataFrame:
"""
Execute SQL query and return results as DataFrame.
@@ -65,20 +65,20 @@ def execute_query(self, query: str) -> pd.DataFrame:
Returns
-------
:
- Query results as a pandas DataFrame
+ Query results as a narwhals DataFrame
"""
...
@abstractmethod
- def get_data(self) -> pd.DataFrame:
+ def get_data(self) -> nw.DataFrame:
"""
Return the unfiltered data as a DataFrame.
Returns
-------
:
- The complete dataset as a pandas DataFrame
+ The complete dataset as a narwhals DataFrame
"""
...
@@ -99,27 +99,26 @@ def cleanup(self) -> None:
class DataFrameSource(DataSource):
- """A DataSource implementation that wraps a pandas DataFrame using DuckDB."""
+ """A DataSource implementation that wraps a DataFrame using DuckDB."""
- _df: nw.DataFrame | nw.LazyFrame
+ _df: nw.DataFrame
- def __init__(self, df: IntoFrame, table_name: str):
+ def __init__(self, df: nw.DataFrame, table_name: str):
"""
- Initialize with a pandas DataFrame.
+ Initialize with a DataFrame.
Parameters
----------
df
- The DataFrame to wrap
+ The DataFrame to wrap (pandas, polars, or any narwhals-compatible frame)
table_name
Name of the table in SQL queries
"""
self._conn = duckdb.connect(database=":memory:")
- self._df = nw.from_native(df)
+ self._df = nw.from_native(df) if not isinstance(df, nw.DataFrame) else df
self.table_name = table_name
- # TODO(@gadenbuie): If the data frame is already SQL-backed, maybe we shouldn't be making a new copy here.
- self._conn.register(table_name, self._df.lazy().collect().to_pandas())
+ self._conn.register(table_name, self._df.to_native())
def get_db_type(self) -> str:
"""
@@ -151,16 +150,8 @@ def get_schema(self, *, categorical_threshold: int) -> str:
"""
schema = [f"Table: {self.table_name}", "Columns:"]
- # Ensure we're working with a DataFrame, not a LazyFrame
- ndf = (
- self._df.head(10).collect()
- if isinstance(self._df, nw.LazyFrame)
- else self._df
- )
-
- for column in ndf.columns:
- # Map pandas dtypes to SQL-like types
- dtype = ndf[column].dtype
+ for column in self._df.columns:
+ dtype = self._df[column].dtype
if dtype.is_integer():
sql_type = "INTEGER"
elif dtype.is_float():
@@ -176,17 +167,14 @@ def get_schema(self, *, categorical_threshold: int) -> str:
column_info = [f"- {column} ({sql_type})"]
- # For TEXT columns, check if they're categorical
if sql_type == "TEXT":
- unique_values = ndf[column].drop_nulls().unique()
+ unique_values = self._df[column].drop_nulls().unique()
if unique_values.len() <= categorical_threshold:
categories = unique_values.to_list()
categories_str = ", ".join([f"'{c}'" for c in categories])
column_info.append(f" Categorical values: {categories_str}")
-
- # For numeric columns, include range
elif sql_type in ["INTEGER", "FLOAT", "DATE", "TIME"]:
- rng = ndf[column].min(), ndf[column].max()
+ rng = self._df[column].min(), self._df[column].max()
if rng[0] is None and rng[1] is None:
column_info.append(" Range: NULL to NULL")
else:
@@ -196,10 +184,12 @@ def get_schema(self, *, categorical_threshold: int) -> str:
return "\n".join(schema)
- def execute_query(self, query: str) -> pd.DataFrame:
+ def execute_query(self, query: str) -> nw.DataFrame:
"""
Execute query using DuckDB.
+ Uses polars if available, otherwise falls back to pandas.
+
Parameters
----------
query
@@ -208,23 +198,22 @@ def execute_query(self, query: str) -> pd.DataFrame:
Returns
-------
:
- Query results as pandas DataFrame
+ Query results as narwhals DataFrame
"""
- return self._conn.execute(query).df()
+ return duckdb_result_to_nw(self._conn.execute(query))
- def get_data(self) -> pd.DataFrame:
+ def get_data(self) -> nw.DataFrame:
"""
Return the unfiltered data as a DataFrame.
Returns
-------
:
- The complete dataset as a pandas DataFrame
+ The complete dataset as a narwhals DataFrame
"""
- # TODO(@gadenbuie): This should just return `self._df` and not a pandas DataFrame
- return self._df.lazy().collect().to_pandas()
+ return self._df
def cleanup(self) -> None:
"""
@@ -412,10 +401,12 @@ def get_schema(self, *, categorical_threshold: int) -> str: # noqa: PLR0912
return "\n".join(schema)
- def execute_query(self, query: str) -> pd.DataFrame:
+ def execute_query(self, query: str) -> nw.DataFrame:
"""
Execute SQL query and return results as DataFrame.
+ Uses polars if available, otherwise falls back to pandas.
+
Parameters
----------
query
@@ -424,20 +415,20 @@ def execute_query(self, query: str) -> pd.DataFrame:
Returns
-------
:
- Query results as pandas DataFrame
+ Query results as narwhals DataFrame
"""
with self._get_connection() as conn:
- return pd.read_sql_query(text(query), conn)
+ return read_sql(text(query), conn)
- def get_data(self) -> pd.DataFrame:
+ def get_data(self) -> nw.DataFrame:
"""
Return the unfiltered data as a DataFrame.
Returns
-------
:
- The complete dataset as a pandas DataFrame
+ The complete dataset as a narwhals DataFrame
"""
return self.execute_query(f"SELECT * FROM {self.table_name}")
diff --git a/pkg-py/src/querychat/_df_compat.py b/pkg-py/src/querychat/_df_compat.py
new file mode 100644
index 00000000..bda2748b
--- /dev/null
+++ b/pkg-py/src/querychat/_df_compat.py
@@ -0,0 +1,74 @@
+"""
+DataFrame compatibility: try polars first, fall back to pandas.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import narwhals.stable.v1 as nw
+
+if TYPE_CHECKING:
+ import duckdb
+ from sqlalchemy.engine import Connection
+ from sqlalchemy.sql.elements import TextClause
+
+_INSTALL_MSG = "Install one with: pip install polars OR pip install pandas"
+
+
+def read_sql(query: TextClause, conn: Connection) -> nw.DataFrame:
+ try:
+ import polars as pl # noqa: PLC0415 # pyright: ignore[reportMissingImports]
+
+ return nw.from_native(pl.read_database(query, connection=conn))
+ except Exception: # noqa: S110
+ # Catches ImportError for polars, and other errors (e.g., missing pyarrow)
+ # Intentional fallback to pandas - no logging needed
+ pass
+
+ try:
+ import pandas as pd # noqa: PLC0415 # pyright: ignore[reportMissingImports]
+
+ return nw.from_native(pd.read_sql_query(query, conn))
+ except ImportError:
+ pass
+
+ raise ImportError(f"SQLAlchemySource requires 'polars' or 'pandas'. {_INSTALL_MSG}")
+
+
+def duckdb_result_to_nw(
+ result: duckdb.DuckDBPyRelation | duckdb.DuckDBPyConnection,
+) -> nw.DataFrame:
+ try:
+ return nw.from_native(result.pl())
+ except Exception: # noqa: S110
+ # Catches ImportError for polars, and other errors (e.g., missing pyarrow)
+ # Intentional fallback to pandas - no logging needed
+ pass
+
+ try:
+ return nw.from_native(result.df())
+ except ImportError:
+ pass
+
+ raise ImportError(f"DataFrameSource requires 'polars' or 'pandas'. {_INSTALL_MSG}")
+
+
+def read_csv(path: str) -> nw.DataFrame:
+ try:
+ import polars as pl # noqa: PLC0415 # pyright: ignore[reportMissingImports]
+
+ return nw.from_native(pl.read_csv(path))
+ except Exception: # noqa: S110
+ # Catches ImportError for polars, and other errors (e.g., missing pyarrow)
+ # Intentional fallback to pandas - no logging needed
+ pass
+
+ try:
+ import pandas as pd # noqa: PLC0415 # pyright: ignore[reportMissingImports]
+
+ return nw.from_native(pd.read_csv(path, compression="gzip"))
+ except ImportError:
+ pass
+
+ raise ImportError(f"Loading data requires 'polars' or 'pandas'. {_INSTALL_MSG}")
diff --git a/pkg-py/src/querychat/_querychat.py b/pkg-py/src/querychat/_querychat.py
index 3a83b090..f6374971 100644
--- a/pkg-py/src/querychat/_querychat.py
+++ b/pkg-py/src/querychat/_querychat.py
@@ -8,6 +8,7 @@
import chatlas
import chevron
+import narwhals.stable.v1 as nw
import sqlalchemy
from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
from shiny.express._stub_session import ExpressStubSession
@@ -29,8 +30,7 @@
if TYPE_CHECKING:
from collections.abc import Callable
- import pandas as pd
- from narwhals.stable.v1.typing import IntoFrame
+ from narwhals.typing import IntoFrame
TOOL_GROUPS = Literal["update", "query"]
@@ -797,14 +797,14 @@ def __init__(
enable_bookmarking=enable,
)
- def df(self) -> pd.DataFrame:
+ def df(self) -> nw.DataFrame:
"""
Reactively read the current filtered data frame that is in effect.
Returns
-------
:
- The current filtered data frame as a pandas DataFrame. If no query
+ The current filtered data frame as a narwhals DataFrame. If no query
has been set, this will return the unfiltered data frame from the
data source.
@@ -883,7 +883,16 @@ def normalize_data_source(
return data_source
if isinstance(data_source, sqlalchemy.Engine):
return SQLAlchemySource(data_source, table_name)
- return DataFrameSource(data_source, table_name)
+ src = nw.from_native(data_source, pass_through=True)
+ if isinstance(src, nw.DataFrame):
+ return DataFrameSource(src, table_name)
+ if isinstance(src, nw.LazyFrame):
+ raise NotImplementedError("LazyFrame data sources are not yet supported (they will be soon).")
+ raise TypeError(
+ f"Unsupported data source type: {type(data_source)}. "
+ "If you believe this type should be supported, please open an issue at "
+ "https://github.com/posit-dev/querychat/issues"
+ )
def as_querychat_client(client: str | chatlas.Chat | None) -> chatlas.Chat:
diff --git a/pkg-py/src/querychat/_querychat_module.py b/pkg-py/src/querychat/_querychat_module.py
index f2bed066..320dcd53 100644
--- a/pkg-py/src/querychat/_querychat_module.py
+++ b/pkg-py/src/querychat/_querychat_module.py
@@ -15,7 +15,7 @@
if TYPE_CHECKING:
from collections.abc import Callable
- import pandas as pd
+ import narwhals.stable.v1 as nw
from shiny import Inputs, Outputs, Session
from shiny.bookmark import BookmarkState, RestoreState
@@ -78,7 +78,7 @@ class ServerValues:
"""
- df: Callable[[], pd.DataFrame]
+ df: Callable[[], nw.DataFrame]
sql: ReactiveStringOrNone
title: ReactiveStringOrNone
client: chatlas.Chat
@@ -182,14 +182,14 @@ def _():
@session.bookmark.on_bookmark
def _on_bookmark(x: BookmarkState) -> None:
- vals = x.values # noqa: PD011
+ vals = x.values
vals["querychat_sql"] = sql.get()
vals["querychat_title"] = title.get()
vals["querychat_has_greeted"] = has_greeted.get()
@session.bookmark.on_restore
def _on_restore(x: RestoreState) -> None:
- vals = x.values # noqa: PD011
+ vals = x.values
if "querychat_sql" in vals:
sql.set(vals["querychat_sql"])
if "querychat_title" in vals:
diff --git a/pkg-py/src/querychat/_utils.py b/pkg-py/src/querychat/_utils.py
index 1d90dfd9..56f2ec37 100644
--- a/pkg-py/src/querychat/_utils.py
+++ b/pkg-py/src/querychat/_utils.py
@@ -122,6 +122,17 @@ def querychat_tool_starts_open(action: Literal["update", "query", "reset"]) -> b
return action != "reset"
+def _escape_html(s: str) -> str:
+ """Escape HTML special characters."""
+ return (
+ str(s)
+ .replace("&", "&amp;")
+ .replace("<", "&lt;")
+ .replace(">", "&gt;")
+ .replace('"', "&quot;")
+ )
+
+
def df_to_html(df: IntoFrame, maxrows: int = 5) -> str:
"""
Convert a DataFrame to an HTML table for display in chat.
@@ -149,11 +160,30 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str:
"Must be able to convert `df` into a Narwhals DataFrame or LazyFrame",
)
- # Generate HTML table
- table_html = df_short.to_pandas().to_html(
- index=False,
- classes="table table-striped",
- )
+ # Generate HTML table directly from narwhals DataFrame
+ columns = df_short.columns
+ rows = df_short.rows()
+
+ # Build HTML table
+ html_parts = ['<table border="1" class="dataframe table table-striped">']
+
+ # Header
+ html_parts.append("  <thead>")
+ html_parts.append('    <tr style="text-align: right;">')
+ html_parts.extend(f"      <th>{_escape_html(col)}</th>" for col in columns)
+ html_parts.append("    </tr>")
+ html_parts.append("  </thead>")
+
+ # Body
+ html_parts.append("  <tbody>")
+ for row in rows:
+ html_parts.append("    <tr>")
+ html_parts.extend(f"      <td>{_escape_html(str(val))}</td>" for val in row)
+ html_parts.append("    </tr>")
+ html_parts.append("  </tbody>")
+
+ html_parts.append("</table>")
+ table_html = "\n".join(html_parts)
# Add note about truncated rows if needed
if len(df_short) != nrow_full:
diff --git a/pkg-py/src/querychat/data/__init__.py b/pkg-py/src/querychat/data/__init__.py
index fe9bec96..867a326f 100644
--- a/pkg-py/src/querychat/data/__init__.py
+++ b/pkg-py/src/querychat/data/__init__.py
@@ -8,11 +8,15 @@
from __future__ import annotations
from importlib.resources import files
+from typing import TYPE_CHECKING
-import pandas as pd
+from querychat._df_compat import read_csv
+if TYPE_CHECKING:
+ import narwhals.stable.v1 as nw
-def titanic() -> pd.DataFrame:
+
+def titanic() -> nw.DataFrame:
"""
Load the Titanic dataset.
@@ -21,8 +25,9 @@ def titanic() -> pd.DataFrame:
Returns
-------
- pandas.DataFrame
- A DataFrame with 891 rows and 15 columns containing Titanic passenger data.
+ :
+ A narwhals DataFrame with 891 rows and 15 columns containing Titanic
+ passenger data.
Examples
--------
@@ -35,10 +40,10 @@ def titanic() -> pd.DataFrame:
"""
# Get the path to the gzipped CSV file using importlib.resources
data_file = files("querychat.data") / "titanic.csv.gz"
- return pd.read_csv(str(data_file), compression="gzip")
+ return read_csv(str(data_file))
-def tips() -> pd.DataFrame:
+def tips() -> nw.DataFrame:
"""
Load the tips dataset.
@@ -48,8 +53,9 @@ def tips() -> pd.DataFrame:
Returns
-------
- pandas.DataFrame
- A DataFrame with 244 rows and 7 columns containing restaurant tip data.
+ :
+ A narwhals DataFrame with 244 rows and 7 columns containing restaurant
+ tip data.
Examples
--------
@@ -62,7 +68,7 @@ def tips() -> pd.DataFrame:
"""
# Get the path to the gzipped CSV file using importlib.resources
data_file = files("querychat.data") / "tips.csv.gz"
- return pd.read_csv(str(data_file), compression="gzip")
+ return read_csv(str(data_file))
__all__ = ["tips", "titanic"]
diff --git a/pkg-py/src/querychat/tools.py b/pkg-py/src/querychat/tools.py
index e0001de1..7d06b73a 100644
--- a/pkg-py/src/querychat/tools.py
+++ b/pkg-py/src/querychat/tools.py
@@ -218,7 +218,7 @@ def query(query: str, _intent: str = "") -> ContentToolResult:
try:
result_df = data_source.execute_query(query)
- value = result_df.to_dict(orient="records")
+ value = result_df.rows(named=True)
# Format table results
tbl_html = df_to_html(result_df, maxrows=5)
diff --git a/pkg-py/tests/test_data.py b/pkg-py/tests/test_data.py
index 0ee2f8f8..6128a779 100644
--- a/pkg-py/tests/test_data.py
+++ b/pkg-py/tests/test_data.py
@@ -1,13 +1,13 @@
"""Tests for the querychat.data module."""
-import pandas as pd
+import narwhals.stable.v1 as nw
from querychat.data import tips, titanic
def test_titanic_returns_dataframe():
- """Test that titanic() returns a pandas DataFrame."""
+ """Test that titanic() returns a narwhals DataFrame."""
df = titanic()
- assert isinstance(df, pd.DataFrame)
+ assert isinstance(df, nw.DataFrame)
def test_titanic_has_expected_shape():
@@ -44,16 +44,19 @@ def test_titanic_data_integrity():
df = titanic()
# Check that survived column has only 0 and 1 values
- assert set(df["survived"].dropna().unique()) <= {0, 1}
+ unique_survived = set(df["survived"].drop_nulls().unique().to_list())
+ assert unique_survived <= {0, 1}
# Check that pclass has only 1, 2, 3
- assert set(df["pclass"].dropna().unique()) <= {1, 2, 3}
+ unique_pclass = set(df["pclass"].drop_nulls().unique().to_list())
+ assert unique_pclass <= {1, 2, 3}
# Check that sex has only 'male' and 'female'
- assert set(df["sex"].dropna().unique()) <= {"male", "female"}
+ unique_sex = set(df["sex"].drop_nulls().unique().to_list())
+ assert unique_sex <= {"male", "female"}
# Check that fare is non-negative
- assert (df["fare"].dropna() >= 0).all()
+ assert df["fare"].drop_nulls().min() >= 0
def test_titanic_creates_new_copy():
@@ -64,14 +67,15 @@ def test_titanic_creates_new_copy():
# They should not be the same object
assert df1 is not df2
- # But they should have the same data
- assert df1.equals(df2)
+ # But they should have the same shape and columns
+ assert df1.shape == df2.shape
+ assert list(df1.columns) == list(df2.columns)
def test_tips_returns_dataframe():
- """Test that tips() returns a pandas DataFrame."""
+ """Test that tips() returns a narwhals DataFrame."""
df = tips()
- assert isinstance(df, pd.DataFrame)
+ assert isinstance(df, nw.DataFrame)
def test_tips_has_expected_shape():
@@ -100,19 +104,21 @@ def test_tips_data_integrity():
df = tips()
# Check that total_bill is positive
- assert (df["total_bill"] > 0).all()
+ assert df["total_bill"].min() > 0
# Check that tip is non-negative
- assert (df["tip"] >= 0).all()
+ assert df["tip"].min() >= 0
# Check that sex has only expected values
- assert set(df["sex"].dropna().unique()) <= {"Male", "Female"}
+ unique_sex = set(df["sex"].drop_nulls().unique().to_list())
+ assert unique_sex <= {"Male", "Female"}
# Check that smoker has only expected values
- assert set(df["smoker"].dropna().unique()) <= {"Yes", "No"}
+ unique_smoker = set(df["smoker"].drop_nulls().unique().to_list())
+ assert unique_smoker <= {"Yes", "No"}
# Check that size is positive
- assert (df["size"] > 0).all()
+ assert df["size"].min() > 0
def test_tips_creates_new_copy():
@@ -123,5 +129,6 @@ def test_tips_creates_new_copy():
# They should not be the same object
assert df1 is not df2
- # But they should have the same data
- assert df1.equals(df2)
+ # But they should have the same shape and columns
+ assert df1.shape == df2.shape
+ assert list(df1.columns) == list(df2.columns)
diff --git a/pkg-py/tests/test_dataframe_source.py b/pkg-py/tests/test_dataframe_source.py
new file mode 100644
index 00000000..7602583a
--- /dev/null
+++ b/pkg-py/tests/test_dataframe_source.py
@@ -0,0 +1,284 @@
+"""Tests for the DataFrameSource class with narwhals compatibility."""
+
+import duckdb
+import narwhals.stable.v1 as nw
+import pandas as pd
+import pytest
+from querychat._datasource import DataFrameSource
+
+# Check if polars and pyarrow are available (both needed for DuckDB + polars)
+try:
+ import polars as pl
+ import pyarrow as pa # noqa: F401
+
+ HAS_POLARS_WITH_PYARROW = True
+except ImportError:
+ HAS_POLARS_WITH_PYARROW = False
+ pl = None # type: ignore[assignment]
+
+
+@pytest.fixture
+def pandas_df():
+ """Create a sample pandas DataFrame."""
+ return pd.DataFrame(
+ {
+ "id": [1, 2, 3, 4, 5],
+ "name": ["Alice", "Bob", "Charlie", "Diana", "Eve"],
+ "age": [25, 30, 35, 28, 32],
+ "salary": [50000.0, 60000.0, 70000.0, 55000.0, 65000.0],
+ "department": ["Engineering", "Sales", "Engineering", "Sales", "Engineering"],
+ }
+ )
+
+
+@pytest.fixture
+def narwhals_df(pandas_df):
+ """Create a narwhals DataFrame from pandas."""
+ return nw.from_native(pandas_df)
+
+
+class TestDataFrameSourceInit:
+ """Tests for DataFrameSource initialization."""
+
+ def test_init_with_pandas_dataframe(self, pandas_df):
+ """Test that DataFrameSource accepts a pandas DataFrame."""
+ source = DataFrameSource(pandas_df, "test_table")
+ assert source.table_name == "test_table"
+
+ def test_init_with_narwhals_dataframe(self, narwhals_df):
+ """Test that DataFrameSource accepts a narwhals DataFrame."""
+ source = DataFrameSource(narwhals_df, "test_table")
+ assert source.table_name == "test_table"
+
+ @pytest.mark.skipif(not HAS_POLARS_WITH_PYARROW, reason="polars or pyarrow not installed")
+ def test_init_with_polars_dataframe(self):
+ """Test that DataFrameSource accepts a polars DataFrame."""
+ polars_df = pl.DataFrame(
+ {
+ "id": [1, 2, 3],
+ "name": ["Alice", "Bob", "Charlie"],
+ }
+ )
+ source = DataFrameSource(polars_df, "test_table")
+ assert source.table_name == "test_table"
+
+
+class TestDataFrameSourceExecuteQuery:
+ """Tests for DataFrameSource.execute_query method."""
+
+ def test_execute_query_returns_narwhals_dataframe(self, pandas_df):
+ """Test that execute_query returns a narwhals DataFrame."""
+ source = DataFrameSource(pandas_df, "employees")
+ result = source.execute_query("SELECT * FROM employees")
+ assert isinstance(result, nw.DataFrame)
+
+ def test_execute_query_select_all(self, pandas_df):
+ """Test SELECT * query."""
+ source = DataFrameSource(pandas_df, "employees")
+ result = source.execute_query("SELECT * FROM employees")
+
+ assert result.shape == (5, 5)
+ assert set(result.columns) == {"id", "name", "age", "salary", "department"}
+
+ def test_execute_query_with_filter(self, pandas_df):
+ """Test query with WHERE clause."""
+ source = DataFrameSource(pandas_df, "employees")
+ result = source.execute_query(
+ "SELECT * FROM employees WHERE department = 'Engineering'"
+ )
+
+ assert result.shape == (3, 5)
+ departments = result["department"].unique().to_list()
+ assert departments == ["Engineering"]
+
+ def test_execute_query_with_aggregation(self, pandas_df):
+ """Test query with aggregation."""
+ source = DataFrameSource(pandas_df, "employees")
+ result = source.execute_query(
+ "SELECT department, AVG(salary) as avg_salary FROM employees GROUP BY department"
+ )
+
+ assert result.shape == (2, 2)
+ assert "department" in result.columns
+ assert "avg_salary" in result.columns
+
+ def test_execute_query_select_columns(self, pandas_df):
+ """Test selecting specific columns."""
+ source = DataFrameSource(pandas_df, "employees")
+ result = source.execute_query("SELECT name, age FROM employees")
+
+ assert result.shape == (5, 2)
+ assert list(result.columns) == ["name", "age"]
+
+ def test_execute_query_order_by(self, pandas_df):
+ """Test query with ORDER BY clause."""
+ source = DataFrameSource(pandas_df, "employees")
+ result = source.execute_query(
+ "SELECT name, age FROM employees ORDER BY age DESC"
+ )
+
+ ages = result["age"].to_list()
+ assert ages == sorted(ages, reverse=True)
+
+ def test_execute_query_empty_result(self, pandas_df):
+ """Test query that returns no rows."""
+ source = DataFrameSource(pandas_df, "employees")
+ result = source.execute_query(
+ "SELECT * FROM employees WHERE age > 100"
+ )
+
+ assert isinstance(result, nw.DataFrame)
+ assert result.shape == (0, 5)
+
+
+class TestDataFrameSourceGetData:
+ """Tests for DataFrameSource.get_data method."""
+
+ def test_get_data_returns_narwhals_dataframe(self, pandas_df):
+ """Test that get_data returns a narwhals DataFrame."""
+ source = DataFrameSource(pandas_df, "employees")
+ result = source.get_data()
+ assert isinstance(result, nw.DataFrame)
+
+ def test_get_data_returns_full_dataset(self, pandas_df):
+ """Test that get_data returns all rows."""
+ source = DataFrameSource(pandas_df, "employees")
+ result = source.get_data()
+
+ assert result.shape == pandas_df.shape
+ assert set(result.columns) == set(pandas_df.columns)
+
+ def test_get_data_preserves_data(self, pandas_df):
+ """Test that get_data preserves data values."""
+ source = DataFrameSource(pandas_df, "employees")
+ result = source.get_data()
+
+ # Check that the data matches
+ original_names = sorted(pandas_df["name"].tolist())
+ result_names = sorted(result["name"].to_list())
+ assert original_names == result_names
+
+
+class TestDataFrameSourceGetSchema:
+ """Tests for DataFrameSource.get_schema method."""
+
+ def test_get_schema_includes_table_name(self, pandas_df):
+ """Test that schema includes table name."""
+ source = DataFrameSource(pandas_df, "employees")
+ schema = source.get_schema(categorical_threshold=10)
+
+ assert "Table: employees" in schema
+ assert "Columns:" in schema
+
+ def test_get_schema_includes_all_columns(self, pandas_df):
+ """Test that schema includes all columns."""
+ source = DataFrameSource(pandas_df, "employees")
+ schema = source.get_schema(categorical_threshold=10)
+
+ for col in pandas_df.columns:
+ assert f"- {col} (" in schema
+
+ def test_get_schema_numeric_ranges(self, pandas_df):
+ """Test that numeric columns include range information."""
+ source = DataFrameSource(pandas_df, "employees")
+ schema = source.get_schema(categorical_threshold=10)
+
+ # Age should have range
+ assert "Range: 25 to 35" in schema
+ # Salary should have range
+ assert "Range: 50000.0 to 70000.0" in schema
+
+ def test_get_schema_categorical_values(self, pandas_df):
+ """Test that categorical columns show unique values."""
+ source = DataFrameSource(pandas_df, "employees")
+ schema = source.get_schema(categorical_threshold=10)
+
+ # Department has only 2 unique values, should be categorical
+ assert "Categorical values:" in schema
+ assert "'Engineering'" in schema
+ assert "'Sales'" in schema
+
+ def test_get_schema_respects_threshold(self, pandas_df):
+ """Test that categorical_threshold is respected."""
+ source = DataFrameSource(pandas_df, "employees")
+
+ # With threshold 1, no columns should be categorical
+ schema_low = source.get_schema(categorical_threshold=1)
+ # Department has 2 unique values, should not be listed as categorical
+ lines = schema_low.split("\n")
+ dept_idx = next(i for i, line in enumerate(lines) if "- department" in line)
+ if dept_idx + 1 < len(lines):
+ assert "Categorical values:" not in lines[dept_idx + 1]
+
+ # With threshold 5, department should be categorical
+ schema_high = source.get_schema(categorical_threshold=5)
+ assert "'Engineering'" in schema_high
+
+
+class TestDataFrameSourceDbType:
+ """Tests for DataFrameSource.get_db_type method."""
+
+ def test_get_db_type_returns_duckdb(self, pandas_df):
+ """Test that get_db_type returns 'DuckDB'."""
+ source = DataFrameSource(pandas_df, "employees")
+ assert source.get_db_type() == "DuckDB"
+
+
+class TestDataFrameSourceCleanup:
+ """Tests for DataFrameSource.cleanup method."""
+
+ def test_cleanup_closes_connection(self, pandas_df):
+ """Test that cleanup closes the DuckDB connection."""
+ source = DataFrameSource(pandas_df, "employees")
+
+ # Should work before cleanup
+ result = source.execute_query("SELECT * FROM employees LIMIT 1")
+ assert result.shape[0] == 1
+
+ # Cleanup
+ source.cleanup()
+
+ # After cleanup, queries should fail
+ with pytest.raises(duckdb.ConnectionException):
+ source.execute_query("SELECT * FROM employees")
+
+
+@pytest.mark.skipif(not HAS_POLARS_WITH_PYARROW, reason="polars or pyarrow not installed")
+class TestDataFrameSourceWithPolars:
+ """Tests for DataFrameSource with polars DataFrames."""
+
+ @pytest.fixture
+ def polars_df(self):
+ """Create a sample polars DataFrame."""
+ return pl.DataFrame(
+ {
+ "id": [1, 2, 3],
+ "name": ["Alice", "Bob", "Charlie"],
+ "value": [10.5, 20.5, 30.5],
+ }
+ )
+
+ def test_execute_query_with_polars(self, polars_df):
+ """Test execute_query with polars source."""
+ source = DataFrameSource(polars_df, "test_data")
+ result = source.execute_query("SELECT * FROM test_data")
+
+ assert isinstance(result, nw.DataFrame)
+ assert result.shape == (3, 3)
+
+ def test_get_data_with_polars(self, polars_df):
+ """Test get_data with polars source."""
+ source = DataFrameSource(polars_df, "test_data")
+ result = source.get_data()
+
+ assert isinstance(result, nw.DataFrame)
+ assert result.shape == polars_df.shape
+
+ def test_polars_result_backend(self, polars_df):
+ """Test that results use polars backend when input is polars."""
+ source = DataFrameSource(polars_df, "test_data")
+ result = source.execute_query("SELECT * FROM test_data")
+
+ # When polars is available, the result should use polars backend
+ native = result.to_native()
+ assert isinstance(native, pl.DataFrame)
diff --git a/pkg-py/tests/test_df_compat.py b/pkg-py/tests/test_df_compat.py
new file mode 100644
index 00000000..065efc2a
--- /dev/null
+++ b/pkg-py/tests/test_df_compat.py
@@ -0,0 +1,134 @@
+"""Tests for the _df_compat module and narwhals DataFrame compatibility."""
+
+import gzip
+import tempfile
+from pathlib import Path
+
+import duckdb
+import narwhals.stable.v1 as nw
+import pytest
+from querychat._df_compat import duckdb_result_to_nw, read_csv
+
+# Check if polars and pyarrow are available (both needed for DuckDB + polars)
+try:
+ import polars as pl
+ import pyarrow as pa # noqa: F401
+
+ HAS_POLARS_WITH_PYARROW = True
+except ImportError:
+ HAS_POLARS_WITH_PYARROW = False
+ pl = None # type: ignore[assignment]
+
+
+class TestReadCsv:
+    """
+    Tests for the read_csv function.
+
+    Note: read_csv is designed for reading gzipped CSV files (used for bundled data).
+    """
+
+    @pytest.fixture
+    def gzip_csv_file(self, tmp_path):
+        """
+        Write a three-row gzipped CSV under pytest's tmp_path and return its path.
+
+        pytest owns the directory's lifetime, so no manual unlink is needed; the
+        path is handed out as a str, the same type the tests received before.
+        """
+        csv_path = tmp_path / "sample.csv.gz"
+        # Explicit encoding so the file does not depend on the platform default
+        with gzip.open(csv_path, "wt", encoding="utf-8") as f:
+            f.write("id,name,value\n1,Alice,100\n2,Bob,200\n3,Charlie,300\n")
+
+        return str(csv_path)
+
+    def test_read_csv_returns_narwhals_dataframe(self, gzip_csv_file):
+        """Test that read_csv returns a narwhals DataFrame."""
+        result = read_csv(gzip_csv_file)
+        assert isinstance(result, nw.DataFrame)
+
+    def test_read_csv_has_correct_shape(self, gzip_csv_file):
+        """Test that read_csv produces the expected shape and column names."""
+        result = read_csv(gzip_csv_file)
+        assert result.shape == (3, 3)
+        assert list(result.columns) == ["id", "name", "value"]
+
+    def test_read_csv_data_integrity(self, gzip_csv_file):
+        """Test that read_csv preserves data correctly."""
+        result = read_csv(gzip_csv_file)
+        names = result["name"].to_list()
+        assert names == ["Alice", "Bob", "Charlie"]
+
+
+class TestDuckdbResultToNw:
+    """Checks for converting DuckDB query results with duckdb_result_to_nw."""
+
+    @pytest.fixture
+    def duckdb_conn(self):
+        """Yield an in-memory DuckDB connection holding a two-row `test` table."""
+        conn = duckdb.connect(":memory:")
+        for statement in (
+            "CREATE TABLE test (id INTEGER, name VARCHAR, value DOUBLE)",
+            "INSERT INTO test VALUES (1, 'Alice', 10.5)",
+            "INSERT INTO test VALUES (2, 'Bob', 20.5)",
+        ):
+            conn.execute(statement)
+        yield conn
+        conn.close()
+
+    def test_duckdb_result_returns_narwhals_dataframe(self, duckdb_conn):
+        """The converted result is a narwhals DataFrame."""
+        frame = duckdb_result_to_nw(duckdb_conn.execute("SELECT * FROM test"))
+        assert isinstance(frame, nw.DataFrame)
+
+    def test_duckdb_result_has_correct_data(self, duckdb_conn):
+        """Shape, column names and the id/name values survive the conversion."""
+        executed = duckdb_conn.execute("SELECT * FROM test ORDER BY id")
+        frame = duckdb_result_to_nw(executed)
+        assert frame.shape == (2, 3)
+        assert list(frame.columns) == ["id", "name", "value"]
+        assert frame["id"].to_list() == [1, 2]
+        assert frame["name"].to_list() == ["Alice", "Bob"]
+
+    def test_duckdb_result_empty_query(self, duckdb_conn):
+        """A query matching no rows still yields a zero-row, three-column frame."""
+        executed = duckdb_conn.execute("SELECT * FROM test WHERE id > 100")
+        frame = duckdb_result_to_nw(executed)
+        assert isinstance(frame, nw.DataFrame)
+        assert frame.shape == (0, 3)
+
+
+@pytest.mark.skipif(
+    not HAS_POLARS_WITH_PYARROW, reason="polars or pyarrow not installed"
+)
+class TestPolarsBackend:
+    """Tests that verify polars backend works correctly when available."""
+
+    def test_read_csv_uses_polars_when_available(self):
+        """Test that read_csv uses polars as the backend when available."""
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
+            f.write("x,y\n1,2\n3,4\n")
+            temp_path = f.name
+
+        try:
+            result = read_csv(temp_path)
+            # The native frame should be polars when polars is available
+            native = result.to_native()
+            assert isinstance(native, pl.DataFrame)
+        finally:
+            Path(temp_path).unlink()
+
+    def test_duckdb_result_uses_polars_when_available(self):
+        """Test that duckdb_result_to_nw uses polars when available."""
+        conn = duckdb.connect(":memory:")
+        try:
+            conn.execute("CREATE TABLE t (x INTEGER)")
+            conn.execute("INSERT INTO t VALUES (1)")
+
+            df = duckdb_result_to_nw(conn.execute("SELECT * FROM t"))
+
+            # The native frame should be polars when polars is available
+            assert isinstance(df.to_native(), pl.DataFrame)
+        finally:
+            # Close in finally so a failing assertion cannot leak the connection
+            conn.close()
diff --git a/pkg-py/tests/test_init_with_pandas.py b/pkg-py/tests/test_init_with_pandas.py
index 7c182639..1179a25c 100644
--- a/pkg-py/tests/test_init_with_pandas.py
+++ b/pkg-py/tests/test_init_with_pandas.py
@@ -64,9 +64,8 @@ def test_init_with_narwhals_dataframe():
assert qc is not None
-def test_init_with_narwhals_lazyframe_direct_query():
- """Test that QueryChat() can accept a narwhals LazyFrame and execute queries."""
- # Create a pandas DataFrame and convert to narwhals LazyFrame
+def test_init_with_narwhals_lazyframe_raises():
+    """Test that QueryChat() raises NotImplementedError for LazyFrames."""
pdf = pd.DataFrame(
{
"id": [1, 2, 3],
@@ -76,20 +75,9 @@ def test_init_with_narwhals_lazyframe_direct_query():
)
nw_lazy = nw.from_native(pdf).lazy()
- # Call QueryChat with the narwhals LazyFrame
- qc = QueryChat(
- data_source=nw_lazy, # TODO(@gadebuie): Fix this type error
- table_name="test_table",
- greeting="hello!",
- )
-
- # Verify the result is correctly configured
- assert qc is not None
- assert hasattr(qc, "data_source")
-
- # Test that we can run a query on the data source
- query_result = qc.data_source.execute_query(
- "SELECT * FROM test_table WHERE id = 2",
- )
- assert len(query_result) == 1
- assert query_result.iloc[0]["name"] == "Bob"
+ with pytest.raises(NotImplementedError, match="LazyFrame"):
+ QueryChat(
+ data_source=nw_lazy,
+ table_name="test_table",
+ greeting="hello!",
+ )
diff --git a/pkg-py/tests/test_querychat.py b/pkg-py/tests/test_querychat.py
index 22dab6d7..3c0e1e9d 100644
--- a/pkg-py/tests/test_querychat.py
+++ b/pkg-py/tests/test_querychat.py
@@ -47,7 +47,7 @@ def test_querychat_init(sample_df):
)
assert len(result) == 1
- assert result.iloc[0]["name"] == "Bob"
+ assert result.item(0, "name") == "Bob"
def test_querychat_custom_id(sample_df):
diff --git a/pyproject.toml b/pyproject.toml
index c4fa03f9..dc296fff 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,6 @@ maintainers = [
]
dependencies = [
"duckdb",
- "pandas",
"shiny>=1.5.1",
"shinywidgets",
"htmltools",
@@ -40,6 +39,11 @@ classifiers = [
"Programming Language :: Python :: 3.14",
]
+[project.optional-dependencies]
+# For SQLAlchemySource and sample data, one of polars or pandas is required
+pandas = ["pandas"]
+polars = ["polars"]
+
[project.urls]
Homepage = "https://github.com/posit-dev/querychat" # TODO update when we have docs
Repository = "https://github.com/posit-dev/querychat"