Alex Constantin-Gomez

Implementing pattern matching in C++17

Pattern matching is one of my favourite language features. It is available in many modern programming languages, especially functional languages such as Haskell and Scala. Unfortunately, it is not part of the C++ language standard. However, recent additions to the C++ language and standard library make it possible to implement a very primitive form of pattern matching. I will explain how to do this in under 50 lines of code using some of the features introduced in C++17.

Visiting variants

The main C++17 additions that allow us to implement pattern matching easily are std::variant and std::visit. A variant is a type-safe union: at any given time, it stores a value of exactly one of its specified alternative types. To operate on the stored value, we typically use std::visit, which invokes a user-supplied functor on that value, dispatching on the type the variant currently holds.
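As a rough illustration of the "one type at a time" behaviour, here is a small sketch using only standard-library calls. The names Value, active_type and reassign_demo are made up for this example and are not part of the original code:

```cpp
#include <string>
#include <variant>

// A variant that can hold an int, a bool or a double, but only one at a time.
using Value = std::variant<int, bool, double>;

// Name of the alternative currently stored, queried with
// std::holds_alternative instead of std::visit.
std::string active_type(const Value& v)
{
    if (std::holds_alternative<int>(v))  return "int";
    if (std::holds_alternative<bool>(v)) return "bool";
    return "double";
}

// Assigning a value of another alternative replaces the one stored before.
Value reassign_demo()
{
    Value v{42};  // holds an int:   v.index() == 0
    v = 3.14;     // holds a double: v.index() == 2
    return v;
}
```

A chain of holds_alternative checks like this is exactly the kind of manual dispatch that std::visit automates.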

The usage is simple, but verbose and cumbersome:

std::variant<int, bool, double> v{42};
struct visitor
{
    void operator()(int d)     { std::cout << "int" << std::endl; }
    void operator()(bool d)    { std::cout << "bool" << std::endl; }
    void operator()(double d)  { std::cout << "double" << std::endl; }
};

std::visit(visitor{}, v);  // Prints "int"

Alternatively, we can use an if constexpr statement inside a generic lambda:

std::visit([](auto d)
{
    if constexpr(std::is_same_v<decltype(d), int>)
        std::cout << "int" << std::endl;
    else if constexpr(std::is_same_v<decltype(d), bool>)
        std::cout << "bool" << std::endl;
    // ...
}, v);
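Here is a self-contained sketch of the generic-lambda approach that returns a value instead of printing; the function name type_name is hypothetical. This version takes the argument as auto&& and strips references and cv-qualifiers with std::decay_t before comparing types:

```cpp
#include <string>
#include <type_traits>
#include <variant>

// Dispatch on the held type with one generic lambda and if constexpr.
std::string type_name(const std::variant<int, bool, double>& v)
{
    return std::visit([](auto&& d) -> std::string {
        // For a const variant, decltype(d) is e.g. const int&, so decay it first.
        using T = std::decay_t<decltype(d)>;
        if constexpr (std::is_same_v<T, int>)
            return "int";
        else if constexpr (std::is_same_v<T, bool>)
            return "bool";
        else
            return "double";
    }, v);
}
```

Without std::decay_t, decltype(d) would be a const reference type, none of the std::is_same_v checks would succeed, and every value would fall into the final branch.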

Although both techniques do the job, we would like something shorter and more readable that resembles pattern matching.

Implementing a match statement

The syntax I am aiming to implement looks like this:

std::variant<double, int, std::string> val = "test";
match(val).on(
    [](const std::string& str)    { std::cout << "str: " << str << std::endl; },
    [](int i)                     { std::cout << "int: " << i << std::endl; },
    [](double d)                  { std::cout << "double: " << d << std::endl; }
);

I find this syntax much more readable, and it closely resembles the match expressions of functional programming languages.

As you can see, we pass in a set of lambdas, each of which handles one of the types the variant can hold. So how do we turn this list of lambdas into a single visitor? The key observation is that a lambda is just an object of an unnamed class type with an overloaded operator(). Therefore, if we could "compose" multiple such classes into a single class that has one operator() overload for each type of the variant, the problem would be solved. How do we compose them? We can simply use inheritance! This technique is known as an overload set.
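To make that observation concrete, here is a hand-written class that behaves like a simple lambda. The names doubler and call_both are illustrative, and the class only approximates what the compiler actually generates:

```cpp
// A hand-written class equivalent to the lambda [](int i) { return i * 2; }.
// The compiler generates an unnamed class much like this one for every
// lambda expression; a lambda's operator() is const by default.
struct doubler
{
    int operator()(int i) const { return i * 2; }
};

// Calls the lambda and the hand-written functor in exactly the same way.
int call_both()
{
    auto lambda = [](int i) { return i * 2; };
    doubler functor;
    return lambda(21) + functor(21);  // both use plain call syntax
}
```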

As mentioned, we can compose functors through inheritance:

struct visitor_int { void operator()(int x) { std::cout << "int"; } };
struct visitor_float { void operator()(float x) { std::cout << "float"; } };

// Compose both functors via inheritance:
struct visitor : visitor_int, visitor_float
{
    using visitor_int::operator();
    using visitor_float::operator();
};

Note that we need the two using-declarations to bring both overloads into the scope of visitor, so that overload resolution can consider them together. Without them, a call to visitor::operator() would be ambiguous: name lookup happens before overload resolution, and finding operator() in two different base classes is an error. Given the example above, we can generalise the pattern by templating over the base classes:

template <typename A, typename B>
struct overload_set : A, B
{
    using A::operator();
    using B::operator();
};
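A minimal sketch of the two-parameter overload_set in action with std::visit. It reuses the visitor_int and visitor_float idea but returns strings instead of printing, so the result is easy to check; the function visit_two is hypothetical:

```cpp
#include <string>
#include <variant>

template <typename A, typename B>
struct overload_set : A, B
{
    using A::operator();
    using B::operator();
};

struct visitor_int   { std::string operator()(int) const   { return "int"; } };
struct visitor_float { std::string operator()(float) const { return "float"; } };

// std::visit calls the overload that matches the alternative currently held.
std::string visit_two(const std::variant<int, float>& v)
{
    return std::visit(overload_set<visitor_int, visitor_float>{}, v);
}
```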

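Furthermore, we can generalise this pattern even more by using variadic templates together with pack expansion in using-declarations, which was introduced in C++17. Because C++17 cannot deduce the template arguments of an aggregate from its initializer on its own (that only arrived in C++20), we also add a deduction guide so that an overload_set can be created directly from lambdas:

template <typename... Functors>
struct overload_set : Functors...
{
    using Functors::operator()...;
};

// Deduction guide: deduce overload_set<F1, F2, ...> from the initializers.
template <typename... Functors>
overload_set(Functors...) -> overload_set<Functors...>;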

We can now instantiate an overload_set with as many lambdas as we want:

overload_set os{
    [](int d)   { std::cout << "int: " << d; },
    [](float d) { std::cout << "float: " << d; }
};
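Passed directly to std::visit, such an overload set replaces the hand-written visitor struct from earlier. A self-contained sketch (the function describe is hypothetical; the deduction guide is what lets overload_set{...} compile in C++17):

```cpp
#include <string>
#include <variant>

template <typename... Functors>
struct overload_set : Functors...
{
    using Functors::operator()...;
};

// Deduction guide (needed in C++17): overload_set{f1, f2, ...} becomes
// overload_set<closure type of f1, closure type of f2, ...>.
template <typename... Functors>
overload_set(Functors...) -> overload_set<Functors...>;

// Visits the variant with one lambda per alternative.
std::string describe(const std::variant<int, float, std::string>& v)
{
    return std::visit(overload_set{
        [](int)                { return std::string("int"); },
        [](float)              { return std::string("float"); },
        [](const std::string&) { return std::string("string"); }
    }, v);
}
```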

Putting everything together

The difficult part is done; now we just need to put everything together to implement the syntax described earlier. We can do this by creating a generic class that accepts a variant and visits it using an overload set:

template <typename Variant>
class match
{
public:
    constexpr explicit match(const Variant& variant) : d_variant(variant) {}

    template<typename... Fs>
    constexpr auto on(Fs... fs)
    {
        return std::visit(overload_set<Fs...> { std::forward<Fs>(fs)... }, d_variant);
    }

private:
    Variant d_variant;
};
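A self-contained sketch that reproduces overload_set and match from above, with a hypothetical render function as the usage example. Since on() returns whatever std::visit returns, the handlers can produce values instead of printing:

```cpp
#include <string>
#include <utility>
#include <variant>

template <typename... Functors>
struct overload_set : Functors...
{
    using Functors::operator()...;
};

template <typename Variant>
class match
{
public:
    constexpr explicit match(const Variant& variant) : d_variant(variant) {}

    // Builds an overload set from the handlers and visits the stored variant.
    template <typename... Fs>
    constexpr auto on(Fs... fs)
    {
        return std::visit(overload_set<Fs...>{ std::forward<Fs>(fs)... }, d_variant);
    }

private:
    Variant d_variant;
};

// Usage: match(val) deduces match<std::variant<double, int, std::string>>.
std::string render(const std::variant<double, int, std::string>& val)
{
    return match(val).on(
        [](const std::string& str) { return "str: " + str; },
        [](int i)                  { return "int: " + std::to_string(i); },
        [](double d)               { return "double: " + std::to_string(d); }
    );
}
```

All handlers must return the same type here, because std::visit requires every invocation of the visitor to have the same return type.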

In very few lines of code, we have implemented a simple but powerful pattern matching construct that dispatches on the type of the value held in a variant.
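One further consequence of relying on overload resolution, not part of the original code: a generic lambda can serve as a catch-all case, much like the wildcard pattern in functional languages. A sketch with a hypothetical classify function:

```cpp
#include <string>
#include <variant>

template <typename... Functors>
struct overload_set : Functors...
{
    using Functors::operator()...;
};

template <typename... Functors>
overload_set(Functors...) -> overload_set<Functors...>;

// int is handled explicitly; every other alternative reaches the generic lambda.
// For an int both overloads match exactly, and the non-template one is preferred.
std::string classify(const std::variant<int, double, std::string>& v)
{
    return std::visit(overload_set{
        [](int i)       { return "int " + std::to_string(i); },
        [](const auto&) { return std::string("something else"); }
    }, v);
}
```

This relies on the tie-breaker that prefers a non-template over a template when both match equally well. Be careful with a fallback written as [](auto&&): on a non-const variant it can win over a handler taking const T&, because binding to T& is a better match than binding to const T&.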
The full code with usage examples can be found on my GitHub.