diff --git a/CLAUDE.md b/CLAUDE.md index 1fb7d115..d15e5578 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -59,6 +59,14 @@ make py-build make py-docs ``` +Before finishing your implementation or committing any code, you should run: + +```bash +uv run ruff check --fix pkg-py --config pyproject.toml +``` + +To get help with making sure code adheres to project standards. + ### R Package ```bash diff --git a/pkg-py/CHANGELOG.md b/pkg-py/CHANGELOG.md index 4d6e38f4..2f70ae3e 100644 --- a/pkg-py/CHANGELOG.md +++ b/pkg-py/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [UNRELEASED] +### Breaking Changes + +* Methods like `execute_query()`, `get_data()`, and `df()` now return a `narwhals.DataFrame` instead of a `pandas.DataFrame`. This allows querychat to drop its `pandas` dependency, and for you to use any `narwhals`-compatible dataframe of your choosing. + * If this breaks existing code, note you can call `.to_native()` on the new dataframe value to get your `pandas` dataframe back. + * Note that `polars` or `pandas` will be needed to realize a `sqlalchemy` connection query as a dataframe. Install with `pip install querychat[pandas]` or `pip install querychat[polars]` + ### New features * `QueryChat.sidebar()`, `QueryChat.ui()`, and `QueryChat.server()` now support an optional `id` parameter to create multiple chat instances from a single `QueryChat` object. (#172) diff --git a/pkg-py/docs/build.qmd b/pkg-py/docs/build.qmd index 71971b24..512cec3e 100644 --- a/pkg-py/docs/build.qmd +++ b/pkg-py/docs/build.qmd @@ -203,7 +203,7 @@ with ui.layout_columns(): @render_plotly def survival_plot(): - d = qc.df() + d = qc.df().to_native() # Convert for pandas groupby() summary = d.groupby('pclass')['survived'].mean().reset_index() return px.bar(summary, x='pclass', y='survived') ``` @@ -271,7 +271,7 @@ with ui.layout_columns(): @render_plotly def survival_by_class(): - df = qc.df() + df = qc.df().to_native() # Convert for pandas groupby() summary = df.groupby('pclass')['survived'].mean().reset_index() return px.bar( summary, @@ -286,16 +286,14 @@ with ui.layout_columns(): @render_plotly def age_dist(): - df = qc.df() - return px.histogram(df, x='age', nbins=30) + return px.histogram(qc.df(), x='age', nbins=30) with ui.card(): ui.card_header("Fare by Class") @render_plotly def fare_by_class(): - df = qc.df() - return px.box(df, x='pclass', y='fare', color='survived') + return px.box(qc.df(), x='pclass', y='fare', color='survived') ui.page_opts( title="Titanic Survival Analysis", @@ -461,7 +459,7 @@ with ui.layout_columns(): @render.plot def survival_by_class(): - df = qc.df() + df = qc.df().to_native() # Convert for pandas groupby() summary = df.groupby('pclass')['survived'].mean().reset_index() fig = px.bar( summary, @@ -477,18 +475,14 @@ with ui.layout_columns(): @render.plot def age_dist(): - df = qc.df() - fig = px.histogram(df, x='age', nbins=30) - return fig + return px.histogram(qc.df(), x='age', nbins=30) with ui.card(): ui.card_header("Fare by Class") @render.plot def fare_by_class(): - df = qc.df() - fig = px.box(df, x='pclass', y='fare', color='survived') - return fig + return px.box(qc.df(), x='pclass', y='fare', color='survived') # Reset button handler @reactive.effect diff --git a/pkg-py/docs/data-sources.qmd b/pkg-py/docs/data-sources.qmd index 5ac97e27..327f6adf 100644 --- a/pkg-py/docs/data-sources.qmd +++ b/pkg-py/docs/data-sources.qmd @@ -63,7 +63,7 @@ app = qc.app() ::: -If you're [building an app](build.qmd), note you can read the queried data frame reactively using the `df()` method, which returns a `pandas.DataFrame` by default. +If you're [building an app](build.qmd), note you can read the queried data frame reactively using the `df()` method, which returns a `narwhals.DataFrame`. Call `.to_native()` on the result to get the underlying pandas or polars DataFrame. ## Databases diff --git a/pkg-py/src/querychat/_datasource.py b/pkg-py/src/querychat/_datasource.py index fd35e3c4..bb082aee 100644 --- a/pkg-py/src/querychat/_datasource.py +++ b/pkg-py/src/querychat/_datasource.py @@ -5,12 +5,12 @@ import duckdb import narwhals.stable.v1 as nw -import pandas as pd from sqlalchemy import inspect, text from sqlalchemy.sql import sqltypes +from ._df_compat import duckdb_result_to_nw, read_sql + if TYPE_CHECKING: - from narwhals.stable.v1.typing import IntoFrame from sqlalchemy.engine import Connection, Engine @@ -53,7 +53,7 @@ def get_schema(self, *, categorical_threshold: int) -> str: ... @abstractmethod - def execute_query(self, query: str) -> pd.DataFrame: + def execute_query(self, query: str) -> nw.DataFrame: """ Execute SQL query and return results as DataFrame. @@ -65,20 +65,20 @@ def execute_query(self, query: str) -> pd.DataFrame: Returns ------- : - Query results as a pandas DataFrame + Query results as a narwhals DataFrame """ ... @abstractmethod - def get_data(self) -> pd.DataFrame: + def get_data(self) -> nw.DataFrame: """ Return the unfiltered data as a DataFrame. Returns ------- : - The complete dataset as a pandas DataFrame + The complete dataset as a narwhals DataFrame """ ... @@ -99,27 +99,26 @@ def cleanup(self) -> None: class DataFrameSource(DataSource): - """A DataSource implementation that wraps a pandas DataFrame using DuckDB.""" + """A DataSource implementation that wraps a DataFrame using DuckDB.""" - _df: nw.DataFrame | nw.LazyFrame + _df: nw.DataFrame - def __init__(self, df: IntoFrame, table_name: str): + def __init__(self, df: nw.DataFrame, table_name: str): """ - Initialize with a pandas DataFrame. + Initialize with a DataFrame. Parameters ---------- df - The DataFrame to wrap + The DataFrame to wrap (pandas, polars, or any narwhals-compatible frame) table_name Name of the table in SQL queries """ self._conn = duckdb.connect(database=":memory:") - self._df = nw.from_native(df) + self._df = nw.from_native(df) if not isinstance(df, nw.DataFrame) else df self.table_name = table_name - # TODO(@gadenbuie): If the data frame is already SQL-backed, maybe we shouldn't be making a new copy here. - self._conn.register(table_name, self._df.lazy().collect().to_pandas()) + self._conn.register(table_name, self._df.to_native()) def get_db_type(self) -> str: """ @@ -151,16 +150,8 @@ def get_schema(self, *, categorical_threshold: int) -> str: """ schema = [f"Table: {self.table_name}", "Columns:"] - # Ensure we're working with a DataFrame, not a LazyFrame - ndf = ( - self._df.head(10).collect() - if isinstance(self._df, nw.LazyFrame) - else self._df - ) - - for column in ndf.columns: - # Map pandas dtypes to SQL-like types - dtype = ndf[column].dtype + for column in self._df.columns: + dtype = self._df[column].dtype if dtype.is_integer(): sql_type = "INTEGER" elif dtype.is_float(): @@ -176,17 +167,14 @@ def get_schema(self, *, categorical_threshold: int) -> str: column_info = [f"- {column} ({sql_type})"] - # For TEXT columns, check if they're categorical if sql_type == "TEXT": - unique_values = ndf[column].drop_nulls().unique() + unique_values = self._df[column].drop_nulls().unique() if unique_values.len() <= categorical_threshold: categories = unique_values.to_list() categories_str = ", ".join([f"'{c}'" for c in categories]) column_info.append(f" Categorical values: {categories_str}") - - # For numeric columns, include range elif sql_type in ["INTEGER", "FLOAT", "DATE", "TIME"]: - rng = ndf[column].min(), ndf[column].max() + rng = self._df[column].min(), self._df[column].max() if rng[0] is None and rng[1] is None: column_info.append(" Range: NULL to NULL") else: @@ -196,10 +184,12 @@ def get_schema(self, *, categorical_threshold: int) -> str: return "\n".join(schema) - def execute_query(self, query: str) -> pd.DataFrame: + def execute_query(self, query: str) -> nw.DataFrame: """ Execute query using DuckDB. + Uses polars if available, otherwise falls back to pandas. + Parameters ---------- query @@ -208,23 +198,22 @@ def execute_query(self, query: str) -> pd.DataFrame: Returns ------- : - Query results as pandas DataFrame + Query results as narwhals DataFrame """ - return self._conn.execute(query).df() + return duckdb_result_to_nw(self._conn.execute(query)) - def get_data(self) -> pd.DataFrame: + def get_data(self) -> nw.DataFrame: """ Return the unfiltered data as a DataFrame. Returns ------- : - The complete dataset as a pandas DataFrame + The complete dataset as a narwhals DataFrame """ - # TODO(@gadenbuie): This should just return `self._df` and not a pandas DataFrame - return self._df.lazy().collect().to_pandas() + return self._df def cleanup(self) -> None: """ @@ -412,10 +401,12 @@ def get_schema(self, *, categorical_threshold: int) -> str: # noqa: PLR0912 return "\n".join(schema) - def execute_query(self, query: str) -> pd.DataFrame: + def execute_query(self, query: str) -> nw.DataFrame: """ Execute SQL query and return results as DataFrame. + Uses polars if available, otherwise falls back to pandas. + Parameters ---------- query @@ -424,20 +415,20 @@ def execute_query(self, query: str) -> pd.DataFrame: Returns ------- : - Query results as pandas DataFrame + Query results as narwhals DataFrame """ with self._get_connection() as conn: - return pd.read_sql_query(text(query), conn) + return read_sql(text(query), conn) - def get_data(self) -> pd.DataFrame: + def get_data(self) -> nw.DataFrame: """ Return the unfiltered data as a DataFrame. Returns ------- : - The complete dataset as a pandas DataFrame + The complete dataset as a narwhals DataFrame """ return self.execute_query(f"SELECT * FROM {self.table_name}") diff --git a/pkg-py/src/querychat/_df_compat.py b/pkg-py/src/querychat/_df_compat.py new file mode 100644 index 00000000..bda2748b --- /dev/null +++ b/pkg-py/src/querychat/_df_compat.py @@ -0,0 +1,74 @@ +""" +DataFrame compatibility: try polars first, fall back to pandas. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import narwhals.stable.v1 as nw + +if TYPE_CHECKING: + import duckdb + from sqlalchemy.engine import Connection + from sqlalchemy.sql.elements import TextClause + +_INSTALL_MSG = "Install one with: pip install polars OR pip install pandas" + + +def read_sql(query: TextClause, conn: Connection) -> nw.DataFrame: + try: + import polars as pl # noqa: PLC0415 # pyright: ignore[reportMissingImports] + + return nw.from_native(pl.read_database(query, connection=conn)) + except Exception: # noqa: S110 + # Catches ImportError for polars, and other errors (e.g., missing pyarrow) + # Intentional fallback to pandas - no logging needed + pass + + try: + import pandas as pd # noqa: PLC0415 # pyright: ignore[reportMissingImports] + + return nw.from_native(pd.read_sql_query(query, conn)) + except ImportError: + pass + + raise ImportError(f"SQLAlchemySource requires 'polars' or 'pandas'. {_INSTALL_MSG}") + + +def duckdb_result_to_nw( + result: duckdb.DuckDBPyRelation | duckdb.DuckDBPyConnection, +) -> nw.DataFrame: + try: + return nw.from_native(result.pl()) + except Exception: # noqa: S110 + # Catches ImportError for polars, and other errors (e.g., missing pyarrow) + # Intentional fallback to pandas - no logging needed + pass + + try: + return nw.from_native(result.df()) + except ImportError: + pass + + raise ImportError(f"DataFrameSource requires 'polars' or 'pandas'. {_INSTALL_MSG}") + + +def read_csv(path: str) -> nw.DataFrame: + try: + import polars as pl # noqa: PLC0415 # pyright: ignore[reportMissingImports] + + return nw.from_native(pl.read_csv(path)) + except Exception: # noqa: S110 + # Catches ImportError for polars, and other errors (e.g., missing pyarrow) + # Intentional fallback to pandas - no logging needed + pass + + try: + import pandas as pd # noqa: PLC0415 # pyright: ignore[reportMissingImports] + + return nw.from_native(pd.read_csv(path, compression="gzip")) + except ImportError: + pass + + raise ImportError(f"Loading data requires 'polars' or 'pandas'. {_INSTALL_MSG}") diff --git a/pkg-py/src/querychat/_querychat.py b/pkg-py/src/querychat/_querychat.py index 3a83b090..f6374971 100644 --- a/pkg-py/src/querychat/_querychat.py +++ b/pkg-py/src/querychat/_querychat.py @@ -8,6 +8,7 @@ import chatlas import chevron +import narwhals.stable.v1 as nw import sqlalchemy from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui from shiny.express._stub_session import ExpressStubSession @@ -29,8 +30,7 @@ if TYPE_CHECKING: from collections.abc import Callable - import pandas as pd - from narwhals.stable.v1.typing import IntoFrame + from narwhals.typing import IntoFrame TOOL_GROUPS = Literal["update", "query"] @@ -797,14 +797,14 @@ def __init__( enable_bookmarking=enable, ) - def df(self) -> pd.DataFrame: + def df(self) -> nw.DataFrame: """ Reactively read the current filtered data frame that is in effect. Returns ------- : - The current filtered data frame as a pandas DataFrame. If no query + The current filtered data frame as a narwhals DataFrame. If no query has been set, this will return the unfiltered data frame from the data source. @@ -883,7 +883,16 @@ def normalize_data_source( return data_source if isinstance(data_source, sqlalchemy.Engine): return SQLAlchemySource(data_source, table_name) - return DataFrameSource(data_source, table_name) + src = nw.from_native(data_source, pass_through=True) + if isinstance(src, nw.DataFrame): + return DataFrameSource(src, table_name) + if isinstance(src, nw.LazyFrame): + raise NotImplementedError("LazyFrame data sources are not yet supported (they will be soon).") + raise TypeError( + f"Unsupported data source type: {type(data_source)}." + "If you believe this type should be supported, please open an issue at " + "https://github.com/posit-dev/querychat/issues" + ) def as_querychat_client(client: str | chatlas.Chat | None) -> chatlas.Chat: diff --git a/pkg-py/src/querychat/_querychat_module.py b/pkg-py/src/querychat/_querychat_module.py index f2bed066..320dcd53 100644 --- a/pkg-py/src/querychat/_querychat_module.py +++ b/pkg-py/src/querychat/_querychat_module.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from collections.abc import Callable - import pandas as pd + import narwhals.stable.v1 as nw from shiny import Inputs, Outputs, Session from shiny.bookmark import BookmarkState, RestoreState @@ -78,7 +78,7 @@ class ServerValues: """ - df: Callable[[], pd.DataFrame] + df: Callable[[], nw.DataFrame] sql: ReactiveStringOrNone title: ReactiveStringOrNone client: chatlas.Chat @@ -182,14 +182,14 @@ def _(): @session.bookmark.on_bookmark def _on_bookmark(x: BookmarkState) -> None: - vals = x.values # noqa: PD011 + vals = x.values vals["querychat_sql"] = sql.get() vals["querychat_title"] = title.get() vals["querychat_has_greeted"] = has_greeted.get() @session.bookmark.on_restore def _on_restore(x: RestoreState) -> None: - vals = x.values # noqa: PD011 + vals = x.values if "querychat_sql" in vals: sql.set(vals["querychat_sql"]) if "querychat_title" in vals: diff --git a/pkg-py/src/querychat/_utils.py b/pkg-py/src/querychat/_utils.py index 1d90dfd9..56f2ec37 100644 --- a/pkg-py/src/querychat/_utils.py +++ b/pkg-py/src/querychat/_utils.py @@ -122,6 +122,17 @@ def querychat_tool_starts_open(action: Literal["update", "query", "reset"]) -> b return action != "reset" +def _escape_html(s: str) -> str: + """Escape HTML special characters.""" + return ( + str(s) + .replace("&", "&") + .replace("<", "<") + .replace(">", ">") + .replace('"', """) + ) + + def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: """ Convert a DataFrame to an HTML table for display in chat. @@ -149,11 +160,30 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: "Must be able to convert `df` into a Narwhals DataFrame or LazyFrame", ) - # Generate HTML table - table_html = df_short.to_pandas().to_html( - index=False, - classes="table table-striped", - ) + # Generate HTML table directly from narwhals DataFrame + columns = df_short.columns + rows = df_short.rows() + + # Build HTML table + html_parts = [''] + + # Header + html_parts.append(" ") + html_parts.append(' ') + html_parts.extend(f" " for col in columns) + html_parts.append(" ") + html_parts.append(" ") + + # Body + html_parts.append(" ") + for row in rows: + html_parts.append(" ") + html_parts.extend(f" " for val in row) + html_parts.append(" ") + html_parts.append(" ") + + html_parts.append("
{_escape_html(col)}
{_escape_html(str(val))}
") + table_html = "\n".join(html_parts) # Add note about truncated rows if needed if len(df_short) != nrow_full: diff --git a/pkg-py/src/querychat/data/__init__.py b/pkg-py/src/querychat/data/__init__.py index fe9bec96..867a326f 100644 --- a/pkg-py/src/querychat/data/__init__.py +++ b/pkg-py/src/querychat/data/__init__.py @@ -8,11 +8,15 @@ from __future__ import annotations from importlib.resources import files +from typing import TYPE_CHECKING -import pandas as pd +from querychat._df_compat import read_csv +if TYPE_CHECKING: + import narwhals.stable.v1 as nw -def titanic() -> pd.DataFrame: + +def titanic() -> nw.DataFrame: """ Load the Titanic dataset. @@ -21,8 +25,9 @@ def titanic() -> pd.DataFrame: Returns ------- - pandas.DataFrame - A DataFrame with 891 rows and 15 columns containing Titanic passenger data. + : + A narwhals DataFrame with 891 rows and 15 columns containing Titanic + passenger data. Examples -------- @@ -35,10 +40,10 @@ def titanic() -> pd.DataFrame: """ # Get the path to the gzipped CSV file using importlib.resources data_file = files("querychat.data") / "titanic.csv.gz" - return pd.read_csv(str(data_file), compression="gzip") + return read_csv(str(data_file)) -def tips() -> pd.DataFrame: +def tips() -> nw.DataFrame: """ Load the tips dataset. @@ -48,8 +53,9 @@ def tips() -> pd.DataFrame: Returns ------- - pandas.DataFrame - A DataFrame with 244 rows and 7 columns containing restaurant tip data. + : + A narwhals DataFrame with 244 rows and 7 columns containing restaurant + tip data. Examples -------- @@ -62,7 +68,7 @@ def tips() -> pd.DataFrame: """ # Get the path to the gzipped CSV file using importlib.resources data_file = files("querychat.data") / "tips.csv.gz" - return pd.read_csv(str(data_file), compression="gzip") + return read_csv(str(data_file)) __all__ = ["tips", "titanic"] diff --git a/pkg-py/src/querychat/tools.py b/pkg-py/src/querychat/tools.py index e0001de1..7d06b73a 100644 --- a/pkg-py/src/querychat/tools.py +++ b/pkg-py/src/querychat/tools.py @@ -218,7 +218,7 @@ def query(query: str, _intent: str = "") -> ContentToolResult: try: result_df = data_source.execute_query(query) - value = result_df.to_dict(orient="records") + value = result_df.rows(named=True) # Format table results tbl_html = df_to_html(result_df, maxrows=5) diff --git a/pkg-py/tests/test_data.py b/pkg-py/tests/test_data.py index 0ee2f8f8..6128a779 100644 --- a/pkg-py/tests/test_data.py +++ b/pkg-py/tests/test_data.py @@ -1,13 +1,13 @@ """Tests for the querychat.data module.""" -import pandas as pd +import narwhals.stable.v1 as nw from querychat.data import tips, titanic def test_titanic_returns_dataframe(): - """Test that titanic() returns a pandas DataFrame.""" + """Test that titanic() returns a narwhals DataFrame.""" df = titanic() - assert isinstance(df, pd.DataFrame) + assert isinstance(df, nw.DataFrame) def test_titanic_has_expected_shape(): @@ -44,16 +44,19 @@ def test_titanic_data_integrity(): df = titanic() # Check that survived column has only 0 and 1 values - assert set(df["survived"].dropna().unique()) <= {0, 1} + unique_survived = set(df["survived"].drop_nulls().unique().to_list()) + assert unique_survived <= {0, 1} # Check that pclass has only 1, 2, 3 - assert set(df["pclass"].dropna().unique()) <= {1, 2, 3} + unique_pclass = set(df["pclass"].drop_nulls().unique().to_list()) + assert unique_pclass <= {1, 2, 3} # Check that sex has only 'male' and 'female' - assert set(df["sex"].dropna().unique()) <= {"male", "female"} + unique_sex = set(df["sex"].drop_nulls().unique().to_list()) + assert unique_sex <= {"male", "female"} # Check that fare is non-negative - assert (df["fare"].dropna() >= 0).all() + assert df["fare"].drop_nulls().min() >= 0 def test_titanic_creates_new_copy(): @@ -64,14 +67,15 @@ def test_titanic_creates_new_copy(): # They should not be the same object assert df1 is not df2 - # But they should have the same data - assert df1.equals(df2) + # But they should have the same shape and columns + assert df1.shape == df2.shape + assert list(df1.columns) == list(df2.columns) def test_tips_returns_dataframe(): - """Test that tips() returns a pandas DataFrame.""" + """Test that tips() returns a narwhals DataFrame.""" df = tips() - assert isinstance(df, pd.DataFrame) + assert isinstance(df, nw.DataFrame) def test_tips_has_expected_shape(): @@ -100,19 +104,21 @@ def test_tips_data_integrity(): df = tips() # Check that total_bill is positive - assert (df["total_bill"] > 0).all() + assert df["total_bill"].min() > 0 # Check that tip is non-negative - assert (df["tip"] >= 0).all() + assert df["tip"].min() >= 0 # Check that sex has only expected values - assert set(df["sex"].dropna().unique()) <= {"Male", "Female"} + unique_sex = set(df["sex"].drop_nulls().unique().to_list()) + assert unique_sex <= {"Male", "Female"} # Check that smoker has only expected values - assert set(df["smoker"].dropna().unique()) <= {"Yes", "No"} + unique_smoker = set(df["smoker"].drop_nulls().unique().to_list()) + assert unique_smoker <= {"Yes", "No"} # Check that size is positive - assert (df["size"] > 0).all() + assert df["size"].min() > 0 def test_tips_creates_new_copy(): @@ -123,5 +129,6 @@ def test_tips_creates_new_copy(): # They should not be the same object assert df1 is not df2 - # But they should have the same data - assert df1.equals(df2) + # But they should have the same shape and columns + assert df1.shape == df2.shape + assert list(df1.columns) == list(df2.columns) diff --git a/pkg-py/tests/test_dataframe_source.py b/pkg-py/tests/test_dataframe_source.py new file mode 100644 index 00000000..7602583a --- /dev/null +++ b/pkg-py/tests/test_dataframe_source.py @@ -0,0 +1,284 @@ +"""Tests for the DataFrameSource class with narwhals compatibility.""" + +import duckdb +import narwhals.stable.v1 as nw +import pandas as pd +import pytest +from querychat._datasource import DataFrameSource + +# Check if polars and pyarrow are available (both needed for DuckDB + polars) +try: + import polars as pl + import pyarrow as pa # noqa: F401 + + HAS_POLARS_WITH_PYARROW = True +except ImportError: + HAS_POLARS_WITH_PYARROW = False + pl = None # type: ignore[assignment] + + +@pytest.fixture +def pandas_df(): + """Create a sample pandas DataFrame.""" + return pd.DataFrame( + { + "id": [1, 2, 3, 4, 5], + "name": ["Alice", "Bob", "Charlie", "Diana", "Eve"], + "age": [25, 30, 35, 28, 32], + "salary": [50000.0, 60000.0, 70000.0, 55000.0, 65000.0], + "department": ["Engineering", "Sales", "Engineering", "Sales", "Engineering"], + } + ) + + +@pytest.fixture +def narwhals_df(pandas_df): + """Create a narwhals DataFrame from pandas.""" + return nw.from_native(pandas_df) + + +class TestDataFrameSourceInit: + """Tests for DataFrameSource initialization.""" + + def test_init_with_pandas_dataframe(self, pandas_df): + """Test that DataFrameSource accepts a pandas DataFrame.""" + source = DataFrameSource(pandas_df, "test_table") + assert source.table_name == "test_table" + + def test_init_with_narwhals_dataframe(self, narwhals_df): + """Test that DataFrameSource accepts a narwhals DataFrame.""" + source = DataFrameSource(narwhals_df, "test_table") + assert source.table_name == "test_table" + + @pytest.mark.skipif(not HAS_POLARS_WITH_PYARROW, reason="polars or pyarrow not installed") + def test_init_with_polars_dataframe(self): + """Test that DataFrameSource accepts a polars DataFrame.""" + polars_df = pl.DataFrame( + { + "id": [1, 2, 3], + "name": ["Alice", "Bob", "Charlie"], + } + ) + source = DataFrameSource(polars_df, "test_table") + assert source.table_name == "test_table" + + +class TestDataFrameSourceExecuteQuery: + """Tests for DataFrameSource.execute_query method.""" + + def test_execute_query_returns_narwhals_dataframe(self, pandas_df): + """Test that execute_query returns a narwhals DataFrame.""" + source = DataFrameSource(pandas_df, "employees") + result = source.execute_query("SELECT * FROM employees") + assert isinstance(result, nw.DataFrame) + + def test_execute_query_select_all(self, pandas_df): + """Test SELECT * query.""" + source = DataFrameSource(pandas_df, "employees") + result = source.execute_query("SELECT * FROM employees") + + assert result.shape == (5, 5) + assert set(result.columns) == {"id", "name", "age", "salary", "department"} + + def test_execute_query_with_filter(self, pandas_df): + """Test query with WHERE clause.""" + source = DataFrameSource(pandas_df, "employees") + result = source.execute_query( + "SELECT * FROM employees WHERE department = 'Engineering'" + ) + + assert result.shape == (3, 5) + departments = result["department"].unique().to_list() + assert departments == ["Engineering"] + + def test_execute_query_with_aggregation(self, pandas_df): + """Test query with aggregation.""" + source = DataFrameSource(pandas_df, "employees") + result = source.execute_query( + "SELECT department, AVG(salary) as avg_salary FROM employees GROUP BY department" + ) + + assert result.shape == (2, 2) + assert "department" in result.columns + assert "avg_salary" in result.columns + + def test_execute_query_select_columns(self, pandas_df): + """Test selecting specific columns.""" + source = DataFrameSource(pandas_df, "employees") + result = source.execute_query("SELECT name, age FROM employees") + + assert result.shape == (5, 2) + assert list(result.columns) == ["name", "age"] + + def test_execute_query_order_by(self, pandas_df): + """Test query with ORDER BY clause.""" + source = DataFrameSource(pandas_df, "employees") + result = source.execute_query( + "SELECT name, age FROM employees ORDER BY age DESC" + ) + + ages = result["age"].to_list() + assert ages == sorted(ages, reverse=True) + + def test_execute_query_empty_result(self, pandas_df): + """Test query that returns no rows.""" + source = DataFrameSource(pandas_df, "employees") + result = source.execute_query( + "SELECT * FROM employees WHERE age > 100" + ) + + assert isinstance(result, nw.DataFrame) + assert result.shape == (0, 5) + + +class TestDataFrameSourceGetData: + """Tests for DataFrameSource.get_data method.""" + + def test_get_data_returns_narwhals_dataframe(self, pandas_df): + """Test that get_data returns a narwhals DataFrame.""" + source = DataFrameSource(pandas_df, "employees") + result = source.get_data() + assert isinstance(result, nw.DataFrame) + + def test_get_data_returns_full_dataset(self, pandas_df): + """Test that get_data returns all rows.""" + source = DataFrameSource(pandas_df, "employees") + result = source.get_data() + + assert result.shape == pandas_df.shape + assert set(result.columns) == set(pandas_df.columns) + + def test_get_data_preserves_data(self, pandas_df): + """Test that get_data preserves data values.""" + source = DataFrameSource(pandas_df, "employees") + result = source.get_data() + + # Check that the data matches + original_names = sorted(pandas_df["name"].tolist()) + result_names = sorted(result["name"].to_list()) + assert original_names == result_names + + +class TestDataFrameSourceGetSchema: + """Tests for DataFrameSource.get_schema method.""" + + def test_get_schema_includes_table_name(self, pandas_df): + """Test that schema includes table name.""" + source = DataFrameSource(pandas_df, "employees") + schema = source.get_schema(categorical_threshold=10) + + assert "Table: employees" in schema + assert "Columns:" in schema + + def test_get_schema_includes_all_columns(self, pandas_df): + """Test that schema includes all columns.""" + source = DataFrameSource(pandas_df, "employees") + schema = source.get_schema(categorical_threshold=10) + + for col in pandas_df.columns: + assert f"- {col} (" in schema + + def test_get_schema_numeric_ranges(self, pandas_df): + """Test that numeric columns include range information.""" + source = DataFrameSource(pandas_df, "employees") + schema = source.get_schema(categorical_threshold=10) + + # Age should have range + assert "Range: 25 to 35" in schema + # Salary should have range + assert "Range: 50000.0 to 70000.0" in schema + + def test_get_schema_categorical_values(self, pandas_df): + """Test that categorical columns show unique values.""" + source = DataFrameSource(pandas_df, "employees") + schema = source.get_schema(categorical_threshold=10) + + # Department has only 2 unique values, should be categorical + assert "Categorical values:" in schema + assert "'Engineering'" in schema + assert "'Sales'" in schema + + def test_get_schema_respects_threshold(self, pandas_df): + """Test that categorical_threshold is respected.""" + source = DataFrameSource(pandas_df, "employees") + + # With threshold 1, no columns should be categorical + schema_low = source.get_schema(categorical_threshold=1) + # Department has 2 unique values, should not be listed as categorical + lines = schema_low.split("\n") + dept_idx = next(i for i, line in enumerate(lines) if "- department" in line) + if dept_idx + 1 < len(lines): + assert "Categorical values:" not in lines[dept_idx + 1] + + # With threshold 5, department should be categorical + schema_high = source.get_schema(categorical_threshold=5) + assert "'Engineering'" in schema_high + + +class TestDataFrameSourceDbType: + """Tests for DataFrameSource.get_db_type method.""" + + def test_get_db_type_returns_duckdb(self, pandas_df): + """Test that get_db_type returns 'DuckDB'.""" + source = DataFrameSource(pandas_df, "employees") + assert source.get_db_type() == "DuckDB" + + +class TestDataFrameSourceCleanup: + """Tests for DataFrameSource.cleanup method.""" + + def test_cleanup_closes_connection(self, pandas_df): + """Test that cleanup closes the DuckDB connection.""" + source = DataFrameSource(pandas_df, "employees") + + # Should work before cleanup + result = source.execute_query("SELECT * FROM employees LIMIT 1") + assert result.shape[0] == 1 + + # Cleanup + source.cleanup() + + # After cleanup, queries should fail + with pytest.raises(duckdb.ConnectionException): + source.execute_query("SELECT * FROM employees") + + +@pytest.mark.skipif(not HAS_POLARS_WITH_PYARROW, reason="polars or pyarrow not installed") +class TestDataFrameSourceWithPolars: + """Tests for DataFrameSource with polars DataFrames.""" + + @pytest.fixture + def polars_df(self): + """Create a sample polars DataFrame.""" + return pl.DataFrame( + { + "id": [1, 2, 3], + "name": ["Alice", "Bob", "Charlie"], + "value": [10.5, 20.5, 30.5], + } + ) + + def test_execute_query_with_polars(self, polars_df): + """Test execute_query with polars source.""" + source = DataFrameSource(polars_df, "test_data") + result = source.execute_query("SELECT * FROM test_data") + + assert isinstance(result, nw.DataFrame) + assert result.shape == (3, 3) + + def test_get_data_with_polars(self, polars_df): + """Test get_data with polars source.""" + source = DataFrameSource(polars_df, "test_data") + result = source.get_data() + + assert isinstance(result, nw.DataFrame) + assert result.shape == polars_df.shape + + def test_polars_result_backend(self, polars_df): + """Test that results use polars backend when input is polars.""" + source = DataFrameSource(polars_df, "test_data") + result = source.execute_query("SELECT * FROM test_data") + + # When polars is available, the result should use polars backend + native = result.to_native() + assert isinstance(native, pl.DataFrame) diff --git a/pkg-py/tests/test_df_compat.py b/pkg-py/tests/test_df_compat.py new file mode 100644 index 00000000..065efc2a --- /dev/null +++ b/pkg-py/tests/test_df_compat.py @@ -0,0 +1,134 @@ +"""Tests for the _df_compat module and narwhals DataFrame compatibility.""" + +import gzip +import tempfile +from pathlib import Path + +import duckdb +import narwhals.stable.v1 as nw +import pytest +from querychat._df_compat import duckdb_result_to_nw, read_csv + +# Check if polars and pyarrow are available (both needed for DuckDB + polars) +try: + import polars as pl + import pyarrow as pa # noqa: F401 + + HAS_POLARS_WITH_PYARROW = True +except ImportError: + HAS_POLARS_WITH_PYARROW = False + pl = None # type: ignore[assignment] + + +class TestReadCsv: + """ + Tests for the read_csv function. + + Note: read_csv is designed for reading gzipped CSV files (used for bundled data). + """ + + @pytest.fixture + def gzip_csv_file(self): + """Create a temporary gzipped CSV file for testing.""" + with tempfile.NamedTemporaryFile(suffix=".csv.gz", delete=False) as f: + temp_path = f.name + + with gzip.open(temp_path, "wt") as f: + f.write("id,name,value\n") + f.write("1,Alice,100\n") + f.write("2,Bob,200\n") + f.write("3,Charlie,300\n") + + yield temp_path + Path(temp_path).unlink() + + def test_read_csv_returns_narwhals_dataframe(self, gzip_csv_file): + """Test that read_csv returns a narwhals DataFrame.""" + result = read_csv(gzip_csv_file) + assert isinstance(result, nw.DataFrame) + + def test_read_csv_has_correct_shape(self, gzip_csv_file): + """Test that read_csv produces correct data.""" + result = read_csv(gzip_csv_file) + assert result.shape == (3, 3) + assert list(result.columns) == ["id", "name", "value"] + + def test_read_csv_data_integrity(self, gzip_csv_file): + """Test that read_csv preserves data correctly.""" + result = read_csv(gzip_csv_file) + names = result["name"].to_list() + assert names == ["Alice", "Bob", "Charlie"] + + +class TestDuckdbResultToNw: + """Tests for the duckdb_result_to_nw function.""" + + @pytest.fixture + def duckdb_conn(self): + """Create a DuckDB connection with test data.""" + conn = duckdb.connect(":memory:") + conn.execute("CREATE TABLE test (id INTEGER, name VARCHAR, value DOUBLE)") + conn.execute("INSERT INTO test VALUES (1, 'Alice', 10.5)") + conn.execute("INSERT INTO test VALUES (2, 'Bob', 20.5)") + yield conn + conn.close() + + def test_duckdb_result_returns_narwhals_dataframe(self, duckdb_conn): + """Test that duckdb_result_to_nw returns a narwhals DataFrame.""" + result = duckdb_conn.execute("SELECT * FROM test") + df = duckdb_result_to_nw(result) + assert isinstance(df, nw.DataFrame) + + def test_duckdb_result_has_correct_data(self, duckdb_conn): + """Test that duckdb_result_to_nw preserves data correctly.""" + result = duckdb_conn.execute("SELECT * FROM test ORDER BY id") + df = duckdb_result_to_nw(result) + + assert df.shape == (2, 3) + assert list(df.columns) == ["id", "name", "value"] + assert df["id"].to_list() == [1, 2] + assert df["name"].to_list() == ["Alice", "Bob"] + + def test_duckdb_result_empty_query(self, duckdb_conn): + """Test handling of empty query results.""" + result = duckdb_conn.execute("SELECT * FROM test WHERE id > 100") + df = duckdb_result_to_nw(result) + + assert isinstance(df, nw.DataFrame) + assert df.shape == (0, 3) + + +@pytest.mark.skipif( + not HAS_POLARS_WITH_PYARROW, reason="polars or pyarrow not installed" +) +class TestPolarsBackend: + """Tests that verify polars backend works correctly when available.""" + + def test_read_csv_uses_polars_when_available(self): + """Test that read_csv uses polars as the backend when available.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: + f.write("x,y\n1,2\n3,4\n") + temp_path = f.name + + try: + result = read_csv(temp_path) + # The native frame should be polars when polars is available + native = result.to_native() + assert isinstance(native, pl.DataFrame) + finally: + Path(temp_path).unlink() + + def test_duckdb_result_uses_polars_when_available(self): + """Test that duckdb_result_to_nw uses polars when available.""" + conn = duckdb.connect(":memory:") + conn.execute("CREATE TABLE t (x INTEGER)") + conn.execute("INSERT INTO t VALUES (1)") + + result = conn.execute("SELECT * FROM t") + df = duckdb_result_to_nw(result) + + # The native frame should be polars when polars is available + native = df.to_native() + assert isinstance(native, pl.DataFrame) + + conn.close() diff --git a/pkg-py/tests/test_init_with_pandas.py b/pkg-py/tests/test_init_with_pandas.py index 7c182639..1179a25c 100644 --- a/pkg-py/tests/test_init_with_pandas.py +++ b/pkg-py/tests/test_init_with_pandas.py @@ -64,9 +64,8 @@ def test_init_with_narwhals_dataframe(): assert qc is not None -def test_init_with_narwhals_lazyframe_direct_query(): - """Test that QueryChat() can accept a narwhals LazyFrame and execute queries.""" - # Create a pandas DataFrame and convert to narwhals LazyFrame +def test_init_with_narwhals_lazyframe_raises(): + """Test that QueryChat() raises TypeError for LazyFrames.""" pdf = pd.DataFrame( { "id": [1, 2, 3], @@ -76,20 +75,9 @@ def test_init_with_narwhals_lazyframe_direct_query(): ) nw_lazy = nw.from_native(pdf).lazy() - # Call QueryChat with the narwhals LazyFrame - qc = QueryChat( - data_source=nw_lazy, # TODO(@gadebuie): Fix this type error - table_name="test_table", - greeting="hello!", - ) - - # Verify the result is correctly configured - assert qc is not None - assert hasattr(qc, "data_source") - - # Test that we can run a query on the data source - query_result = qc.data_source.execute_query( - "SELECT * FROM test_table WHERE id = 2", - ) - assert len(query_result) == 1 - assert query_result.iloc[0]["name"] == "Bob" + with pytest.raises(NotImplementedError, match="LazyFrame"): + QueryChat( + data_source=nw_lazy, + table_name="test_table", + greeting="hello!", + ) diff --git a/pkg-py/tests/test_querychat.py b/pkg-py/tests/test_querychat.py index 22dab6d7..3c0e1e9d 100644 --- a/pkg-py/tests/test_querychat.py +++ b/pkg-py/tests/test_querychat.py @@ -47,7 +47,7 @@ def test_querychat_init(sample_df): ) assert len(result) == 1 - assert result.iloc[0]["name"] == "Bob" + assert result.item(0, "name") == "Bob" def test_querychat_custom_id(sample_df): diff --git a/pyproject.toml b/pyproject.toml index c4fa03f9..dc296fff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,6 @@ maintainers = [ ] dependencies = [ "duckdb", - "pandas", "shiny>=1.5.1", "shinywidgets", "htmltools", @@ -40,6 +39,11 @@ classifiers = [ "Programming Language :: Python :: 3.14", ] +[project.optional-dependencies] +# For SQLAlchemySource and sample data, one of polars or pandas is required +pandas = ["pandas"] +polars = ["polars"] + [project.urls] Homepage = "https://github.com/posit-dev/querychat" # TODO update when we have docs Repository = "https://github.com/posit-dev/querychat"