欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

TensorFlow 源码阅读[1] OpKernel的注册

程序员文章站 2022-07-07 11:22:56
...

OpKernel介绍

TensorFlow 源码阅读[1] OpKernel的注册

在TF的架构中,OpKernel是Ops和硬件的中间层,用来抽象统一各个硬件平台上的Kernel类和接口。

注册过程

我们首先大致列出OpKernel注册的过程,后面再详细分析,我们按照调用顺序,从上层往下说:

  1. 在各个xxx_op.cc文件中调用REGISTER_KERNEL_BUILDER()
  2. 调用OpKernelRegistrar的构造函数
  3. 并在该构造函数中调用OpKernelRegistrar::InitInternal
  4. 调用GlobalKernelRegistry获取保存注册信息的map
  5. 将Key和kernel保存到map中

分析

现在我们来逐个分析,在上面我们是从调用过程往下走,在这里,我们尝试从底层往上走。

1.KernelRegistration

首先我们需要关注的是KernelRegistration类,它用来保存OpKernel注册所需的信息,包括KernelDef、kernel的名字以及kernel的创建方法factory:

struct KernelRegistration {
  KernelRegistration(const KernelDef& d, StringPiece c,
                     std::unique_ptr<kernel_factory::OpKernelFactory> f)
      : def(d), kernel_class_name(c), factory(std::move(f)) {}

  const KernelDef def;
  const string kernel_class_name;
  std::unique_ptr<kernel_factory::OpKernelFactory> factory;
};

2.KernelRegistry

这个结构体用来保存OpKernel的注册信息KernelRegistration,并将这些信息保存到一个unordered_multimap里:

struct KernelRegistry {
  mutex mu;
  std::unordered_multimap<string, KernelRegistration> registry
      TF_GUARDED_BY(mu);
};

这个map维持一个Key到OpKernel注册信息之间的关系,而这个Key,是这样生成的:

const string key =
        Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
            kernel_def->label());

既然是unordered_multimap,说明一个Key可以对应多个KernelRegistration。
KernelRegistry的实例是通过下面这个函数构造的:

void* GlobalKernelRegistry() {
  static KernelRegistry* global_kernel_registry = []() {
    KernelRegistry* registry = new KernelRegistry;
    OpRegistry::Global()->RegisterValidator(ValidateKernelRegistrations);
    return registry;
  }();
  return global_kernel_registry;
}

3.OpKernelRegistrar

上面我们提到了OpKernel需要保存的信息,以及这些信息是保存在一个unordered_multimap中的,下面我们要来看这个保存的过程。
我们首先来看这个类的构造函数:

// 构造函数1
OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    std::unique_ptr<OpKernelFactory> factory) {
    // Perform the check in the header to allow compile-time optimization
    // to a no-op, allowing the linker to remove the kernel symbols.
    if (kernel_def != nullptr) {
      InitInternal(kernel_def, kernel_class_name, std::move(factory));
    }
  }

//构造函数2
OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    OpKernel* (*create_fn)(OpKernelConstruction*)) {
    // Perform the check in the header to allow compile-time optimization
    // to a no-op, allowing the linker to remove the kernel symbols.
    if (kernel_def != nullptr) {
      InitInternal(kernel_def, kernel_class_name,
                   absl::make_unique<PtrOpKernelFactory>(create_fn));
    }
  }

这里涉及到另外一个类OpKernelFactory,我们也可以看下它的定义:

class OpKernelFactory {
 public:
  virtual OpKernel* Create(OpKernelConstruction* context) = 0;
  virtual ~OpKernelFactory() = default;
};

从这个类的create函数我们就可以看出,OpKernelRegistrar的亮哥构造函数其实大同小异,第一个参数是kernel_del,第二个参数是kernel_class_name,第三个参数都是创建这个kernel的函数。
我们来看一下OpKernelRegistrar构造函数的核心部分:

void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
                                     StringPiece kernel_class_name,
                                     std::unique_ptr<OpKernelFactory> factory) {
  // See comments in register_kernel::Name in header for info on _no_register.
  if (kernel_def->op() != "_no_register") {
    const string key =
        Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
            kernel_def->label());

	auto global_registry =
	        reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
	    mutex_lock l(global_registry->mu);
	    global_registry->registry.emplace(
	        key,
	        KernelRegistration(*kernel_def, kernel_class_name, std::move(factory)));
	}
}

这个GlobalKernelRegistry我们之前已经说过了,它返回的是一个KernelRegistry实例,global_registry->registry 就是我们之前说的保存注册信息的map,也就是说,OpKernel的注册发生在OpKernelRegistrar的构造函数中!
我们顺藤摸瓜,看看这个构造函数是怎么被调用的。

4. REGISTER_KERNEL_BUILDER

OpKernelRegistrar的构造就是在REGISTER_KERNEL_BUILDER宏定义中:

#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
  REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)

#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
  REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)

#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)        \
  constexpr bool should_register_##ctr##__flag =                      \
      SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__);                        \
  static ::tensorflow::kernel_factory::OpKernelRegistrar              \
      registrar__body__##ctr##__object(                               \
          should_register_##ctr##__flag                               \
              ? ::tensorflow::register_kernel::kernel_builder.Build() \
              : nullptr,                                              \
          #__VA_ARGS__,                                               \
          [](::tensorflow::OpKernelConstruction* context)             \
              -> ::tensorflow::OpKernel* {                            \
            return new __VA_ARGS__(context);                          \
          });

宏定义理解起来往往比较麻烦,不要着急,我们一个个看。

首先做一些宏定义知识的补充,可能不是所有人都清楚(比如我-_-!):

__COUNTER__ 可以理解为一个int型计数器,初始值为0,每出现一次,值+1
#x 将x转换成一个字符串
##ctr 变量拼接,就是将ctr的值拼接到整个变量中
__VA_ARGS__可变参数

有了上面这些知识,我们再来看这些宏就没这么复杂了:

  1. 首先REGISTER_KERNEL_BUILDER接受两个参数,一个是kernel_builder,另一个是可变参数;
  2. 将这两个参数传给REGISTER_KERNEL_BUILDER_UNIQ_HELPER,而这个宏在前面的宏的基础上,增加了一个计数器,并将这三个参数传给下一个定义的宏
  3. REGISTER_KERNEL_BUILDER_UNIQ接受了这三个参数,然后定义一个临时变量should_register_##ctr##__flag,根据我们上面宏定义的知识,ctr和flag的值都会拼接到register_后面,而这个bool值的结果是SHOULD_REGISTER_OP_KERNEL(#_VA_ARGS),看字面意思就可以理解为是否需要注册这个OpKernel;然后定义了一个static的OpKernelRegistrar变量registrar__body__##ctr##__object,且调用了OpKernelRegistrar的第二类构造函数:

至此我们找到了构造OpKernelRegistrar的地方,也就是说每次使用宏REGISTER_KERNEL_BUILDER注册OpKernel,都会调用OpKernelRegistrar并将对应的Kernel信息存到map中。

  1. 我们看一下OpKernelRegistrar构造函数的参数:

1)should_register_##ctr##__flag ? ::tensorflow::register_kernel::kernel_builder.Build() : nullptr 也就是说如果需要创建这个OpKernel,就传入::tensorflow::register_kernel::kernel_builder.Build()这个参数的值我们后面会介绍,根据构造函数的三个参数,我们暂时只需要知道这一长串会返回一个KernelDef对象
2) #__VA_ARGS__ 第二个参数是可变参数变成的字符串,也就是kernel_class_name
3)[](::tensorflow::OpKernelConstruction* context) -> ::tensorflow::OpKernel* { return new __VA_ARGS__(context);这是一个lamda表达式函数,入参数OpKernelConstruction* context,返回类型是OpKernel*,这个函数指针本身构成了第三个参数,即OpKernel* (*create_fn)(OpKernelConstruction*)

到此我们应该理解了这个复杂的宏REGISTER_KERNEL_BUILDER,只需要正确使用这个宏,就可以注册一个OpKernel!!!
遗留了一个问题,就是为什么这个kernel_builder.Build(),就相当于是KernelDef对象呢?

5.如何使用这个宏?

我们看一下官方的例子:

REGISTER_KERNEL_BUILDER(Name("Test1").Device(tensorflow::DEVICE_CPU),DummyKernel);

这里我们看到第一个参数是Name("Test1").Device(tensorflow::DEVICE_CPU)这个东西为什么就是KernelDef呢?我们看一下这个Name究竟是什么,说实话这个类不太好找:


class Name : public KernelDefBuilder {
 public:
  explicit Name(const char* op)
      : KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {}
};

原来这个Name类是继承自KernelDefBuilder类,且在它的构造函数中,调用了基类的构造函数,传入的是op的名字,我们再来看一下这个基类:

class KernelDefBuilder {
 public:
  explicit KernelDefBuilder(const char* op_name);
  ~KernelDefBuilder();
  KernelDefBuilder& Device(const char* device_type);
  template <typename T>
  KernelDefBuilder& AttrConstraint(const char* attr_name, gtl::ArraySlice<T> allowed);
  template <typename T>
  KernelDefBuilder& AttrConstraint(const char* attr_name, T allowed);
  KernelDefBuilder& TypeConstraint(const char* attr_name,
gtl::ArraySlice<DataType> allowed);
  KernelDefBuilder& TypeConstraint(const char* attr_name, DataType allowed);
  template <class T>
  KernelDefBuilder& TypeConstraint(const char* attr_name);
  KernelDefBuilder& HostMemory(const char* arg_name);
  KernelDefBuilder& Label(const char* label);
  KernelDefBuilder& Priority(int32 priority);
  const KernelDef* Build();
 private:
  KernelDef* kernel_def_;
  TF_DISALLOW_COPY_AND_ASSIGN(KernelDefBuilder);
};

基类KernelDefBuilder也接受一个op_name作为构造参数,且我们现在可以看到,刚才Name(“Test1”)后面的.Device()实际上就是KernelDefBuilder的成员函数,返回的是KernelDefBuilder&类型。

在得到这个KernelDefBuilder&类型的返回值后,在通过调用kernel_builder.Build()方法,就得到了const KernelDef* 类型的返回值,这就回答了我们刚才的问题!

总结

我们花了很久的时间,就是为了搞清楚TF究竟是如何设计和实现Opkernel的注册的。我们先是简单介绍了从调用到底层实现,然后详细的从底层开始分析了每一步的实现。不得不说TF这一套东西很复杂,但是只要多看两遍,也可以理解。

对于OpKernel类来说,往下有它自身的数据类和数据管理类,以及构造辅助类,往上被封装到一个宏定义中,在后面说到Op的时候,会发现整体思路和OpKernel十分相似,所以理解其中一个,另一个理解起来是水到渠成。

参考

  1. TF源码
  2. 『深度长文』Tensorflow代码解析(三)