#!/usr/bin/env python3
# Copyright 2020 The Skywater PDK Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import csv
import io
import json
import os
import pprint
import re
import sys

from collections import defaultdict, OrderedDict
from pathlib import Path

from common import lib_extract_from_path, version_extract_from_path, lib_extract_from_name, extract_version_and_lib_from_path, copy_file_to_output
from common import convert_libname, convert_cellname, convert_cell_fullname, convert_pinname



debug = False


def debug_print(x):
    """Print ``x`` when the module-level ``debug`` flag is true.

    Returns ``None`` after printing, or ``0`` when debugging is off.
    """
    if not debug:
        return 0
    print(x)
    return None

# Apache-2.0 license header written as CDL/SPICE comment lines (every line
# starts with '*'), terminated by an empty line.
# NOTE(review): not referenced by the code visible in this file; presumably
# consumed elsewhere — confirm before removing.
Copyright_header = (
    '* Copyright 2020 The SkyWater PDK Authors\n'
    '*\n'
    '* Licensed under the Apache License, Version 2.0 (the "License");\n'
    '* you may not use this file except in compliance with the License.\n'
    '* You may obtain a copy of the License at\n'
    '*\n'
    '*     https://www.apache.org/licenses/LICENSE-2.0\n'
    '*\n'
    '* Unless required by applicable law or agreed to in writing, software\n'
    '* distributed under the License is distributed on an "AS IS" BASIS,\n'
    '* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n'
    '* See the License for the specific language governing permissions and\n'
    '* limitations under the License.\n'
    '\n')


# One ``.SUBCKT <name> <ports>`` ... ``.ENDS [<name>]`` block.
#  - ``.SUBCKT`` must start a line (optionally indented), so commented-out
#    ``*.SUBCKT`` lines are not parsed as real subcircuits; both dots are
#    escaped so they only match a literal '.'.
#  - The name cannot run onto the next line, and the port list may be empty.
#  - ``.ENDS`` may be followed by trailing blanks / '\r' and may be the last
#    thing in the file (no trailing newline required), so the final
#    subcircuit of a file is not silently dropped.
RE_SUBCKT = re.compile(
    r'^[ \t]*\.SUBCKT[ \t]+(?P<name>\S+)[ \t]*(?P<ports>[^\n]*)'
    r'(?P<contents>.*?)'
    r'\n\.ENDS(\s+(?P=name))?[ \t\r]*(?:\n|\Z)',
    re.I | re.DOTALL | re.MULTILINE)
# NOTE(review): RE_PININFO is not referenced in the code visible here.
RE_PININFO = re.compile('^\\*\\.PININFO$')


def process_line(l, s='='):
    """Split a line into positional tokens and ``key<s>value`` pairs.

    Tokens are separated by whitespace. A token containing ``s`` is split on
    its first ``s`` into a ``(key, value)`` pair; every other token is
    positional.

    Args:
        l: The line to split.
        s: Key/value separator (``'='`` for element parameters, ``':'`` for
            ``*.PININFO`` entries).

    Returns:
        ``(positional, pairs)``: the list of positional tokens and the list
        of ``(key, value)`` tuples, both in input order. A blank line gives
        ``([], [])``.

    Raises:
        AssertionError: if a key is repeated, or if the line ends with a
            positional token after a keyword token.

    >>> process_line('R1 a b w=1 l=2')
    (['R1', 'a', 'b'], [('w', '1'), ('l', '2')])
    >>> process_line(' A:I B:O', ':')
    ([], [('A', 'I'), ('B', 'O')])
    """
    start = []
    args = []
    seen_keys = set()
    for b in l.split():
        if s not in b:
            start.append(b)
            continue
        k, v = b.split(s, 1)
        # Track keys in a set: checking ``k`` against the (key, value) tuples
        # stored in ``args`` could never detect a duplicate.
        assert k not in seen_keys, (k, args)
        seen_keys.add(k)
        # A None marker records where a run of keyword tokens began.
        if not start or start[-1] is not None:
            start.append(None)
        args.append((k, v))

    if args:
        # The line must end inside the keyword run, i.e. on the None marker.
        assert start[-1] is None, start
        start.pop()

    # NOTE(review): interleaved lines such as 'a k=1 b k2=2' still pass and
    # leave a None inside the positional list — confirm whether they should
    # be rejected.
    return start, args


def _convert_value(v):
    """Best-effort conversion of a keyword value string to None, float or int.

    Empty strings become None. Strings containing '.' or 'e' are tried as
    float, all others as int; on failure the original string is returned.
    """
    if not v:
        return None
    if '.' in v or 'e' in v:
        try:
            return float(v)
        except ValueError:
            return v
    try:
        return int(v)
    except ValueError:
        return v


def collect_rows(processed_lines):
    """Arrange ``process_line`` results into header lines and table rows.

    Args:
        processed_lines: Sequence of ``(positional, keywords)`` pairs as
            produced by ``process_line``; ``keywords`` may be a list of
            ``(key, value)`` pairs or an empty mapping. It is not modified.

    Returns:
        ``(header, rows)``. ``header`` is the list of positional-token lists
        of the leading lines that carry no keyword arguments. ``rows`` is a
        list of ``OrderedDict`` objects, one per remaining non-empty line, all
        with the same keys: integer positions ``0..N-1`` (N = the longest
        positional list among those lines) followed by every keyword name in
        first-seen order. Missing cells are ``None``; keyword values are
        converted with ``_convert_value``.
    """
    # Work on a copy so the caller's list is left intact.
    lines = list(processed_lines)

    # Leading lines without keyword arguments form the header.
    split = 0
    while split < len(lines) and not lines[split][1]:
        split += 1
    header = [args for args, _ in lines[:split]]
    body = lines[split:]
    if not body:
        return header, []

    # Column layout shared by every row: positional slots, then keyword names.
    width = max(len(args) for args, _ in body)
    row_headers = OrderedDict((i, None) for i in range(width))
    for _, kw in body:
        for k, _ in kw:
            row_headers.setdefault(k, None)

    rows = []
    for args, kw in body:
        if not args and not kw:
            continue
        row = row_headers.copy()
        for i, a in enumerate(args):
            row[i] = a
        for k, v in kw:
            row[k] = _convert_value(v)
        rows.append(row)

    return header, rows

# SPICE MOSFET instance line:
#   MXXXXXXX ND NG NS NB MNAME <L=VAL> <W=VAL> <AD=VAL> <AS=VAL>
# e.g.
#   mI29    Ab Bb mid2 vnb nlowvt m=1 w=0.64 l=0.15 mult=1 sa=265e-3 sb=265e-3 sd=280e-3 topography=normal area=0.063 perim=1.14
#
# Only spaces/tabs separate the fields, so a match never crosses a line break.
# (With ``\s`` separators a line ending right after MNAME consumed its newline
# and was merged with the next line, and such a line at end of input was not
# matched at all.) Whatever follows MNAME is left untouched.
RE_MOSFET = re.compile(
    r'^[ \t]*[Mm](?P<mosfet>\S+)[ \t]+(?P<nd>\S+)[ \t]+(?P<ng>\S+)'
    r'[ \t]+(?P<ns>\S+)[ \t]+(?P<nb>\S+)[ \t]+(?P<mname>\S+)',
    re.MULTILINE)
RE_MOSFET_SUB = r'M\g<mosfet> nd=\g<nd> ng=\g<ng> ns=\g<ns> nb=\g<nb> mname=\g<mname>'

def expand_mosfet(s):
    """Rewrite MOSFET instance lines so their positional fields become keywords.

    Each match of ``RE_MOSFET`` is replaced using ``RE_MOSFET_SUB``: the
    leading ``m``/``M`` of the device name is normalised to ``M`` and the
    drain, gate, source, bulk and model fields become ``nd=``, ``ng=``,
    ``ns=``, ``nb=`` and ``mname=``. Text outside the matches is untouched.

    >>> expand_mosfet("mI29    Ab Bb mid2 vnb nlowvt m=1 w=0.64 l=0.15 mult=1 sa=265e-3 sb=265e-3 sd=280e-3 topography=normal area=0.063 perim=1.14")
    'MI29 nd=Ab ng=Bb ns=mid2 nb=vnb mname=nlowvt m=1 w=0.64 l=0.15 mult=1 sa=265e-3 sb=265e-3 sd=280e-3 topography=normal area=0.063 perim=1.14'

    >>> print(expand_mosfet('''\\
    ... m2 Drain Gate Source Substrate model_use_only w=5.05 l=0.5 m=1
    ... m8 Drain Gate Source Substrate model_use_only w=3.01 l=0.5 m=1
    ... ''').strip())
    M2 nd=Drain ng=Gate ns=Source nb=Substrate mname=model_use_only w=5.05 l=0.5 m=1
    M8 nd=Drain ng=Gate ns=Source nb=Substrate mname=model_use_only w=3.01 l=0.5 m=1

    """
    expanded = re.sub(RE_MOSFET, RE_MOSFET_SUB, s)
    return expanded



# Positional MOSFET fields (see RE_MOSFET) that expand_mosfet() rewrites as keywords:
# ND = the name of the drain terminal
# NG = the name of the gate terminal
# NS = the name of the source terminal
# NB = the name of the bulk (backgate) terminal
# MNAME = the name of the model used

def ignore_subckt(new_lib, oldname):
    """Return True if the subcircuit ``oldname`` should not be converted.

    For the ``sky130_fd_sc_hdll`` library, names containing ``iops8a_``,
    ``s8ppscio_`` or ``s8_esd``, and the exact names ``inv_p``, ``nor2_p``
    and ``icecap`` are skipped. For every library, names containing
    ``libcell`` or ``s8pir_10r_vcells_lvs`` are skipped.
    """
    if new_lib == 'sky130_fd_sc_hdll':
        hdll_fragments = ('iops8a_', 's8ppscio_', 's8_esd')
        hdll_exact = ('inv_p', 'nor2_p', 'icecap')
        if any(frag in oldname for frag in hdll_fragments) or oldname in hdll_exact:
            return True
    return any(frag in oldname for frag in ('libcell', 's8pir_10r_vcells_lvs'))


# A newline followed by '+ ' marks a SPICE continuation line.
RE_CONTINUES = re.compile(r'\n\+ ')
# Two consecutive '*.PININFO' comment lines; group 1 is the first line's payload.
RE_PININFO_MULTILINE = re.compile(r'\n\*\.PININFO ([^\n]*?)\n\*\.PININFO ')

def squash_pininfo(contents):
    """Merge consecutive ``*.PININFO`` comment lines into a single line.

    The entries of each following line are appended to the first line,
    separated by a space so that the last entry of one line and the first
    entry of the next remain distinct whitespace-separated tokens.

    >>> squash_pininfo('''
    ... BLAH
    ... *.PININFO 1
    ... *.PININFO 2
    ... *.PININFO 3
    ... BLAH
    ... ''')
    '\\nBLAH\\n*.PININFO 1 2 3\\nBLAH\\n'
    """
    # One substitution pass merges non-overlapping pairs of lines, so repeat
    # until a pass changes nothing.
    previous = None
    while contents != previous:
        previous = contents
        # The trailing space in the replacement keeps entries from adjacent
        # lines apart (without it 'A:I' + 'B:O' became 'A:IB:O').
        contents = RE_PININFO_MULTILINE.sub(r'\n*.PININFO \1 ', contents)
    return contents


def convert_cdl_to_tsv(new_lib, info, old_subcktname, old_ports, old_contents):
    """Convert one CDL subcircuit into tab-separated netlist text.

    Args:
        new_lib: Target library name; must equal the library prefix (the part
            before ``__``) of ``convert_cell_fullname(old_subcktname, new_lib)``.
        info: Dict updated in place: ``info[new_subcktname]`` is set to the
            list of ``(converted pin name, PININFO value)`` pairs, with value
            ``'?'`` when the subcircuit has no ``*.PININFO`` line.
        old_subcktname: Subcircuit name as written in the CDL file.
        old_ports: Text following the name on the ``.SUBCKT`` line.
        old_contents: Text between the ``.SUBCKT`` line and ``.ENDS``.

    Returns:
        ``(new_subcktname, tsv_text)``. The first TSV row holds the new name
        and its ``pin:value`` list; the second is the column header (``Name``,
        ``Formula``, then one column per keyword seen). After that come the
        leading keyword-less lines (name + remaining tokens in the Formula
        column) and then one row per element line, whose Formula is
        ``MOSFET``/``RESISTOR``/``DIODE`` for names starting with M/R/D.

    Raises:
        AssertionError: for any input structure this converter does not
            handle.
    """
    old_portnames, old_arguments = process_line(old_ports)

    subckt_contents = squash_pininfo(old_contents)
    subckt_contents = expand_mosfet(subckt_contents)

    if debug:
        print()
        print('---', old_subcktname)
        print(subckt_contents)
        print('---')

    pininfo_count = subckt_contents.count('*.PININFO')
    assert pininfo_count <= 1, str(pininfo_count)+'\n\n'+subckt_contents

    # --- Parse the body into (positional, keyword) lines plus the pin list.
    old_pins = []
    processed_lines = []
    if old_arguments:
        # Keyword parameters on the .SUBCKT line become a row with an empty name.
        processed_lines.append(([''], old_arguments))
    for l in subckt_contents.splitlines():
        l = l.strip()
        if not l:
            continue
        assert l[0] != '.', (l, subckt_contents)
        if l.startswith("*.PININFO "):
            assert '=' not in l, l
            # l[9:] drops the '*.PININFO' prefix; entries look like NAME:VALUE.
            a, b = process_line(l[9:], ':')
            assert not a, (a, b, l, l[10:].split())
            assert b, (a, b, l)
            old_pins = b
        elif l.startswith("* NOTE"):
            # Dump the partial parse only when debugging (it used to be an
            # unconditional print on every NOTE line).
            if debug:
                print(processed_lines)
            # partition() yields '' instead of raising IndexError when the
            # NOTE has no ':'.
            processed_lines.append((['Note', l.partition(':')[2].strip()], {}))
        else:
            processed_lines.append(process_line(l))

    assert len(old_pins) > 0 or not pininfo_count

    try:
        header, rows = collect_rows(processed_lines)
    except AssertionError:
        pprint.pprint(processed_lines)
        print("subckt ----")
        print(subckt_contents)
        print("-----------")
        sys.stdout.flush()
        raise

    # --- Reconcile the .SUBCKT port list with the *.PININFO pin list.
    old_pinnames = [a for a, _ in old_pins]
    if old_pins and set(old_portnames) != set(old_pinnames):
        error = '{} != {}'.format(old_portnames, old_pinnames) + '\n' + old_ports + '\n' + subckt_contents
        assert len(old_portnames) == len(old_pinnames), error
        # Mismatches are only accepted when every port name starts with
        # 'pin' and the subcircuit is an 'xcmv' cell; then PININFO names win.
        for p in old_portnames:
            if not p.startswith('pin'):
                raise AssertionError('Invalid pin name: {}'.format(repr(p))+'\n'+error)
        assert "xcmv" in old_subcktname, error
        print("Rewrote port names from:", old_portnames, "to", old_pinnames)
        old_portnames = old_pinnames
    elif not old_pins:
        # No PININFO line: pin values are unknown.
        old_pins = [(a, '?') for a in old_portnames]

    new_subcktname = convert_cell_fullname(old_subcktname, new_lib)
    ext_lib, _ = new_subcktname.split('__', 1)
    assert new_lib == ext_lib, (new_lib, ext_lib)

    new_portnames = {o: convert_pinname(o) for o in old_portnames}
    new_pins = [(convert_pinname(a), b) for a, b in old_pins]
    info[new_subcktname] = new_pins

    # --- Emit the TSV.
    f = io.StringIO()
    writer = csv.writer(f, delimiter='\t')

    # Positional columns (integer keys) get an empty heading; the first one
    # becomes 'Name' and a 'Formula' column is inserted right after it.
    if rows:
        tsv_header = ['' if isinstance(h, int) else h for h in rows[0].keys()]
    else:
        tsv_header = ['', '']
    assert tsv_header[0] == '', tsv_header
    tsv_header[0] = 'Name'
    tsv_header.insert(1, 'Formula')

    writer.writerow([new_subcktname, ' '.join('{}:{}'.format(a, b) for a, b in new_pins)])
    writer.writerow(tsv_header)

    MNAME_IDX = None
    if 'mname' in tsv_header:
        MNAME_IDX = tsv_header.index('mname')

    # Leading keyword-less lines: the name plus the remaining tokens (with
    # port names converted) joined into the Formula column.
    for name, *extra in header:
        assert name not in old_portnames
        new_extra = [new_portnames.get(e, e) for e in extra]
        if name.startswith('XICE'):
            # XICE* instances must reference 'icecap' and are left out.
            assert new_extra[-1] == 'icecap', (name, new_extra)
            continue

        if name[0] == 'X':
            # Subcircuit instance: the last token is the referenced cell.
            new_extra[-1] = convert_cell_fullname(new_extra[-1], new_lib)

        writer.writerow([name, ' '.join(new_extra)])

    for r in rows:
        name, *extra = r.values()
        new_extra = [new_portnames.get(e, e) for e in extra]

        formula = ''
        if name:
            kind = name[0].upper()
            if kind == 'M':
                formula = 'MOSFET'
                assert MNAME_IDX is not None, (tsv_header, r)
                # extra[i] lands in TSV column i + 2 (after Name and Formula).
                new_extra[MNAME_IDX-2] = convert_cellname(new_extra[MNAME_IDX-2])
            elif kind == 'R':
                formula = 'RESISTOR'
            elif kind == 'D':
                formula = 'DIODE'

        writer.writerow([name]+[formula]+new_extra)

    return new_subcktname, f.getvalue()


def change_names_cdl(new_lib, contents):
    """Split CDL text into per-subcircuit TSV netlists.

    Args:
        new_lib: Target library name, passed to ``ignore_subckt`` and
            ``convert_cdl_to_tsv``.
        contents: Full text of a CDL file.

    Returns:
        ``(output, info)``: ``output`` is a list of
        ``(new_subcktname, tsv_text)`` in file order (ignored subcircuits
        omitted); ``info`` maps each new subcircuit name to its pin list.

    Raises:
        AssertionError: if anything other than blank lines or ``*`` comment
            lines appears before the first subcircuit or between two
            subcircuits (text after the last one is not checked).
    """
    # Join '+' continuation lines. Substituting a single space (rather than
    # deleting the match) keeps the last token of the previous line and the
    # first token of the continuation from being fused together.
    contents = RE_CONTINUES.sub(' ', contents)
    output = []
    info = {}

    last_subckt_endpos = 0
    for subckt in RE_SUBCKT.finditer(contents):
        between = contents[last_subckt_endpos:subckt.start(0)]
        for l in between.splitlines():
            stripped = l.strip()
            # Whitespace-only lines are accepted just like empty ones.
            assert not stripped or stripped.startswith('*'), l
        last_subckt_endpos = subckt.end(0)

        old_subcktname = subckt.group('name')
        if ignore_subckt(new_lib, old_subcktname):
            continue

        new_subcktname, tsv_content = convert_cdl_to_tsv(
            new_lib, info,
            old_subcktname,
            subckt.group('ports'),
            subckt.group('contents'))

        if debug:
            print()
            print(">>>", old_subcktname)
            print(tsv_content, end="")
            print("<<<")

        output.append((new_subcktname, tsv_content))

    return output, info


def filemain(input_file, temp_dir, final_dir, new_lib, ver):
    """Split one CDL file into per-cell netlist files.

    For every subcircuit returned by ``change_names_cdl`` this writes
    ``<temp_dir>/<cell>.netlist.tsv`` and ``<temp_dir>/<cell>.netlist.pins``
    and passes each to ``copy_file_to_output``. If a ``.netlist.tsv`` already
    exists with different content, the new text is written alongside it as
    ``.netlist.tsv.new`` and an AssertionError is raised.

    Files ending in ``source.cdl`` are ignored, and ``vcells_lvs`` files whose
    path does not contain ``timedwards`` are skipped.

    Args:
        input_file: Path of the CDL file (string).
        temp_dir: Directory for the generated files; created if missing.
        final_dir: Passed through to ``copy_file_to_output``.
        new_lib: Target library name, passed to ``change_names_cdl`` and
            ``copy_file_to_output``.
        ver: Version string, passed to ``copy_file_to_output``.

    Returns:
        Always 0.
    """
    if input_file.endswith('source.cdl'):
        return 0

    os.makedirs(temp_dir, exist_ok=True)
    if 'vcells_lvs' in input_file and 'timedwards' not in input_file:
        print('Skipping', input_file)
        return 0

    # load files
    with open(input_file, 'r') as in_f:
        contents = in_f.read()

    output, info = change_names_cdl(new_lib, contents)

    print("netlist.tsv ----")
    for m in output:
        netlistname, netlistcontent = m
        assert netlistname != '???', m

        tmp_file = os.path.join(temp_dir, netlistname+'.netlist.tsv')
        # csv.writer terminates rows with '\r\n'. newline='' disables newline
        # translation both when writing and reading, so the bytes on disk are
        # exactly the string that is compared on a later run, on any platform.
        if os.path.exists(tmp_file):
            with open(tmp_file, newline='') as f:
                currentcontents = f.read()
            if currentcontents != netlistcontent:
                with open(tmp_file+'.new', 'w', newline='') as f:
                    f.write(netlistcontent)
            assert currentcontents == netlistcontent, '\n'.join([tmp_file, '--'*5, repr(currentcontents), '++'*5, repr(netlistcontent)])
        else:
            with open(tmp_file, 'w', newline='') as f:
                f.write(netlistcontent)

        copy_file_to_output(tmp_file, final_dir, new_lib, ver, netlistname)
    print("---------")
    print()
    print("Pins ----")
    for netlistname, netlistdata in sorted(info.items()):
        tmp_file = os.path.join(temp_dir, netlistname+'.netlist.pins')
        with open(tmp_file, 'w') as f:
            for pinname, pinprop in sorted(netlistdata):
                f.write(f"{pinname} {pinprop}")
                f.write("\n")
        copy_file_to_output(tmp_file, final_dir, new_lib, ver, netlistname)
    print("---------")

    return 0


def main(args, infile):
    """Convert ``infile`` or, if it is a directory, every ``*.cdl`` below it.

    Args:
        args: Parsed command-line namespace; ``args.temp`` and ``args.output``
            are used.
        infile: Path (``str`` or ``pathlib.Path``) to a CDL file or directory.

    Returns:
        Always 0.
    """
    if os.path.isdir(infile):
        # rglob() yields paths that already include ``infile``, so they are
        # passed on as-is (joining them onto ``infile`` again would double a
        # relative prefix).
        for cdl_file in sorted(Path(infile).rglob('*.cdl')):
            main(args, cdl_file)
        return 0

    path = os.path.abspath(infile)
    old_lib, new_lib, ver = extract_version_and_lib_from_path(path)
    print("-->", path, old_lib, new_lib, ver)

    tempdir = os.path.join(args.temp, 'cdl_split', new_lib, ver)
    print()
    print()
    print("Processing", path, "in", tempdir)
    print('-'*75)
    filemain(path, tempdir, str(args.output), new_lib, ver)
    print('-'*75)
    return 0


if __name__ == "__main__":
    import doctest
    # Run this module's doctests first and stop if any of them fail.
    fails, _ = doctest.testmod()
    if fails != 0:
        # sys.exit rather than the site-module exit(), which is intended for
        # the interactive interpreter and is absent under ``python -S``.
        sys.exit(1)
    print("Tests Passed")
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input",
        help="The path to the source directory",
        type=Path)
    parser.add_argument(
        "output",
        help="The path to the output directory",
        type=Path)
    parser.add_argument(
        "temp",
        help="The path to the temp directory",
        type=Path)
    args = parser.parse_args()
    sys.exit(main(args, args.input))
