|
14 | 14 | from databricks.sql.result_set import ThriftResultSet |
15 | 15 |
|
16 | 16 |
|
| 17 | +class _StubArrowQueue: |
| 18 | + """Minimal queue that hands back a pre-built pyarrow.Table once. |
| 19 | +
|
| 20 | + Used to inject a schemaless / wrong-schema placeholder that the real |
| 21 | + ArrowQueue would never produce — this is what CloudFetchQueue emits |
| 22 | + when ``self.table is None`` and ``schema_bytes`` is missing. |
| 23 | + """ |
| 24 | + |
| 25 | + def __init__(self, table): |
| 26 | + self._table = table |
| 27 | + self._consumed = False |
| 28 | + |
| 29 | + def _take(self): |
| 30 | + if self._consumed: |
| 31 | + return self._table.slice(0, 0) |
| 32 | + self._consumed = True |
| 33 | + return self._table |
| 34 | + |
| 35 | + def next_n_rows(self, num_rows): |
| 36 | + return self._take() |
| 37 | + |
| 38 | + def remaining_rows(self): |
| 39 | + return self._take() |
| 40 | + |
| 41 | + def close(self): |
| 42 | + pass |
| 43 | + |
| 44 | + |
17 | 45 | @pytest.mark.skipif(pa is None, reason="PyArrow is not installed") |
18 | 46 | class FetchTests(unittest.TestCase): |
19 | 47 | """ |
@@ -110,6 +138,39 @@ def fetch_results( |
110 | 138 | ) |
111 | 139 | return rs |
112 | 140 |
|
| 141 | + @staticmethod |
| 142 | + def make_dummy_result_set_from_queue_list(queue_list, description=None): |
| 143 | + """Like make_dummy_result_set_from_batch_list but yields pre-built queues. |
| 144 | +
|
| 145 | + Lets tests inject queues whose returned tables have an arbitrary schema |
| 146 | + (or no schema at all) — needed to reproduce the CloudFetch placeholder |
| 147 | + case that ``ArrowQueue`` would never produce. |
| 148 | + """ |
| 149 | + queue_index = 0 |
| 150 | + |
| 151 | + def fetch_results(**_): |
| 152 | + nonlocal queue_index |
| 153 | + q = queue_list[queue_index] |
| 154 | + queue_index += 1 |
| 155 | + return q, queue_index < len(queue_list), 0 |
| 156 | + |
| 157 | + mock_thrift_backend = Mock(spec=ThriftDatabricksClient) |
| 158 | + mock_thrift_backend.fetch_results = fetch_results |
| 159 | + |
| 160 | + rs = ThriftResultSet( |
| 161 | + connection=Mock(), |
| 162 | + execute_response=ExecuteResponse( |
| 163 | + command_id=None, |
| 164 | + status=None, |
| 165 | + has_been_closed_server_side=False, |
| 166 | + description=description or [], |
| 167 | + lz4_compressed=True, |
| 168 | + is_staging_operation=False, |
| 169 | + ), |
| 170 | + thrift_client=mock_thrift_backend, |
| 171 | + ) |
| 172 | + return rs |
| 173 | + |
113 | 174 | def assertEqualRowValues(self, actual, expected): |
114 | 175 | self.assertEqual(len(actual) if actual else 0, len(expected) if expected else 0) |
115 | 176 | for act, exp in zip(actual, expected): |
@@ -267,6 +328,68 @@ def test_fetchone_without_initial_results(self): |
267 | 328 | dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_2) |
268 | 329 | self.assertEqual(dummy_result_set.fetchone(), None) |
269 | 330 |
|
| 331 | + # Regression tests for fetchmany_arrow / fetchall_arrow handling of |
| 332 | + # the schemaless CloudFetch placeholder. |
| 333 | + def test_fetchall_arrow_drops_mismatched_empty_placeholder(self): |
| 334 | + # First fetch_results() call hands back a 0-row placeholder whose |
| 335 | + # schema does not match the real chunks. The second call |
| 336 | + # hands back real data. |
| 337 | + placeholder = pa.Table.from_pydict( |
| 338 | + {"stale_col": []}, schema=pa.schema({"stale_col": pa.string()}) |
| 339 | + ) |
| 340 | + _, real_table = self.make_arrow_table([[1], [2], [3]]) |
| 341 | + rs = self.make_dummy_result_set_from_queue_list( |
| 342 | + [_StubArrowQueue(placeholder), _StubArrowQueue(real_table)], |
| 343 | + description=[("col0", "integer", None, None, None, None, None)], |
| 344 | + ) |
| 345 | + |
| 346 | + result = rs.fetchall_arrow() |
| 347 | + |
| 348 | + self.assertEqual(result.num_rows, 3) |
| 349 | + self.assertEqual(result.schema.names, ["col0"]) |
| 350 | + self.assertEqual(result.column(0).to_pylist(), [1, 2, 3]) |
| 351 | + |
| 352 | + def test_fetchall_arrow_all_empty_returns_zero_row_table(self): |
| 353 | + # Every queue call returns the schemaless placeholder — the |
| 354 | + # call site should fall back to zero_row_table without crashing. |
| 355 | + placeholder = pa.Table.from_pydict({}) |
| 356 | + rs = self.make_dummy_result_set_from_queue_list( |
| 357 | + [_StubArrowQueue(placeholder)], |
| 358 | + ) |
| 359 | + |
| 360 | + result = rs.fetchall_arrow() |
| 361 | + |
| 362 | + self.assertIsInstance(result, pa.Table) |
| 363 | + self.assertEqual(result.num_rows, 0) |
| 364 | + |
| 365 | + def test_fetchmany_arrow_drops_mismatched_empty_placeholder(self): |
| 366 | + # See ``test_fetchall_arrow_drops_mismatched_empty_placeholder``. |
| 367 | + placeholder = pa.Table.from_pydict( |
| 368 | + {"stale_col": []}, schema=pa.schema({"stale_col": pa.string()}) |
| 369 | + ) |
| 370 | + _, real_table = self.make_arrow_table([[1], [2], [3]]) |
| 371 | + rs = self.make_dummy_result_set_from_queue_list( |
| 372 | + [_StubArrowQueue(placeholder), _StubArrowQueue(real_table)], |
| 373 | + description=[("col0", "integer", None, None, None, None, None)], |
| 374 | + ) |
| 375 | + |
| 376 | + result = rs.fetchmany_arrow(3) |
| 377 | + |
| 378 | + self.assertEqual(result.num_rows, 3) |
| 379 | + self.assertEqual(result.schema.names, ["col0"]) |
| 380 | + self.assertEqual(result.column(0).to_pylist(), [1, 2, 3]) |
| 381 | + |
| 382 | + def test_fetchmany_arrow_all_empty_returns_zero_row_table(self): |
| 383 | + placeholder = pa.Table.from_pydict({}) |
| 384 | + rs = self.make_dummy_result_set_from_queue_list( |
| 385 | + [_StubArrowQueue(placeholder)], |
| 386 | + ) |
| 387 | + |
| 388 | + result = rs.fetchmany_arrow(10) |
| 389 | + |
| 390 | + self.assertIsInstance(result, pa.Table) |
| 391 | + self.assertEqual(result.num_rows, 0) |
| 392 | + |
270 | 393 |
|
271 | 394 | if __name__ == "__main__": |
272 | 395 | unittest.main() |
0 commit comments