Taichi计算效率的请教

各位老师好,打扰老师们我想请教两个问题~

  1. taichi图形课以及文档中都提到了taichi自动的做了JIT处理,请问从计算效率方面,taichi和其他框架比如JAX做JIT之后是否有优势呢?
  2. taichi在for循环的最外层做了自动并行处理提升计算效率,请问这种建立在for上的自动并行化处理和我们在python里直接写成高阶矩阵操作在效率上哪个更好呢?比如我们在JAX/PyTorch这些框架里为了获得更好的计算效率,会尽可能减少for的使用将code尽可能写成高阶矩阵或tensor,然后用vmap等手段去加速。那taichi这种基于for的自动并行相比JAX/PyTorch来说是否有优势呢?

这两个问题都不能一概而论,非常看具体的应用场景。在有高度优化的库的时候,通常现有的框架会有性能上的优势,比如跑神经网络训练一定是用PyTorch更好,Taichi也因此做了和PyTorch Tensor的互通。关于高阶算子,我认为BLAS-3级别的算子(矩阵乘法)是不太容易在Taichi里充分优化的,我们在语法上的支持并不是很完备。BLAS-1和BLAS-2级别的算子使用Taichi通常能达到CUDA的水平,更不用说和PyTorch/JAX相比了。另外,我们鼓励大家把不同的操作写到一个loop里面,这在神经网络框架中称为算子融合,是一种优化技术。使用高阶库来写的时候做算子融合是个比较复杂的事情,非常依赖框架本身的优化,在融合算子有库支持的时候,比如“conv + relu”/“conv + batchnorm”,仍然是PyTorch快,但是如果融合算子没有支持,结果就很难说,需要实际测试一下。但是在Taichi里面写for循环的时候,合并几个相同的循环是很自然很简单的事情,因此在小算子很多的情况下Taichi通常会有性能优势。作为用户,你可以把Taichi当成一个使用Python语法的CUDA,同时还支持移动端设备;如果你的场景没什么特别好用的库,比如新写的前后处理算子,或者希望进一步优化性能,只管用Taichi来探索;如果你有高度优化的库,那么可以结合起来用:高度优化的部分使用库来写,欠优化的部分使用Taichi来补充

4 个赞