local table = require("table")
local string = require("string")
local luabit = require("bit")
local tostr = string.char

-- Running totals of doubles decoded / encoded so far; exposed through ljp_stat().
local double_decode_count = 0
local double_encode_count = 0

-- cache bitops
-- luabit exposes the right shift as `brshift`; LuaJIT's bit module calls it `rshift`.
local band = luabit.band
local rshift = luabit.brshift or luabit.rshift

--- Reduce a byte value modulo `v`.
-- Negative inputs have 256 added first, so a signed byte is treated as unsigned.
-- @param x value to reduce
-- @param v modulus
-- @return x % v, computed after the optional +256 adjustment
local function byte_mod(x, v)
	local unsigned = (x < 0) and (x + 256) or x
	return unsigned % v
end

-- Module-level working buffers shared by every packer / unpacker.
local strbuf = "" -- for unpacking: the input string, assigned by ljp_unpack
local strary = {} -- for packing: list of encoded fragments, reset and concatenated by ljp_pack

--- Append a header byte followed by a 16-bit big-endian integer to the pack buffer.
-- @param n value to encode; a negative value is stored as n + 65536
-- @param h header (type) byte written before the two payload bytes
local function strary_append_int16(n, h)
	local value = n
	if value < 0 then
		value = value + 65536
	end
	local high = math.floor(value / 256)
	local low = value % 256
	strary[#strary + 1] = tostr(h, high, low)
end

--- Append a header byte followed by a 32-bit big-endian integer to the pack buffer.
-- @param n value to encode; a negative value is stored as n + 4294967296
-- @param h header (type) byte written before the four payload bytes
local function strary_append_int32(n, h)
	local value = n
	if value < 0 then
		value = value + 4294967296
	end
	local b1 = math.floor(value / 16777216)
	local b2 = math.floor(value / 65536) % 256
	local b3 = math.floor(value / 256) % 256
	local b4 = value % 256
	strary[#strary + 1] = tostr(h, b1, b2, b3, b4)
end

local doubleto8bytes -- forward declaration

--- Append a number to the pack buffer as a double: header byte 0xcb plus 8 payload bytes.
-- Also bumps the encode counter reported by the stat function.
-- @param n number to encode
local strary_append_double = function(n)
	double_encode_count = double_encode_count + 1
	-- doubleto8bytes produces little-endian bytes; reversing yields big-endian order
	local big_endian = doubleto8bytes(n):reverse()
	strary[#strary + 1] = tostr(0xcb)
	strary[#strary + 1] = big_endian
end

--- IEEE 754

--- Encode a Lua number as the 8 bytes of an IEEE 754 binary64 value.
-- @param x number to encode (zero, infinities, NaN, subnormals and normal values)
-- @return string of 8 bytes in little-endian order (least significant byte first)
-- @return the value left over after the top byte has been extracted
doubleto8bytes = function(x)
	-- Split off the low byte of v: returns the remaining high part and the byte as a 1-char string.
	local function grab_byte(v)
		return math.floor(v / 256), tostr(math.fmod(math.floor(v), 256))
	end
	local sign = 0
	if x < 0 then
		sign = 1
		x = -x
	end
	local mantissa, exponent = math.frexp(x)
	if x ~= x then -- NaN: all-ones exponent with a non-zero fraction (quiet NaN)
		mantissa, exponent = 2 ^ 51, 2047
	elseif x == 0 then -- zero
		mantissa, exponent = 0, 0
	elseif x == 1 / 0 then -- infinity: all-ones exponent, zero fraction
		mantissa, exponent = 0, 2047
	elseif exponent < -1021 then
		-- subnormal: exponent field is 0 and the fraction is x scaled by 2^1074
		mantissa, exponent = math.ldexp(x, 1074), 0
	else
		-- frexp yields 0.5 <= m < 1; rescale to the 52-bit integer fraction of 1.f
		mantissa = (mantissa * 2 - 1) * math.ldexp(0.5, 53)
		exponent = exponent + 1022
	end

	local v, byte = "" -- convert to bytes
	x = mantissa
	for _ = 1, 6 do
		-- the remainder must be carried in x so each pass emits the next-higher byte
		x, byte = grab_byte(x)
		v = v .. byte -- 47:0
	end
	x, byte = grab_byte(exponent * 16 + x)
	v = v .. byte -- 55:48
	x, byte = grab_byte(sign * 128 + x)
	v = v .. byte -- 63:56
	return v, x
end

--- Interpret a list of bits as a binary fraction.
-- The first entry weighs 1/2, the second 1/4, and so on.
-- @param ary sequence of bit values, most significant first
-- @return sum of ary[i] * 2^-i
local function bitstofrac(ary)
	local total, weight = 0, 0.5
	local i = 1
	while ary[i] ~= nil do
		total = total + weight * ary[i]
		weight = weight / 2
		i = i + 1
	end
	return total
end

--- Expand a list of byte values into a flat list of bits.
-- Each byte contributes 8 entries, most significant bit first.
-- @param ary sequence of byte values
-- @return sequence of 0/1 values, 8 per input byte
local function bytestobits(ary)
	local bits = {}
	for _, byte_value in ipairs(ary) do
		for shift = 7, 0, -1 do
			bits[#bits + 1] = band(rshift(byte_value, shift), 1)
		end
	end
	return bits
end

--- Decode an IEEE 754 binary64 value.
-- @param v string holding the 8 bytes in little-endian order
--          (byte 8 carries the sign bit and the high exponent bits)
-- @return the decoded number (zero, subnormal, normal, +/-infinity or NaN)
local function bytestodouble(v)
	-- sign:1bit
	-- exp: 11bit (2048, bias=1023)
	local sign = math.floor(v:byte(8) / 128)
	local exp = band(v:byte(8), 127) * 16 + rshift(v:byte(7), 4) - 1023 -- bias
	-- frac: 52 bit
	local fracbytes = {
		band(v:byte(7), 15),
		v:byte(6),
		v:byte(5),
		v:byte(4),
		v:byte(3),
		v:byte(2),
		v:byte(1), -- big endian
	}
	local bits = bytestobits(fracbytes)

	-- the first 4 bits are the masked-off high nibble of byte 7, not fraction bits
	for _ = 1, 4 do
		table.remove(bits, 1)
	end

	if sign == 1 then
		sign = -1
	else
		sign = 1
	end

	local frac = bitstofrac(bits)
	if exp == -1023 then -- exponent field 0: zero or subnormal
		if frac == 0 then
			return 0
		end
		-- subnormal: no implicit leading 1, exponent fixed at -1022
		return math.ldexp(frac, -1022) * sign
	end
	if exp == 1024 then -- exponent field 2047: infinity or NaN
		if frac == 0 then
			return 1 / 0 * sign
		end
		return 0 / 0 -- NaN
	end

	local real = math.ldexp(1 + frac, exp)

	return real * sign
end

--- packers

-- Dispatch table: Lua type name -> function that appends the encoded value to the pack buffer.
local packers = {}

--- Pack a value of any type by handing it to the packer registered under its type name.
-- @param data value to pack
function packers.dynamic(data)
	local pack_fn = packers[type(data)]
	return pack_fn(data)
end

--- Pack nil: a single 0xc0 byte.
packers["nil"] = function()
	strary[#strary + 1] = tostr(0xc0)
end

--- Pack a boolean: 0xc3 for true, 0xc2 for false.
-- @param data boolean to pack
packers.boolean = function(data)
	local marker = data and 0xc3 or 0xc2
	strary[#strary + 1] = tostr(marker)
end

--- Pack a number using the smallest encoding whose range contains it.
-- Integral values within the 32-bit ranges use the integer formats; everything
-- else (fractional values and integers beyond 32 bits) is written as a double.
-- @param n number to pack
packers.number = function(n)
	if math.floor(n) ~= n then
		-- fractional value: only a double can carry it
		strary_append_double(n)
	elseif n >= 4294967296 or n < -2147483648 then
		-- integral but outside the 32-bit formats; 64-bit integers are not emitted
		strary_append_double(n)
	elseif n >= 65536 then
		strary_append_int32(n, 0xce) -- uint32
	elseif n >= 256 then
		strary_append_int16(n, 0xcd) -- uint16
	elseif n >= 128 then
		strary[#strary + 1] = tostr(0xcc, n) -- uint8
	elseif n >= 0 then
		strary[#strary + 1] = tostr(n) -- positive fixnum
	elseif n >= -32 then
		strary[#strary + 1] = tostr(0xe0 + ((n + 256) % 32)) -- negative fixnum
	elseif n >= -128 then
		strary[#strary + 1] = tostr(0xd0, byte_mod(n, 0x100)) -- int8
	elseif n >= -32768 then
		strary_append_int16(n, 0xd1) -- int16
	else
		strary_append_int32(n, 0xd2) -- int32
	end
end

--- Pack a string as raw bytes: a length header (fixraw, raw16 or raw32) followed by the data.
-- Raises "overflow" when the length does not fit in 32 bits.
-- @param data string to pack
packers.string = function(data)
	local len = #data
	if len >= 4294967296 then
		error("overflow")
	end
	if len < 32 then
		strary[#strary + 1] = tostr(0xa0 + len)
	elseif len < 65536 then
		strary_append_int16(len, 0xda)
	else
		strary_append_int32(len, 0xdb)
	end
	strary[#strary + 1] = data
end

-- Lua types that cannot be packed: each packer raises its own "unimplemented" error.
for type_name, message in pairs({
	["function"] = "unimplemented:function",
	userdata = "unimplemented:userdata",
	thread = "unimplemented:thread",
}) do
	packers[type_name] = function()
		error(message)
	end
end

--- Pack a Lua table as an array or a map.
-- A table whose keys are all positive integers is packed as an array of length
-- max-key (missing indices are packed as nil). Any other table is packed as a
-- map of its key/value pairs. Raises "overflow" when the element count does not
-- fit in 32 bits.
-- @param data table to pack
packers.table = function(data)
	local is_map, ndata, nmax = false, 0, 0
	for k, _ in pairs(data) do
		-- Only positive integer keys are valid array indices. A zero, negative or
		-- fractional numeric key would be dropped by the 1..nmax loop below (or
		-- corrupt the length header), so such tables must go out as maps.
		if type(k) == "number" and k >= 1 and math.floor(k) == k then
			if k > nmax then
				nmax = k
			end
		else
			is_map = true
		end
		ndata = ndata + 1
	end
	if is_map then -- pack as map
		if ndata < 16 then
			table.insert(strary, tostr(0x80 + ndata))
		elseif ndata < 65536 then
			strary_append_int16(ndata, 0xde)
		elseif ndata < 4294967296 then
			strary_append_int32(ndata, 0xdf)
		else
			error("overflow")
		end
		for k, v in pairs(data) do
			packers[type(k)](k)
			packers[type(v)](v)
		end
	else -- pack as array
		if nmax < 16 then
			table.insert(strary, tostr(0x90 + nmax))
		elseif nmax < 65536 then
			strary_append_int16(nmax, 0xdc)
		elseif nmax < 4294967296 then
			strary_append_int32(nmax, 0xdd)
		else
			error("overflow")
		end
		for i = 1, nmax do
			packers[type(data[i])](data[i])
		end
	end
end

-- types decoding

-- Type byte -> type name, for the types identified by one fixed byte.
-- Bytes not listed here are resolved by type_for (range-encoded "fix" types or "undefined").
local types_map = {
	[0xc0] = "nil",
	[0xc2] = "false",
	[0xc3] = "true",
	[0xca] = "float",
	[0xcb] = "double",
	[0xcc] = "uint8",
	[0xcd] = "uint16",
	[0xce] = "uint32",
	[0xcf] = "uint64",
	[0xd0] = "int8",
	[0xd1] = "int16",
	[0xd2] = "int32",
	[0xd3] = "int64",
	[0xda] = "raw16",
	[0xdb] = "raw32",
	[0xdc] = "array16",
	[0xdd] = "array32",
	[0xde] = "map16",
	[0xdf] = "map32",
}

--- Classify a type byte.
-- Fixed single-byte types come from types_map; the remaining bytes are
-- classified by the range they fall in.
-- @param n type byte (0..255)
-- @return type name, or "undefined" for a byte with no assigned meaning
local type_for = function(n)
	local named = types_map[n]
	if named then
		return named
	end
	if n < 0x80 then
		return "fixnum_posi"
	end
	if n < 0x90 then
		return "fixmap"
	end
	if n < 0xa0 then
		return "fixarray"
	end
	if n < 0xc0 then
		return "fixraw"
	end
	if n > 0xdf then
		return "fixnum_neg"
	end
	return "undefined"
end

-- Payload size in bytes (not counting the type byte) for the fixed-width
-- numeric types that are decoded through unpacker_number.
local types_len_map = {
	uint16 = 2,
	uint32 = 4,
	uint64 = 8,
	int16 = 2,
	int32 = 4,
	int64 = 8,
	float = 4,
	double = 8,
}

--- unpackers
-- Dispatch table: type name (as returned by type_for) -> function(offset)
-- returning the new offset and the decoded value.

local unpackers = {}

--- Decode a fixed-width big-endian number from the unpack buffer.
-- @param offset 0-based position of the byte just before the payload
--               (the payload occupies offset + 1 .. offset + nlen)
-- @param ntype one of "uint16_t", "uint32_t", "int16_t", "int32_t", "double_t"
-- @param nlen payload length in bytes (2, 4 or 8)
-- @return the decoded number
-- Raises "need more data" when the buffer ends before the payload does, and
-- "unpack_number: not impl:<ntype>" for any other ntype.
local unpack_number = function(offset, ntype, nlen)
	-- fail clearly on truncated input instead of doing arithmetic on nil bytes
	if #strbuf < offset + nlen then
		error("need more data")
	end

	local b1, b2, b3, b4, b5, b6, b7, b8
	if nlen >= 2 then
		b1, b2 = string.byte(strbuf, offset + 1, offset + 2)
	end
	if nlen >= 4 then
		b3, b4 = string.byte(strbuf, offset + 3, offset + 4)
	end
	if nlen >= 8 then
		b5, b6, b7, b8 = string.byte(strbuf, offset + 5, offset + 8)
	end

	if ntype == "uint16_t" then
		return b1 * 256 + b2
	elseif ntype == "uint32_t" then
		return b1 * 65536 * 256 + b2 * 65536 + b3 * 256 + b4
	elseif ntype == "int16_t" then
		local n = b1 * 256 + b2
		-- two's complement: only values with the top bit set are negative
		if n >= 32768 then
			n = n - 65536
		end
		return n
	elseif ntype == "int32_t" then
		local n = b1 * 65536 * 256 + b2 * 65536 + b3 * 256 + b4
		-- two's complement: only values with the top bit set are negative
		if n >= 2147483648 then
			n = n - 4294967296
		end
		return n
	elseif ntype == "double_t" then
		-- bytestodouble expects little-endian input, so reverse the wire order
		local s = tostr(b8, b7, b6, b5, b4, b3, b2, b1)
		double_decode_count = double_decode_count + 1
		local n = bytestodouble(s)
		return n
	else
		error("unpack_number: not impl:" .. ntype)
	end
end

--- Unpacker shared by the fixed-width numeric types.
-- Reads the type byte at offset + 1 and decodes the payload that follows it.
-- Raises "float is not implemented" for the float type.
-- @param offset 0-based position just before the type byte
-- @return offset just past the payload
-- @return the decoded number
local function unpacker_number(offset)
	local type_byte = string.byte(strbuf, offset + 1, offset + 1)
	local obj_type = type_for(type_byte)
	local nlen = types_len_map[obj_type]
	if obj_type == "float" then
		error("float is not implemented")
	end
	local ntype = obj_type .. "_t"
	local next_offset = offset + nlen + 1
	local value = unpack_number(offset + 1, ntype, nlen)
	return next_offset, value
end

--- Decode n key/value pairs into a new table.
-- @param offset 0-based position just before the first key
-- @param n number of pairs to read
-- @return offset just past the last value
-- @return table holding the decoded pairs
local function unpack_map(offset, n)
	local result = {}
	local cursor = offset
	for _ = 1, n do
		local key, value
		cursor, key = unpackers.dynamic(cursor)
		assert(cursor)
		cursor, value = unpackers.dynamic(cursor)
		assert(cursor)
		result[key] = value
	end
	return cursor, result
end

--- Decode n consecutive values into a new array-like table.
-- @param offset 0-based position just before the first element
-- @param n number of elements to read
-- @return offset just past the last element
-- @return table with the decoded values at indices 1..n
local function unpack_array(offset, n)
	local items = {}
	local cursor = offset
	for index = 1, n do
		local item
		cursor, item = unpackers.dynamic(cursor)
		assert(cursor)
		items[index] = item
	end
	return cursor, items
end

--- Decode the value that starts at offset + 1, whatever its type.
-- Raises "need more data" when offset is at or past the end of the buffer.
-- @param offset 0-based position just before the type byte
-- @return whatever the type-specific unpacker returns (new offset, value)
unpackers.dynamic = function(offset)
	local available = #strbuf
	if offset >= available then
		error("need more data")
	end
	local type_byte = string.byte(strbuf, offset + 1, offset + 1)
	local handler = unpackers[type_for(type_byte)]
	return handler(offset)
end

-- A type byte with no assigned meaning cannot be decoded.
unpackers.undefined = function()
	error("unimplemented:undefined")
end

-- nil occupies just its type byte.
unpackers["nil"] = function(offset)
	local next_offset = offset + 1
	return next_offset, nil
end

-- The two booleans are single-byte constants as well.
for type_name, constant in pairs({ ["false"] = false, ["true"] = true }) do
	unpackers[type_name] = function(offset)
		return offset + 1, constant
	end
end

--- Positive fixnum: the type byte itself is the value.
-- @param offset 0-based position just before the byte
-- @return offset + 1, the byte value
unpackers.fixnum_posi = function(offset)
	local position = offset + 1
	return position, string.byte(strbuf, position, position)
end

--- uint8: a type byte followed by one unsigned payload byte.
-- Raises "need more data" when the payload byte is missing, rather than
-- silently returning no value and an offset past the end of the buffer.
-- @param offset 0-based position just before the type byte
-- @return offset + 2, the payload byte value
unpackers.uint8 = function(offset)
	local value = string.byte(strbuf, offset + 2, offset + 2)
	if value == nil then
		error("need more data")
	end
	return offset + 2, value
end

-- The wider unsigned integers share the generic fixed-width decoder.
-- NOTE: unpack_number has no "uint64_t" branch, so decoding a uint64 raises "not impl".
for _, type_name in ipairs({ "uint16", "uint32", "uint64" }) do
	unpackers[type_name] = unpacker_number
end

--- Negative fixnum: the type byte is the value's 8-bit two's-complement form.
-- @param offset 0-based position just before the byte
-- @return offset + 1, the negative value (byte - 256)
unpackers.fixnum_neg = function(offset)
	local raw = string.byte(strbuf, offset + 1, offset + 1)
	return offset + 1, raw - 256
end

--- int8: a type byte followed by one signed (two's-complement) payload byte.
-- Raises "need more data" when the payload byte is missing.
-- @param offset 0-based position just before the type byte
-- @return offset + 2, the signed value in -128..127
unpackers.int8 = function(offset)
	local i = string.byte(strbuf, offset + 2, offset + 2)
	if i == nil then
		error("need more data")
	end
	if i > 127 then
		i = i - 256
	end
	return offset + 2, i
end

-- Signed integers and floating-point types share the generic fixed-width decoder.
-- NOTE: unpacker_number raises "float is not implemented" for float, and
-- unpack_number has no "int64_t" branch, so int64 raises "not impl".
for _, type_name in ipairs({ "int16", "int32", "int64", "float", "double" }) do
	unpackers[type_name] = unpacker_number
end

--- fixraw: the low 5 bits of the type byte give the length of the raw data that follows.
-- Raises "require more data" when the buffer is shorter than that length.
-- @param offset 0-based position just before the type byte
-- @return offset just past the data, the data as a string
unpackers.fixraw = function(offset)
	local header = string.byte(strbuf, offset + 1, offset + 1)
	local n = byte_mod(header, 0x1f + 1)
	if (#strbuf - 1 - offset) < n then
		error("require more data")
	end

	local first = offset + 2
	local b = ""
	if n > 0 then
		b = string.sub(strbuf, first, first + n - 1)
	end
	return offset + n + 1, b
end

--- raw16: type byte, 16-bit big-endian length, then that many bytes of raw data.
-- Raises "require more data" when the buffer is shorter than the declared length.
-- @param offset 0-based position just before the type byte
-- @return offset just past the data, the data as a string
unpackers.raw16 = function(offset)
	local n = unpack_number(offset + 1, "uint16_t", 2)
	if (#strbuf - 1 - 2 - offset) < n then
		error("require more data")
	end
	local first = offset + 4 -- skip the type byte and the 2 length bytes
	local b = string.sub(strbuf, first, first + n - 1)
	return offset + n + 3, b
end

--- raw32: type byte, 32-bit big-endian length, then that many bytes of raw data.
-- Raises an error when the buffer is shorter than the declared length.
-- @param offset 0-based position just before the type byte
-- @return offset just past the data, the data as a string
unpackers.raw32 = function(offset)
	local n = unpack_number(offset + 1, "uint32_t", 4)
	if (#strbuf - 1 - 4 - offset) < n then
		error("require more data (possibly bug)")
	end
	local first = offset + 6 -- skip the type byte and the 4 length bytes
	local b = string.sub(strbuf, first, first + n - 1)
	return offset + n + 5, b
end

--- fixarray: the low 4 bits of the type byte give the element count.
unpackers.fixarray = function(offset)
	local header = string.byte(strbuf, offset + 1, offset + 1)
	local count = byte_mod(header, 0x0f + 1)
	return unpack_array(offset + 1, count)
end

--- array16: the element count is a 16-bit integer following the type byte.
unpackers.array16 = function(offset)
	local count = unpack_number(offset + 1, "uint16_t", 2)
	return unpack_array(offset + 3, count)
end

--- array32: the element count is a 32-bit integer following the type byte.
unpackers.array32 = function(offset)
	local count = unpack_number(offset + 1, "uint32_t", 4)
	return unpack_array(offset + 5, count)
end

--- fixmap: the low 4 bits of the type byte give the number of key/value pairs.
unpackers.fixmap = function(offset)
	local header = string.byte(strbuf, offset + 1, offset + 1)
	local count = byte_mod(header, 0x0f + 1)
	return unpack_map(offset + 1, count)
end

--- map16: the pair count is a 16-bit integer following the type byte.
unpackers.map16 = function(offset)
	local count = unpack_number(offset + 1, "uint16_t", 2)
	return unpack_map(offset + 3, count)
end

--- map32: the pair count is a 32-bit integer following the type byte.
unpackers.map32 = function(offset)
	local count = unpack_number(offset + 1, "uint32_t", 4)
	return unpack_map(offset + 5, count)
end

-- Main functions

--- Serialize a Lua value.
-- Resets the shared fragment buffer, runs the type-specific packer, and joins
-- the fragments it produced.
-- @param data value to pack
-- @return the encoded string
local function ljp_pack(data)
	strary = {}
	packers.dynamic(data)
	return table.concat(strary, "")
end

--- Decode one value from an encoded string.
-- @param s encoded string
-- @param offset optional 0-based position to start decoding at (defaults to 0)
-- @return offset just past the decoded value, and the value itself;
--         or false and "invalid argument" when s is not a string
local function ljp_unpack(s, offset)
	if type(s) ~= "string" then
		return false, "invalid argument"
	end
	local start = offset
	if start == nil then
		start = 0
	end
	strbuf = s
	local next_offset, data = unpackers.dynamic(start)
	return next_offset, data
end

--- Report how many doubles have been decoded and encoded so far.
-- @return table with fields double_decode_count and double_encode_count
local function ljp_stat()
	local stats = {}
	stats.double_decode_count = double_decode_count
	stats.double_encode_count = double_encode_count
	return stats
end

-- Public module table: pack / unpack a value, and read the double-conversion counters.
return {
	pack = ljp_pack,
	unpack = ljp_unpack,
	stat = ljp_stat,
}