#!/bin/sh
# -*- python -*-
# This file is bilingual: it is both a POSIX shell script and a Python program.
# The shell code below finds our preferred python and re-executes this very
# file with it.
# The next line is a shell no-op (the quotes collapse to the `:` builtin) and,
# for Python, opens a triple-quoted string literal that hides the shell code.
# See https://stackoverflow.com/a/47886254
""":"
# Candidates, in order: the SYSU_PYTHON environment variable (if set),
# then python3, then python.
SYSU_PREFERRED_PYTHONS="python3 python"
# SYSU_PREFERRED_PYTHONS is left unquoted on purpose so that it word-splits
# into one candidate per loop iteration.
for cmd in "${SYSU_PYTHON:-}" ${SYSU_PREFERRED_PYTHONS}; do
    if command -v "$cmd" >/dev/null; then
        # Export the resolved interpreter path, then replace this shell with
        # the interpreter running this file and the original arguments.
        export SYSU_PYTHON="$(command -v "$cmd")"
        exec "${SYSU_PYTHON}" "$0" "$@"
    fi
done
# Only reached when none of the candidates could be resolved.
echo "==> Error: $0 could not find a python interpreter!" >&2
exit 1
":"""

import argparse
import gc
import glob
import gzip
import json
import logging
import os
import re
import statistics
import subprocess
import sys
import tempfile


def main(*argv):

    def unittest_parser(testcase, clang, preprocessor, parser, filenames, time_out, grammar=None):
        """Compare the AST emitted by the tool under test with clang's AST.

        For each file the preprocessor output is fed to
        ``clang -cc1 -ast-dump=json`` to obtain the reference AST (a
        ``*.json.gz`` file next to a ``*.sysu.c`` source is used instead when
        it exists), and to either ``grammar`` or
        ``clang -cc1 -dump-tokens | parser`` to obtain the AST under test.
        Both outputs are parsed as JSON and compared node by node.

        Args:
            testcase: test name; selects which node keys are compared.
            clang: clang command, interpolated into a shell command line.
            preprocessor: preprocessor command; the filename is appended.
            parser: parser command, used only when ``grammar`` is falsy.
            filenames: source files to check.
            time_out: timeout in seconds for each subprocess, or None.
            grammar: optional command that reads the preprocessed source and
                prints the AST directly.

        Returns:
            0 when every file that did not time out matches, -1 on the first
            mismatch (including output that is not valid JSON).
        """
        logger = logging.getLogger(__name__)

        def explicit_children(children):
            """Drop child nodes flagged ``isImplicit`` or ``implicit``.

            Entries that are not JSON objects are kept so that the type check
            in ``check_ast`` reports them instead of crashing here.
            """
            return [child for child in children
                    if not isinstance(child, dict) or not (
                        child.get("isImplicit", False) or
                        child.get("implicit", False))]

        def check_ast(ast0, ast1, testcase, command0, command1):
            """Recursively compare reference node ``ast0`` with ``ast1``.

            Returns 0 when the nodes match and -1 after logging the first
            difference together with both command lines.
            """
            def wk_exit(n):
                logger.error("---")
                # Log the node without its children.  A copy is used so the
                # dict being iterated by the caller is never mutated.
                current = ast0
                if isinstance(ast0, dict):
                    current = {k: v for k, v in ast0.items() if k != "inner"}
                logger.error("current node: ")
                logger.error(current)
                logger.error("---")
                logger.error(command0)
                logger.error(command1)
                return n

            if type(ast0) != type(ast1):
                logger.error("\nfail: type err.")
                return wk_exit(-1)

            key_inner = ["inner"]

            if testcase in ["parser-0", "parser-1", "grammar-4", "grammar-5"]:
                key_kind = ["kind", "name", "value"]
                skip_inner = ["InitListExpr"]  # skip list initializers
            elif testcase in ["parser-2", "grammar-6"]:
                key_kind = ["kind", "name", "value", "type"]
                skip_inner = []
            else:
                # Any other test name compares every key except the ignored
                # ones.
                key_kind = []
                skip_inner = []
            key_ignore = ["id"]

            for key, value0 in ast0.items():
                if key in key_ignore:
                    continue
                compare_children = key in key_inner
                if compare_children and ast0.get("kind") in skip_inner:
                    continue
                if not compare_children and key_kind and key not in key_kind:
                    continue

                # A key missing from the tested AST is a failure of its own.
                # It must not be turned into a placeholder value, which would
                # crash the children comparison below.
                if key not in ast1:
                    logger.error("\nfail: key '"+key+"' not find")
                    return wk_exit(-1)
                value1 = ast1[key]

                if compare_children:
                    if not isinstance(value1, list):
                        logger.error("\nfail: inner not match")
                        return wk_exit(-1)
                    children0 = explicit_children(value0)
                    children1 = explicit_children(value1)
                    if len(children0) != len(children1):
                        logger.error("\nfail: inner not match")
                        logger.error(str(len(children0)))
                        logger.error(str(len(children1)))
                        return wk_exit(-1)
                    for child0, child1 in zip(children0, children1):
                        r = check_ast(child0, child1,
                                      testcase, command0, command1)
                        if r != 0:
                            return r
                elif value0 != value1:
                    logger.error("\nfail: key not match")
                    logger.error(key)
                    return wk_exit(-1)
            return 0

        for ff, filename in enumerate(filenames):
            logger.info("[{}/{}] {}".format(ff, len(filenames), filename))
            command = preprocessor + " " + filename
            # command0/command1 are only used in failure reports; they show
            # how to reproduce both ASTs by hand.
            command0 = command+" | "+clang + " -cc1 -ast-dump=json"
            if grammar:
                command1 = command+" | "+grammar
                tested = grammar
            else:
                command1 = command+" | "+clang + " -cc1 -dump-tokens 2>&1 | "+parser
                tested = clang + " -cc1 -dump-tokens 2>&1 |"+parser
            try:
                src = subprocess.run(
                    command, timeout=time_out, stdout=subprocess.PIPE, shell=True).stdout
                ast0 = None
                if filename.endswith(".sysu.c"):
                    # Prefer a pre-computed reference AST when it is present.
                    gz = filename[0:filename.rfind(".sysu.c")]+".json.gz"
                    if os.path.exists(gz):
                        with gzip.open(gz, "rb") as f:
                            ast0 = f.read()
                if ast0 is None:
                    ast0 = subprocess.run(
                        clang + " -cc1 -ast-dump=json",
                        input=src, timeout=time_out,
                        stdout=subprocess.PIPE, shell=True).stdout
                ast0 = json.loads(ast0)
                gc.collect()
                ast1 = json.loads(subprocess.run(
                    tested,
                    input=src, timeout=time_out, stdout=subprocess.PIPE, shell=True).stdout)
                gc.collect()
            except subprocess.TimeoutExpired:
                logger.warning("TimeoutExpired. Skip.")
                continue
            except ValueError:
                # json.loads raises a ValueError subclass on malformed or
                # undecodable output; report it as a test failure.
                logger.error("\nfail: output is not valid JSON")
                logger.error("---")
                logger.error(command0)
                logger.error(command1)
                return -1
            r = check_ast(ast0, ast1, testcase, command0, command1)
            if r != 0:
                return r
        return 0

    def unittest_lexer(testcase, clang, preprocessor, lexer, filenames, time_out):
        """Compare the token dump of the lexer under test with clang's.

        For each file the preprocessor output is fed both to
        ``clang -cc1 -dump-tokens`` and to ``lexer``; the two outputs are
        compared line by line.  Which fields of a line are compared depends
        on ``testcase``.

        Args:
            testcase: test name; selects the compared fields.
            clang: clang command, interpolated into a shell command line.
            preprocessor: preprocessor command; the filename is appended.
            lexer: command of the lexer under test.
            filenames: source files to check.
            time_out: timeout in seconds for the clang and lexer runs, or
                None.

        Returns:
            0 when all files match, -1 on the first difference or on a line
            that does not have the expected dump format.
        """
        logger = logging.getLogger(__name__)

        def split_tokens(t):
            """Split one dump line into [kind, spelling, flags, location].

            Returns None when the line lacks a quoted spelling or a
            ``Loc=<`` marker.
            """
            tmp = str(t, encoding="utf8")
            k1 = tmp.find("'")
            k2 = tmp.rfind("'")
            k3 = tmp.rfind("Loc=<")
            if k1 == -1 or k2 == -1 or k3 == -1:
                return None
            tok0 = tmp.split()[0].strip()
            str0 = tmp[k1+1:k2].strip()
            mid0 = str(tmp[k2+1:k3].strip().split())
            loc0 = tmp[k3:].strip()
            return [tok0, str0, mid0, loc0]

        for ff, filename in enumerate(filenames):
            logger.info("[{}/{}] {}".format(ff, len(filenames), filename))
            command = preprocessor + " " + filename
            src = subprocess.run(
                command, stdout=subprocess.PIPE, shell=True).stdout
            # command0/command1 are only used in failure reports.
            command0 = command+" | "+clang + " -cc1 -dump-tokens 2>&1 "
            command1 = command+" | "+lexer + " 2>&1 "
            tokens0 = subprocess.run(
                clang + " -cc1 -dump-tokens 2>&1 ",
                input=src, timeout=time_out, stdout=subprocess.PIPE, shell=True).stdout.splitlines()
            tokens1 = subprocess.run(
                lexer + " 2>&1 ",
                input=src, timeout=time_out, stdout=subprocess.PIPE, shell=True).stdout.splitlines()

            def wk_exit(n):
                logger.error("---")
                logger.error(command0)
                logger.error(command1)
                return n

            def mismatch(label, line_no, expected, actual):
                """Log a diff-style report for one differing field."""
                logger.error(
                    "\nfail: " + label + " " + str(line_no) + "c" + str(line_no))
                logger.error("< " + expected)
                logger.error("---")
                logger.error("> " + actual)
                return wk_exit(-1)

            if len(tokens0) != len(tokens1):
                logger.error("\nfail: counts of tokens are different")
                return wk_exit(-1)

            for i in range(len(tokens0)):
                parsed = []
                for raw in (tokens0[i], tokens1[i]):
                    fields = split_tokens(raw)
                    if fields is None:
                        # A malformed line is a test failure; it must not be
                        # unpacked into the four fields below.
                        logger.error("\nfail: format error")
                        logger.error("---")
                        logger.error(raw)
                        return wk_exit(-1)
                    parsed.append(fields)
                (tok0, str0, mid0, loc0), (tok1, str1, mid1, loc1) = parsed

                if testcase in ["lexer-3", "grammar-3"]:
                    if mid0 != mid1:
                        return mismatch("mid", i+1, mid0, mid1)
                if testcase in ["lexer-2", "lexer-3", "grammar-2", "grammar-3"]:
                    if loc0 != loc1:
                        return mismatch("loc", i+1, loc0, loc1)
                if testcase in ["lexer-0", "lexer-1", "lexer-2", "lexer-3", "grammar-0", "grammar-1", "grammar-2", "grammar-3"]:
                    if tok0 != tok1:
                        return mismatch("tok", i+1, tok0, tok1)
                    if str0 != str1:
                        return mismatch("str", i+1, str0, str1)
        return 0

    def benchmark_generator_and_optimizer(
            testcase,
            clang,
            generator,
            optimizer,
            filenames,
            time_out):
        """Benchmark the generator/optimizer pipeline against ``clang -O3``.

        For each file two LLVM IR modules are produced: one by
        ``clang -O3 -S -emit-llvm`` and one by piping the JSON AST through
        ``generator`` and then ``optimizer``.  Both are linked with
        ``-lsysy -lsysu`` and run; exit code and stdout must be identical.
        When the reference run reports a ``TOTAL: ..H-..M-..S-..us`` line on
        stderr, the ratio of the two timings is recorded.

        A JSON report is always printed to stdout before returning.

        Args:
            testcase: test name (not used by this function).
            clang: clang executable.
            generator: IR generator executable; reads the JSON AST on stdin.
            optimizer: IR optimizer executable; reads IR on stdin.
            filenames: source files to benchmark.
            time_out: timeout in seconds for each subprocess, or None.

        Returns:
            0 on completion, otherwise the non-zero exit code of the failing
            tool, or -1 on an output mismatch or a missing timer.
        """
        logger = logging.getLogger(__name__)
        results = {"tests": []}

        def wk_exit(n):
            print(json.dumps(results))
            return n

        def get_tm(sp):
            """Return the last ``TOTAL`` time on ``sp.stderr`` in microseconds.

            Returns -1 when stderr holds no timer line, so callers can test
            the result with ``> 0`` / ``<= 0``.
            """
            found = re.findall(
                b'TOTAL: (\\d*)H-(\\d*)M-(\\d*)S-(\\d*)us', sp.stderr)
            if not found:
                return -1
            hours, minutes, seconds, micros = found[-1]
            val = int(hours)
            val = val * 60 + int(minutes)
            val = val * 60 + int(seconds)
            val = val * 1000000 + int(micros)
            return val

        for filename in filenames:
            results["tests"].append({
                "name": filename,
                "score": 0,
                "max_score": 1,
                "output": ""})
        performance = []
        score = 0
        with tempfile.TemporaryDirectory() as tmpdirname:
            ir = os.path.join(tmpdirname, "a.ll")
            exe = os.path.join(tmpdirname, "a.out")
            for ff, filename in enumerate(filenames):
                logger.info("[{}/{}] {}".format(ff, len(filenames), filename))
                entry = results["tests"][ff]
                res = []
                try:
                    # Reference IR.  Built inside the try block so that a
                    # timeout here is reported like any other timeout and the
                    # JSON report is still printed.
                    sysu = [subprocess.run(
                            [clang, "-O3", "-S", "-emit-llvm", "-o-", filename],
                            timeout=time_out,
                            stdout=subprocess.PIPE).stdout]
                    # Developer toggle: when truthy, clang -O0 output is
                    # benchmarked instead of the generator/optimizer pipeline.
                    test_clang = 0
                    if test_clang:
                        sysu.append(subprocess.run(
                            [clang, "-O0", "-S", "-emit-llvm", "-o-", filename],
                            timeout=time_out,
                            stdout=subprocess.PIPE).stdout)
                    else:
                        s = None
                        if filename.endswith(".sysu.c"):
                            # Prefer a pre-computed JSON AST when present.
                            gz = filename[0:filename.rfind(
                                ".sysu.c")]+".json.gz"
                            if os.path.exists(gz):
                                with gzip.open(gz, "rb") as f:
                                    s = f.read()
                        if s is None:
                            # Use the clang passed by the caller for every
                            # step, including preprocessing.
                            s = subprocess.run(
                                [clang, "-E", "-O0", filename],
                                timeout=time_out,
                                stdout=subprocess.PIPE)
                            if s.returncode:
                                return wk_exit(s.returncode)
                            s = s.stdout
                            s = subprocess.run(
                                [clang, "-cc1", "-O0", "-ast-dump=json"],
                                timeout=time_out,
                                input=s,
                                stdout=subprocess.PIPE)
                            if s.returncode:
                                return wk_exit(s.returncode)
                            s = s.stdout
                        # Round-trip through the json module so the tools
                        # receive ASCII-only JSON.
                        s = json.loads(s)
                        gc.collect()
                        s = json.dumps(s).encode(encoding="ascii")
                        for it in [generator, optimizer]:
                            s = subprocess.run(
                                [it],
                                timeout=time_out,
                                input=s,
                                stdout=subprocess.PIPE)
                            if s.returncode:
                                return wk_exit(s.returncode)
                            s = s.stdout
                        sysu.append(s)
                    msg = "Compile Finish."
                    logger.info(msg)
                    entry["output"] += msg + "\n"
                    inputs = None
                    if filename.endswith(".sysu.c"):
                        gz = filename[0:filename.rfind(".sysu.c")]+".in.gz"
                        if os.path.exists(gz):
                            with gzip.open(gz, "rb") as f:
                                inputs = f.read()
                    for it in sysu:
                        with open(ir, "wb") as f:
                            f.write(it)
                        r = subprocess.run(
                            [clang, "-O0", "-lsysy", "-lsysu", "-o", exe, ir],
                            timeout=time_out).returncode
                        if r:
                            logger.error("ir error")
                            return wk_exit(r)
                        jt = subprocess.run(
                            [exe],
                            timeout=time_out,
                            input=inputs,
                            stderr=subprocess.PIPE,
                            stdout=subprocess.PIPE)
                        res.append(jt)
                        tm = get_tm(jt)
                        if len(res) == 1:
                            # First entry of sysu is the clang -O3 reference.
                            msg = "[{}/{}] clang -O3: {}us, ret {}".format(
                                ff, len(filenames), tm, jt.returncode)
                            logger.info(msg)
                            entry["output"] += msg + "\n"
                        else:
                            msg = "[{}/{}] sysu-lang: {}us, ret {}".format(
                                ff, len(filenames), tm, jt.returncode)
                            logger.info(msg)
                            entry["output"] += msg + "\n"

                            if res[0].returncode != jt.returncode or \
                                    res[0].stdout != jt.stdout:
                                logger.error("fail: not equ")
                                return wk_exit(-1)
                            ref_tm = get_tm(res[0])
                            if ref_tm > 0:
                                # The reference was timed, so the tested
                                # binary must report a timer as well.
                                if tm <= 0:
                                    logger.error("fail: no timer")
                                    return wk_exit(-1)
                                performance.append(ref_tm/tm)
                            entry["score"] = 1

                            score = score + 1
                except subprocess.TimeoutExpired:
                    # If only the reference finished, count the tested binary
                    # as having taken the whole timeout.
                    if len(res) == 1 and get_tm(res[0]) > 0:
                        performance.append(get_tm(res[0])/(time_out * 1000000))
                    msg = "TimeoutExpired."
                    logger.warning(msg)
                    entry["output"] += msg + "\n"
            if len(performance):
                results["leaderboard"] = [{
                    "name": "performance",
                    "order": 1,
                    "is_desc": True,
                    "value": statistics.geometric_mean(performance)}, {
                    "name": "score",
                    "order": 2,
                    "is_desc": True,
                    "value": score}]
        return wk_exit(0)

    def convert_sysy(filenames):
        """Convert SysY sources (``*.sy`` / ``*.sysy``) into ``*.sysu.c`` files.

        A file that mentions any SysY runtime function gets
        ``#include <sysy/sylib.h>`` prepended.  In addition a guarded
        ``#define __SYSY 202203L`` is prepended when the file has the
        ``.sysy`` suffix or the first runtime function found is one of the
        float/format functions.  Other filenames are skipped.

        Args:
            filenames: paths to consider.

        Returns:
            Always 0.
        """
        logger = logging.getLogger(__name__)
        # The float/format runtime functions are searched first; finding one
        # of them triggers the __SYSY define.
        patterns22 = ["getfloat", "putfloat",
                      "getfarray", "putfarray", "putf"]
        patterns = patterns22+[
            "getint", "getch", "getarray", "putint",
            "putch", "putarray", "starttime", "stoptime"]

        for ff, filename in enumerate(filenames):
            logger.info("[{}/{}] {}".format(ff, len(filenames), filename))
            if not filename.endswith((".sy", ".sysy")):
                continue
            # Context managers close the handles even if a read or write
            # raises.
            with open(filename, mode="r") as file:
                src = file.read()
            for pattern in patterns:
                if pattern in src:
                    src = "#include <sysy/sylib.h>\n"+src
                    if filename.endswith(".sysy") or \
                            pattern in patterns22:
                        src = "#ifndef __SYSY\n" +\
                            "#define __SYSY 202203L\n" +\
                            "#endif\n" +\
                            src
                    break
            # Both "x.sy" and "x.sysy" map to "x.sysu.c".
            stem = filename[0:-2] if filename.endswith(".sysy") else filename
            with open(stem+"su.c", mode="w") as file:
                file.write(src)
        return 0

    def compile_S(
            preprocessor,
            lexer,
            parser,
            generator,
            optimizer,
            translator,
            linker,
            filenames):
        """Run the compilation pipeline on each file and print the result.

        Every file is preprocessed and the output is piped, stage by stage,
        through lexer, parser, generator, optimizer and translator.  The
        output of the parser stage is re-serialized through the json module
        before it is handed on.  ``linker`` is accepted but not used.

        Returns:
            0 on success, or the exit code of the first stage that failed.
        """
        stages = (lexer, parser, generator, optimizer, translator)
        for source_path in filenames:
            proc = subprocess.run(
                [preprocessor, source_path], stdout=subprocess.PIPE)
            if proc.returncode:
                return proc.returncode
            data = proc.stdout
            for stage in stages:
                proc = subprocess.run(
                    [stage], input=data, stdout=subprocess.PIPE)
                if proc.returncode:
                    return proc.returncode
                data = proc.stdout
                if stage == parser:
                    # Normalize the parser's JSON to ASCII-only bytes.
                    data = json.loads(data)
                    gc.collect()
                    data = json.dumps(data).encode(encoding="ascii")
            print(data.decode(encoding="utf8"))

        return 0

    logging.getLogger(__name__).setLevel(level=logging.INFO)
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger(__name__).addHandler(console)

    parser = argparse.ArgumentParser(
        prog=os.path.basename(argv[0]))
    specs = ["clang", "sysu-preprocessor", "sysu-lexer", "sysu-parser",
             "sysu-generator", "sysu-optimizer", "sysu-translator", "sysu-linker", "sysu-grammar"]
    for spec in specs:
        parser.add_argument(
            "--"+spec,
            help="Specify "+spec,
            default=spec if spec == "clang" else
            os.path.join(os.path.dirname(os.path.abspath(__file__)), spec))
    parser.add_argument(
        "--recursion-limit",
        help="(default 1048576)",
        type=int,  # 防爆
        default=1048576)
    parser.add_argument(
        "--convert-sysy",
        help="Convert *.sy in [file ...]",
        action="store_true")
    parser.add_argument("--unittest")
    parser.add_argument(
        "--unittest-timeout",
        help="unittest-timeout in seconds, '-1' for no skip (default 120)",
        type=int,  # 慎重测试！
        default=120)
    parser.add_argument(
        "-S",
        help="Only run preprocess and compilation steps",
        action="store_true")
    parser.add_argument(
        "-o",
        help="Write output to <file>",
        metavar="<file>",
        default="-")
    parser.add_argument("file", default=["-"], nargs="*")
    try:
        args = parser.parse_args(argv[1:])
    except:
        logging.getLogger(__name__).error(
            os.path.basename(argv[0])+": argumentError")
        logging.getLogger(__name__).error(
            os.path.basename(argv[0])+": try to use clang")
        return subprocess.run(("clang",)+argv[1:]).returncode

    sys.setrecursionlimit(args.recursion_limit)
    if args.o != "-":
        sys.stdout = open(args.o, "w")
    filenames = []
    for f in args.file:
        tmp = glob.glob(f, recursive=True)
        for t in tmp:
            if t not in filenames:
                filenames.append(t)
    if args.convert_sysy:
        return convert_sysy(filenames)
    if args.unittest:
        filenames.sort()
        time_out = None
        if args.unittest_timeout >= 0:
            time_out = args.unittest_timeout
        for _test in ["lexer-0", "lexer-1", "lexer-2","lexer-3"]:
            if _test in args.unittest:
                return unittest_lexer(args.unittest, args.clang,
                                  args.sysu_preprocessor, args.sysu_lexer,
                                  filenames, time_out)
        for _test in ["parser-0", "parser-1", "parser-2","parser-3"]:
            if _test in args.unittest:
                return unittest_parser(args.unittest, args.clang,
                                   args.sysu_preprocessor, args.sysu_parser,
                                   filenames, time_out)
        for _test in ["grammar-0", "grammar-1", "grammar-2","grammar-3"]:
            if _test in args.unittest:
                return unittest_lexer(args.unittest, args.clang,
                                  args.sysu_preprocessor, args.sysu_grammar + " -dump-tokens ",
                                  filenames, time_out)
        for _test in ["grammar-4", "grammar-5", "grammar-6","grammar-7"]:
            if _test in args.unittest:
                return unittest_parser(args.unittest, args.clang,
                                   args.sysu_preprocessor, None,
                                   filenames, time_out, args.sysu_grammar)
        if "benchmark_generator_and_optimizer" in args.unittest:
            return benchmark_generator_and_optimizer(
                args.unittest,
                args.clang,
                args.sysu_generator,
                args.sysu_optimizer,
                filenames,
                time_out)
    if args.S:
        returncode = compile_S(
            args.sysu_preprocessor, args.sysu_lexer,
            args.sysu_parser, args.sysu_generator,
            args.sysu_optimizer, args.sysu_translator,
            args.sysu_linker, filenames)
        return returncode

    logging.getLogger(__name__).error(
        os.path.basename(argv[0])+": try to use clang")
    return subprocess.run((args.clang,)+argv[1:]).returncode


if __name__ == "__main__":
    # Pass the whole argv on, program name included; the value returned by
    # main becomes the process exit status.
    exit_status = main(*sys.argv)
    sys.exit(exit_status)
