import redis
from redis.sentinel import Sentinel
import time
import struct
import numpy as np
from IO.IOBase import IOBase
#numpy encoder for redis, see earthworm RedisNumpy.hh for detail

def numpy2bytes(matrix):
    #size
    dim = matrix.ndim
    dimArr = matrix.shape
    header=struct.pack('Q', dim)
    for i in range(dim):
        header += struct.pack('Q', dimArr[i])
    dtypestr=str(matrix.dtype)
    header += struct.pack('Q', len(dtypestr))
    header += bytes (dtypestr, encoding='utf-8')

    encoded = header + matrix.tobytes()
    return encoded

def bytes2numpy(encoded):
    dim= struct.unpack('Q',encoded[:8])[0]
    shapeEndPos = 8 + dim*8
    shape = struct.unpack('Q'*dim, encoded[8:shapeEndPos])
    dtypelen = struct.unpack('Q',encoded[shapeEndPos:(shapeEndPos+8)])[0]
    dtypestr = ''.join(chr(i) for i in encoded[(shapeEndPos+8):(shapeEndPos+8+dtypelen)])

    return np.frombuffer(encoded[(shapeEndPos+8+dtypelen):], dtype=dtypestr).reshape(shape)

class RedisHelper(IOBase):
    def __init__(self, ip_port, passwd, timeout, db=0, master_name='mymaster'):
        super().__init__()
        self.sole = False
        self.redisWrite = None
        self.redisRead = None
        self.db = db
        begin = time.time()

        if type(ip_port) is tuple:
            self.redisRead = redis.Redis(host=ip_port[0], port=ip_port[1], db=self.db, password=passwd, socket_timeout=5, retry_on_timeout=True)
            self.redisWrite=self.redisRead
            self.sole=True
        else:
            while True:
                if not self.redisWrite:
                    sentinel = Sentinel(ip_port, socket_timeout=10)
                    try:
                        self.soleRW = False
                        self.redisRead = sentinel.slave_for(master_name, socket_timeout=30, password=passwd)
                        self.redisWrite = sentinel.master_for(master_name, socket_timeout=30, password=passwd)
                    except (redis.exceptions.ConnectionError):
                        print( 'ERROR: Redis falied')
                    if self.redisWrite is not None:
                        if self.redisRead is None: self.redisRead=self.redisWrite
                        self.soleRW = True
                        break
                    else:
                        print( 'ERROR: Redis falied')
                    if time.time()-begin>timeout:
                        print( 'WARNING: Reids timeout')
                        break

    def soleRW(self):
        return self.soleRW

    def read(self, path):
    	return self.redisRead.get(path)

    def write(self, path, data):
    	return self.redisWrite.set(path, data)

    def writeNumpyArray(self, path, matrix):
        encoded = numpy2bytes(matrix)
        self.write(path,encoded)

    def readNumpyArray(self, path):
        encoded = self.read(path)
        return bytes2numpy(encoded)
