diff --git a/test/test_message.py b/test/test_message.py new file mode 100644 index 0000000000..c984a21167 --- /dev/null +++ b/test/test_message.py @@ -0,0 +1,349 @@ +# Copyright 2026-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for message.py.""" + +from __future__ import annotations + +import struct +import sys +from unittest.mock import MagicMock + +sys.path[0:0] = [""] + +from test import unittest + +from bson import CodecOptions, encode +from pymongo.compression_support import ZlibContext +from pymongo.errors import DocumentTooLarge, OperationFailure +from pymongo.message import ( + _compress, + _convert_client_bulk_exception, + _convert_exception, + _convert_write_result, + _gen_find_command, + _gen_get_more_command, + _get_more_compressed, + _get_more_uncompressed, + _maybe_add_read_preference, + _op_msg, + _query_compressed, + _query_uncompressed, + _raise_document_too_large, +) +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference, SecondaryPreferred + +_OPTS = CodecOptions() +_ZLIB_CTX = ZlibContext(-1) + + +class TestMaybeAddReadPreference(unittest.TestCase): + def test_primary_no_read_preference_added(self): + spec: dict = {"find": "col"} + result = _maybe_add_read_preference(spec, ReadPreference.PRIMARY) + self.assertNotIn("$readPreference", result) + self.assertNotIn("$query", result) + + def test_secondary_adds_read_preference(self): + spec: dict = {"find": "col"} + result = _maybe_add_read_preference(spec, ReadPreference.SECONDARY) + self.assertIn("$readPreference", result) + self.assertEqual(result["$readPreference"]["mode"], "secondary") + self.assertIn("$query", result) + + def test_secondary_preferred_no_tags_does_not_add(self): + spec: dict = {"find": "col"} + result = _maybe_add_read_preference(spec, ReadPreference.SECONDARY_PREFERRED) + self.assertNotIn("$readPreference", result) + + def test_secondary_preferred_with_tags_adds_read_preference(self): + pref = SecondaryPreferred(tag_sets=[{"dc": "east"}]) + spec: dict = {"find": "col"} + result = _maybe_add_read_preference(spec, pref) + self.assertIn("$readPreference", result) + + def test_existing_query_wrapper_preserved(self): + spec: dict = {"$query": {"x": 1}, "other": 2} + result = _maybe_add_read_preference(spec, ReadPreference.SECONDARY) + self.assertIn("$readPreference", result) + self.assertEqual(result["$query"], {"x": 1}) + + +class TestConvertException(unittest.TestCase): + def test_basic_exception(self): + exc = ValueError("bad value") + doc = _convert_exception(exc) + self.assertEqual(doc["errmsg"], "bad value") + self.assertEqual(doc["errtype"], "ValueError") + + def test_client_bulk_exception_includes_code(self): + exc = OperationFailure("failed", code=11000) + doc = _convert_client_bulk_exception(exc) + self.assertEqual(doc["errmsg"], "failed") + self.assertEqual(doc["code"], 11000) + self.assertEqual(doc["errtype"], "OperationFailure") + + +class TestConvertWriteResult(unittest.TestCase): + """Tests for _convert_write_result. + + In the update command spec, `q` is the query/filter and `u` is the update document. + """ + + def test_insert_basic(self): + cmd = {"documents": [{"_id": 1}, {"_id": 2}]} + result = _convert_write_result("insert", cmd, {"n": 0}) + self.assertEqual(result["ok"], 1) + self.assertEqual(result["n"], 2) + + def test_update_basic(self): + cmd = {"updates": [{"q": {}, "u": {"$set": {"x": 1}}}]} + result = _convert_write_result("update", cmd, {"n": 1, "updatedExisting": True}) + self.assertEqual(result["ok"], 1) + self.assertNotIn("upserted", result) + + def test_update_with_upserted_id(self): + cmd = {"updates": [{"q": {}, "u": {"_id": 42}}]} + result = _convert_write_result("update", cmd, {"n": 1, "upserted": 42}) + self.assertIn("upserted", result) + self.assertEqual(result["upserted"][0]["_id"], 42) + + def test_update_upsert_id_precedence(self): + # When _id is in both the update document and the query spec, + # the update document's _id wins. + cmd = {"updates": [{"q": {"_id": 99}, "u": {"_id": 42}}]} + result = _convert_write_result("update", cmd, {"n": 1, "updatedExisting": False}) + self.assertEqual(result["upserted"][0]["_id"], 42) + + def test_update_upsert_no_upserted_id_from_query(self): + cmd = {"updates": [{"q": {"_id": 77}, "u": {"$set": {"x": 1}}}]} + result = _convert_write_result("update", cmd, {"n": 1, "updatedExisting": False}) + self.assertIn("upserted", result) + self.assertEqual(result["upserted"][0]["_id"], 77) + + def test_delete_basic(self): + cmd = {"deletes": [{"q": {}, "limit": 1}]} + result = _convert_write_result("delete", cmd, {"n": 1}) + self.assertEqual(result["ok"], 1) + self.assertEqual(result["n"], 1) + + def test_write_error(self): + cmd = {"documents": [{"_id": 1}]} + gle = {"n": 0, "err": "duplicate key error", "code": 11000} + result = _convert_write_result("insert", cmd, gle) + self.assertIn("writeErrors", result) + self.assertEqual(result["writeErrors"][0]["code"], 11000) + + def test_write_concern_timeout(self): + cmd = {"documents": [{"_id": 1}]} + gle = {"n": 1, "errmsg": "timeout", "wtimeout": True} + result = _convert_write_result("insert", cmd, gle) + self.assertIn("writeConcernError", result) + self.assertEqual(result["writeConcernError"]["code"], 64) + + def test_write_error_with_err_info(self): + # Covers the `if "errInfo" in result:` branch, which test_write_error does not enter. + cmd = {"documents": [{"_id": 1}]} + gle = {"n": 0, "err": "err", "code": 123, "errInfo": {"detail": "x"}} + result = _convert_write_result("insert", cmd, gle) + self.assertIn("errInfo", result["writeErrors"][0]) + + +class TestCompress(unittest.TestCase): + def test_compressed_message_has_op_compressed_header(self): + msg = _compress(2013, b"hello world", _ZLIB_CTX)[1] + op_code = struct.unpack("