April 22, 2020

更亲开发者的深度学习包: PyTorch Lightning -- 简介

更亲开发者的深度学习包: PyTorch Lightning -- 简介

本质上 Lightning  就是 PyTorch代码,更像是一种编程风格指引。老实说这东西着实不错!基本上是实现了开发者熟悉的状态机方式来写深度学习的代码。

有开发背景的同学肯定对有限状态机(FSM)的概念不陌生,所谓状态机模型,就是所有状态都提前预知并定义好成一个类似 callback, 然后程序员所需要做的就是在定义好的状态或者说动作里面填写上具体的功能代码即可,所以一般也叫有限状态机,上一个例子就容易理解我在说什么了:


// 顾名思义,这个 viewDidLoad 函数就是在该 viewcontroller 的 view 完成了加载之后会被执行。
- (void)viewDidLoad
{
    [super viewDidLoad];
	// Do any additional setup after loading the view, typically from a nib.

    UIImage *temp = [[UIImage imageNamed:@"title.png"] imageWithRenderingMode: UIImageRenderingModeAlwaysOriginal];
    UIBarButtonItem *barButtonItem = [[UIBarButtonItem alloc] initWithImage:temp style:UIBarButtonItemStyleBordered target:self action:@selector(action)];
    self.navigationItem.leftBarButtonItem = barButtonItem;
    self.navigationItem.leftBarButtonItem.enabled = NO;
    
    protocol = [[RBLProtocol alloc] init];
    protocol.delegate = self;
    protocol.ble = ble;
    
    NSLog(@"ControlView: viewDidLoad");
}

-(void)viewDidAppear:(BOOL)animated
{
    NSLog(@"ControlView: viewDidAppear");
    
    syncTimer = [NSTimer scheduledTimerWithTimeInterval:(float)3.0 target:self selector:@selector(syncTimeout:) userInfo:nil repeats:NO];

    [protocol queryProtocolVersion];
}

-(void)viewDidDisappear:(BOOL)animated
{
    NSLog(@"ControlView: viewDidDisappear");

    total_pin_count = 0;
    [tv reloadData];
    
    init_done = 0;
}

- (void)didReceiveMemoryWarning
{
    [super didReceiveMemoryWarning];
    // Dispose of any resources that can be recreated.
}

FSM 的优缺点都非常明显,经验老到的前辈把可以预期的状态都已 callback 的方式写到代码中,后来者只需要在相应的状态 callback 里面增删代码既可实现新的功能。优点是逻辑清晰易上手,缺点是一旦遇到新增的状态,全部状态机代码都得重审一遍,以免遗漏状态切换,所以这种方式只适合在流程比较成熟的场景下,比如上面 iOS 程序的开发流程,或者其他诸如 react.js 的成熟前端流程中。

由William Falcon (williamFalcon) 同学领衔的 PyTorch Lightning 的出现说明了两点:

  1. 针对特定任务的深度学习任务流程已经慢慢成熟。
  2. 深度学习领域已经「快要」做好准备迎接更多的程序猿入场了。

个人最关心的还是它对 BERT-like 模型和 TPU 上的支持,通过下面的方式在 colab 或者 GCP 上开启 PyTorch 对 TPU 也就是 Lightning 对 TPU 的支持:

#@title 开启 TPU 支持:

import collections
from datetime import datetime, timedelta
import os
import requests
import threading

_VersionConfig = collections.namedtuple('_VersionConfig', 'wheels,server')
VERSION = "torch_xla==nightly"  #@param ["xrt==1.15.0", "torch_xla==nightly"]
CONFIG = {
    'xrt==1.15.0': _VersionConfig('1.15', '1.15.0'),
    'torch_xla==nightly': _VersionConfig('nightly', 'XRT-dev{}'.format(
        (datetime.today() - timedelta(1)).strftime('%Y%m%d'))),
}[VERSION]
DIST_BUCKET = 'gs://tpu-pytorch/wheels'
TORCH_WHEEL = 'torch-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCH_XLA_WHEEL = 'torch_xla-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCHVISION_WHEEL = 'torchvision-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)

# Update TPU XRT version
def update_server_xrt():
  print('Updating server-side XRT to {} ...'.format(CONFIG.server))
  url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format(
      TPU_ADDRESS=os.environ['COLAB_TPU_ADDR'].split(':')[0],
      XRT_VERSION=CONFIG.server,
  )
  print('Done updating server-side XRT: {}'.format(requests.post(url)))

update = threading.Thread(target=update_server_xrt)
update.start()

# Install Colab TPU compat PyTorch/TPU wheels and dependencies
!pip uninstall -y torch torchvision
!gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" .
!pip install "$TORCH_WHEEL"
!pip install "$TORCH_XLA_WHEEL"
!pip install "$TORCHVISION_WHEEL"
!sudo apt-get install libomp5
update.join()

然后配置最新版本的 pytorch lightning 和 transformers 支持:

!pip install git+https://github.com/PyTorchLightning/pytorch-lightning.git --upgrade
!pip install transformers --upgrade

注意!改完之后先不要马上往下运行,由于原有的 tqdm 包被新的包替换,建议这里重启下 colab 代码执行程序,之后就可以安全的按照官方提供的 bert demo 一步步实验了,只要一行小小的改动就可以把原有的单卡环境改到 8-core TPU 支持了:

trainer = pl.Trainer(gpus=1,max_epochs=1)    

改成:

trainer = pl.Trainer(num_tpu_cores=8,max_epochs=1)    

当然,如果当前是自己部署的多卡多机 GPU 训练环境,可以分别像这样指定 单机多卡、多机多卡的分布式训练环境( 分布式模式可选单机多卡的 dp/多机多卡的 ddp/单机内走dp,机器之间走ddp 的 ddp2 ...):



# 单机8卡,
trainer = pl.Trainer(gpus=8, distributed_backend='dp')

# 4机4x8 卡
trainer = pl.Trainer(gpus=8, distributed_backend='ddp', num_nodes=4)