/*
 * ASE/ACE neural controller -- lecture handout dated 3/3/2020,
 * page 1 (slides 1-6).
 *
 * Slide "Data structures" (block diagram; arrows were lost in extraction,
 * the wiring below is the presumable reading of the surviving labels):
 *
 *   state --> DECODER --> vector x[1..n] --> ASE (weights wa) --> y --> System
 *                                      \--> ACE (weights wc)  = Critic
 *   System --> r --> Critic --> r* --> ASE
 */

#include <stdlib.h>   /* rand(), RAND_MAX -- needed by frand() */

/* ---- Slide "Global constants" ---- */
#define MAXINP     256    // Maximum number of input units
#define ASE_ETA0   10.0   // ASE default learning rate
#define ASE_DECAY0 0.85   // ASE default eligibility decay
#define ACE_ETA0   0.5    // ACE default learning rate
#define ACE_DECAY0 0.4    // ACE default eligibility decay
#define GAMMA0     0.95   // default prediction discount

/* ---- Slide "Global variables" ---- */
static float wa[MAXINP];     // ASE weights
static float wc[MAXINP];     // ACE weights
static float x[MAXINP];      // input vector
static float eligi[MAXINP];  // ASE eligibility vector
static float trace[MAXINP];  // ACE trace vector
static float ase_eta;        // ASE learning rate
static float ase_decay;      // ASE eligibility decay
static float ace_eta;        // ACE learning rate
static float ace_decay;      // ACE eligibility decay (slide misspelled it "ace_deacay")
static float discount;       // ACE prediction discount
static int   inputs;         // number of input units
static int   r;              // primary reinforcement
static int   s;              // ACE output (secondary reinforcement)
static int   y;              // ASE output

/*
 * ---- Slide "Basic functions" (overview diagram) ----
 *   init_net              clear_traces
 *   ase_output  (box -> y)   ace_output  (box -> p)
 *   update_ase_weights    update_ace_weights
 *   ase_trace_decay       ace_trace_decay
 *
 *   Auxiliary functions:  frand (x in [min, max])   sign (x -> y)
 */

/* ---- Slide "Auxiliary functions" ---- */

/*
 * Return a pseudo-random float in [xmin, xmax], drawn with rand().
 *
 * The slide declared both parameters as `float *` but used them (and called
 * the function) with plain values; they are taken by value here.
 */
float frand(float xmin, float xmax)
{
    float range = xmax - xmin;

    return xmin + range * (float)rand() / RAND_MAX;
}

/*
 * Return 1 when x is strictly positive, -1 otherwise (including x == 0).
 */
int sign(float x)
{
    if (x > 0)
        return 1;
    return -1;
}

/* ---- Slide "Initialize network" ---- */

/*
 * Set the number of input units, zero both weight vectors and load the
 * default learning parameters.
 *
 * in: requested number of input units; clamped to [0, MAXINP] so that the
 *     loops over `inputs` can never run past the MAXINP-sized arrays.
 */
void init_net(int in)
{
    if (in < 0)
        in = 0;
    if (in > MAXINP)
        in = MAXINP;
    inputs = in;

    for (int i = 0; i < inputs; i++) {
        wa[i] = 0.0f;
        wc[i] = 0.0f;
    }

    ase_eta   = (float)ASE_ETA0;
    ace_eta   = (float)ACE_ETA0;
    ase_decay = (float)ASE_DECAY0;
    ace_decay = (float)ACE_DECAY0;
    discount  = (float)GAMMA0;
}
3/3/2020 Compute outputs Update weights / traces float ase_output( int x) void update_weights( float r) { { float net int i; int y; for (i=0; i<inputs; i++) { net = wa[x] + frand( ‐ 0.5, 0.5); wa[i] += r * ase_eta * eligi[i]; y = sign(net); wc[i] += r * ace_eta * trace[i]; } return y; } } void update_traces( int box, int y) float ace_output( int x) { { eligi[box] += (1.0 ‐ ase_decay) * y; return wc[x]; trace[box] += (1.0 ‐ ase_decay); } } 7 8 Trace decay Cart-pole model . void clear_traces() M c = cart mass { M p = pole mass int i; . L = pole length x for (i=0; i<inputs; i++) F = applied force F eligi[i] = trace[i] = 0.0; } x void decay_traces() { 2 ( M M ) g sin ( F M L sin ) cos int i; c p p 4 for (i=0; i<inputs; i++) { 2 ( M M ) L M L cos c p p eligi[i] = ase_decay * eligi[i]; 3 trace[i] = ace_decay * trace[i]; 2 } F M L ( sin cos ) p x } M M c p 9 10 Compute state System state typedef struct { int compute_state( float force, STATE *s) float x; // cart position { float v; // cart speed float ct, st; // sin, cos float tetha; // pole angle float x_acc, t_acc; // linear & angular acceleration float omega; // pole angular velocity float dt; // integration step } STATE; ct = cos(s.theta); st = sin(s.theta); 3 x .. t_acc = <see equation >; -XL XL .. 3 x_acc = <see equation x >; v -VL VL s.pos += s.speed*dt; 6 s.speed += x_acc*dt; -T6 -T1 T1 T6 s.theta += s.omega*dt; 3 s.omega += t_acc*dt; } -W50 W50 162 0 11 12 2
3/3/2020 Decode state Decode state v #define XL 0.8 // ASE default learning rate #define VL 0.5 // ASE default eligibility decay #define T1 0.01745 // PI/180 #define T6 0.10472 // 6*PI/180 6 7 8 #define W50 0.87266 // 50*PI/180 int decode_state(STATUS s) VL { int box; 3 4 5 x if (s.x < ‐ XL) box = 0; -VL else if (s.x < XL) box = 1; else box = 2; 0 1 2 if (s.v < ‐ VL) ; else if (s.v < VL) box += 3; else box += 6; -XL XL 13 14 Decode state Decode state A more structured way to determine the box number is to if (s.theta < ‐ T6) ; decode each state variable and then combine them: else if (s.theta < ‐ T1) box += 9; else if (s.theta < 0) box += 18; int decode_x( float x) else if (s.theta < T1) box += 27; { else if (s.theta < T6) box += 36; if (x < ‐ XL) return 0; else box += 45; if (x < XL) return 1; else return 2; if (s.omega < ‐ W50) ; } else if (s.omega < W50) box += 54; else box += 108; int decode_v( float v) { return box; if (v < ‐ VL) return 0; } if (v < VL) return 1; else return 2; } 15 16 Decode state Decode state int decode_w( float w) #define NBX 3 // number of boxes for position x #define NBV 3 // number of boxes for speed v { #define NBT 6 // number of boxes for theta if (w < ‐ W50) return 0; #define NBW 3 // number of boxes for omega if (w < W50) return 1; else return 2; The overall box number can be computed as: } int decode_state(STATUS s) int decode_t( float t) { { int box; if (t < ‐ T6) return 0; if (t < ‐ T1) return 1; box = box_x(s.pos) + if (t < 0) return 2; box_v(s.speed) * NBX + if (t < T1) return 3; box_t(s.theta) * NBX*NBV + if (t < T6) return 4; box_w(s.omega) * NBX*BBV*NBT; else return 5; } return box; } 17 18 3
3/3/2020 Learning cycle Learning cycle int main() while (duration < MAX_ITE) { { duration++; long duration; // # steps pole balanced y = ase_output(box); long failures; // # failures update_traces(box, y); long total_steps; // total # of steps force = FORCE*y; int y; // ASE output fail = compute_state(force, &s); int box; // decoded state region box = decode_state(s); int fail;// failure flag update_weights(box, fail); float force; // applied force to the cart STATE s; // system state if (fail) { clear_traces(); init_net(NBOXES); s = set_state(0,0,0,0); clear_traces(); box = decode_state(s); s = set_state(0,0,0,0); duration = 0; box = decode_state(s); failures++; } duration = 0; } total_steps = 0; } 19 20 4
Recommend
More recommend