Wednesday, October 16, 2013

Creating Recursive Lambdas ... and returning them too!

Ever since C++ adopted lambda expressions, many have stumbled upon the question of whether they can be recursive. IMO, if an anonymous function needs to call itself, why not create a named function/functor in the first place? But I'm not here today to argue whether it is a good or bad design. So here it is: the most common way to create a recursive lambda in C++11.
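#include <functional>

void test() 
{
  // fib is named before the lambda is created, so the lambda can capture it by reference.
  std::function<int(int)> fib = [&fib](int n)
  {
    return (n <= 2)? 1 : fib(n-1) + fib(n-2);
  };
  // fib(6) == 8 here; the closure must not be used after test() returns.
}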
The way it works is quite interesting. First, I declared a std::function object, fib, and captured it in the lambda by reference. When the lambda captures fib, it is not initialized yet because nothing has been assigned to it. That is fine: capturing by reference only needs the name and the type of fib, not its value, so the compiler does not complain. It happily creates a closure that refers to the not-yet-initialized std::function. Right after that, it uses the closure to initialize the fib object. From then on, the reference inside the lambda refers to a valid object, and the lambda simply uses the captured std::function reference to call itself.

A problem with this approach is that the recursive lambda cannot be returned. Why? When the function ends, so does the fib object and, consequently, the reference inside the closure becomes invalid. Capture by value is also futile because the closure would end up with a copy of a std::function that has no target yet. If fib is default-constructed first and assigned afterwards, that copy is empty and throws a std::bad_function_call exception upon invocation. Copying fib inside its own initializer is worse still, because it reads an object that has not been constructed.
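The by-value failure is easy to reproduce. The sketch below is an illustration, not part of the original code: the names by_value_capture_throws and by_value_demo are made up, and fib is default-constructed first so that the captured copy is a well-defined empty std::function.

```cpp
#include <cassert>
#include <functional>

// Hypothetical helper: returns true if invoking the by-value "recursive"
// lambda with argument n ends in std::bad_function_call.
bool by_value_capture_throws(int n)
{
  std::function<int(int)> fib;  // default-constructed, so it is empty
  fib = [fib](int k)            // the closure stores a copy of the *empty* fib
  {
    return (k <= 2) ? 1 : fib(k-1) + fib(k-2);
  };

  try {
    fib(n);
    return false;               // base cases never touch the captured copy
  } catch (const std::bad_function_call &) {
    return true;                // the recursive step called the empty copy
  }
}

// Usage: only the recursive step fails.
void by_value_demo()
{
  assert(!by_value_capture_throws(2));
  assert(by_value_capture_throws(3));
}
```

The named fib is valid after the assignment, but the copy inside the closure was taken before it and stays empty.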

Is there a way out? As it turns out, there is!

UPDATE: Within hours of uploading the original post, I came to understand a sleeker way to create and return recursive lambdas. Thanks to the generous people out there on reddit/r/cpp.
#include <functional>
#include <iostream>

std::function<int(int)> create()
{
  int foo = 20;
  std::function<int(int)> f = 
    [=](int n) mutable -> int {          // outer lambda: holds its own copy of foo
         std::function<int(int)> recurse;
         recurse = [&](int n) -> int {   // inner lambda: refers to the outer lambda's foo
            foo = 10;
            return (n<=2)? 1 : recurse(n-1) + recurse(n-2); 
         };  
         return recurse(n);
       };  
  std::cout << f(6) << ", foo=" << foo << "\n"; // prints "8, foo=20"
  return f;
}
The technique involves creating nested lambdas. You don't have to return the inner recursive lambda, which becomes somewhat clunky as described later in this post. Instead, you return the outer lambda, which when invoked creates a recursive lambda using the same technique described at the beginning. The outer lambda also invokes the inner one and returns the result right away, because the inner lambda and the std::function it refers to live only for the duration of that call.

Capturing the right state at the right place may become important if your recursive lambda is stateful. The outer lambda (copied in f) captures the state by value. I.e., it captures foo by value. The inner lambda captures its state by reference. However, the foo it refers to is not the local foo variable in function create. Instead, it refers to the state captured by the outer lambda. So when you modify foo inside the inner lambda, it modifies the copy inside the outer lambda and the local foo remains unchanged. Further, to allow modification, the outer lambda must be mutable.

This behavior does not appear surprising at all when you consider the "context" in which the inner lambda runs. The create function might have returned long before the inner lambda ever got a chance to execute. The compiler is smart enough to figure that out, so the context of the capture for the inner lambda is limited to the state captured by the outer lambda.

foo is referenced inside the inner lambda only. But syntactically it lives inside the outer lambda too. And therefore, it must be captured by the outer lambda. This allows clear separation of who manages the state and who manages recursion.

Of course, there are many other combinations of capturing state for the inner and outer lambdas. Those are left to the reader as an exercise!
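As one such combination, here is a minimal sketch in which the captured state is a call counter. The names make_counting_fib and counting_demo are hypothetical and do not come from the original post. The outer lambda owns a copy of calls, the inner lambda increments that copy by reference, and every copy of the returned std::function carries its own counter.

```cpp
#include <cassert>
#include <functional>

// Hypothetical example: the returned function runs the fib recursion for n
// and returns how many recursive invocations this particular copy has made.
std::function<int(int)> make_counting_fib()
{
  int calls = 0;                       // local state, copied into the outer lambda
  return [=](int n) mutable -> int {   // outer: owns its own copy of calls
    std::function<int(int)> recurse;
    recurse = [&](int k) -> int {      // inner: refers to the outer lambda's copy
      ++calls;
      return (k <= 2) ? 1 : recurse(k-1) + recurse(k-2);
    };
    recurse(n);
    return calls;
  };
}

void counting_demo()
{
  std::function<int(int)> f = make_counting_fib();
  assert(f(5) == 9);    // computing fib(5) takes 9 invocations
  assert(f(5) == 18);   // the state persists between calls on the same object
  std::function<int(int)> g = f;  // copying f copies the captured counter
  assert(g(5) == 27);
  assert(f(5) == 27);   // f did not see the calls made through g
}
```

The explicit -> int return types keep the lambdas valid under the strict C++11 rules for return type deduction.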

Finally, returning the std::function object will make a copy of the outer lambda and in turn the state captured by it. In most cases this is desirable. But if it is not, you could create a shared_ptr to the std::function object and use that to pass the stateful recursive lambda around.

std::shared_ptr<std::function<int(int)>> is not callable, however. I mean, you have to dereference it before you can call it like a function. So it cannot be passed to STL algorithms. A wrapper may be what you need to take care of that.
#include <functional>
#include <memory>
#include <utility>

template <class T>
class func_wrapper // version 3
{
  std::shared_ptr<std::function<T>> func;

public:

  func_wrapper(std::function<T> f)
    : func(std::make_shared<std::function<T>>(std::move(f)))
  { }

  template <class... Args>
  auto operator () (Args&&... args) const 
    -> decltype((*func)(std::forward<Args>(args)...)) 
  {
    return (*func)(std::forward<Args>(args)...);
  }
};
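To see what the wrapper buys you, here is a small sketch that passes a stateful callable to an STL algorithm. The wrapper is repeated so the example compiles on its own; make_call_counter, next_count_after_transform and sharing_demo are hypothetical names, and a plain counter stands in for a recursive lambda to keep the example short.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Same shape as the wrapper above, repeated so the example is self-contained.
template <class T>
class func_wrapper
{
  std::shared_ptr<std::function<T>> func;

public:
  func_wrapper(std::function<T> f)
    : func(std::make_shared<std::function<T>>(std::move(f)))
  { }

  template <class... Args>
  auto operator () (Args&&... args) const
    -> decltype((*func)(std::forward<Args>(args)...))
  {
    return (*func)(std::forward<Args>(args)...);
  }
};

// Hypothetical stateful callable: returns how many times it has been called.
std::function<int(int)> make_call_counter()
{
  int calls = 0;
  return [calls](int) mutable -> int { return ++calls; };
}

// Runs an algorithm on a copy of f, then asks f itself for the next count.
template <class F>
int next_count_after_transform(F f)
{
  std::vector<int> in(3, 0), out(3);
  std::transform(in.begin(), in.end(), out.begin(), f);  // f is copied here
  return f(0);
}

void sharing_demo()
{
  // A plain std::function is copied, state and all, so f never saw the calls.
  assert(next_count_after_transform(make_call_counter()) == 1);
  // Copies of the wrapper share one std::function, so the count carries over.
  assert(next_count_after_transform(
           func_wrapper<int(int)>(make_call_counter())) == 4);
}
```

std::transform takes its function object by value, which is exactly the situation where sharing the underlying std::function makes a visible difference.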
Continue reading if you are interested in avoiding nested lambdas.

WARNING: Hacks ahead!
The idea is something like this: If the std::function is captured by value, we get two copies. One is uninitialized and the other is initialized. If somehow, during the initialization of the std::function object, we could go back and modify the copy inside the closure, maybe we will get lucky.

So here is how I have to rewrite the above recursive lambda. Because I also intend to return it, I captured the variables by value.
const func_wrapper<int(int)> create()
{
  func_wrapper<int(int)> fib;
  fib = [fib](int n)
  {
    return (n <= 2)? 1 : fib(n-1) + fib(n-2);
  };
  return fib;
}

The func_wrapper class does a bunch of interesting things and most of my discussion is going to center around it. Note, however, that I've separated the default initialized fib object from the assignment. It is very important to do it that way.

The reason behind separating the default construction from assignment is that copy-constructor and copy-assignment operators are called separately for the fib object. First, the fib object is default-initialized so we have something "decent" to work with. When the fib object is captured by-value, its copy-constructor is called as part of the usual capture semantics. The copy constructor is the only chance we get to "interact" with the func_wrapper inside the closure object. Later on, copy-assignment operator is called on fib with closure object as the right hand side.

These two functions give us just the right opportunity to set up some references inside func_wrapper to change the state of the captured func_wrapper. So let's look at the func_wrapper class without further ado.
#include <functional>
#include <memory>
#include <utility>

template <class T>
class func_wrapper
{
  std::shared_ptr<std::function<T>> func;
  func_wrapper *captured;

public:

  func_wrapper() 
    : captured(nullptr)
  {}

  func_wrapper(const func_wrapper &) = default;
  func_wrapper(func_wrapper &&) = default;

  // non-const copy-ctor
  func_wrapper(func_wrapper & f)
    : func(f.func),
      captured(nullptr)
  {
    f.captured = this;
  }

  template <class U> 
  const func_wrapper & operator = (U&& closure)
  {
    func = std::make_shared<std::function<T>>();
    if(captured)
      captured->func = func;

    (*func) = std::forward<U>(closure);
    return *this;
  }

  template <class... Args>
  auto operator () (Args&&... args) const 
    -> decltype((*func)(std::forward<Args>(args)...)) 
  {
    return (*func)(std::forward<Args>(args)...);
  }
};

The func_wrapper class is small but quite interesting (at least, I think so!). The role of func_wrapper is to behave just like std::function but the main difference is that func_wrapper uses a shared_ptr to a std::function. This way, multiple func_wrapper objects can share the same std::function object. You know where I am going with this, right?!

The default constructor, copy-constructor (const version), and the move-constructor of func_wrapper are very typical. There is also a non-const copy-constructor: "func_wrapper(func_wrapper &)". This constructor is called only when the right-hand-side object is a non-const lvalue. I.e., this is the constructor that gets invoked when the fib object is captured by-value in the lambda.

What's unusual about this constructor is that it modifies the right-hand-side lvalue. It is allowed because the parameter is non-const. We store the this pointer of the nameless func_wrapper (inside the closure) in the named fib object (outside the closure). I.e., the named fib object now "links" to the nameless object inside the closure. The func_wrapper inside the closure links to nullptr. You can think of it as a little linked list of func_wrapper objects. The shared_ptrs in both objects are just default-initialized at this stage.
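The overload selection can be checked in isolation. The sketch below uses a made-up probe type (not part of the original code) whose non-const copy constructor marks its source, in the same spirit as f.captured = this above.

```cpp
#include <cassert>

// Hypothetical probe type: remembers whether the non-const copy constructor
// was ever invoked with this object as its source.
struct probe
{
  bool copied_as_nonconst;

  probe() : copied_as_nonconst(false) {}
  probe(const probe &) : copied_as_nonconst(false) {}
  probe(probe &&) : copied_as_nonconst(false) {}
  probe(probe &source) : copied_as_nonconst(false)
  {
    source.copied_as_nonconst = true;  // allowed: source is non-const
  }
};

// Captures a non-const lvalue by value and reports whether it was marked.
bool nonconst_source_is_marked()
{
  probe p;
  auto lambda = [p]() { return p.copied_as_nonconst; };
  lambda();
  return p.copied_as_nonconst;
}

// Same experiment with a const source: only probe(const probe &) is viable.
bool const_source_is_marked()
{
  const probe p;
  auto lambda = [p]() { return p.copied_as_nonconst; };
  lambda();
  return p.copied_as_nonconst;
}

void capture_demo()
{
  assert(nonconst_source_is_marked());
  assert(!const_source_is_marked());
}
```

Capturing a const object cannot reach the non-const constructor, which is why making the returned wrapper const protects it from being relinked.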

Right after the copy construction of the captured variables, the compiler creates a closure object and assigns it to the named func_wrapper object ("fib"). The template assignment operator (with universal reference) of func_wrapper is invoked. The first thing we do is allocate an empty std::function object using make_shared. This initializes the func shared pointer in the named func_wrapper ("fib" again). We cannot use the lambda closure object to create the std::function just yet because the std::function constructor is going to make an internal copy of the closure. Our captured pointer, however, points to the one inside the closure object, which is still on the stack. We have to modify the shared_ptr inside the closure object before passing it on to the std::function. That's what we do next.

We assign the (*this).func object to the captured func object only if something has been captured. As it turns out, we have captured a pointer to the func_wrapper inside the closure object. So we assign the shared_ptr to it.

At this stage, the reference count of the empty std::function object is 2. One inside the "fib" object and other one inside the closure object.

Now we are ready to assign the closure object to initialize the empty std::function. We use std::forward to pass on the closure object to the std::function assignment operator. Thus, std::function does not really make copies of the closure but instead moves it. The last step is to return *this.
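The move-versus-copy behavior can be observed with a small counting type. This is a sketch under assumptions: tracer, copies_during_assignment and assignment_demo are made-up names, and the counts are what the std::function specification leads one to expect. The number of moves is left to the library and is not counted.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Hypothetical tracer: counts copy constructions through a shared counter.
struct tracer
{
  int *copies;
  explicit tracer(int *counter) : copies(counter) {}
  tracer(const tracer &other) : copies(other.copies) { ++*copies; }
  tracer(tracer &&other) : copies(other.copies) {}
};

// Returns how many times the captured tracer was copied while the closure
// was being assigned to a std::function.
int copies_during_assignment(bool as_rvalue)
{
  int copies = 0;
  tracer t(&copies);
  auto closure = [t](int n) { return t.copies ? n : 0; };
  copies = 0;  // ignore the copy made by the capture itself

  std::function<int(int)> f;
  if (as_rvalue)
    f = std::move(closure);  // what std::forward<U> does for a temporary closure
  else
    f = closure;             // an lvalue closure has to be copied
  return copies;
}

void assignment_demo()
{
  assert(copies_during_assignment(true) == 0);   // moved, never copied
  assert(copies_during_assignment(false) >= 1);  // copied at least once
}
```

In create() the closure is a temporary, so the first branch is the one that matters for func_wrapper.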

At this stage, the lambda is ready to invoke itself recursively. The first call is made, of course, using the named func_wrapper.
#include <iostream>

int main(void)
{
  auto const fib = create();
  std::cout << fib(6); // prints 8
}
The func_wrapper has an overloaded function call operator just like std::function does. It simply uses std::forward to pass the arguments on to the closure held by the shared_ptr.

Note that the fib object in main is const. This ensures that no further modifications can be made to this recursive lambda wrapper. After all, this is the only reference we have to an anonymous recursive lambda closure. The fib object can be passed by value to other functions without issues because the const version of the copy-constructor gets invoked in those cases. That constructor makes no modifications to the rhs lvalue.

So, are we done? Not quite, unfortunately! There's a memory leak!

Fixing the Memory Leak

Note that the shared_ptr inside the closure object points to the very std::function that owns the closure. That's a cycle! Even though the named fib object goes out of scope, the reference count of the std::function never reaches zero because there is at least one shared_ptr keeping it alive. Therefore, you end up leaking memory.
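The cycle can be reproduced without func_wrapper at all. Here is a minimal sketch, with made-up names (fib_fn, strong_self_capture_outlives_owner, cycle_demo), in which a weak_ptr acts purely as an observer. The cycle is broken by hand at the end so that the example itself does not leak.

```cpp
#include <cassert>
#include <functional>
#include <memory>

typedef std::function<int(int)> fib_fn;

// Returns true if the std::function is still alive after the only named
// shared_ptr to it has gone out of scope.
bool strong_self_capture_outlives_owner()
{
  std::weak_ptr<fib_fn> watcher;  // observes without owning
  {
    std::shared_ptr<fib_fn> fib = std::make_shared<fib_fn>();
    *fib = [fib](int n) {         // the closure owns the function that owns it
      return (n <= 2) ? 1 : (*fib)(n-1) + (*fib)(n-2);
    };
    watcher = fib;
  }                               // the named owner is gone here

  const bool still_alive = !watcher.expired();
  if (std::shared_ptr<fib_fn> leaked = watcher.lock())
    *leaked = nullptr;            // drop the closure to break the cycle by hand
  return still_alive;
}

void cycle_demo()
{
  assert(strong_self_capture_outlives_owner());
}
```

Without the manual reset at the end, nothing would ever release the std::function.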

So we need std::weak_ptr. And here is an improved version of the func_wrapper that does not leak.
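#include <functional>
#include <memory>
#include <type_traits>
#include <utility>

// is_callable is not a standard trait; it has to be provided separately.
template <class T>
class func_wrapper // version 2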
{
  std::shared_ptr<std::function<T>> func;
  std::weak_ptr<std::function<T>> weak_func;
  func_wrapper *captured;

public:

  func_wrapper() 
    : captured(nullptr)
  {}

  func_wrapper(const func_wrapper &) = default;
  func_wrapper(func_wrapper &&) = default;

  func_wrapper(func_wrapper & f)
    : func(f.func),
      weak_func(f.weak_func),
      captured(nullptr)
  {
    f.captured = this;
  }

  template <class U> 
  typename 
    std::enable_if<
      is_callable<typename std::remove_reference<U>::type>::value,
      const func_wrapper &>::type
  operator = (U&& closure)
  {
    weak_func = func = std::make_shared<std::function<T>>();
    if(captured)
       captured->weak_func = func;
    (*func) = std::forward<U>(closure);
    return *this;
  }

  template <class... Args>
  auto operator () (Args&&... args) const 
    -> decltype((*func)(std::forward<Args>(args)...)) 
  {
    if(func)
      return (*func)(std::forward<Args>(args)...);
    else
      return (*weak_func.lock())(std::forward<Args>(args)...);
  }
};


The overall concept is the same except that we never initialize the shared_ptr inside the closure object. Instead we use a std::weak_ptr. std::weak_ptr does not increase the reference count. When it needs the object it is referring to, it has to create a shared_ptr object and dereference that one. If in the meanwhile the reference count drops to zero, the std::function and the closure are reclaimed as expected.
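Stripped of the wrapper, the weak_ptr idea looks roughly like this. It is a sketch with hypothetical names (fib_fn, make_weak_fib, weak_self_capture_is_released), not the func_wrapper code itself.

```cpp
#include <cassert>
#include <functional>
#include <memory>

typedef std::function<int(int)> fib_fn;

// Builds a recursive std::function whose closure holds only a weak_ptr to it.
std::shared_ptr<fib_fn> make_weak_fib()
{
  std::shared_ptr<fib_fn> fib = std::make_shared<fib_fn>();
  std::weak_ptr<fib_fn> weak = fib;
  *fib = [weak](int n) -> int {
    std::shared_ptr<fib_fn> self = weak.lock();  // temporary strong reference
    return (n <= 2) ? 1 : (*self)(n-1) + (*self)(n-2);
  };
  return fib;
}

// Returns true if the function is destroyed together with its last owner.
bool weak_self_capture_is_released()
{
  std::weak_ptr<fib_fn> watcher;
  {
    std::shared_ptr<fib_fn> fib = make_weak_fib();
    watcher = fib;
    assert((*fib)(6) == 8);  // recursion works while an owner exists
  }
  return watcher.expired();
}
```

The closure can only run while some shared_ptr keeps the std::function alive, so lock() succeeds whenever the lambda is actually executing.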

There are two ways to invoke the recursive closure from within the overloaded operator (). If there is a named reference ("fib"), the shared_ptr is active, so we use it in the first branch of operator(). Within the closure, however, there is a "half-baked" func_wrapper that does not have a strong reference but only a weak one. So the second branch turns the weak reference into a strong one (by calling .lock()) and forwards the arguments just as usual.

Additionally, I'm using std::enable_if in the template assignment operator to make sure that what you are assigning is indeed a callable object. is_callable is a non-standard trait I found here.
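For readers who want something to compile against, here is one possible stand-in for such a trait. It is an assumption, not the trait referred to above: it only detects a single, non-template operator(), so lambdas and std::function qualify while func_wrapper itself (whose operator() is a template) does not. The helper names in namespace detail, generic_functor and is_callable_demo are made up.

```cpp
#include <functional>
#include <type_traits>

namespace detail {
  // Chosen when &U::operator() is well-formed, i.e. U has exactly one
  // non-template operator().
  template <class U>
  auto has_call_operator(int) -> decltype(&U::operator(), std::true_type());

  // Fallback for everything else.
  template <class U>
  std::false_type has_call_operator(...);
}

// Hypothetical stand-in for the non-standard is_callable trait.
template <class T>
struct is_callable : decltype(detail::has_call_operator<T>(0)) {};

struct generic_functor  // operator() is a template, like func_wrapper's
{
  template <class... Args> void operator () (Args&&...) const {}
};

void is_callable_demo()
{
  auto lambda = [](int n) { return n; };
  static_assert(is_callable<decltype(lambda)>::value, "lambdas qualify");
  static_assert(is_callable<std::function<int(int)>>::value, "so does std::function");
  static_assert(!is_callable<int>::value, "int has no operator()");
  static_assert(!is_callable<generic_functor>::value, "template operator() is rejected");
  (void)lambda;
}
```

With a trait of this shape, assigning one func_wrapper to another does not select the template assignment operator, because func_wrapper's own operator() is a template.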

So there you have it. A reusable solution in C++11 for recursive lambdas that you can return and use just like any other object. BTW, if you know a compelling use-case for using recursive lambdas, I'm all ears!

6 comments:

Lev said...

While this is a nice trick, recursive functions are very inefficient, so it's best not to use them at all.

Anonymous said...

I'd like to be able to create (sort of) recursive lambdas, but the context I want them doesn't allow this solution. I'm passing a lambda directly to a function, so I can't declare any variable at the same time. Passing lambdas directly is sort of where lambdas shine. If you have to write the expression elsewhere to declare a variable you might just as well just go with an entirely separate named functor or function. Useful recursive lambdas (or just a lambda that needs to refer to itself, say, to place itself on an asynchronous queue) really requires a keyword or something so that the lambda can literally refer to itself.

jacobgypsum said...

If the lambda is not stateful, it is enough to declare it static (so that you can return it as reference), isn't it?

jacobgypsum said...

I wrote a solution that does not use a custom class, returns a recursive function itself, and works for stateful recursive functions too. The idea is that if we need to use the variable after our function returns, we need to create it on the heap. Since we don't want to leak memory, we use smart pointers. I put both a shared_ptr and a weak_ptr in the lambda. Then, using a magic parameter, I reset the shared_ptr in the copy that refers to itself, while it is retained in the copy I return.

#include <cstdio>
#include <functional>
#include <memory>

std::function<int(int)> fib_generator(int a, int b)
{
  std::shared_ptr<std::function<int(int)>> fib {new std::function<int(int)>};
  std::weak_ptr<std::function<int(int)>> wfib(fib);
  *fib=[wfib, fib, a, b](int n) mutable
  {
    if (n==-1) { fib.reset(); return 0; }
    std::shared_ptr<std::function<int(int)>> sfib=wfib.lock();
    return (n <= 2)? ((n==1) ? a : b) : (*sfib)(n-1) + (*sfib)(n-2);
  };
  std::function<int(int)> rfib=*fib;
  (*fib)(-1);
  return rfib;
}

int main()
{
  auto fib1=fib_generator(1, 1);
  auto fib2=fib_generator(2, 2);
  printf("%i\n", fib1(10));
  printf("%i\n", fib2(10));
  return 0;
}

jacobgypsum said...

I wrote an even more ugly variant that does not use magic values (though still requires using a constant parameter with which the function is known to return fast). I put a shared_ptr in the lambda that shows whether it is the first run. Using this, I reset both shared_ptr's in the copy that refers to itself.

std::function<int(int)> fib_generator(int a, int b)
{
  std::shared_ptr<std::function<int(int)>> fib {new std::function<int(int)>};
  std::weak_ptr<std::function<int(int)>> wfib(fib);
  std::shared_ptr<bool> first_run {new bool {true}};
  *fib=[wfib, fib, first_run, a, b](int n) mutable
  {
    if (first_run && *first_run)
    {
      fib.reset();
      *first_run=false;
      first_run.reset();
    }
    std::shared_ptr<std::function<int(int)>> sfib=wfib.lock();
    return (n <= 2)? ((n==1) ? a : b) : (*sfib)(n-1) + (*sfib)(n-2);
  };
  std::function<int(int)> rfib=*fib;
  (*fib)(1);
  return rfib;
}

Anyone has a nicer one?

jacobgypsum said...

Wtf? Blogger removed my template parameters, except the >.