-- This script implements Recurrent Highway Networks (Zilly et al., 2016).
-- The main changes relative to the variational dropout implementation by Gal (2015) (which in turn is based on Zaremba's code) are:
-- 1. Using a Recurrent Highway Layer instead of an LSTM
-- 2. Adding a negative initial bias to the transform gates to facilitate learning long-term dependencies
-- 3. Tuning the hyperparameters for Recurrent Highway Networks
-- All other parts of the code should be identical or close to identical to Gal's implementation.
--
-- Single-model test perplexity improves from Gal's 73.4 to 64.4.
-- Please use the TensorFlow script to reproduce this result with MC dropout and weight tying.
--
-- References:
-- Zilly, J. G., Srivastava, R. K., Koutnik, J., Schmidhuber, J., "Recurrent Highway Networks", 2016.
-- Gal, Y., "A Theoretically Grounded Application of Dropout in Recurrent Neural Networks", 2015.
-- Zaremba, W., Sutskever, I., Vinyals, O., "Recurrent Neural Network Regularization", 2014.
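--
-- For reference, the coupled-gate RHN recurrence built in rhn() below is, per
-- micro-step l = 1..recurrence_depth at each time step (the input x enters only at l = 1):
--   t_l = sigmoid(W_T x + R_T,l s_{l-1} + b_T,l + initial_bias)   -- transform gate
--   h_l = tanh(W_H x + R_H,l s_{l-1} + b_H,l)                     -- candidate state
--   c_l = 1 - t_l                                                 -- coupled carry gate
--   s_l = c_l * s_{l-1} + t_l * h_l
-- where s_0 is the hidden state carried over from the previous time step, and the
-- dropout masks are applied to x and s_{l-1} before the linear maps.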

-- Prefer fbcunn if it is available, otherwise fall back to cunn; one of the two is required.
local ok,cunn = pcall(require, 'fbcunn')
if not ok then
    ok,cunn = pcall(require,'cunn')
    if ok then
        print("warning: fbcunn not found. Falling back to cunn") 
        LookupTable = nn.LookupTable
    else
        print("Could not find cunn or fbcunn. Either is required")
        os.exit()
    end
else
    deviceParams = cutorch.getDeviceProperties(1)
    cudaComputeCapability = deviceParams.major + deviceParams.minor/10
    LookupTable = nn.LookupTable
end
require('nngraph')
require('base')
local ptb = require('data')

local params = {batch_size=20,
                seq_length=35,
                layers=1,
                decay=1.02,
                rnn_size=1000,
                dropout_x=0.25,
                dropout_i=0.75,
                dropout_h=0.25,
                dropout_o=0.75,
                init_weight=0.04,
                lr=0.2,
                vocab_size=10000,
                max_epoch=20,
                max_max_epoch=1000,
                max_grad_norm=10,
                weight_decay=1e-7,
                recurrence_depth=10,
                initial_bias=-4}
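-- Notes on the hyperparameters above:
--   recurrence_depth : number of RHN micro-steps per time step
--   initial_bias     : constant added to the transform-gate pre-activation
--                      (negative, so the transform gates start mostly closed)
--   dropout_x        : drop rate for the word-embedding mask (tied per word, see sample_noise)
--   dropout_i/h/o    : drop rates for the layer input, the recurrent state and the output layer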


local disable_dropout = false
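-- Variational dropout helper: multiplies its input by a precomputed Bernoulli noise
-- mask (sampled in sample_noise and shared across the unrolled time steps), rather
-- than sampling a fresh mask inside the module as nn.Dropout would.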
local function local_Dropout(input, noise)
  return nn.CMulTable()({input, noise})
end

local function transfer_data(x)
  return x:cuda()
end

local state_train, state_valid, state_test
local model = {}
local paramx, paramdx


local function rhn(x, prev_c, prev_h, noise_i, noise_h)
  -- Reshape to (batch_size, n_gates, hid_size)
  -- Then slice the n_gates dimension, i.e. dimension 2
  local reshaped_noise_i = nn.Reshape(2,params.rnn_size)(noise_i)
  local reshaped_noise_h = nn.Reshape(2,params.rnn_size)(noise_h)
  local sliced_noise_i   = nn.SplitTable(2)(reshaped_noise_i)
  local sliced_noise_h   = nn.SplitTable(2)(reshaped_noise_h)
  -- Compute, for each micro-step, the transform gate and the candidate transform
  local dropped_h_tab = {}
  local h2h_tab = {}
  local t_gate_tab = {}
  local c_gate_tab = {}
  local in_transform_tab = {}
  local s_tab = {}
  for layer_i = 1, params.recurrence_depth do
    local i2h        = {}
    h2h_tab[layer_i] = {}
    if layer_i == 1 then
      for i = 1, 2 do
        -- Use select table to fetch each gate
        local dropped_x         = local_Dropout(x, nn.SelectTable(i)(sliced_noise_i))
        dropped_h_tab[layer_i]  = local_Dropout(prev_h, nn.SelectTable(i)(sliced_noise_h))
        i2h[i]                  = nn.Linear(params.rnn_size, params.rnn_size)(dropped_x)
        h2h_tab[layer_i][i]     = nn.Linear(params.rnn_size, params.rnn_size)(dropped_h_tab[layer_i])
      end
      t_gate_tab[layer_i]       = nn.Sigmoid()(nn.AddConstant(params.initial_bias, false)(nn.CAddTable()({i2h[1], h2h_tab[layer_i][1]})))
      in_transform_tab[layer_i] = nn.Tanh()(nn.CAddTable()({i2h[2], h2h_tab[layer_i][2]}))
      c_gate_tab[layer_i]       = nn.AddConstant(1,false)(nn.MulConstant(-1, false)(t_gate_tab[layer_i]))
      s_tab[layer_i]           = nn.CAddTable()({
        nn.CMulTable()({c_gate_tab[layer_i], prev_h}),
        nn.CMulTable()({t_gate_tab[layer_i], in_transform_tab[layer_i]})
      })
    else
      for i = 1, 2 do
        -- Use select table to fetch each gate
        dropped_h_tab[layer_i]  = local_Dropout(s_tab[layer_i-1], nn.SelectTable(i)(sliced_noise_h))
        h2h_tab[layer_i][i]     = nn.Linear(params.rnn_size, params.rnn_size)(dropped_h_tab[layer_i])
      end
      t_gate_tab[layer_i]       = nn.Sigmoid()(nn.AddConstant(params.initial_bias, false)(h2h_tab[layer_i][1]))
      in_transform_tab[layer_i] = nn.Tanh()(h2h_tab[layer_i][2])
      c_gate_tab[layer_i]       = nn.AddConstant(1,false)(nn.MulConstant(-1, false)(t_gate_tab[layer_i]))
      s_tab[layer_i]           = nn.CAddTable()({
        nn.CMulTable()({c_gate_tab[layer_i], s_tab[layer_i-1]}),
        nn.CMulTable()({t_gate_tab[layer_i], in_transform_tab[layer_i]})
      })
    end
  end
  local next_h = s_tab[params.recurrence_depth]
  -- RHNs have no cell state; prev_c is passed through unchanged only to keep the
  -- LSTM-style (c, h) interface of the surrounding code.
  local next_c = prev_c
  return next_c, next_h
end

-- Build the single-time-step training graph: embedding lookup with word-level dropout,
-- the RHN layer(s), and a log-softmax output with an NLL loss. The graph is cloned
-- across the sequence in setup() via g_cloneManyTimes.
local function create_network()
  local x                = nn.Identity()()
  local y                = nn.Identity()()
  local prev_s           = nn.Identity()()
  local noise_x          = nn.Identity()()
  local noise_i          = nn.Identity()()
  local noise_h          = nn.Identity()()
  local noise_o          = nn.Identity()()
  local i                = {[0] = LookupTable(params.vocab_size,
                                              params.rnn_size)(x)}
  i[0] = local_Dropout(i[0], noise_x)
  local next_s           = {}
  local split            = {prev_s:split(2 * params.layers)}
  local noise_i_split    = {noise_i:split(params.layers)}
  local noise_h_split    = {noise_h:split(params.layers)}
  for layer_idx = 1, params.layers do
    local prev_c         = split[2 * layer_idx - 1]
    local prev_h         = split[2 * layer_idx]
    local n_i            = noise_i_split[layer_idx]
    local n_h            = noise_h_split[layer_idx]
    local next_c, next_h = rhn(i[layer_idx - 1], prev_c, prev_h, n_i, n_h)
    table.insert(next_s, next_c)
    table.insert(next_s, next_h)
    i[layer_idx] = next_h
  end
  local h2y              = nn.Linear(params.rnn_size, params.vocab_size)
  local dropped          = local_Dropout(i[params.layers], noise_o)
  local pred             = nn.LogSoftMax()(h2y(dropped))
  local err              = nn.ClassNLLCriterion()({pred, y})
  local module           = nn.gModule({x, y, prev_s, noise_x, noise_i, noise_h, noise_o},
                                      {err, nn.Identity()(next_s)})
  module:getParameters():uniform(-params.init_weight, params.init_weight)
  return transfer_data(module)
end

-- Allocate the model state, the noise buffers and the per-time-step clones of the network.
local function setup()
  print("Creating an RHN network.")
  local core_network = create_network()
  paramx, paramdx = core_network:getParameters()
  model.s = {}
  model.ds = {}
  model.start_s = {}
  for j = 0, params.seq_length do
    model.s[j] = {}
    for d = 1, 2 * params.layers do
      model.s[j][d] = transfer_data(torch.zeros(params.batch_size, params.rnn_size))
    end
  end
  for d = 1, 2 * params.layers do
    model.start_s[d] = transfer_data(torch.zeros(params.batch_size, params.rnn_size))
    model.ds[d] = transfer_data(torch.zeros(params.batch_size, params.rnn_size))
  end

  model.noise_i = {}
  model.noise_x = {}
  model.noise_xe = {}
  for j = 1, params.seq_length do
    model.noise_x[j] = transfer_data(torch.zeros(params.batch_size, 1))
    model.noise_xe[j] = torch.expand(model.noise_x[j], params.batch_size, params.rnn_size)
    model.noise_xe[j] = transfer_data(model.noise_xe[j])
  end
  model.noise_h = {}
  for d = 1, params.layers do
    model.noise_i[d] = transfer_data(torch.zeros(params.batch_size, 2 * params.rnn_size))
    model.noise_h[d] = transfer_data(torch.zeros(params.batch_size, 2 * params.rnn_size))
  end
  model.noise_o = transfer_data(torch.zeros(params.batch_size, params.rnn_size))
  model.core_network = core_network
  model.rnns = g_cloneManyTimes(core_network, params.seq_length)
  model.norm_dw = 0
  model.err = transfer_data(torch.zeros(params.seq_length))

  model.pred = {}
  for j = 1, params.seq_length do
    model.pred[j] = transfer_data(torch.zeros(params.batch_size, params.vocab_size))
  end
  local y                = nn.Identity()()
  local pred             = nn.Identity()()
  local err              = nn.ClassNLLCriterion()({pred, y})
  model.test             = transfer_data(nn.gModule({y, pred}, {err}))
end

local function reset_state(state)
  state.pos = 1
  if model ~= nil and model.start_s ~= nil then
    for d = 1, 2 * params.layers do
      model.start_s[d]:zero()
    end
  end
end

local function reset_ds()
  for d = 1, #model.ds do
    model.ds[d]:zero()
  end
end

-- Convenience functions for handling the dropout noise masks.
-- sample_noise draws fresh Bernoulli masks for one training sequence, scaled by the
-- inverse keep probability.
local function sample_noise(state)
  for i = 1, params.seq_length do
    model.noise_x[i]:bernoulli(1 - params.dropout_x)
    model.noise_x[i]:div(1 - params.dropout_x)
  end
 
  -- Tie the embedding dropout mask across occurrences of the same word within a
  -- sequence (Gal's embedding dropout).
  for b = 1, params.batch_size do
    for i = 1, params.seq_length do
      local x = state.data[state.pos + i - 1][b]
      for j = i+1, params.seq_length do
        if state.data[state.pos + j - 1][b] == x then
          model.noise_x[j][b] = model.noise_x[i][b]
          -- only the next occurrence needs to be overridden here; later occurrences
          -- get the mask copied along when their own turn comes
          break
        end
      end
    end
  end
  for d = 1, params.layers do
    model.noise_i[d]:bernoulli(1 - params.dropout_i)
    model.noise_i[d]:div(1 - params.dropout_i)
    model.noise_h[d]:bernoulli(1 - params.dropout_h)
    model.noise_h[d]:div(1 - params.dropout_h)
  end
  model.noise_o:bernoulli(1 - params.dropout_o)
  model.noise_o:div(1 - params.dropout_o)
end

-- Set all masks to 1, i.e. disable dropout (used for validation and test).
local function reset_noise()
  for j = 1, params.seq_length do
    model.noise_x[j]:zero():add(1)
  end
  for d = 1, params.layers do
    model.noise_i[d]:zero():add(1)
    model.noise_h[d]:zero():add(1)
  end
  model.noise_o:zero():add(1)
end

-- Forward pass over one unrolled sequence of params.seq_length steps.
local function fp(state)
  g_replace_table(model.s[0], model.start_s)
  if state.pos + params.seq_length > state.data:size(1) then
    reset_state(state)
  end

  if disable_dropout then reset_noise() else sample_noise(state) end
  for i = 1, params.seq_length do
    local x = state.data[state.pos]
    local y = state.data[state.pos + 1]
    local s = model.s[i - 1]
    model.err[i], model.s[i] = unpack(model.rnns[i]:forward(
      {x, y, s, model.noise_xe[i], model.noise_i, model.noise_h, model.noise_o}))
    state.pos = state.pos + 1
  end
  g_replace_table(model.start_s, model.s[params.seq_length])
  return model.err
end

-- Backward pass (truncated BPTT over the same sequence), followed by gradient-norm
-- clipping, an SGD update and weight decay.
local function bp(state)
  paramdx:zero()
  reset_ds()
  for i = params.seq_length, 1, -1 do
    state.pos = state.pos - 1
    local x = state.data[state.pos]
    local y = state.data[state.pos + 1]
    local s = model.s[i - 1]
    local derr = transfer_data(torch.ones(1))
    local tmp = model.rnns[i]:backward( -- Yarin: do we need model.noise_x[i+1]?
      {x, y, s, model.noise_xe[i], model.noise_i, model.noise_h, model.noise_o},
      {derr, model.ds})[3]
    g_replace_table(model.ds, tmp)
    cutorch.synchronize()
  end
  state.pos = state.pos + params.seq_length
  model.norm_dw = paramdx:norm()
  if model.norm_dw > params.max_grad_norm then
    local shrink_factor = params.max_grad_norm / model.norm_dw
    paramdx:mul(shrink_factor)
  end
  paramx:add(paramdx:mul(-params.lr))
  paramx:add(-params.weight_decay, paramx)
end

-- Evaluate on the validation set with dropout disabled (the standard approximation,
-- not MC dropout).
local function run_valid()
  reset_state(state_valid)
  disable_dropout = true
  local len = (state_valid.data:size(1) - 1) / (params.seq_length)
  local perp = 0
  for i = 1, len do
    local p = fp(state_valid)
    perp = perp + p:mean()
  end
  print("Validation set perplexity : " .. g_f3(torch.exp(perp / len)))
  disable_dropout = false
end

-- Evaluate on the test set one time step at a time with dropout disabled.
local function run_test()
  reset_state(state_test)
  reset_noise()
  local perp = 0
  local len = state_test.data:size(1)
  g_replace_table(model.s[0], model.start_s)
  for i = 1, (len - 1) do
    local x = state_test.data[i]
    local y = state_test.data[i + 1]
    perp_tmp, model.s[1] = unpack(model.rnns[1]:forward(
      {x, y, model.s[0], model.noise_xe[1], model.noise_i, model.noise_h, model.noise_o}))
    perp = perp + perp_tmp[1]
    g_replace_table(model.s[0], model.s[1])
  end
  print("Test set perplexity : " .. g_f3(torch.exp(perp / (len - 1))))
end

local function main()
  g_init_gpu(1)
  state_train = {data=transfer_data(ptb.traindataset(params.batch_size))}
  state_valid =  {data=transfer_data(ptb.validdataset(params.batch_size))}
  state_test =  {data=transfer_data(ptb.testdataset(params.batch_size))}
  print("Network parameters:")
  print(params)
  local states = {state_train, state_valid, state_test}
  for _, state in pairs(states) do
    reset_state(state)
  end
  setup()
  local step = 0
  local epoch = 0
  local total_cases = 0
  local beginning_time = torch.tic()
  local start_time = torch.tic()
  print("Starting training.")
  local epoch_size = torch.floor(state_train.data:size(1) / params.seq_length)
  local perps
  while epoch < params.max_max_epoch do
    local perp = fp(state_train):mean()
    if perps == nil then
      perps = torch.zeros(epoch_size):add(perp)
    end
    perps[step % epoch_size + 1] = perp
    step = step + 1
    bp(state_train)
    total_cases = total_cases + params.seq_length * params.batch_size
    epoch = step / epoch_size
    if step % torch.round(epoch_size / 10) == 10 then
      local wps = torch.floor(total_cases / torch.toc(start_time))
      local since_beginning = g_d(torch.toc(beginning_time) / 60)
      print('epoch = ' .. g_f3(epoch) ..
            ', train perp. = ' .. g_f3(torch.exp(perps:mean())) ..
            ', wps = ' .. wps ..
            ', dw:norm() = ' .. g_f3(model.norm_dw) ..
            ', lr = ' ..  g_f3(params.lr) ..
            ', since beginning = ' .. since_beginning .. ' mins.')
    end
    if step % epoch_size == 0 then
      run_valid()
      run_test()
      if epoch > params.max_epoch then
          params.lr = params.lr / params.decay
      end
    end
    if step % 33 == 0 then
      cutorch.synchronize()
      collectgarbage()
    end
  end
  run_test()
  print("Training is over.")
end

main()