Skip to content

Commit cea2802

Browse files
[mypyc] Add primitive for bytes decode() method
1 parent a3ce6d5 commit cea2802

File tree

6 files changed

+161
-7
lines changed

6 files changed

+161
-7
lines changed

mypyc/irbuild/specialize.py

+47
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
RTuple,
5050
RType,
5151
bool_rprimitive,
52+
bytes_rprimitive,
5253
c_int_rprimitive,
5354
dict_rprimitive,
5455
int16_rprimitive,
@@ -83,6 +84,11 @@
8384
join_formatted_strings,
8485
tokenizer_format_call,
8586
)
87+
from mypyc.primitives.bytes_ops import (
88+
bytes_decode_ascii_strict,
89+
bytes_decode_latin1_strict,
90+
bytes_decode_utf8_strict,
91+
)
8692
from mypyc.primitives.dict_ops import (
8793
dict_items_op,
8894
dict_keys_op,
@@ -740,6 +746,47 @@ def str_encode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) ->
740746
return None
741747

742748

749+
@specialize_function("decode", bytes_rprimitive)
def bytes_decode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
    """Specialize ``bytes.decode()`` for common encodings with strict error handling.

    The specialized primitives always decode in strict mode with a fixed
    encoding, so the fast path is only taken when both the encoding and the
    error handler are known at compile time:

    * the encoding is omitted or given as a string literal (positionally or
      as ``encoding=...``) and names UTF-8, ASCII or Latin-1;
    * the error handler is omitted or given as the literal ``errors="strict"``.

    Args:
        builder: IR builder used to evaluate the receiver and emit the op.
        expr: The ``<receiver>.decode(...)`` call expression.
        callee: The callee of ``expr``; must be a member expression.

    Returns:
        The value produced by the specialized decode primitive, or None to
        fall back to the generic decode implementation.
    """
    if not isinstance(callee, MemberExpr):
        return None

    # decode() takes at most an encoding and an errors argument.
    if len(expr.args) > 2:
        return None

    encoding = "utf8"
    errors = "strict"

    for index, (arg, kind, name) in enumerate(zip(expr.args, expr.arg_kinds, expr.arg_names)):
        # A non-literal argument is only known at run time, so we can't pick
        # a primitive for it here (and must not silently assume the default).
        if not isinstance(arg, StrExpr):
            return None
        if kind == ARG_POS:
            if index != 0:
                # A positional 'errors' argument is left to the generic path.
                return None
            encoding = arg.value
        elif kind == ARG_NAMED and name == "encoding":
            encoding = arg.value
        elif kind == ARG_NAMED and name == "errors":
            errors = arg.value
        else:
            # Star args or an unexpected keyword: let the generic path report it.
            return None

    # The primitives below hard-code strict error handling.
    if errors != "strict":
        return None

    # Fold case and separators so that spellings such as "UTF-8", "utf_8"
    # and "utf8" all map to the same key.
    normalized = encoding.lower().replace("-", "").replace("_", "")

    if normalized in ("utf8", "utf", "u8", "cp65001"):
        decode_op = bytes_decode_utf8_strict
    elif normalized in ("ascii", "usascii", "646"):
        decode_op = bytes_decode_ascii_strict
    elif normalized in ("latin1", "latin", "iso88591", "cp819", "8859", "l1"):
        decode_op = bytes_decode_latin1_strict
    else:
        return None

    return builder.primitive_op(decode_op, [builder.accept(callee.expr)], expr.line)
788+
789+
743790
@specialize_function("mypy_extensions.i64")
744791
def translate_i64(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
745792
if len(expr.args) != 1 or expr.arg_kinds[0] != ARG_POS:

mypyc/lib-rt/CPy.h

+3
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,9 @@ CPyTagged CPyBytes_GetItem(PyObject *o, CPyTagged index);
764764
PyObject *CPyBytes_Concat(PyObject *a, PyObject *b);
765765
PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter);
766766
CPyTagged CPyBytes_Ord(PyObject *obj);
767+
PyObject *CPy_DecodeUtf8(PyObject *bytes_obj);
768+
PyObject *CPy_DecodeLatin1(PyObject *bytes_obj);
769+
PyObject *CPy_DecodeAscii(PyObject *bytes_obj);
767770

768771

769772
int CPyBytes_Compare(PyObject *left, PyObject *right);

mypyc/lib-rt/bytes_ops.c

+39
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,42 @@ CPyTagged CPyBytes_Ord(PyObject *obj) {
162162
PyErr_SetString(PyExc_TypeError, "ord() expects a character");
163163
return CPY_INT_TAG;
164164
}
165+
166+
167+
// Decode a bytes-like object as UTF-8 using strict error handling.
//
// Returns a new str reference, or NULL with an exception set on failure.
PyObject *CPy_DecodeUtf8(PyObject *bytes_obj) {
    if (PyBytes_Check(bytes_obj)) {
        // Fast path: decode directly from the internal buffer of the bytes object.
        const char *data = PyBytes_AS_STRING(bytes_obj);
        Py_ssize_t size = PyBytes_GET_SIZE(bytes_obj);
        // NULL errors means "strict".
        return PyUnicode_DecodeUTF8(data, size, NULL);
    }
    // Slow path for other bytes-like objects such as bytearray
    // (NOTE(review): a value typed as bytes may be a bytearray at runtime --
    // confirm against the cast rules). Objects that can't be decoded still
    // raise TypeError.
    return PyUnicode_FromEncodedObject(bytes_obj, "utf-8", NULL);
}
178+
179+
180+
// Decode a bytes-like object as Latin-1 using strict error handling.
//
// Returns a new str reference, or NULL with an exception set on failure.
PyObject *CPy_DecodeLatin1(PyObject *bytes_obj) {
    if (PyBytes_Check(bytes_obj)) {
        // Fast path: decode directly from the internal buffer of the bytes object.
        const char *data = PyBytes_AS_STRING(bytes_obj);
        Py_ssize_t size = PyBytes_GET_SIZE(bytes_obj);
        // NULL errors means "strict".
        return PyUnicode_DecodeLatin1(data, size, NULL);
    }
    // Slow path for other bytes-like objects such as bytearray
    // (NOTE(review): a value typed as bytes may be a bytearray at runtime --
    // confirm against the cast rules). Objects that can't be decoded still
    // raise TypeError.
    return PyUnicode_FromEncodedObject(bytes_obj, "latin-1", NULL);
}
191+
192+
193+
// Decode a bytes-like object as ASCII using strict error handling.
//
// Returns a new str reference, or NULL with an exception set on failure.
PyObject *CPy_DecodeAscii(PyObject *bytes_obj) {
    if (PyBytes_Check(bytes_obj)) {
        // Fast path: decode directly from the internal buffer of the bytes object.
        const char *data = PyBytes_AS_STRING(bytes_obj);
        Py_ssize_t size = PyBytes_GET_SIZE(bytes_obj);
        // NULL errors means "strict".
        return PyUnicode_DecodeASCII(data, size, NULL);
    }
    // Slow path for other bytes-like objects such as bytearray
    // (NOTE(review): a value typed as bytes may be a bytearray at runtime --
    // confirm against the cast rules). Objects that can't be decoded still
    // raise TypeError.
    return PyUnicode_FromEncodedObject(bytes_obj, "ascii", NULL);
}

mypyc/primitives/bytes_ops.py

+25
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
ERR_NEG_INT,
1919
binary_op,
2020
custom_op,
21+
custom_primitive_op,
2122
function_op,
2223
load_address_op,
2324
method_op,
@@ -107,3 +108,27 @@
107108
c_function_name="CPyBytes_Ord",
108109
error_kind=ERR_MAGIC,
109110
)
111+
112+
# Primitives for decoding a bytes value with a fixed encoding. Each one takes
# only the bytes value and returns a str by calling the named C helper; the
# encoding and the error handler are baked into the helper. The names follow
# the "decode_<encoding>" pattern so the three primitives can be told apart.

# bytes.decode() with UTF-8 encoding
bytes_decode_utf8_strict = custom_primitive_op(
    name="decode_utf8",
    arg_types=[bytes_rprimitive],
    return_type=str_rprimitive,
    c_function_name="CPy_DecodeUtf8",
    error_kind=ERR_MAGIC,
)

# bytes.decode() with Latin-1 encoding
bytes_decode_latin1_strict = custom_primitive_op(
    name="decode_latin1",
    arg_types=[bytes_rprimitive],
    return_type=str_rprimitive,
    c_function_name="CPy_DecodeLatin1",
    error_kind=ERR_MAGIC,
)

# bytes.decode() with ASCII encoding
bytes_decode_ascii_strict = custom_primitive_op(
    name="decode_ascii",
    arg_types=[bytes_rprimitive],
    return_type=str_rprimitive,
    c_function_name="CPy_DecodeAscii",
    error_kind=ERR_MAGIC,
)

mypyc/test-data/irbuild-bytes.test

+41
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,44 @@ L0:
185185
r10 = CPyBytes_Build(2, var, r9)
186186
b4 = r10
187187
return 1
188+
189+
[case testDecode]
190+
def f(b: bytes) -> None:
191+
b.decode()
192+
b.decode('utf8')
193+
b.decode('utf-8', 'strict')
194+
b.decode('utf-8', 'strict')
195+
b.decode('latin1', 'strict')
196+
b.decode('ascii')
197+
b.decode('latin-1')
198+
b.decode('utf-8', 'ignore')
199+
b.decode('ascii', 'replace')
200+
b.decode('latin1', 'ignore')
201+
[out]
202+
def f(b):
203+
b :: bytes
204+
r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14, r15, r16, r17, r18, r19, r20, r21 :: str
205+
L0:
206+
r0 = CPy_DecodeUtf8(b)
207+
r1 = CPy_DecodeUtf8(b)
208+
r2 = 'utf-8'
209+
r3 = 'strict'
210+
r4 = CPy_Decode(b, r2, r3)
211+
r5 = 'utf-8'
212+
r6 = 'strict'
213+
r7 = CPy_Decode(b, r5, r6)
214+
r8 = 'latin1'
215+
r9 = 'strict'
216+
r10 = CPy_Decode(b, r8, r9)
217+
r11 = CPy_DecodeAscii(b)
218+
r12 = CPy_DecodeLatin1(b)
219+
r13 = 'utf-8'
220+
r14 = 'ignore'
221+
r15 = CPy_Decode(b, r13, r14)
222+
r16 = 'ascii'
223+
r17 = 'replace'
224+
r18 = CPy_Decode(b, r16, r17)
225+
r19 = 'latin1'
226+
r20 = 'ignore'
227+
r21 = CPy_Decode(b, r19, r20)
228+
return 1

mypyc/test-data/irbuild-str.test

+6-7
Original file line numberDiff line numberDiff line change
@@ -335,14 +335,13 @@ def f(b: bytes) -> None:
335335
[out]
336336
def f(b):
337337
b :: bytes
338-
r0, r1, r2, r3, r4, r5 :: str
338+
r0, r1, r2, r3, r4 :: str
339339
L0:
340-
r0 = CPy_Decode(b, 0, 0)
341-
r1 = 'utf-8'
342-
r2 = CPy_Decode(b, r1, 0)
343-
r3 = 'utf-8'
344-
r4 = 'backslashreplace'
345-
r5 = CPy_Decode(b, r3, r4)
340+
r0 = CPy_DecodeUtf8(b)
341+
r1 = CPy_DecodeUtf8(b)
342+
r2 = 'utf-8'
343+
r3 = 'backslashreplace'
344+
r4 = CPy_Decode(b, r2, r3)
346345
return 1
347346

348347
[case testEncode_64bit]

0 commit comments

Comments
 (0)