diff --git a/js/ml/rl.js b/js/ml/rl.js new file mode 100644 index 0000000..4ef15eb --- /dev/null +++ b/js/ml/rl.js @@ -0,0 +1,1189 @@ +/** + ** ============================== + ** O O O OOOO + ** O O O O O O + ** O O O O O O + ** OOOO OOOO O OOO OOOO + ** O O O O O O O + ** O O O O O O O + ** OOOO OOOO O O OOOO + ** ============================== + ** Dr. Stefan Bosse http://www.bsslab.de + ** + ** COPYRIGHT: THIS SOFTWARE, EXECUTABLE AND SOURCE CODE IS OWNED + ** BY THE AUTHOR(S). + ** THIS SOURCE CODE MAY NOT BE COPIED, EXTRACTED, + ** MODIFIED, OR OTHERWISE USED IN A CONTEXT + ** OUTSIDE OF THE SOFTWARE SYSTEM. + ** + ** $AUTHORS: Ankit Kuwadekar, Stefan Bosse + ** $INITIAL: (C) 2015, Andrej Karpathy + ** $MODIFIED: (C) 2006-2019 bLAB by sbosse + ** $VERSION: 1.1.2 + ** + ** $INFO: + ** + ** Reinforcement Learning module that implements several common RL algorithms. + ** Portable models (TDAgent/DPAgent/DQNAgent) + ** + ** $ENDOFINFO + */ +"use strict"; + +var options = { + version:'1.1.2' +} +var Io = Require('com/io') +var R = module.exports; // the Recurrent library + + +// Utility fun +function assert(condition, message) { + // from http://stackoverflow.com/questions/15313418/javascript-assert + if (!condition) { + message = message || "Assertion failed"; + if (typeof Error !== "undefined") { + throw new Error(message); + } + throw message; // Fallback + } +} + +// Random numbers utils +var return_v = false; +var v_val = 0.0; +var gaussRandom = function() { + if(return_v) { + return_v = false; + return v_val; + } + var u = 2*Math.random()-1; + var v = 2*Math.random()-1; + var r = u*u + v*v; + if(r == 0 || r > 1) return gaussRandom(); + var c = Math.sqrt(-2*Math.log(r)/r); + v_val = v*c; // cache this + return_v = true; + return u*c; +} +var randf = function(a, b) { return Math.random()*(b-a)+a; } +var randi = function(a, b) { return Math.floor(Math.random()*(b-a)+a); } +var randn = function(mu, std){ return mu+gaussRandom()*std; } + +// helper function returns array of zeros of length n +// and uses typed arrays if available +var zeros = function(n) { + if(typeof(n)==='undefined' || isNaN(n)) { return []; } + if(typeof ArrayBuffer === 'undefined') { + // lacking browser support + var arr = new Array(n); + for(var i=0;i= 0 && ix < M.w.length); + return M.w[ix]; + }, + set: function(M, row, col, v) { + // slow but careful accessor function + var ix = (M.d * row) + col; + assert(ix >= 0 && ix < M.w.length); + M.w[ix] = v; + }, + setFrom: function(M, arr) { + for(var i=0,n=arr.length;i=0;i--) { + G.backprop[i](); // tick! + } + }, + rowPluck: function(G, m, ix) { + // pluck a row of m with index ix and return it as col vector + assert(ix >= 0 && ix < m.n); + var d = m.d; + var out = Mat(d, 1); + for(var i=0,n=d;i 0 ? out.dw[i] : 0.0; + } + } + G.backprop.push(backward); + } + return out; + }, + mul: function(G, m1, m2) { + // multiply matrices m1 * m2 + assert(m1.d === m2.n, 'matmul dimensions misaligned'); + + var n = m1.n; + var d = m2.d; + var out = Mat(n,d); + for(var i=0;i maxval) maxval = m.w[i]; } + + var s = 0.0; + for(var i=0,n=m.w.length;i clipval) { + mdwi = clipval; + num_clipped++; + } + if(mdwi < -clipval) { + mdwi = -clipval; + num_clipped++; + } + num_tot++; + + // update (and regularize) + m.w[i] += - step_size * mdwi / Math.sqrt(s.w[i] + S.smooth_eps) - regc * m.w[i]; + m.dw[i] = 0; // reset gradients for next iteration + } + } + } + solver_stats['ratio_clipped'] = num_clipped*1.0/num_tot; + return solver_stats; + } +} + +var initLSTM = function(input_size, hidden_sizes, output_size) { + // hidden size should be a list + + var model = {}; + for(var d=0;d