On the Noisy Gradient Descent that Generalizes as SGD
Jingfeng Wu, Wenqing Hu, Haoyi Xiong, Jun Huan, Vladimir Braverman, Zhanxing Zhu
Johns Hopkins University; Missouri University of Science and Technology; Baidu Research; Styling AI; Peking University
Stochastic gradient descent (SGD)

Loss function:
\[
L(\theta) = \frac{1}{n} \sum_{i=1}^{n} \ell(x_i; \theta)
\]

SGD update, with mini-batch $B_t$ of size $b$ and stochastic gradient $\tilde g(\theta_t) = \frac{1}{b} \sum_{i \in B_t} \nabla_\theta \ell(x_i; \theta_t)$:
\[
\underbrace{\theta_{t+1} = \theta_t - \eta\, \tilde g(\theta_t)}_{\text{SGD}}
= \underbrace{\theta_t - \eta\, \nabla_\theta L(\theta_t)}_{\text{GD}}
- \eta \underbrace{\big(\tilde g(\theta_t) - \nabla_\theta L(\theta_t)\big)}_{v_{\mathrm{sgd}}(\theta_t):\ \text{(unbiased) gradient noise}}
\]
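To make the decomposition concrete, here is a minimal numpy sketch (our own illustration, not the paper's code) that takes one mini-batch SGD step on a toy least-squares loss and checks that it equals the GD step plus the noise term; all variable names are ours.

```python
# Minimal sketch: one mini-batch SGD step decomposed into the full-gradient
# (GD) step plus the unbiased gradient noise v_sgd.
import numpy as np

rng = np.random.default_rng(0)
n, d, b, eta = 100, 5, 10, 0.1          # samples, dim, batch size, step size
X, y = rng.normal(size=(n, d)), rng.normal(size=n)
theta = np.zeros(d)

# Per-example gradients of l(x_i; theta) = 0.5 * (x_i @ theta - y_i)^2
per_example_grads = (X @ theta - y)[:, None] * X          # shape (n, d)
full_grad = per_example_grads.mean(axis=0)                # grad of L(theta)

batch = rng.choice(n, size=b, replace=False)              # mini-batch B_t
g_tilde = per_example_grads[batch].mean(axis=0)           # stochastic gradient
v_sgd = g_tilde - full_grad                               # mean-zero noise

# The SGD step equals the GD step plus a noise term:
theta_next = theta - eta * g_tilde
assert np.allclose(theta_next, theta - eta * full_grad - eta * v_sgd)
```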
Noise matters! SGD >> GD
• How? <= Still open…
• Which? <= This work!
[Figure: SGD vs. GD on CIFAR-10, ResNet-18, w/o weight decay, w/o data augmentation]
Which noise matters?
\[
v_{\mathrm{sgd}}(\theta) = \tilde g(\theta) - \nabla_\theta L(\theta)
\]
1. Magnitude <= YES! (e.g., Jastrzębski et al. 2017)
2. Covariance structure <= YES! (e.g., Zhu et al. 2018)
3. Distribution class <= No! (this work) Bernoulli? Gaussian? Lévy? ...
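As a hedged sketch of the kind of controlled comparison this question invites (fix the magnitude and covariance of the SGD noise, swap only its distribution class), one can replace the mini-batch noise with a Gaussian of matched first two moments. The setup and all names below are our own assumptions, not the paper's code; the covariance formula is the standard identity for sampling without replacement.

```python
# Sketch: swap the SGD noise for a Gaussian with the same first two moments.
import numpy as np

rng = np.random.default_rng(0)
n, d, b, eta = 100, 5, 10, 0.1
X, y = rng.normal(size=(n, d)), rng.normal(size=n)
theta = np.zeros(d)

grads = (X @ theta - y)[:, None] * X       # per-example gradients, shape (n, d)
full_grad = grads.mean(axis=0)

# Covariance of the mini-batch noise under sampling without replacement:
# Cov[v_sgd] = (n-b) / (b*n*(n-1)) * sum_i (g_i - g_bar)(g_i - g_bar)^T
centered = grads - full_grad
noise_cov = (n - b) / (b * n * (n - 1)) * centered.T @ centered

# Gaussian surrogate noise with the same mean (zero) and covariance as v_sgd:
v_gauss = rng.multivariate_normal(np.zeros(d), noise_cov)
theta_next = theta - eta * (full_grad + v_gauss)
```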
Intuition

For a quadratic loss, the generalization error
\[
\mathbb{E}_{x,\theta_T}\big[\ell(x; \theta_T) - \ell(x; \theta^*)\big]
\]
only depends on the first two moments of $\theta_T$, which in turn only depend on the first two moments of $v(\theta)$, because the update
\[
\theta_{t+1} = \theta_t - \eta\, \nabla_\theta L(\theta_t) - \eta\, v(\theta_t)
\]
is linear in $\theta_t$.

Noise matters! But noise class does not!
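A worked version of this moment argument, under our reconstruction of the stated assumptions: quadratic loss $L(\theta) = \tfrac12 (\theta - \theta^*)^\top H (\theta - \theta^*)$ and unbiased noise $\mathbb{E}[v(\theta_t) \mid \theta_t] = 0$ (this is the standard computation, not quoted from the paper).

```latex
% Moment recursions: both involve v only through its first two moments.
\begin{align*}
\theta_{t+1} - \theta^*
  &= (I - \eta H)(\theta_t - \theta^*) - \eta\, v(\theta_t) ,\\
\mathbb{E}[\theta_{t+1} - \theta^*]
  &= (I - \eta H)\, \mathbb{E}[\theta_t - \theta^*] ,\\
\operatorname{Cov}[\theta_{t+1}]
  &= (I - \eta H)\operatorname{Cov}[\theta_t](I - \eta H)^\top
     + \eta^2\, \mathbb{E}\!\left[\operatorname{Cov}[v(\theta_t) \mid \theta_t]\right].
\end{align*}
% The excess risk of a quadratic loss is itself a function of only the
% first two moments of theta_T:
\[
\mathbb{E}\big[L(\theta_T) - L(\theta^*)\big]
  = \tfrac12 \operatorname{tr}\!\big(H \operatorname{Cov}[\theta_T]\big)
  + \tfrac12\, \mu_T^\top H \mu_T ,
\qquad \mu_T = \mathbb{E}[\theta_T] - \theta^* .
\]
```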
A closer look at the noise of SGD

\[
v_{\mathrm{sgd}}(\theta)
= \tilde g(\theta) - \nabla_\theta L(\theta)
= \underbrace{\nabla_\theta \ell(\theta)}_{\text{gradient matrix}} \cdot W_{\mathrm{sgd}}
- \nabla_\theta \ell(\theta) \cdot \frac{\mathbb{1}}{n}
= \nabla_\theta \ell(\theta) \cdot \underbrace{V_{\mathrm{sgd}}}_{\text{sampling noise}}
\]

• Gradient matrix: $\nabla_\theta \ell(\theta) = \big(\nabla_\theta \ell(x_1; \theta), \ldots, \nabla_\theta \ell(x_n; \theta)\big)$
• Sampling vector: $W_{\mathrm{sgd}}$, with $b$ entries equal to $1/b$ and $n - b$ entries equal to $0$
• Sampling noise: $V_{\mathrm{sgd}} = W_{\mathrm{sgd}} - \mathbb{1}/n$
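A minimal numpy sketch of this decomposition (our own illustration; the names grad_matrix, W_sgd, V_sgd follow the slide's notation), verifying that $\nabla_\theta \ell(\theta) \cdot V_{\mathrm{sgd}} = \tilde g(\theta) - \nabla_\theta L(\theta)$ exactly.

```python
# Sketch: build the gradient matrix and the sampling vector W_sgd
# (b entries equal to 1/b, the remaining n-b entries equal to 0), then
# check that grad_matrix @ V_sgd reproduces g_tilde - full_grad.
import numpy as np

rng = np.random.default_rng(0)
n, d, b = 100, 5, 10
grad_matrix = rng.normal(size=(d, n))       # columns: per-example gradients

W_sgd = np.zeros(n)
W_sgd[rng.choice(n, size=b, replace=False)] = 1.0 / b   # mini-batch sampling
V_sgd = W_sgd - np.ones(n) / n                          # sampling noise

g_tilde = grad_matrix @ W_sgd               # stochastic gradient
full_grad = grad_matrix @ (np.ones(n) / n)  # full gradient
assert np.allclose(grad_matrix @ V_sgd, g_tilde - full_grad)
```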