Skip to content

Commit 545b3a2

Browse files
committed
[python][RDF] Add support for std::string in AsRVec
1 parent e85f9ba commit 545b3a2

File tree

1 file changed

+18
-6
lines changed
  • bindings/pyroot/pythonizations/python/ROOT/_pythonization

1 file changed

+18
-6
lines changed

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rvec.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
7777
def _get_cpp_type_from_numpy_type(dtype):
7878
cpptypes = {"i2": "Short_t", "u2": "UShort_t", "i4": "int", "u4": "unsigned int", "i8": "Long64_t", "u8": "ULong64_t", "f4": "float", "f8": "double", "b1": "bool"}
7979

80-
if not dtype in cpptypes:
80+
if dtype not in cpptypes:
8181
raise RuntimeError("Object not convertible: Python object has unknown data-type '" + dtype + "'.")
8282

8383
return cpptypes[dtype]
@@ -93,10 +93,11 @@ def _AsRVec(arr):
9393
This function returns an RVec which adopts the memory of the given
9494
PyObject. The RVec takes the data pointer and the size from the array
9595
interface dictionary.
96+
Note that for arrays of strings, the input strings are copied into the RVec.
9697
"""
9798
import ROOT
9899
import math
99-
import platform
100+
import numpy as np
100101

101102
# Get array interface of object
102103
interface = arr.__array_interface__
@@ -110,17 +111,28 @@ def _AsRVec(arr):
110111

111112
# Get the typestring and properties thereof
112113
typestr = interface["typestr"]
114+
dtype = typestr[1:]
115+
116+
# Construct an RVec of strings
117+
if dtype == "O" or dtype.startswith("U"):
118+
underlying_object_types = {type(elem) for elem in arr}
119+
if len(underlying_object_types) > 1:
120+
raise TypeError("All elements in the numpy array must be of the same type. Found types: {}".format(underlying_object_types))
121+
122+
if underlying_object_types and underlying_object_types.pop() in [str, np.str_]:
123+
return ROOT.VecOps.RVec["std::string"](arr)
124+
else:
125+
raise TypeError("Cannot create an RVec from a numpy array of data type object.")
126+
113127
if len(typestr) != 3:
114128
raise RuntimeError(
115129
"Object not convertible: __array_interface__['typestr'] returned '"
116130
+ typestr
117131
+ "' with invalid length unequal 3."
118132
)
119-
120-
dtype = typestr[1:]
121-
cppdtype = _get_cpp_type_from_numpy_type(dtype)
122-
133+
123134
# Construct an RVec of the correct data-type
135+
cppdtype = _get_cpp_type_from_numpy_type(dtype)
124136
out = ROOT.VecOps.RVec[cppdtype](ROOT.module.cppyy.ll.reinterpret_cast[f"{cppdtype} *"](data), size)
125137

126138
# Bind pyobject holding adopted memory to the RVec

0 commit comments

Comments
 (0)