dependent type

我也来写一个类型安全的 printf

本文使用的语言是 C++17,并且只会实现 printf 中很少的一部分功能,主要是展示如何在 C++ 中利用模板和编译期计算来实现依赖类型(dependent type)。本文写得较为简略,需要读者对 C++17 有一些初步的了解。

先来看一个简单的例子,只接受 "%d" 和 "%s" 这两个模式串的 println。你可以点击代码下面的 run 来运行该代码,也可以试着修改部分代码,来检验我们实现的类型约束。

#include <iostream>
#include <functional>
#include <type_traits>

using namespace std;

template<const char*format>
static auto println() {
    if constexpr (format[0]=='%') {
        if constexpr (format[1]=='d') {
            return [](int x){cout<<x<<endl;};
        } else if constexpr (format[1]=='s') {
            return [](const char* x){cout<<x<<endl;};
        } else {
            return "error";
        }
    } else {
        return "error";
    }
}

constexpr const char fs1[] = "%d";
constexpr const char fs2[] = "%s";
int main(){
    auto f1 = println<fs1>();
    f1(1);
    // f1("Hello, DT.\n");

    auto f2 = println<fs2>();
    // f2(1);
    f2("Hello, DT!");
    return 0;
}

这个例子可以说明,在 C++17 中,我们可以通过 auto 和 constexpr 来实现类型计算。if constexpr 语句的不同分支下返回的类型不同,使得我们不需要通过繁杂的函数重载和 SFINAE 来进行类型计算。

接下来进入正题,先介绍一些辅助函数/类型。

自定义的 Unit 类型

struct unit_t {char x;};

获取单参函数的参数类型

template<typename T,typename R>
constexpr auto get_arg(R (*f)(T)){
    return T{};
}

值得一提的是,C++ 中并不能直接将类型作为返回值,但我们可以返回一个该类型的无意义的数据,需要用这个类型的时候用 decltype 取出来。

因为一个 printf 中可能会打印多个变量,因此我们需要想办法去遍历这些变量,在 C++ 模板中常用于遍历的方法是递归。这里,我们引入一种 continuation 类型,简写为 cont。它是一个接受一个参数,返回一个函数的函数。这个返回的函数可以理解为接下来要做的事情。在本文的设计中,typed_printf 函数接受一个模式字符串,返回一个 cont,你喂一个参数,就打印一部分内容,并得到一个新的 cont,直到打印结束。

例如,"Hello, my dear %s %s\n" 把这个模式串传给 typed_printf,你会得到一个 cont1,你给这个 cont1 传一个参数 "friend",它会打印 "Hello, my dear friend",然后返回你一个 cont2。你给 cont2 传一个参数 "c++",它会打印 " c++\n",然后返回你一个什么事也不做的 cont3。

最终的效果是这样的

def_typed_printf(f4,"Hello, my dear %s %s\n");
f4("friend")("c++");

下面是一些与 cont 相关的辅助函数。

判断一个 cont 是否不需要参数。

template<typename T>
constexpr bool cont_takes_no_arg(T cont){
    using cont_t = decay_t<T>;
    using arg_type = decay_t<decltype(get_arg(cont))>;
    return is_same<unit_t,arg_type>::value;
}

计算 cont 的返回值类型

template<typename R,typename X>
constexpr auto cont_ret_type(R (*cont)(X)){
    return R{};
}

template<typename R>
constexpr auto cont_ret_type(R (*cont)()){
    return R{};
}

计算 cont 的参数类型

template<typename R,typename X>
constexpr auto cont_arg_type(R (*cont)(X)){
    return X{};
}

template<typename R>
constexpr auto cont_arg_type(R (*cont)()){
    return unit_t{};
}

cont 的组合函数。print_var 模板产生一个新的 cont,它打印一个变量,然后返回一个指定的 cont。print_const 模板产生打印字符的 cont。

template<typename T,typename R,typename X,R (*cont)(X)>
auto print_var(T x){
    cout<<x;
    return cont;
}

template<typename T,typename R,typename X,R (*cont)(void)>
auto print_var(T x){
    cout<<x;
    return cont();
}

template<char c,typename R,typename X,R (*cont)(X)>
auto print_const(X x){
    cout<<c;
    return cont(x);
}

template<char c,typename R,typename X,R (*cont)(void)>
auto print_const(){
    cout<<c;
    return cont();
}

什么事情也不做的 print_nothing

unit_t print_nothing(){return unit_t{};}

两个辅助宏

#define cont_ret_t decay_t<decltype(cont_ret_type(cont))>
#define cont_arg_t decay_t<decltype(cont_arg_type(cont))>

有了以上准备,我们就能很容易地写出一个类型安全的 printf 了。

template<const char*format,int i>
constexpr auto _typed_printf(){
    if constexpr (format[i]=='%' && format[i+1] == 'd') {
        constexpr auto cont = _typed_printf<format,i+2>();
        return print_var<int,cont_ret_t,cont_arg_t,cont>;
    } else if constexpr (format[i]=='%' && format[i+1] == 's') {
        constexpr auto cont = _typed_printf<format,i+2>();
        return print_var<const char*,cont_ret_t,cont_arg_t,cont>;
    } else if constexpr (format[i]!='\0') {
        constexpr auto cont = _typed_printf<format,i+1>();
        return print_const<format[i],cont_ret_t,cont_arg_t,cont>;
    } else {
        return print_nothing;
    }
}

用 i 来进行字符串的遍历操作,每一步都用 print_var 或 print_const 组合出新的 cont,类似于一个 fold 操作。

所有代码如下:

/*
 * title: type safe printf
 * author: nicekingwei
 * related knowledge:
 *  - value and type
 *      value->value: function
 *      type->value: parametric polymorphism
 *      type->type: generic
 *      value->type: dependent type
 *  - auto
 *  - if constexpr
 */
#include <iostream>
#include <functional>
#include <type_traits>

using namespace std;

template<const char*format>
static auto println() {
    if constexpr (format[0]=='%') {
        if constexpr (format[1]=='d') {
            return [](int x){cout<<x<<endl;};
        } else if constexpr (format[1]=='s') {
            return [](const char* x){cout<<x<<endl;};
        } else {
            return "error";
        }
    } else {
        return "error";
    }
}

struct unit_t {char x;};

template<typename T,typename R>
constexpr auto get_arg(R (*f)(T)){
    return T{};
}

template<typename T>
constexpr bool cont_takes_no_arg(T cont){
    using cont_t = decay_t<T>;
    using arg_type = decay_t<decltype(get_arg(cont))>;
    return is_same<unit_t,arg_type>::value;
}


template<typename T,typename R,typename X,R (*cont)(X)>
auto print_var(T x){
    cout<<x;
    return cont;
}

template<typename T,typename R,typename X,R (*cont)(void)>
auto print_var(T x){
    cout<<x;
    return cont();
}

template<char c,typename R,typename X,R (*cont)(X)>
auto print_const(X x){
    cout<<c;
    return cont(x);
}

template<char c,typename R,typename X,R (*cont)(void)>
auto print_const(){
    cout<<c;
    return cont();
}


template<typename R,typename X>
constexpr auto cont_ret_type(R (*cont)(X)){
    return R{};
}

template<typename R>
constexpr auto cont_ret_type(R (*cont)()){
    return R{};
}

template<typename R,typename X>
constexpr auto cont_arg_type(R (*cont)(X)){
    return X{};
}

template<typename R>
constexpr auto cont_arg_type(R (*cont)()){
    return unit_t{};
}

unit_t print_nothing(){return unit_t{};}

#define cont_ret_t decay_t<decltype(cont_ret_type(cont))>
#define cont_arg_t decay_t<decltype(cont_arg_type(cont))>

template<const char*format,int i>
constexpr auto _typed_printf(){
    if constexpr (format[i]=='%' && format[i+1] == 'd') {
        constexpr auto cont = _typed_printf<format,i+2>();
        return print_var<int,cont_ret_t,cont_arg_t,cont>;
    } else if constexpr (format[i]=='%' && format[i+1] == 's') {
        constexpr auto cont = _typed_printf<format,i+2>();
        return print_var<const char*,cont_ret_t,cont_arg_t,cont>;
    } else if constexpr (format[i]!='\0') {
        constexpr auto cont = _typed_printf<format,i+1>();
        return print_const<format[i],cont_ret_t,cont_arg_t,cont>;
    } else {
        return print_nothing;
    }
}


constexpr const char fs1[] = "%d";
constexpr const char fs2[] = "%s";
constexpr const char fs3[] = "Hello, DT! %s %d\n";
constexpr const char fs4[] = "%d";

#define def_typed_printf(f,str) constexpr static const char str_fmt##f[] = str; auto f = _typed_printf<str_fmt##f,0>();

int main(){
    auto f1 = println<fs1>();
    f1(1);
    // f1("Hello, DT.\n");

    auto f2 = println<fs2>();
    // f2(1);
    f2("Hello, DT!");

    auto f3 = _typed_printf<fs3,0>();
    f3("My string")(233);

    def_typed_printf(f4,"Hello, my dear %s %s\n");
    f4("friend")("c++");
    // f4("friend")(233);

    return 0;
}