ACCELERATING ML DEVELOPMENT WITH PYTORCH
Soumith Chintala, Facebook AI
PYTORCH OVERVIEW
NVIDIA SUPPORT & COLLABORATION
• Hardware support
• Software collaboration
• Core library integration
• Scalability & deployment
GOING FROM RESEARCH TO PRODUCTION
1. DETERMINE APPROACH
2. PREPARE DATA
3. BUILD & TRAIN MODEL
4. TRANSFER TO PRODUCTION
5. DEPLOY & SCALE
torch.jit (PyTorch 1.0.0) — Code == Model vs. Code == Model == Data

Eager function:

    def myfun(x):
        y = x * x
        z = y.tanh()
        return z

Scripted function:

    @torch.jit.script
    def myfun(x):
        y = x * x
        z = y.tanh()
        return z

Resulting graph IR — export and run anywhere:

    graph(%x : Dynamic) {
      %y : Dynamic = aten::mul(%x, %x)
      %z : Dynamic = aten::tanh(%y)
      return (%z)
    }
torch.jit (PyTorch 1.0.0) — Eager execution vs. execution with ahead-of-time analysis

Eager execution:

    def myfun(x):
        y = x * x
        z = y.tanh()
        return z

Execution with ahead-of-time analysis:

    @torch.jit.script
    def myfun(x):
        y = x * x
        z = y.tanh()
        return z
With ahead-of-time analysis, the scripted function can be rewritten as a whole — here the multiply and tanh are fused:

    @torch.jit.script
    def myfun(x):
        return tanh_mul(x)

• saturate faster hardware
• whole program optimizations
PyTorch Eager Mode — Models are Python programs
• Simple
• Debuggable — print and pdb
• Hackable — use any Python library
• Needs Python to run
• Difficult to optimize and parallelize

PyTorch Script Mode — Models are programs written in an optimizable subset of Python
• Production deployment
• No Python dependency
• Optimizable
PYTORCH JIT — Tools to transition eager code into script mode

EAGER MODE (for prototyping, training, and experiments)
    → @torch.jit.script / torch.jit.trace →
SCRIPT MODE (for use at scale in production)
Transitioning a model with torch.jit.trace

Take an existing eager model and provide example inputs. The tracer runs the function, recording the tensor operations performed, and we turn the recording into a Torch Script module.
• Can reuse existing eager model code
• ⚠ Control flow is ignored

    import torch
    import torchvision

    def foo(x, y):
        return 2*x + y

    # trace a model by providing example inputs
    traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

    traced_resnet = torch.jit.trace(torchvision.models.resnet18(),
                                    torch.rand(1, 3, 224, 224))
Tracing

    def foo(x, t):
        y = x.mm(x)
        print(y)  # still works!
        return y + t

    x = torch.Tensor([[1,2],[3,4]])
    foo(x, 1)

    trace = torch.jit.trace(foo, (x, 1))
    trace.save("serialized.pt")

Traced graph: X feeds MatMul, then Add with the constant 1.
Tracing a loop

    def foo(x, t):
        y = x.mm(x)
        print(y)  # still works!
        return y + t

    def bar(x, w):
        y = torch.zeros(1, 2)
        for t in x:
            y = foo(y, w, t)
        return y

    trace = torch.jit.trace(foo, (x, 1))
    trace.save("serialized.pt")

Traced graph (loop unrolled): starting from [0,0], MatMul and Add are repeated for X[0], X[1], X[2] with the fixed weight w.
Script

    def foo(x, t):
        y = x.mm(x)
        print(y)  # still works!
        return y + t

    @script
    def bar(x, w):
        y = torch.zeros(1, 2)
        for t in x:
            y = foo(y, w, t)
        return y

    trace = torch.jit.trace(foo, (x, 1))
    trace.save("serialized.pt")

Scripted graph (loop preserved): for i in range(X.shape[0]): MatMul with w and X[i], then Add, updating y.
Transitioning a model with @torch.jit.script

• Write the model directly in a subset of Python, annotated with @torch.jit.script or @torch.jit.script_method
• Control flow is preserved
• print statements can be used for debugging
• You can mix both trace and script in a single model
• Remove the annotations to debug using standard Python tools

    class RNN(torch.jit.ScriptModule):
        def __init__(self, W_h, U_h, W_y, b_h, b_y):
            super(RNN, self).__init__()
            self.W_h = nn.Parameter(W_h)
            self.U_h = nn.Parameter(U_h)
            self.W_y = nn.Parameter(W_y)
            self.b_h = nn.Parameter(b_h)
            self.b_y = nn.Parameter(b_y)

        @torch.jit.script_method
        def forward(self, x, h):
            y = []
            for t in range(x.size(0)):
                h = torch.tanh(x[t] @ self.W_h + h @ self.U_h + self.b_h)
                y += [torch.tanh(h @ self.W_y + self.b_y)]
                if t % 10 == 0:
                    print("stats: ", h.mean(), h.var())
            return torch.stack(y), h
Under the hood of @torch.jit.script
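An illustrative sketch (not from the slides): once a function is compiled with @torch.jit.script, the IR graph and the generated Torch Script source can be inspected directly, which is a convenient way to see what the compiler produced under the hood.

    import torch

    @torch.jit.script
    def myfun(x):
        y = x * x
        z = y.tanh()
        return z

    # Inspect what the compiler built.
    print(myfun.graph)  # the aten::mul / aten::tanh graph shown on the earlier slide
    print(myfun.code)   # the Torch Script source the runtime will execute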
Predictable error messages — @torch.jit.script reports errors at PARSE TIME and at RUNTIME.
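A minimal sketch of a parse-time error (my own illustration, not from the slides): Torch Script is statically typed, so reassigning a variable to a different type is rejected when the decorator compiles the function, with an error that points at the offending source line.

    import torch

    @torch.jit.script
    def bad_fn(x):
        y = x          # y has type Tensor here
        y = "hello"    # reassigning to a str is rejected at compile (parse) time
        return y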
Loading a model without Python

Torch Script models can be saved to a model archive and loaded in a Python-free executable using a C++ API. Our C++ Tensor API is the same as our Python API, so you can do preprocessing and post-processing before calling the model.

    # Python: save model
    traced_resnet = torch.jit.trace(torchvision.models.resnet18(),
                                    torch.rand(1, 3, 224, 224))
    traced_resnet.save("serialized_resnet.pt")

    // C++: load and run model
    auto module = torch::jit::load("serialized_resnet.pt");
    auto example = torch::rand({1, 3, 224, 224});
    auto output = module->forward({example}).toTensor();
    std::cout << output.slice(1, 0, 5) << '\n';
HARDWARE EFFICIENCY — Faster operator performance

In PyTorch 1.0:
• Leveraging specialized libraries: MKL-DNN, cuDNN, etc.
• Faster implementations for dozens of basic tensor operations

What's next:
• Exposing all of the best operator implementations from Caffe2
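As a hedged illustration (not from the slides), you can query which of these vendor libraries your PyTorch build dispatches to, and let cuDNN autotune convolutions:

    import torch

    # Illustrative only: check the backends this build was compiled against.
    print(torch.backends.mkl.is_available())    # MKL on CPU
    print(torch.backends.cudnn.is_available())  # cuDNN on GPU
    # Let cuDNN benchmark and pick the fastest convolution algorithm
    # for the input shapes it actually sees at runtime.
    torch.backends.cudnn.benchmark = True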
HARDWARE EFFICIENCY — Connecting to the ONNX ecosystem

Vendor runtimes are best for running things fast.

In PyTorch 1.0:
• Export entire model to ONNX for inference

What's coming:
• ONNXIFI runtimes as part of a bigger model through the JIT
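An illustrative sketch (not from the slides) of the standard export path, using torch.onnx.export with an example input that traces the model:

    import torch
    import torchvision

    # Export a model to ONNX for use in a vendor inference runtime.
    model = torchvision.models.resnet18()
    dummy_input = torch.rand(1, 3, 224, 224)
    torch.onnx.export(model, dummy_input, "resnet18.onnx")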
SCALABILITY — Distributed training

Challenges:
• Scaling to hundreds of GPUs
• Heterogeneous clusters, Ethernet/InfiniBand
• Potentially unreliable nodes

In PyTorch 1.0:
• Fully revamped distributed backend: c10d
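A minimal sketch (my own illustration, not from the slides) of the front end this backend sits behind: initialize a process group and wrap the model in DistributedDataParallel. It assumes a launcher has set the MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE environment variables and that each process drives one GPU; local_rank is a hypothetical value the launcher would normally provide.

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel

    # NCCL backend for multi-GPU training; rank/world size come from the environment.
    dist.init_process_group(backend="nccl", init_method="env://")

    local_rank = 0  # hypothetical: normally supplied by the launcher
    torch.cuda.set_device(local_rank)

    model = nn.Linear(128, 10).cuda()
    # Gradient all-reduce is overlapped with the backward pass via c10d collectives.
    ddp_model = DistributedDataParallel(model, device_ids=[local_rank])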
SCALABILITY & CROSS-PLATFORM — Deployment in C++

Often Python is not an option:
• High overhead on small models
• Multithreaded services bottleneck on the GIL
• Deployment service might be C++ only

In PyTorch 1.0:
• Convert the inference part of the model to Torch Script (+ state_dict)
• Link with libtorch.so in your C++ application
PYTORCH DISTRIBUTED TRAINING
Teng Li, Facebook AI
SIGNIFICANCE OF SCALABLE DISTRIBUTED TRAINING

More computing power, more training data, larger models →
• Significant training time speedups
• Great extent of model exploration
DISTRIBUTED – WHAT’S NEW? • A brand new performance-driven distributed backend: C10D
HIGHLIGHTS — PyTorch 1.0 Distributed

Brand new backend design:
• Fully asynchronous backend library: C10D
• Both Python and C++ support
• Fully backward-compatible frontend Python API

Highly scalable performance:
• Near roofline performance on key workloads
• Data Parallel: single-node, multi-GPU
• Data Parallel: multi-node, multi-GPU
C10D LIBRARY — DESIGN AND FEATURES
• Backends: Gloo, NCCL, MPI
• Fully asynchronous collectives for all backends
• Both Python and C++ APIs
• Performance-driven design
• Self-managed CUDA streams for parallel execution
• Upcoming: fault tolerance with elasticity
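An illustrative sketch of the asynchronous collectives c10d exposes (assumes a process group is already initialized as in the earlier sketch and the tensor lives on the local GPU): passing async_op=True returns a work handle, so communication can overlap with computation until you wait on it.

    import torch
    import torch.distributed as dist

    # Assumes dist.init_process_group(...) has already run (see earlier sketch).
    grad = torch.ones(1024, device="cuda")

    # Kick off an asynchronous all-reduce; c10d runs it on its own CUDA stream.
    work = dist.all_reduce(grad, op=dist.ReduceOp.SUM, async_op=True)

    # ... other computation can proceed while the collective is in flight ...

    work.wait()                     # block until the all-reduce has completed
    grad /= dist.get_world_size()   # average the summed gradients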