import requests
import json
import time
import os
import re
from http.server import BaseHTTPRequestHandler, HTTPServer
from datakit_framework import DataKitFramework

class zabbixCollect(DataKitFramework):
    """Collect numeric Zabbix item history and report it to DataKit as metrics.

    The collection mode comes from the ``COLLECT_TYPE`` environment variable:

    * ``api`` (default): each ``run()`` polls ``history.get`` through the Zabbix
      JSON-RPC API for a recent time window.
    * ``stream``: ``run()`` blocks serving an HTTP endpoint that accepts POSTed
      newline-delimited JSON history records (NOTE(review): presumably sent by
      a Zabbix streaming/real-time export connector -- confirm the sender).

    Every history record is joined on ``itemid`` with a reference table built
    from ``host.get``/``item.get``; the table supplies the measurement name,
    the field (metric) name and the tags. Records whose item is not in the
    table are dropped.
    """

    name = 'zabbix_collect'
    interval = 60  # triggered interval seconds.
    _zabbix_host, _zabbix_user, _zabbix_passwd = None, None, None
    # Not used inside this class; kept in case external code reads them.
    start_ts = None
    end_ts = None
    all_items = {}
    # Cached session token from user.login; None until the first API call.
    _auth_token = None
    # Cached itemid -> row position lookup; (re)built by load_refer_table().
    _row_index = None
    # Zabbix item ``type`` code -> readable name for the zabbix_agent_type tag.
    item_type_dict = {"0":"Zabbix_agent","2":"Zabbix_trapper","3":"Simple_check","5":"Zabbix_internal","7":"Zabbix_agent_(active)","8":"Zabbix_aggregate","9":"Web_item","10":"External_check","11":"Database_monitor","12":"IPMI_agent","13":"SSH_agent","14":"Telnet_agent","15":"Calculated","16":"JMX_agent","17":"SNMP_trap","18":"Dependent_item","19":"HTTP_agent","20":"SNMP_agent","21":"Script"}
    # Characters rewritten in LLD prototype key parameters before they are
    # turned into a regex: backslash -> '/', regex metacharacters -> '_'.
    _PROTOTYPE_PARAM_TRANS = str.maketrans({'\\': '/', '.': '_', '*': '_', '?': '_', '+': '_'})

    def __init__(self, **kwargs):
        """Read connection and collection settings from environment variables.

        ``kwargs`` is accepted but not used.
        """
        super().__init__(ip='127.0.0.1', port=9529)
        self._zabbix_host = os.environ.get('ZABBIX_HOST', 'http://<zabbix-host>/zabbix')
        self._zabbix_user = os.environ.get('ZABBIX_USER', 'zabbix-user')
        self._zabbix_passwd = os.environ.get('ZABBIX_PASSWD', 'zabbix-pwd')
        # Width in seconds of each history.get slice in api mode.
        self.sample_rate = 30
        self._zabbix_version = os.environ.get('ZABBIX_VERSION', '7.0')
        self.collect_type = os.environ.get('COLLECT_TYPE', 'api')
        # May be a str (from the environment) or the int default; converted in run().
        self.stream_listener_port = os.environ.get('STREAM_LISTENER_PORT', 8000)
        self.only_zabbix_metrics = os.environ.get('ONLY_ZABBIX_METRICS', 'false').lower() == 'true'
        self.zabbix_server_group = os.environ.get('ZABBIX_SERVER_GROUP', 'Zabbix servers')

    def run(self):
        """Refresh the reference table, then collect using the configured mode."""
        self.load_refer_table()
        collect_type = self.collect_type
        if collect_type == 'api':
            self.zabbix_collect_with_api()
        elif collect_type == 'stream':
            self._serve_stream()
        else:
            print(f'unknown COLLECT_TYPE {collect_type!r}, expected "api" or "stream"')

    def _serve_stream(self):
        """Serve POSTs of newline-delimited JSON history records until the server stops."""
        collector = self

        class ZabbixHandler(BaseHTTPRequestHandler):
            def do_POST(self):
                try:
                    length = int(self.headers.get('Content-Length') or 0)
                    body = self.rfile.read(length).decode('utf-8')
                    # One JSON object per line; tolerate blank lines and a
                    # missing trailing newline.
                    records = [json.loads(line) for line in body.splitlines() if line.strip()]
                except ValueError as e:  # also covers JSONDecodeError / UnicodeDecodeError
                    print(f'invalid stream payload: {e}')
                    self.send_error(400, 'invalid payload')
                    return
                try:
                    collector.handle_and_write(records)
                except Exception as e:
                    # Always answer the client, even if reporting failed.
                    print(f'failed to handle stream payload: {e}')
                    self.send_error(500)
                    return
                self.send_response(200)
                self.end_headers()

        port = int(self.stream_listener_port)
        try:
            print(f'start listen on {port}')
            with HTTPServer(('', port), ZabbixHandler) as server:
                server.serve_forever()
        except Exception as e:
            # Best effort: report why the listener stopped instead of hiding it.
            print(f'stream listener on port {port} stopped: {e}')

    def zabbix_collect_with_api(self):
        """Poll numeric history for [now-90s, now-30s) in ``sample_rate``-second slices.

        Each slice is fetched for history types 0 and 3 and reported
        separately. The trailing 30 s offset presumably gives Zabbix time to
        store recent values (NOTE(review): confirm).
        """
        now = int(time.time())
        start_ts = now - 90
        end_ts = now - 30
        while start_ts < end_ts:
            # history.get treats time_from and time_till as inclusive, so stop
            # one second early; otherwise boundary values (between slices and
            # between consecutive runs) are reported twice.
            till = min(start_ts + self.sample_rate, end_ts) - 1
            timepoints = []
            for history in (0, 3):  # 0 = numeric float, 3 = numeric unsigned
                timepoints += self.do_request('history.get', time_from=start_ts, time_till=till,
                                              limit=1000000, history=history)
            start_ts += self.sample_rate
            self.handle_and_write(timepoints)

    def load_refer_table(self):
        """(Re)build the in-memory reference table and its itemid lookup."""
        table = self.fetch_refer_table()[0]
        self.row_data = table['row_data']
        self.column_name = table['column_name']
        if 'itemid' in self.column_name:
            idx = self.column_name.index('itemid')
            self.key_list = [row[idx] for row in self.row_data]
        else:
            # No items were found, so the table has no columns at all.
            self.key_list = []
        self._row_index = self._build_row_index(self.key_list)

    @staticmethod
    def _build_row_index(key_list):
        """Map each itemid to the position of its first row (what list.index returned)."""
        index = {}
        for i, key in enumerate(key_list):
            index.setdefault(key, i)
        return index

    def get_refer_tag(self, itemid):
        """Look up the reference row for ``itemid``.

        Returns:
            ``(measurement, metric, tags)`` where ``tags`` holds the row's other
            non-empty columns (``itemid`` excluded), or ``(None, None, None)``
            when the item is not in the reference table.
        """
        row_index = self._row_index
        if row_index is None:
            row_index = self._row_index = self._build_row_index(self.key_list)
        curr_idx = row_index.get(str(itemid))
        if curr_idx is None:
            return None, None, None
        row = self.row_data[curr_idx]
        refer_tag = {col: v for col, v in zip(self.column_name, row) if v and col != 'itemid'}
        return refer_tag.pop('measurement', None), refer_tag.pop('metric', None), refer_tag

    def handle_and_write(self, time_points):
        """Convert history records to DataKit points and pass them to ``self.report``.

        Each record must carry ``itemid``, ``clock`` (seconds), ``ns`` and a
        numeric ``value``. Records whose item is unknown, or whose value is
        not numeric, are skipped. Returns whatever ``self.report`` returns, or
        None when there is nothing to handle.
        """
        if not time_points:
            print('no time points to handle')
            return
        data = []
        for tp in time_points:
            measurement, metric, tags = self.get_refer_tag(tp['itemid'])
            if not measurement:
                # not match refer table
                continue
            try:
                value = float(tp['value'])
            except (TypeError, ValueError):
                # Skip a single bad value instead of failing the whole batch.
                continue
            data.append({
                "measurement": measurement,
                # Nanosecond timestamp from Zabbix's clock + ns pair.
                "timestamp": int(tp['clock']) * 1000000000 + int(tp['ns']),
                "tags": tags,
                "fields": {metric: value},
            })
        in_data = {
            'M': data,
            # Collect name
            'input': "zabbix_collect",
        }
        return self.report(in_data)

    def fetch_refer_table(self):
        """Build the reference table from Zabbix hosts and their numeric items.

        Returns a one-element list holding a dict with ``table_name``,
        ``column_name`` (sorted union of all tag keys), ``column_type`` (all
        'string') and ``row_data`` (one row per item; missing cells are '').
        Only items with value_type 0 or 3 are included.

        Raises:
            Exception: if ONLY_ZABBIX_METRICS is set and the server group is missing.
        """
        all_hosts = {host['hostid']: self._host_tags(host) for host in self._fetch_hosts()}
        print(f'all hosts num: {len(all_hosts)}')

        # Group by host and Get zabbix item, 1000 hosts per item.get call.
        rslt = []
        valid_value_type = ('0', '3')
        host_ids = list(all_hosts)
        for i in range(0, len(host_ids), 1000):
            items = self.do_request('item.get', output=['itemid', 'type', 'hostid', 'key_', 'value_type'],
                                    selectItemDiscovery=['key_'], hostids=host_ids[i:i + 1000])
            for item in items:
                if item['value_type'] not in valid_value_type:
                    continue
                tags = dict(all_hosts[item['hostid']])
                # Replaces the host-interface type with the item's type name;
                # unmapped type codes are kept verbatim instead of raising KeyError.
                tags['zabbix_agent_type'] = self.item_type_dict.get(item['type'], item['type'])
                tags['itemid'] = item['itemid']
                measurement, metric, param = self.extract_item_key(item)
                tags.update(param)
                tags['metric'] = metric
                tags['measurement'] = measurement
                rslt.append(tags)
        keys = sorted({key for tags in rslt for key in tags})
        data = [[tags.get(key, '') for key in keys] for tags in rslt]
        return [{"table_name": "zabbix-refer-table",
                 "column_name": list(keys),
                 "column_type": ['string'] * len(keys),
                 "row_data": data}]

    def _fetch_hosts(self):
        """Return hosts with their interfaces, limited to one group when ONLY_ZABBIX_METRICS is set."""
        group_filter = {}
        if self.only_zabbix_metrics:
            groups = self.do_request('hostgroup.get', output='extend', filter={'name': self.zabbix_server_group})
            if len(groups) == 0:
                raise Exception(f'zabbix server group {self.zabbix_server_group} not found')
            group_filter['groupids'] = groups[0]['groupid']
        return self.do_request('host.get', output=['hostid', 'host', 'interfaces'],
                               selectInterfaces=['type', 'ip'], **group_filter)

    @staticmethod
    def _host_tags(host):
        """Base tags for a host, taken from its lowest-numbered interface type.

        Interface preference: 1-agent > 2-snmp > 3-ipmi > 4-jmx.
        """
        interfaces = host.get('interfaces') or []
        if interfaces:
            interface = min(interfaces, key=lambda x: int(x['type']))
            ip, agent_type = interface['ip'], interface['type']
        else:
            # A host may have no interfaces; tag it with empty values
            # instead of raising IndexError.
            ip, agent_type = '', ''
        return {'host': host['host'].lower(), 'ip': ip, 'zabbix_agent_type': agent_type, 'source': 'zabbix'}

    def extract_item_key(self, item):
        """Split an item key into ``(measurement, metric, tags)``.

        Spaces in the key become '.'. The measurement is the key text before
        the first '.' (or '[' if that comes first). Without a '[' the metric
        is the whole key and there are no tags. Otherwise each bracketed
        parameter either becomes a tag, when the item's LLD prototype has a
        macro at that position, or is appended to the metric as '.param'.
        """
        key_ = item['key_'].replace(' ', '.')
        bracket = key_.find('[')
        base = key_ if bracket < 0 else key_[:bracket]
        # split() keeps the whole base when there is no '.'; the previous
        # find('.')-based slice dropped the last character in that case.
        measurement = base.split('.', 1)[0]
        if bracket < 0:
            return measurement, key_, {}
        metric = base
        params = key_[bracket + 1: len(key_) - 1].split(',')
        template = self._discovery_template(item)
        tag = {}
        for i, p in enumerate(params):
            if i < len(template) and template[i][0] != '':
                k, regex = template[i]
                p = p.replace('\\', '/')
                try:
                    rslt = re.search(regex, p)
                    tag[k.lower()] = rslt.group(k) if rslt else p
                except (re.error, IndexError) as e:
                    # Prototype parameter did not yield a usable pattern/group.
                    print(f'cannot apply discovery regex {regex!r} to {p!r}: {e}')
            elif p != '':
                metric = metric + '.' + p
        return measurement, metric, tag

    def _discovery_template(self, item):
        """Build per-parameter ``(tag_name, regex)`` pairs from the item's LLD prototype key.

        A prototype parameter containing ``{#NAME}`` yields a named-group regex
        and the tag name NAME (or the user-macro name when ``{$NAME}`` is also
        present); other parameters yield ``('', '')``. Returns [] when the
        item was not discovered or its prototype key has no parameters.
        """
        discovery = item.get('itemDiscovery')
        if not discovery:
            return []
        t_key = discovery['key_']
        idx = t_key.find('[')
        if idx < 0:
            return []
        template = []
        for p in t_key[idx + 1: len(t_key) - 1].split(','):
            p = p.translate(self._PROTOTYPE_PARAM_TRANS)
            if '{#' not in p:
                template.append(('', ''))
                continue
            k = re.findall(r'\{#(.*?)\}', p)[0].replace('.', '_')
            regex = p.replace('{#', '(?P<').replace('}', '>.*)')
            if '{$' in regex:
                # '$' must be escaped: unescaped it is an end-of-string anchor,
                # so the pattern never matched and findall(...)[0] raised IndexError.
                k = re.findall(r'\{\$(.*?)\}', p)[0].replace('.', '_')
                regex = regex.replace('{$', '(?P<')
            template.append((k, regex))
        return template

    def _version_tuple(self):
        """Parse ZABBIX_VERSION ('7.0', '6.4.2', '7') into a comparable 3-int tuple.

        Raises:
            ValueError: if the version does not start with a number.
        """
        parts = []
        for piece in str(self._zabbix_version).split('.'):
            digits = re.match(r'\d+', piece)
            if not digits:
                break
            parts.append(int(digits.group()))
        if not parts:
            raise ValueError(f'invalid ZABBIX_VERSION: {self._zabbix_version!r}')
        return tuple((parts + [0, 0, 0])[:3])

    def get_token(self):
        """Log in with ``user.login`` and return a new session token.

        Zabbix 5.4 renamed the login parameter ``user`` to ``username`` and
        6.4 removed ``user``, so the parameter name follows ZABBIX_VERSION.

        Raises:
            Exception: with the API's error object when login fails.
        """
        user_field = 'username' if self._version_tuple() >= (5, 4, 0) else 'user'
        param = {'jsonrpc': '2.0', 'method': 'user.login',
                 'params': {user_field: self._zabbix_user, 'password': self._zabbix_passwd}, 'id': '1'}
        # NOTE(review): TLS verification is disabled (verify=False) as before;
        # consider making it configurable.
        resp = requests.post(f'{self._zabbix_host}/api_jsonrpc.php', headers={'Content-Type': 'application/json-rpc'},
                             data=json.dumps(param), verify=False, timeout=60).json()
        if 'error' in resp:
            raise Exception(resp['error'])
        return resp['result']

    def _get_auth_token(self):
        """Return the cached session token, logging in first if there is none."""
        if self._auth_token is None:
            self._auth_token = self.get_token()
        return self._auth_token

    def _post_api(self, method, params):
        """POST one JSON-RPC request with the session token and return the decoded body.

        Raises:
            Exception: if the HTTP status is not 200.
        """
        token = self._get_auth_token()
        param = {'jsonrpc': '2.0', 'method': method, 'params': params, 'id': '1'}
        headers = {'Content-Type': 'application/json-rpc'}
        if self._version_tuple() >= (7, 0, 0):
            # The ``auth`` request property is deprecated in Zabbix 7.0 (and
            # removed later); pass the session token as a Bearer header.
            headers['Authorization'] = f'Bearer {token}'
        else:
            param['auth'] = token
        resp = requests.post(f'{self._zabbix_host}/api_jsonrpc.php', headers=headers,
                             data=json.dumps(param), verify=False, timeout=600)
        if resp.status_code != 200:
            print(f'method: {method}, args: {params}')
            raise Exception(resp)
        return resp.json()

    def do_request(self, method, **kwargs):
        """Call Zabbix API ``method`` with ``kwargs`` as params and return its result.

        One login session is reused across calls (previously every call
        logged in again). If a call made with a reused session fails, the
        session may have expired or been terminated: log in again and retry
        once; an error that persists is raised.

        Raises:
            Exception: with the API's error object, or on a non-200 response.
        """
        for _ in range(2):
            reused_token = self._auth_token is not None
            resp = self._post_api(method, kwargs)
            if isinstance(resp, list):
                return resp
            if 'error' not in resp:
                return resp['result']
            if not reused_token:
                break
            # validate session activity: drop the cached token and relogin.
            self._auth_token = None
        raise Exception(resp['error'])

