Commit a2f6fde0 authored by Wenzel Jakob's avatar Wenzel Jakob
Browse files

support for overriding virtual functions

parent 04358b02
......@@ -61,6 +61,7 @@ add_library(example SHARED
example/example9.cpp
example/example10.cpp
example/example11.cpp
example/example12.cpp
)
set_target_properties(example PROPERTIES PREFIX "")
......
......@@ -38,6 +38,7 @@ The following core C++ features can be mapped to Python
- STL data structures
- Smart pointers with reference counting like `std::shared_ptr`
- Internal references with correct reference counting
- C++ classes with virtual (and pure virtual) methods can be extended in Python
## Goodies
In addition to the core functionality, pybind11 provides some extra goodies:
......
......@@ -20,6 +20,7 @@ void init_ex8(py::module &);
void init_ex9(py::module &);
void init_ex10(py::module &);
void init_ex11(py::module &);
void init_ex12(py::module &);
PYTHON_PLUGIN(example) {
py::module m("example", "pybind example plugin");
......@@ -35,6 +36,7 @@ PYTHON_PLUGIN(example) {
init_ex9(m);
init_ex10(m);
init_ex11(m);
init_ex12(m);
return m.ptr();
}
/*
example/example12.cpp -- overriding virtual functions from Python
Copyright (c) 2015 Wenzel Jakob <wenzel@inf.ethz.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#include "example.h"
#include <pybind/functional.h>
/* This is an example class that we'll want to be able to extend from Python */
class Example12 {
public:
Example12(int state) : state(state) {
cout << "Constructing Example12.." << endl;
}
~Example12() {
cout << "Destructing Example12.." << endl;
}
virtual int run(int value) {
std::cout << "Original implementation of Example12::run(state=" << state
<< ", value=" << value << ")" << std::endl;
return state + value;
}
virtual void pure_virtual() = 0;
private:
int state;
};
/* This is a wrapper class that must be generated */
class PyExample12 : public Example12 {
public:
using Example12::Example12; /* Inherit constructors */
virtual int run(int value) {
/* Generate wrapping code that enables native function overloading */
PYBIND_OVERLOAD(
int, /* Return type */
Example12, /* Parent class */
run, /* Name of function */
value /* Argument(s) */
);
}
virtual void pure_virtual() {
PYBIND_OVERLOAD_PURE(
void, /* Return type */
Example12, /* Parent class */
pure_virtual /* Name of function */
/* This function has no arguments */
);
}
};
int runExample12(Example12 *ex, int value) {
return ex->run(value);
}
void runExample12Virtual(Example12 *ex) {
ex->pure_virtual();
}
void init_ex12(py::module &m) {
/* Important: use the wrapper type as a template
argument to class_<>, but use the original name
to denote the type */
py::class_<PyExample12>(m, "Example12")
/* Declare that 'PyExample12' is really an alias for the original type 'Example12' */
.alias<Example12>()
.def(py::init<int>())
/* Reference original class in function definitions */
.def("run", &Example12::run)
.def("pure_virtual", &Example12::pure_virtual);
m.def("runExample12", &runExample12);
m.def("runExample12Virtual", &runExample12Virtual);
}
#!/usr/bin/env python
from __future__ import print_function
import sys
sys.path.append('.')
from example import Example12, runExample12, runExample12Virtual
class ExtendedExample12(Example12):
def __init__(self, state):
super(ExtendedExample12, self).__init__(state + 1)
self.data = "Hello world"
def run(self, value):
print('ExtendedExample12::run(%i), calling parent..' % value)
return super(ExtendedExample12, self).run(value + 1)
def pure_virtual(self):
print('ExtendedExample12::pure_virtual(): %s' % self.data)
ex12 = Example12(10)
print(runExample12(ex12, 20))
try:
runExample12Virtual(ex12)
except Exception as e:
print("Caught expected exception: " + str(e))
ex12p = ExtendedExample12(10)
print(runExample12(ex12p, 20))
runExample12Virtual(ex12p)
......@@ -37,28 +37,6 @@ void dog_bark(const Dog &dog) {
dog.bark();
}
class Example5 {
public:
Example5(py::handle self, int state)
: self(self), state(state) {
cout << "Constructing Example5.." << endl;
}
~Example5() {
cout << "Destructing Example5.." << endl;
}
void callback(int value) {
py::gil_scoped_acquire gil;
cout << "In Example5::callback() " << endl;
py::object method = self.attr("callback");
method.call(state, value);
}
private:
py::handle self;
int state;
};
bool test_callback1(py::object func) {
func.call();
return false;
......@@ -69,16 +47,11 @@ int test_callback2(py::object func) {
return result.cast<int>();
}
void test_callback3(Example5 *ex, int value) {
py::gil_scoped_release gil;
ex->callback(value);
}
void test_callback4(const std::function<int(int)> &func) {
void test_callback3(const std::function<int(int)> &func) {
cout << "func(43) = " << func(43)<< std::endl;
}
std::function<int(int)> test_callback5() {
std::function<int(int)> test_callback4() {
return [](int i) { return i+1; };
}
......@@ -99,8 +72,4 @@ void init_ex5(py::module &m) {
m.def("test_callback2", &test_callback2);
m.def("test_callback3", &test_callback3);
m.def("test_callback4", &test_callback4);
m.def("test_callback5", &test_callback5);
py::class_<Example5>(m, "Example5")
.def(py::init<py::object, int>());
}
......@@ -24,29 +24,17 @@ from example import test_callback1
from example import test_callback2
from example import test_callback3
from example import test_callback4
from example import test_callback5
from example import Example5
def func1():
print('Callback function 1 called!')
def func2(a, b, c, d):
print('Callback function 2 called : ' + str(a) + ", " + str(b) + ", " + str(c) + ", "+ str(d))
return c
class MyCallback(Example5):
def __init__(self, value):
Example5.__init__(self, self, value)
def callback(self, value1, value2):
print('got callback: %i %i' % (value1, value2))
return d
print(test_callback1(func1))
print(test_callback2(func2))
callback = MyCallback(3)
test_callback3(callback, 4)
test_callback4(lambda i: i+1)
f = test_callback5()
test_callback3(lambda i: i + 1)
f = test_callback4()
print("func(43) = %i" % f(43))
......@@ -601,6 +601,7 @@ template <typename T> inline object cast(const T &value, return_value_policy pol
}
template <typename T> inline T handle::cast() { return pybind::cast<T>(m_ptr); }
template <> inline void handle::cast() { return; }
template <typename... Args> inline object handle::call(Args&&... args_) {
const size_t size = sizeof...(Args);
......@@ -624,6 +625,8 @@ template <typename... Args> inline object handle::call(Args&&... args_) {
PyTuple_SetItem(tuple, counter++, result);
PyObject *result = PyObject_CallObject(m_ptr, tuple);
Py_DECREF(tuple);
if (result == nullptr && PyErr_Occurred())
throw error_already_set();
return object(result, false);
}
......
......@@ -27,6 +27,7 @@
#include <vector>
#include <string>
#include <stdexcept>
#include <unordered_set>
#include <unordered_map>
#include <memory>
......@@ -114,13 +115,6 @@ struct buffer_info {
}
};
// C++ bindings of core Python exceptions
struct stop_iteration : public std::runtime_error { public: stop_iteration(const std::string &w="") : std::runtime_error(w) {} };
struct index_error : public std::runtime_error { public: index_error(const std::string &w="") : std::runtime_error(w) {} };
struct error_already_set : public std::exception { public: error_already_set() {} };
/// Thrown when pybind::cast or handle::call fail due to a type casting error
struct cast_error : public std::runtime_error { public: cast_error(const std::string &w = "") : std::runtime_error(w) {} };
NAMESPACE_BEGIN(detail)
inline std::string error_string();
......@@ -145,10 +139,19 @@ struct type_info {
void *get_buffer_data = nullptr;
};
struct overload_hash {
inline std::size_t operator()(const std::pair<const PyObject *, const char *>& v) const {
size_t value = std::hash<const void *>()(v.first);
value ^= std::hash<const void *>()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2);
return value;
}
};
/// Internal data struture used to track registered instances and types
struct internals {
std::unordered_map<const std::type_info *, type_info> registered_types;
std::unordered_map<void *, PyObject *> registered_instances;
std::unordered_map<const void *, PyObject *> registered_instances;
std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
};
/// Return a reference to the current 'internals' information
......@@ -176,5 +179,20 @@ template <typename T, size_t N> struct decay<T[N]> { typedef typename deca
/// Helper type to replace 'void' in some expressions
struct void_type { };
/// to_string variant which also accepts strings
template <typename T> inline typename std::enable_if<!std::is_enum<T>::value, std::string>::type
to_string(const T &value) { return std::to_string(value); }
template <> inline std::string to_string(const std::string &value) { return value; }
template <typename T> inline typename std::enable_if<std::is_enum<T>::value, std::string>::type
to_string(T value) { return std::to_string((int) value); }
NAMESPACE_END(detail)
// C++ bindings of core Python exceptions
struct stop_iteration : public std::runtime_error { public: stop_iteration(const std::string &w="") : std::runtime_error(w) {} };
struct index_error : public std::runtime_error { public: index_error(const std::string &w="") : std::runtime_error(w) {} };
struct error_already_set : public std::runtime_error { public: error_already_set() : std::runtime_error(detail::error_string()) {} };
/// Thrown when pybind::cast or handle::call fail due to a type casting error
struct cast_error : public std::runtime_error { public: cast_error(const std::string &w = "") : std::runtime_error(w) {} };
NAMESPACE_END(pybind)
......@@ -25,8 +25,6 @@ public:
object src(src_, true);
value = [src](Args... args) -> Return {
object retval(pybind::handle(src).call(std::move(args)...));
if (retval.ptr() == nullptr && PyErr_Occurred())
throw error_already_set();
/* Visual studio 2015 parser issue: need parentheses around this expression */
return (retval.template cast<Return>());
};
......
......@@ -24,7 +24,6 @@
#endif
#include <pybind/cast.h>
#include <iostream>
NAMESPACE_BEGIN(pybind)
......@@ -46,12 +45,8 @@ template <typename T> inline arg_t<T> arg::operator=(const T &value) { return ar
/// Annotation for methods
struct is_method {
#if PY_MAJOR_VERSION < 3
PyObject *class_;
is_method(object *o) : class_(o->ptr()) { }
#else
is_method(object *) { }
#endif
};
/// Annotation for documentation
......@@ -76,9 +71,7 @@ private:
short keywords = 0;
return_value_policy policy = return_value_policy::automatic;
std::string signature;
#if PY_MAJOR_VERSION < 3
PyObject *class_ = nullptr;
#endif
PyObject *sibling = nullptr;
const char *doc = nullptr;
function_entry *next = nullptr;
......@@ -126,21 +119,18 @@ private:
kw[entry->keywords++] = "self";
kw[entry->keywords++] = a.name;
}
template <typename T>
static void process_extra(const pybind::arg_t<T> &a, function_entry *entry, const char **kw, const char **def) {
if (entry->is_method && entry->keywords == 0)
kw[entry->keywords++] = "self";
kw[entry->keywords] = a.name;
def[entry->keywords++] = strdup(std::to_string(a.value).c_str());
def[entry->keywords++] = strdup(detail::to_string(a.value).c_str());
}
static void process_extra(const pybind::is_method &m, function_entry *entry, const char **, const char **) {
entry->is_method = true;
#if PY_MAJOR_VERSION < 3
entry->class_ = m.class_;
#else
(void) m;
#endif
}
static void process_extra(const pybind::return_value_policy p, function_entry *entry, const char **, const char **) { entry->policy = p; }
static void process_extra(pybind::sibling s, function_entry *entry, const char **, const char **) { entry->sibling = s.value; }
......@@ -366,35 +356,38 @@ private:
m_entry->sibling = PyMethod_GET_FUNCTION(m_entry->sibling);
#endif
function_entry *entry = m_entry;
bool overloaded = false;
if (!entry->sibling || !PyCFunction_Check(entry->sibling)) {
entry->def = new PyMethodDef();
memset(entry->def, 0, sizeof(PyMethodDef));
entry->def->ml_name = entry->name;
entry->def->ml_meth = reinterpret_cast<PyCFunction>(*dispatcher);
entry->def->ml_flags = METH_VARARGS | METH_KEYWORDS;
capsule entry_capsule(entry, [](PyObject *o) { destruct((function_entry *) PyCapsule_GetPointer(o, nullptr)); });
m_ptr = PyCFunction_New(entry->def, entry_capsule.ptr());
function_entry *s_entry = nullptr, *entry = m_entry;
if (m_entry->sibling && PyCFunction_Check(m_entry->sibling)) {
capsule entry_capsule(PyCFunction_GetSelf(m_entry->sibling), true);
s_entry = (function_entry *) entry_capsule;
if (s_entry->class_ != m_entry->class_)
s_entry = nullptr; /* Method override */
}
if (!s_entry) {
m_entry->def = new PyMethodDef();
memset(m_entry->def, 0, sizeof(PyMethodDef));
m_entry->def->ml_name = m_entry->name;
m_entry->def->ml_meth = reinterpret_cast<PyCFunction>(*dispatcher);
m_entry->def->ml_flags = METH_VARARGS | METH_KEYWORDS;
capsule entry_capsule(m_entry, [](PyObject *o) { destruct((function_entry *) PyCapsule_GetPointer(o, nullptr)); });
m_ptr = PyCFunction_New(m_entry->def, entry_capsule.ptr());
if (!m_ptr)
throw std::runtime_error("cpp_function::cpp_function(): Could not allocate function object");
} else {
m_ptr = entry->sibling;
m_ptr = m_entry->sibling;
inc_ref();
capsule entry_capsule(PyCFunction_GetSelf(m_ptr), true);
function_entry *parent = (function_entry *) entry_capsule, *backup = parent;
while (parent->next)
parent = parent->next;
parent->next = entry;
entry = backup;
overloaded = true;
entry = s_entry;
while (s_entry->next)
s_entry = s_entry->next;
s_entry->next = m_entry;
}
std::string signatures;
int index = 0;
function_entry *it = entry;
while (it) { /* Create pydoc it */
if (overloaded)
if (s_entry)
signatures += std::to_string(++index) + ". ";
signatures += "Signature : " + std::string(it->signature) + "\n";
if (it->doc && strlen(it->doc) > 0)
......@@ -783,6 +776,12 @@ public:
metaclass().attr(name) = property;
return *this;
}
template <typename target> class_ alias() {
auto &instances = pybind::detail::get_internals().registered_types;
instances[&typeid(target)] = instances[&typeid(type)];
return *this;
}
private:
static void init_holder(PyObject *inst_) {
instance_type *inst = (instance_type *) inst_;
......@@ -882,6 +881,43 @@ public:
inline ~gil_scoped_release() { PyEval_RestoreThread(state); }
};
inline function get_overload(const void *this_ptr, const char *name) {
handle py_object = detail::get_object_handle(this_ptr);
handle type = py_object.get_type();
auto key = std::make_pair(type.ptr(), name);
/* Cache functions that aren't overloaded in python to avoid
many costly dictionary lookups in Python */
auto &cache = detail::get_internals().inactive_overload_cache;
if (cache.find(key) != cache.end())
return function();
function overload = (function) py_object.attr(name);
if (overload.is_cpp_function()) {
cache.insert(key);
return function();
}
PyFrameObject *frame = PyThreadState_Get()->frame;
pybind::str caller = pybind::handle(frame->f_code->co_name).str();
if (strcmp((const char *) caller, name) == 0)
return function();
return overload;
}
#define PYBIND_OVERLOAD_INT(ret_type, class_name, name, ...) { \
pybind::gil_scoped_acquire gil; \
pybind::function overload = pybind::get_overload(this, #name); \
if (overload) \
return overload.call(__VA_ARGS__).cast<ret_type>(); }
#define PYBIND_OVERLOAD(ret_type, class_name, name, ...) \
PYBIND_OVERLOAD_INT(ret_type, class_name, name, __VA_ARGS__) \
return class_name::name(__VA_ARGS__)
#define PYBIND_OVERLOAD_PURE(ret_type, class_name, name, ...) \
PYBIND_OVERLOAD_INT(ret_type, class_name, name, __VA_ARGS__) \
throw std::runtime_error("Tried to call pure virtual function \"" #name "\"");
NAMESPACE_END(pybind)
#if defined(_MSC_VER)
......
......@@ -331,10 +331,12 @@ public:
PyObject *ptr = m_ptr;
if (ptr == nullptr)
return false;
#if PY_MAJOR_VERSION < 3
#if PY_MAJOR_VERSION >= 3
if (PyInstanceMethod_Check(ptr))
ptr = PyInstanceMethod_GET_FUNCTION(ptr);
#endif
if (PyMethod_Check(ptr))
ptr = PyMethod_GET_FUNCTION(ptr);
#endif
return PyCFunction_Check(ptr);
}
};
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment