diff --git a/mypyc/doc/str_operations.rst b/mypyc/doc/str_operations.rst index 64ebbdcfa0bb..c36574ed7864 100644 --- a/mypyc/doc/str_operations.rst +++ b/mypyc/doc/str_operations.rst @@ -38,8 +38,9 @@ Methods * ``s1.find(s2: str)`` * ``s1.find(s2: str, start: int)`` * ``s1.find(s2: str, start: int, end: int)`` -* ``s.isspace()`` * ``s.isalnum()`` +* ``s.isdigit()`` +* ``s.isspace()`` * ``s.join(x: Iterable)`` * ``s.lstrip()`` * ``s.lstrip(chars: str)`` diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index ee5f5932f849..6e4f9a729ab1 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -782,6 +782,7 @@ CPyTagged CPyStr_Ord(PyObject *obj); PyObject *CPyStr_Multiply(PyObject *str, CPyTagged count); bool CPyStr_IsSpace(PyObject *str); bool CPyStr_IsAlnum(PyObject *str); +bool CPyStr_IsDigit(PyObject *str); // Bytes operations diff --git a/mypyc/lib-rt/str_ops.c b/mypyc/lib-rt/str_ops.c index 6530472e42e4..fba0bb39395c 100644 --- a/mypyc/lib-rt/str_ops.c +++ b/mypyc/lib-rt/str_ops.c @@ -677,3 +677,40 @@ bool CPyStr_IsAlnum(PyObject *str) { } return true; } + +bool CPyStr_IsDigit(PyObject *str) { + Py_ssize_t len = PyUnicode_GET_LENGTH(str); + if (len == 0) return false; + +#define CHECK_ISDIGIT(TYPE, DATA, CHECK) \ + { \ + const TYPE *data = (const TYPE *)(DATA); \ + for (Py_ssize_t i = 0; i < len; i++) { \ + if (!CHECK(data[i])) \ + return false; \ + } \ + } + + // ASCII fast path + if (PyUnicode_IS_ASCII(str)) { + CHECK_ISDIGIT(Py_UCS1, PyUnicode_1BYTE_DATA(str), Py_ISDIGIT); + return true; + } + + switch (PyUnicode_KIND(str)) { + case PyUnicode_1BYTE_KIND: + CHECK_ISDIGIT(Py_UCS1, PyUnicode_1BYTE_DATA(str), Py_UNICODE_ISDIGIT); + break; + case PyUnicode_2BYTE_KIND: + CHECK_ISDIGIT(Py_UCS2, PyUnicode_2BYTE_DATA(str), Py_UNICODE_ISDIGIT); + break; + case PyUnicode_4BYTE_KIND: + CHECK_ISDIGIT(Py_UCS4, PyUnicode_4BYTE_DATA(str), Py_UNICODE_ISDIGIT); + break; + default: + Py_UNREACHABLE(); + } + return true; + +#undef CHECK_ISDIGIT +} diff --git a/mypyc/primitives/str_ops.py b/mypyc/primitives/str_ops.py index 86559f162f90..374654b32df2 100644 --- a/mypyc/primitives/str_ops.py +++ b/mypyc/primitives/str_ops.py @@ -413,6 +413,14 @@ error_kind=ERR_NEVER, ) +method_op( + name="isdigit", + arg_types=[str_rprimitive], + return_type=bool_rprimitive, + c_function_name="CPyStr_IsDigit", + error_kind=ERR_NEVER, +) + # obj.decode() method_op( diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 633a3fdc32e3..f2d4e77addec 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -133,6 +133,7 @@ def islower(self) -> bool: ... def count(self, substr: str, start: Optional[int] = None, end: Optional[int] = None) -> int: pass def isspace(self) -> bool: ... def isalnum(self) -> bool: ... + def isdigit(self) -> bool: ... class float: def __init__(self, x: object) -> None: pass diff --git a/mypyc/test-data/irbuild-str.test b/mypyc/test-data/irbuild-str.test index ce2fbab6fad5..4245279937c4 100644 --- a/mypyc/test-data/irbuild-str.test +++ b/mypyc/test-data/irbuild-str.test @@ -994,3 +994,14 @@ def is_alnum(x): L0: r0 = CPyStr_IsAlnum(x) return r0 + +[case testStrIsDigit] +def is_digit(x: str) -> bool: + return x.isdigit() +[out] +def is_digit(x): + x :: str + r0 :: bool +L0: + r0 = CPyStr_IsDigit(x) + return r0 diff --git a/mypyc/test-data/run-strings.test b/mypyc/test-data/run-strings.test index 1f9651e13c69..df6d7363a5b1 100644 --- a/mypyc/test-data/run-strings.test +++ b/mypyc/test-data/run-strings.test @@ -1299,3 +1299,37 @@ def test_isalnum_unicode() -> None: # Unicode letter/digit mixed with punctuation — not alnum assert not "\u00E9!".isalnum() assert not "\u4E2D\u2000".isalnum() # CJK + whitespace + +[case testIsDigit] +from typing import Any + +def test_isdigit() -> None: + for i in range(0x110000): + c = chr(i) + a: Any = c + assert c.isdigit() == a.isdigit() + +def test_isdigit_strings() -> None: + # ASCII digits + assert "0123456789".isdigit() + assert not "".isdigit() + assert not " ".isdigit() + assert not "a".isdigit() + assert not "abc".isdigit() + assert not "!@#".isdigit() + + # Mixed ASCII + assert not "123abc".isdigit() + assert not "abc123".isdigit() + assert not "12 34".isdigit() + assert not "123!".isdigit() + + # Unicode digits + assert "\u0660\u0661\u0662".isdigit() + assert "\u00b2\u00b3".isdigit() + assert "123\U0001d7ce\U0001d7cf\U0001d7d0".isdigit() + + # Mixed digits and Unicode non-digits + assert not "\u00e9\u00e8".isdigit() + assert not "123\u00e9".isdigit() + assert not "\U0001d7ce!".isdigit()