#!/usr/bin/env python3
# Copyright 2020 The Skywater PDK Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import glob
import os
import pprint
import re
import subprocess
import sys
import traceback

from collections import defaultdict
from shutil import copyfile, move
from pathlib import Path
from multiprocessing import Pool

from lef_canonical import canonicalize_lef

import common
from common import \
        lib_extract_from_name, \
        extract_version_and_lib_from_path, \
        copy_file_to_output, \
        \
        convert_libname, \
        convert_cell_fullname, \
        convert_pinname

# Module-wide verbosity switches. debug_stop() / debug_restart() temporarily
# turn both off and later restore them.
#   debug      -- progress messages, and Magic output lines containing 'rror'
#   superdebug -- per-record GDS trace and every line of Magic output
debug = True
superdebug = True


def prepend_copyright(filename):
    """Insert the project copyright header at the start of *filename* in place.

    The header variant is chosen by file extension: the ``'#'`` comment form
    for ``.lef`` files and the ``'*'`` comment form for ``.spice`` files.

    Raises:
        IOError: if the extension is neither ``.lef`` nor ``.spice``.
    """
    comment_char_by_suffix = (('.lef', '#'), ('.spice', '*'))
    for suffix, comment_char in comment_char_by_suffix:
        if filename.endswith(suffix):
            header = common.copyright_header[comment_char]
            break
    else:
        raise IOError("Don't know how to add copyright to: "+filename)

    with open(filename, 'r+') as fh:
        body = fh.read()
        # Rewind and write header + old content over the same file.
        fh.seek(0, 0)
        fh.write(header + body)


def _magic_tcl_header(ofile, gdsfile):
    print('#!/bin/env wish',        file=ofile)
    print('drc off',                file=ofile)
    print('scalegrid 1 2',          file=ofile)
    print('cif istyle vendorimport',file=ofile)
    print('gds readonly true',      file=ofile)
    print('gds rescale false',      file=ofile)
    print('tech unlock *',          file=ofile)
    print('cif warning default',    file=ofile)
    print('set VDD VPWR',           file=ofile)
    print('set GND VGND',           file=ofile)
    print('set SUB SUBS',           file=ofile)
    print('gds read ' + gdsfile,    file=ofile)


def magic_get_cells(input_gds, input_techfile):
    """Run Magic on *input_gds* with a script that only reads the GDS and quits.

    A ``get_cells.tcl`` script is written next to *input_gds* and executed by
    :func:`run_magic` with that directory as Magic's working directory.

    Args:
        input_gds: path to the GDS file to read.
        input_techfile: Magic technology file passed with ``-T``.

    Returns:
        The ``(cell_name, lines)`` list produced by :func:`run_magic`
        (presumably callers read the cell names from it — verify at call site).

    Raises:
        SystemError: propagated from :func:`run_magic` on Magic failure.
    """
    destdir, gdsfile = os.path.split(input_gds)
    # A bare filename has an empty directory part. Use the current directory
    # so the script is not written to '/get_cells.tcl' and cwd is valid.
    destdir = destdir or '.'
    tcl_path = os.path.join(destdir, 'get_cells.tcl')

    # Generate a tcl script for Magic
    with open(tcl_path, 'w') as ofile:
        _magic_tcl_header(ofile, gdsfile)
        print('quit -noprompt', file=ofile)
    return run_magic(destdir, tcl_path, input_techfile)


def magic_split_gds(input_gds, input_techfile, cell_list):
    """Use Magic to write each cell of *input_gds* out separately.

    A ``split_gds.tcl`` script is written next to *input_gds*. After reading
    the GDS, it runs ``load``, ``gds write`` and ``save`` for every cell in
    *cell_list*, then quits. Magic runs with that directory as its working
    directory, so the per-cell outputs land there.

    Args:
        input_gds: path to the multi-cell GDS file.
        input_techfile: Magic technology file passed with ``-T``.
        cell_list: iterable of cell names to write out.

    Returns:
        The ``(cell_name, lines)`` list produced by :func:`run_magic`.

    Raises:
        SystemError: propagated from :func:`run_magic` on Magic failure.
    """
    destdir, gdsfile = os.path.split(input_gds)
    # A bare filename has an empty directory part. Use the current directory
    # so the script is not written to '/split_gds.tcl' and cwd is valid.
    destdir = destdir or '.'
    tcl_path = os.path.join(destdir, 'split_gds.tcl')

    # Generate a tcl script for Magic
    with open(tcl_path, 'w') as ofile:
        _magic_tcl_header(ofile, gdsfile)

        # Write out the cells. Braces quote the name for Tcl.
        for cell in cell_list:
            escaped_cell = '{'+cell+'}'
            print(f'load {escaped_cell}',      file=ofile)
            print(f'gds write {escaped_cell}', file=ofile)
            print(f'save {escaped_cell}',      file=ofile)

        print('quit -noprompt', file=ofile)
    return run_magic(destdir, tcl_path, input_techfile)


# Magic output lines that mark a run as failed even when Magic exits with 0.
# NOTE(review): run_magic applies this with .match(), i.e. only at the very
# start of a line; messages like "... couldn't be read" may appear mid-line.
# Confirm whether .search() was intended.
FATAL_ERROR = re.compile(r"((Error parsing)|(No such file or directory)|(couldn't be read))")
# run_magic treats the quoted name in lines like `Reading "<name>".` as the
# cell whose output follows.
# NOTE(review): the final '.' is unescaped and matches any character.
READING_REGEX = re.compile(r'Reading "([^"]*)".')


def run_magic(destdir, tcl_path, input_techfile, d="null"):
    """Run Magic (under valgrind) in batch mode on a generated Tcl script.

    Magic is started with ``-nowrapper -d<d> -noconsole -T <techfile>`` and
    *destdir* as its working directory. stderr is merged into stdout. The
    output is split into sections: the first, unnamed section holds lines
    before any ``Reading "<cell>".`` line, and each such line starts a new
    section named after the cell. The "input off lambda grid" CIF warning is
    dropped. Sections are echoed with a ``magic <cell>:`` prefix: every line
    when ``superdebug`` is set, only lines containing ``rror`` when ``debug``
    is set.

    Args:
        destdir: working directory for Magic.
        tcl_path: script to execute (made absolute before passing to Magic).
        input_techfile: Magic technology file.
        d: display type passed as ``-d<d>`` (default ``"null"``).

    Returns:
        list of ``(cell_name, lines)`` tuples as described above.

    Raises:
        SystemError: if Magic exits non-zero or prints a FATAL_ERROR line.
            The message includes the command, the full output, *destdir* and
            *tcl_path*.
    """
    cmd = ['valgrind', 'magic', '-nowrapper', '-d'+d, '-noconsole',
           '-T', input_techfile, os.path.abspath(tcl_path)]
    proc = subprocess.run(cmd,
                          stdin=subprocess.DEVNULL,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.STDOUT,
                          cwd=destdir,
                          universal_newlines=True)
    assert proc.stdout

    snap_warning = 'CIF file read warning: Input off lambda grid by 1/2; snapped to grid'
    sections = [('', [])]
    fatal_errors = []
    name_width = 0
    for text in proc.stdout.splitlines():
        if text.startswith(snap_warning):
            continue
        if FATAL_ERROR.match(text):
            fatal_errors.append(text)
        reading = READING_REGEX.match(text)
        if reading:
            section_name = reading.group(1)
            name_width = max(name_width, len(section_name))
            sections.append((section_name, []))
        sections[-1][1].append(text)

    for section_name, section_lines in sections:
        prefix = "magic " + section_name.ljust(name_width) + ':'
        for text in section_lines:
            if superdebug or (debug and 'rror' in text):
                print(prefix, text)

    # stderr is redirected into stdout above, so this is always None.
    assert not proc.stderr, proc.stderr

    if proc.returncode == 0 and not fatal_errors:
        return sections

    if fatal_errors:
        headline = ['ERROR: Magic had fatal errors in output:'] + fatal_errors
    else:
        headline = ['ERROR: Magic exited with status ' + str(proc.returncode)]
    report = headline + [
        "",
        " ".join(cmd),
        '='*75,
        proc.stdout,
        '='*75,
        destdir,
        tcl_path,
        '-'*75,
        headline[0],
    ]
    raise SystemError('\n'.join(report))


def magic_generate_output(techfile, gds_file, lef_file, cdl_file, cell):
    """Generate LEF and SPICE views for one cell by scripting Magic.

    A ``gen_output.<cell>.tcl`` script is written next to *gds_file* and
    executed by :func:`run_magic` with that directory as the working
    directory. The script:

    * reads *gds_file* and renumbers the top cell's ports (canonical order);
    * if *lef_file* is given, reads it and, when ports VPB / VNB exist with a
      ``"default"`` use or class, sets VPB to ``power`` / VNB to ``ground``
      and both to class ``bidirectional``;
    * if *cdl_file* is given, reads it with ``readspice`` to override the
      port order with the CDL one;
    * writes the cell's LEF (``lef write ... -toplayer``);
    * extracts the layout and runs ``ext2spice`` in ``lvs`` mode.

    All inputs are validated before the script is created, so a failed
    check does not leave a partial script behind (which would trip the
    "script must not exist" assertion on a retry).

    Args:
        techfile: Magic technology file.
        gds_file: GDS file containing *cell*.
        lef_file: optional LEF used to annotate ports (falsy to skip).
        cdl_file: optional CDL used for port ordering (falsy to skip).
        cell: name of the cell to load.

    Returns:
        The ``(cell_name, lines)`` list produced by :func:`run_magic`.

    Raises:
        SystemError: propagated from :func:`run_magic` on Magic failure.
    """
    print("magic_generate_output", techfile, gds_file, lef_file, cdl_file, cell)
    assert os.path.exists(techfile), techfile
    assert os.path.exists(gds_file), gds_file
    if lef_file:
        assert os.path.exists(lef_file), lef_file
    if cdl_file:
        assert os.path.exists(cdl_file), cdl_file

    # Empty directory part -> current directory (valid cwd, no '/' prefix).
    destdir = os.path.dirname(gds_file) or '.'
    tcl_path = os.path.join(destdir, f'gen_output.{cell}.tcl')
    assert not os.path.exists(tcl_path), tcl_path

    # Braces quote the cell name for Tcl.
    escaped_cell = '{'+cell+'}'

    # Generate a tcl script for Magic
    with open(tcl_path, 'w') as ofile:
        _magic_tcl_header(ofile, gds_file)

        # Force a canonical port order
        print('select top cell',          file=ofile)
        print('port renumber',            file=ofile)

        # LEF writing
        if lef_file:
            print('lef read '+lef_file,   file=ofile)
            # Complement missing annotations from the SkyWater LEF
            print(f'load {escaped_cell}',      file=ofile)
            print('if {![catch {set vpbpin [port VPB index]}]} {', file=ofile)
            print('    if {[port VPB use] == "default"} {', file=ofile)
            print('        port VPB use power', file=ofile)
            print('    }', file=ofile)
            print('    if {[port VPB class] == "default"} {', file=ofile)
            print('        port VPB class bidirectional', file=ofile)
            print('    }', file=ofile)
            print('}', file=ofile)
            print('if {![catch {set vnbpin [port VNB index]}]} {', file=ofile)
            print('    if {[port VNB use] == "default"} {', file=ofile)
            print('        port VNB use ground', file=ofile)
            print('    }', file=ofile)
            print('    if {[port VNB class] == "default"} {', file=ofile)
            print('        port VNB class bidirectional', file=ofile)
            print('    }', file=ofile)
            print('}', file=ofile)

        # Override the port order to match the order given in the CDL file.
        if cdl_file:
            print('readspice '+cdl_file,  file=ofile)
        print(f'load {escaped_cell}',      file=ofile)

        print(f'lef write {escaped_cell} -toplayer', file=ofile)

        # Netlist extraction
        print(f'load {escaped_cell}',      file=ofile)
        print('extract do local',         file=ofile)
        print('extract all',              file=ofile)
        print('ext2spice lvs',            file=ofile)
        print('ext2spice',                file=ofile)

        print('quit -noprompt', file=ofile)
    return run_magic(destdir, tcl_path, techfile)


class Rewriter:
    """Track, and optionally rename, the strings found while walking a GDS file.

    ``replace_gds_strings`` calls ``set_<record>`` for every record type it
    maps. For the record types it was asked to rewrite, it also calls
    ``rewrite_<record>`` to obtain replacement text. The rewriter records:

    * every string seen,
    * which structures (cells) exist,
    * how often each structure is instantiated,
    * which pin names each structure carries.

    It converts old library / cell / pin names to the new naming with the
    helpers imported from ``common``.

    ``set_string`` / ``rewrite_string`` are rebound by :meth:`set_string_type`,
    so the next STRING record is handled either as a pin name or as a
    generic string.
    """

    # Substitutions applied, in order, by rewrite_generic_string(). The
    # 'S8PIR_10R' / 'S8PIR-10R' rule comes first so the whole token is
    # replaced before the bare 'S8' rule runs. Both are case-insensitive.
    other_replacements = [
        (re.compile('S8PIR[_-]10R', re.I),
            lambda m: 'SKY130'),
        (re.compile('S8', re.I),
            lambda m: 'SKY130'),
    ]

    def __init__(self, original_path):
        # Every string seen in the GDS data.
        self._strings = set()
        # Rename caches (old -> new).
        self._cache_libname = None    # the one library name seen so far
        self._cache_cellname = {}
        self._cache_pinname = {}      # keyed by '<structure>:<pin>'
        self._cache_string = {}

        self.structures_reset()

        self.old_lib, self.new_lib, self.version = extract_version_and_lib_from_path(original_path)

        # Rebound by set_string_type() depending on the text type.
        self.set_string = self.set_generic_string
        self.rewrite_string = self.rewrite_generic_string

    def add_string(self, s):
        """Remember *s* as a string present in the GDS data."""
        self._strings.add(s)

    def strings(self):
        """Return all strings seen so far, sorted."""
        return sorted(self._strings)

    def structures(self):
        """Return a copy of the structure map (old name -> new name or None)."""
        return dict(self._structures)

    def structures_usecount(self):
        """Return ``(structure, instance_count)`` pairs, most-used first.

        Ties are ordered by structure name. An empty list is returned when
        no structure has been seen.
        """
        return sorted(self._structures_instances.items(),
                      key=lambda item: (-item[1], item[0]))

    def structures_reset(self):
        """Forget all structure, instance-count and pin information."""
        self._structures = {}
        self._structures_instances = {}
        self._structures_pins = {}
        self._structure_last = None

    def replacements(self):
        """Return a dict of every old -> new rename recorded so far.

        Combines the library, cell, pin and generic-string caches. Identity
        mappings are dropped, as is the library entry when no library name
        has been seen.
        """
        pairs = [(self._cache_libname, self.new_lib)]
        pairs.extend(self._cache_cellname.items())
        pairs.extend(self._cache_pinname.items())
        pairs.extend(self._cache_string.items())
        return {a: b for a, b in pairs if a != b and a is not None}

    def rewrite_library_name(self, s):
        """Return the new library name for LIBNAME *s*.

        Asserts that every LIBNAME passed to this rewriter is the same.
        """
        if self._cache_libname is not None:
            assert s == self._cache_libname, (s, self._cache_libname)
        else:
            self._cache_libname = s
        if s != self.new_lib:
            print("Rewriting library from", repr(s), "to", repr(self.new_lib))
        return self.new_lib

    def set_library_name(self, s):
        self.add_string(s)

    def _convert_cell(self, old_name, is_real_structure=True):
        """Compute and cache the new full cell name for *old_name*.

        For names containing ``'$$'``, the part before it is also registered
        (as a non-real structure). *is_real_structure* is unused here.
        """
        if '$$' in old_name:
            base_cellname, _ = old_name.split('$$', 1)
            self.rewrite_structure_name(base_cellname, False)

        old_cellname = old_name
        # s8rf cells do not always carry their library prefix.
        if self.old_lib == 's8rf':
            if not old_name.startswith('s8rf_'):
                old_cellname = 's8rf_'+old_name

        new_name = convert_cell_fullname(old_cellname, self.new_lib)
        assert new_name.startswith(self.new_lib), (new_name, self.new_lib)
        assert old_name not in self._cache_cellname, (old_name, self._cache_cellname)
        self._cache_cellname[old_name] = new_name

    def pins(self, s):
        """Return the pin names registered on structure *s*, in first-seen order."""
        return list(self._structures_pins[s].keys())

    def rewrite_structure_name(self, s, is_real_structure=True):
        """Return (and record) the new name for structure *s*.

        For a real structure, *s* must have been announced with
        set_structure_name() and not yet been renamed.
        """
        assert (not is_real_structure) or (s in self._structures), s
        if s not in self._cache_cellname:
            self._convert_cell(s, is_real_structure)
        ns = self._cache_cellname[s]
        if is_real_structure:
            assert s in self._structures, repr((s, ns))+'\n'+pprint.pformat(self._structures)
            assert self._structures[s] is None, repr((s, self._structures[s]))+'\n'+pprint.pformat(self._structures)
        self._structures[s] = ns

        if s != ns:
            print("Rewriting structure from", repr(s), "to", repr(ns))

        return ns

    def set_structure_name(self, sn):
        """Start a new current structure *sn* (resetting any earlier entry for it)."""
        self.add_string(sn)

        if debug:
            if self._structure_last is not None:
                print("Clearing current structure (was", self._structure_last+")")
                print()

        self._structures[sn] = None
        self._structures_instances[sn] = 0
        self._structures_pins[sn] = {}
        self._structure_last = sn

        if debug:
            print()
            print("Setting current structure to", self._structure_last)

    def set_instance_struct(self, s):
        """Count an instance of structure *s* and clear the current structure."""
        self.add_string(s)
        assert s in self._structures, (s, self._structures)
        self._structures_instances[s] += 1
        if debug:
            if self._structure_last is not None:
                print("Clearing current structure (was", self._structure_last+")")
                print()
        self._structure_last = None

    def rewrite_instance_struct(self, s):
        """Return the new name of the already-converted structure *s*."""
        self.add_string(s)
        assert s in self._cache_cellname, (s, self._cache_cellname)
        sn = self._cache_cellname[s]
        if s != sn:
            print("Rewriting structure instance from", repr(s), "to", repr(sn))
        return sn

    def set_pin_name(self, s):
        """Register pin *s* on the current structure (not yet renamed)."""
        self.add_string(s)
        assert self._structure_last is not None, s
        if s not in self._structures_pins[self._structure_last]:
            self._structures_pins[self._structure_last][s] = None

    def rewrite_pin_name(self, s):
        """Return the converted name of pin *s* on the current structure.

        Results are cached per ``'<structure>:<pin>'`` key.
        """
        self.add_string(s)
        assert self._structure_last is not None, s
        assert self._structure_last in self._structures, self._structure_last

        new_cellname = self._structures[self._structure_last]
        if new_cellname is None:
            new_cellname = self._structure_last

        pk = f'{self._structure_last}:{s}'
        if pk in self._cache_pinname:
            return self._cache_pinname[pk]

        pn = convert_pinname(s.upper(), new_cellname)
        if s != pn:
            print(" Rewriting pin on", self._structure_last, "from", repr(s), "to", repr(pn))
        self._cache_pinname[pk] = pn
        self._structures_pins[self._structure_last][s] = pn
        return pn

    # (new library, library found in a string, cell) -> name to rewrite
    # instead, for strings that refer to a cell of a different library.
    overrides = {
        ('sky130_fd_sc_hvl','scs8ls', 'inv_2'):             'scs8hvl_inv_2',
        ('sky130_fd_sc_hvl','scs8ls', 'inv_4'):             'scs8hvl_inv_4',
        ('sky130_fd_sc_ls', 'scs8ms', 'tapvgndnovpb_1'):    'scs8ls_tapvgndnovpb_1',
        ('sky130_fd_sc_ls', 'scs8lp', 'diode_2'):           'scs8ls_diode_2',
        ('sky130_fd_sc_ls', 'scs8lp', 'tap_2'):             'scs8ls_tap_2',
        ('sky130_fd_sc_ls', 'scs8ms', 'tapvgnd2_1'):        'scs8ls_tapvgnd2_1',
        ('sky130_fd_sc_ls', 'scs8ms', 'tapvgnd_1'):         'scs8ls_tapvgnd_1',
        ('sky130_fd_sc_ls', 'scs8ms', 'tapvpwrvgnd_1'):     'scs8ls_tapvpwrvgnd_1',
        ('sky130_fd_sc_ms', 'scs8ls', 'clkdlyinv3sd1_1'):   'scs8ms_clkdlyinv3sd1_1',
        ('sky130_fd_sc_ms', 'scs8ls', 'clkdlyinv3sd2_1'):   'scs8ms_clkdlyinv3sd2_1',
        ('sky130_fd_sc_ms', 'scs8ls', 'clkdlyinv3sd3_1'):   'scs8ms_clkdlyinv3sd3_1',
        ('sky130_fd_sc_ms', 'scs8ls', 'clkdlyinv5sd1_1'):   'scs8ms_clkdlyinv5sd1_1',
        ('sky130_fd_sc_ms', 'scs8lp', 'dlygate4s15_1'):     'scs8ms_dlygate4s15_1',
        ('sky130_fd_sc_ms', 'scs8ls', 'tap_1'):             'scs8ms_tap_1',
        ('sky130_fd_sc_ms', 'scs8lp', 'tap_2'):             'scs8ms_tap_2',
    }

    def set_generic_string(self, s):
        self.add_string(s)

    def rewrite_generic_string(self, old_str):
        """Return the replacement for a generic (non-pin) string *old_str*.

        Checked in order:

        1. An already converted cell name: return its new cell name without
           the library prefix (the part after ``'__'``).
        2. A pin name of the current structure: return the recorded new pin
           name. This is None if the pin was only registered via
           set_pin_name().
        3. A cell of another library listed in ``overrides``, or an
           ``scs8hd`` cell inside ``sky130_fd_sc_hdll``: rewrite the
           overriding name recursively.

        Otherwise, convert any embedded ``<lib>_<cell>`` name, replace every
        known old cell name, apply ``other_replacements``, and cache the
        result.
        """
        self.add_string(old_str)
        if old_str in self._cache_string:
            return self._cache_string[old_str]

        new_str = old_str

        # Is this a simple cell name string?
        if old_str in self._cache_cellname:
            cell_fullname = self._cache_cellname[old_str]
            _, cell_name = cell_fullname.split('__', 1)
            if old_str != cell_name:
                print(" Rewriting string (cell name) from", repr(old_str), "to", repr(cell_name))
            return cell_name

        # Is this a simple pin name string?
        if self._structure_last:
            if old_str in self._structures_pins[self._structure_last]:
                new_str = self._structures_pins[self._structure_last][old_str]
                if old_str != new_str:
                    print(" Rewriting string (pin name on ", self._structure_last, ") from", repr(old_str), "to", repr(new_str))
                return new_str

        # Check this isn't for a different library....
        ext_libname, ext_cellname = lib_extract_from_name(old_str)
        okey = (self.new_lib, ext_libname, ext_cellname)
        if okey in self.overrides:
            override = self.overrides[okey]
            if debug:
                print('Overriding {} with {}'.format(old_str, override))
            return self.rewrite_generic_string(override)
        elif (self.new_lib, ext_libname) == ('sky130_fd_sc_hdll', 'scs8hd'):
            override = 'scs8hdll_'+ext_cellname
            if debug:
                print('Overriding {} with {}'.format(old_str, override))
            return self.rewrite_generic_string(override)
        else:
            assert ext_libname is None or ext_libname == self.old_lib, (
                old_str, ext_libname, ext_cellname, self.old_lib)

            if ext_libname is not None and ext_cellname is not None:
                new_cellname = convert_cell_fullname(ext_cellname, self.new_lib)
                assert new_cellname.startswith(self.new_lib), (new_cellname, self.new_lib)
                new_str = new_str.replace(
                    f'{ext_libname}_{ext_cellname}', new_cellname)

        # Does this string contain a cell name?
        for old_cellname, new_cellname in self._cache_cellname.items():
            if old_cellname in new_str:
                new_str = new_str.replace(old_cellname, new_cellname)

        # Does this contain other things needed to replaced?
        for regex, rep in self.other_replacements:
            new_str = regex.sub(rep, new_str)

        self._cache_string[old_str] = new_str
        if old_str != new_str:
            print(" Rewriting string from", repr(old_str), "to", repr(new_str))
        return new_str

    def set_string_type(self, i):
        """Choose how the next STRING record is handled from text type *i*.

        Types 5, 16 and 59 are treated as pin names when a structure is
        current. Everything else is treated as a generic string.
        """
        # Processed Name            Processed Description       Legacy Name  Description  GDS Layer+Purpose
        # ('nwell', 'label')        ('nwell', 'label')          nwellLabel   nwell label  64  5     -- VPB?
        # ('tap', 'label')          ('tap', 'label')            tapLabel     tap label    65  5
        # ('poly', 'label')         ('poly', 'label')           polyLabel    poly label   66  5
        # ('polytt',)               ('poly', 'drawing')         polytt       poly label   66  5
        # (('li', 1), 'text')       (('li', 1), 'drawing')      li1tt        li1 label    67  5
        # (('metal', 1), 'label')   (('metal', 1), 'label')     met1Label    met1 label   68  5
        # (('metal', 1), 'text')    (('metal', 1), 'drawing')   met1tt       met1 label   68  5
        # (('metal', 2), 'label')   (('metal', 2), 'label')     met2Label    met2 label   69  5
        # (('metal', 2), 'text')    (('metal', 2), 'drawing')   met2tt       met2 label   69  5
        # (('metal', 3), 'label')   (('metal', 3), 'label')     met3Label    met3 label   70  5
        # (('metal', 3), 'text')    (('metal', 3), 'drawing')   met3tt       met3 label   70  5
        # (('metal', 4), 'label')   (('metal', 4), 'label')     met4Label    met4 label   71  5
        # (('metal', 4), 'text')    (('metal', 4), 'drawing')   met4tt       met4 label   71  5
        # (('metal', 5), 'label')   (('metal', 5), 'label')     met5Label    met5 label   72  5
        # (('metal', 5), 'text')    (('metal', 5), 'drawing')   met5tt       met5 label   72  5
        # ('pad', 'text')           ('pad', 'label')            padText      pad label    76  5

        # ('pwelltt',)              ('pwell', 'label')          pwelltt      pwell label  64  59    -- VNB?
        # ('difftt',)               ('diffusion', 'drawing')    difftt       diff label   65  6

        # NOTE:  Purposes 5 and 16 are confused in the SkyWater GDS files.
        # One should be used for pins, the other for text, not intermingled
        # as they are.  Most non-pin text seems to be on layer 83:44.

        if i in (5, 16, 59):
            if superdebug:
                print(" Next string is a pin name on", self._structure_last)
            # FIXME: Hack -- self._structure should never be none for these types!?
            if self._structure_last is None:
                self.set_string = self.set_generic_string
                self.rewrite_string = self.rewrite_generic_string
            else:
                self.set_string = self.set_pin_name
                self.rewrite_string = self.rewrite_pin_name
        else:
            if superdebug:
                print(" Next string is a generic string")
            self.set_string = self.set_generic_string
            self.rewrite_string = self.rewrite_generic_string

    def set_label(self, s):
        return self.set_generic_string(s)



def replace_gds_strings(gds_filename, call_rewrite_on_rtypes, rewriter):
    """Walk the records of a GDS file and rewrite selected strings in place.

    Every record is read from its 4-byte header: a 2-byte big-endian length,
    then 1-byte record type and 1-byte data type. Records whose type is in
    ``rtype_mapping`` are reported to *rewriter*:

    * two-byte-integer records (datatype 2) via ``set_<name>(value)``;
    * ASCII string records (datatype 6) via ``add_string(text)`` and,
      except for record type 44, ``set_<name>(text)``.

    For string records whose name is in *call_rewrite_on_rtypes*,
    ``rewrite_<name>(text)`` supplies a replacement. Changed records are
    re-encoded (null-padded to an even length) and spliced back. The result
    is written back over *gds_filename*.

    Args:
        gds_filename: GDS file to read and overwrite.
        call_rewrite_on_rtypes: record names (values of ``rtype_mapping``)
            whose strings should be rewritten. When ``'structure_name'`` is
            listed, ``'instance_struct'`` is rewritten too. The caller's
            list is not modified.
        rewriter: object with the ``add_string`` / ``set_*`` / ``rewrite_*``
            methods, e.g. :class:`Rewriter`.

    Raises:
        SystemError: if a zero-length record is followed by non-zero bytes.
    """
    with open(gds_filename, 'rb') as ifile:
        gdsdata = bytearray(ifile.read())

    # Work on a copy: appending to the caller's list would leak
    # 'instance_struct' into later calls that reuse the same list.
    call_rewrite_on_rtypes = list(call_rewrite_on_rtypes)
    # If we are rewriting the structure names, we need to rewrite the instances
    # which point to the structure.
    if 'structure_name' in call_rewrite_on_rtypes:
        call_rewrite_on_rtypes.append('instance_struct')

    rtype_mapping = {
        2:  'library_name',   # libname
        6:  'structure_name', # strname - Structure Definition Name
        18: 'instance_struct',# sname - Instances structure
        22: 'string_type',    # Indicates what the next string is in reference too..
        25: 'string',         # string
        44: 'label',          # text 83 44 -- textlabel
    }
    known_names = set(rtype_mapping.values())
    for r in call_rewrite_on_rtypes:
        assert r in known_names, r

    datalen = len(gdsdata)
    if superdebug:
        print('Original data length = ' + str(datalen))
    dataptr = 0
    while dataptr < datalen:
        reclen = int.from_bytes(gdsdata[dataptr:dataptr + 2], 'big')

        # The GDS files seem to occasionally end up with trailing zero bytes.
        if reclen == 0:
            if debug:
                print("{:10d} (of {:10d} - {:10d} left)".format(dataptr, datalen, datalen-dataptr), 'Found zero-length record at position in', gds_filename)
            if superdebug:
                print('Remaining data', repr(gdsdata[dataptr:]))
            for i in range(dataptr, datalen):
                if gdsdata[i] != 0:
                    raise SystemError('Found non-zero pad byte at {} ({}): {}'.format(i, hex(i), repr(gdsdata[i])))
            break

        rtype = gdsdata[dataptr + 2]
        rtype_name = rtype_mapping.get(rtype, '??? - {}'.format(rtype))
        datatype = gdsdata[dataptr + 3]
        # Distance to the next record; only changes if this record is rewritten.
        advance = reclen

        # FIXME: Hack to use different method for pin names...
        if datatype == 2 and rtype in rtype_mapping:
            value = int.from_bytes(gdsdata[dataptr+4:dataptr+6], 'big')
            if superdebug:
                print(
                    "{:10d} (of {:10d} - {:10d} left)".format(dataptr, datalen, datalen-dataptr),
                    'Record type = {:15s} '.format(rtype_name),
                    value,
                )
            getattr(rewriter, 'set_'+rtype_name)(value)

        # Datatype 6 is STRING
        if datatype == 6:
            bstring = gdsdata[dataptr + 4: dataptr + reclen]

            # Was original string padded with null byte?  If so, remove the
            # null byte.  (An empty string record has no payload at all.)
            if bstring and bstring[-1] == 0:
                decoded = bstring[:-1].decode('ascii')
            else:
                decoded = bstring.decode('ascii')

            rewriter.add_string(decoded)

            # NOTE(review): string records of unmapped types are printed here
            # and then fail with AttributeError on the set_ lookup below.
            if '???' in rtype_name:
                print(rtype_name, decoded)

            if rtype not in (44,):
                getattr(rewriter, 'set_'+rtype_name)(decoded)

            if rtype_name in call_rewrite_on_rtypes:
                skipped = False
                repstring = getattr(rewriter, 'rewrite_'+rtype_name)(decoded)
            else:
                skipped = True
                repstring = decoded
            assert repstring is not None

            changed = (decoded != repstring)
            if superdebug:
                print(
                    "{:10d} (of {:10d} - {:10d} left)".format(dataptr, datalen, datalen-dataptr),
                    'Record type = {:15s} '.format(rtype_name),
                    'Skipped = {:5s}'.format(str(skipped)),
                    end="  ",
                )
                if changed:
                    print(repr(decoded), '->', repr(repstring))
                else:
                    print(repr(decoded), '==', repr(repstring))

            if changed:
                brepstring = repstring.encode('ascii')
                # Record sizes must be even
                if len(brepstring) % 2 != 0:
                    brepstring += b'\x00'
                newlen = len(brepstring) + 4

                # Assemble the new record and splice it over the old one.
                newrecord = (newlen.to_bytes(2, byteorder='big')
                             + bytes((rtype, datatype))
                             + brepstring)
                gdsdata[dataptr:dataptr+reclen] = newrecord

                # Adjust the data end location
                datalen += (newlen - reclen)
                advance = newlen

        # Advance the pointer past the data
        dataptr += advance

    with open(gds_filename, 'wb') as ofile:
        ofile.write(gdsdata)


def gds_strings_get(input_path, tmp_gds):
    """Return the sorted unique strings in *tmp_gds*, read with the GDS walker.

    The file is scanned with a fresh :class:`Rewriter` (library naming taken
    from *input_path*) and no rewrites requested, so its contents are written
    back unchanged. Debug output is suppressed during the scan and restored
    afterwards, even if the scan fails.

    Raises:
        AssertionError: if the scan unexpectedly recorded any rename.
    """
    rewriter = Rewriter(input_path)
    debug_stop()
    try:
        replace_gds_strings(tmp_gds, [], rewriter)
        assert not rewriter.replacements(), (tmp_gds, rewriter.replacements())
    finally:
        debug_restart()
    return rewriter.strings()


def gds_strings_print(h, input_path, tmp_gds):
    """Print the strings found in *tmp_gds* (see gds_strings_get), one per line.

    The listing is introduced by *h* and framed by ``---`` lines. Strings are
    right-aligned to the longest one. A file with no strings prints an
    empty frame instead of failing.
    """
    strings = gds_strings_get(input_path, tmp_gds)
    max_len = max((len(s) for s in strings), default=0)
    print()
    print(h, end=" ")
    print(f"strings in GDS file {tmp_gds} (using GDS reader):")
    print('---')
    for s in strings:
        print("  ", s.rjust(max_len))
    print('---')


def strings_get(pathname):
    """Count the printable strings in a (binary) file.

    Runs the ``strings`` utility on *pathname*. The command is passed as an
    argument list, so paths with spaces or shell metacharacters work and a
    failure raises ``subprocess.CalledProcessError`` instead of being
    swallowed by a pipeline.

    Returns:
        dict mapping each string (trailing whitespace removed) to the number
        of times it occurs.
    """
    output = subprocess.check_output(['strings', pathname]).decode('utf-8')
    string_counts = {}
    for line in output.splitlines():
        s = line.rstrip()
        string_counts[s] = string_counts.get(s, 0) + 1
    return string_counts


def strings_print(h, strings):
    """Print a ``{string: count}`` mapping, most frequent first.

    Output is a blank line, the header *h*, then one line per string framed
    by ``---`` lines. Each string is right-aligned to the longest one and
    followed by its count. Ties are ordered by string. An empty mapping
    prints just the header and frame.
    """
    max_len = max((len(s) for s in strings), default=0)
    print()
    print(h)
    print("---")
    for s, c in sorted(strings.items(), key=lambda x: (-x[1], x[0])):
        print("  ", s.rjust(max_len)+':', c)
    print("---")


# Saved `debug` / `superdebug` values, one entry per unmatched debug_stop().
ldebug = []
lsuperdebug = []

def debug_stop():
    """Turn debug output off, saving the current flags for debug_restart().

    Calls nest: each debug_stop() pushes the current flags and should be
    paired with exactly one debug_restart().
    """
    global debug, superdebug
    ldebug.append(debug)
    lsuperdebug.append(superdebug)
    debug = superdebug = False


def debug_restart():
    """Restore the flags saved by the most recent unmatched debug_stop().

    Raises:
        AssertionError: if there is no saved state to restore.
    """
    global debug, superdebug
    assert ldebug
    assert lsuperdebug
    superdebug = lsuperdebug.pop()
    debug = ldebug.pop()


def pool_write_new(args):
    """Produce the final outputs for one cell that was split out of a library GDS.

    Takes a single tuple so it can be mapped over, presumably with a
    ``multiprocessing.Pool`` (TODO confirm at the call site).

    Steps:
      1. Copy ``<src_temp_dir>/<cell>.gds`` into a per-cell temp directory
         and run ``CHANGE_GDS_DATE`` on it.
      2. Re-scan the GDS with a fresh :class:`Rewriter`, asserting that
         nothing is left to rename, and write the cell's pins to
         ``<cell>.gds.pins``.
      3. Locate the cell's CDL and LEF in *final_dir*. If either path cannot
         be determined, stop and return None.
      4. Generate LEF and SPICE with Magic, canonicalize them, rename the LEF
         to ``<cell>.magic.lef``, and copy the GDS, .pins, .magic.lef and
         .spice files to the output.

    Args:
        args: ``(src_temp_dir, final_dir, techfile, input_path, new_cellname)``.

    Returns:
        *new_cellname* when outputs were written, otherwise None.
    """
    src_temp_dir, final_dir, techfile, input_path, new_cellname = args

    temp_dir = os.path.join(src_temp_dir, new_cellname)
    os.makedirs(temp_dir, exist_ok=True)

    cell_gds_file = os.path.join(src_temp_dir, new_cellname+'.gds')
    gds_file = os.path.abspath(os.path.join(temp_dir, new_cellname+'.gds'))
    assert os.path.exists(cell_gds_file), cell_gds_file
    assert not os.path.exists(gds_file), gds_file
    copyfile(cell_gds_file, gds_file)

    # CHANGE_GDS_DATE is defined elsewhere in this file; presumably it
    # normalizes the GDS timestamps -- TODO confirm the meaning of "1 0".
    subprocess.check_call(CHANGE_GDS_DATE+" 1 0 "+gds_file, shell=True)

    if debug:
        strings_print(f"Final strings in GDS file {gds_file} (using strings):", strings_get(gds_file))
        gds_strings_print('Final', input_path, gds_file)

    # Re-read the final GDS with a fresh rewriter; nothing should be left to
    # rename.  The debug flags are restored even if this fails.
    debug_stop()
    try:
        rewriter = Rewriter(input_path)
        replace_gds_strings(gds_file, [], rewriter)
        r = rewriter.replacements()
        assert not r, (new_cellname, r)
    finally:
        debug_restart()

    # Write the output .pins files
    pins = rewriter.pins(new_cellname)
    if not pins:
        print("WARNING: No pins on", new_cellname)

    gds_opin_file = os.path.join(temp_dir, new_cellname+'.gds.pins')
    assert not os.path.exists(gds_opin_file), gds_opin_file
    with open(gds_opin_file, 'w') as f:
        for pin in pins:
            f.write(pin)
            f.write('\n')

    lib = rewriter.new_lib
    ver = rewriter.version

    in_cdl_file = common.get_final_path(final_dir, lib, ver, new_cellname, '.cdl')
    if not in_cdl_file:
        return

    in_lef_file = common.get_final_path(final_dir, lib, ver, new_cellname, '.lef')
    if not in_lef_file:
        return

    if not os.path.exists(in_cdl_file):
        print("Missing CDL:", in_cdl_file)
        in_cdl_file = None

    if not os.path.exists(in_lef_file):
        print("Missing LEF:", in_lef_file)
        in_lef_file = None

    if in_cdl_file and in_lef_file:
        print(new_cellname, "found both LEF and CDL inputs")

    out_lef_file = os.path.join(temp_dir, new_cellname+'.lef')
    assert not os.path.exists(out_lef_file), ("Existing lef:", out_lef_file)

    out_spice_file = os.path.join(temp_dir, new_cellname+'.spice')
    assert not os.path.exists(out_spice_file), ("Existing spice:", out_spice_file)

    magic_generate_output(techfile, gds_file, in_lef_file, in_cdl_file, new_cellname)

    assert os.path.exists(out_lef_file), ("Missing new LEF:", out_lef_file)
    rewrite_out_lef(out_lef_file)

    out_mlef_file = os.path.join(temp_dir, new_cellname+'.magic.lef')
    assert not os.path.exists(out_mlef_file), ("Existing .magic.lef:", out_mlef_file)
    os.rename(out_lef_file, out_mlef_file)

    assert os.path.exists(out_spice_file), ("Missing new spice:", out_spice_file)
    rewrite_out_spice(out_spice_file)

    # GDS file -> output
    copy_file_to_output(gds_file, final_dir, lib, ver, new_cellname)
    copy_file_to_output(gds_opin_file, final_dir, lib, ver, new_cellname)

    # .magic.lef file -> output
    copy_file_to_output(out_mlef_file, final_dir, lib, ver, new_cellname)

    # spice file -> output
    copy_file_to_output(out_spice_file, final_dir, lib, ver, new_cellname)
    return new_cellname


def rewrite_out_lef(lef_file):
    """Canonicalize a Magic-written LEF file in place, then prepend the copyright header."""
    with open(lef_file, 'r') as src:
        canonical = canonicalize_lef(src.read())

    with open(lef_file, 'w') as dst:
        dst.write(canonical)

    prepend_copyright(lef_file)


# A header line of the form "* NGSPICE file created from <x> - technology: <y>".
RE_NGSPICE = re.compile(r'\* NGSPICE file created from .* - technology: .*')
# One or more consecutive newlines.
RE_WHITESPACE = re.compile(r'\n+')
# A ".subckt <name> <rest of line>" line through the next ".ends": the named
# group 'subckt' is the name, group 2 the rest of the header line and
# group 3 the (non-greedy) body.
RE_SUBCKT = re.compile(r'^\.subckt (?P<subckt>[^ ]*)([^\n]*)$(.*?)^\.ends', flags=re.M|re.DOTALL)

def rewrite_contents_spice(spice_file, contents):
    """

    >>> print(rewrite_contents_spice('RANDOM/sky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhvtop.spice', '''
    ... .subckt sky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhv C1 C0 MET5
    ... X0 C1 C0 C1 C1 sky130_fd_pr__nfet_05v0_nvt w=1e+07u l=4e+06u
    ... X1 C1 C0 C1 C1 sky130_fd_pr__nfet_05v0_nvt w=1e+07u l=4e+06u
    ... .ends
    ... .subckt sky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhvtop M5 C0 SUB
    ... Xsky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhv_0[0|0] SUB C0 M5 sky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhv
    ... Xsky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhv_0[1|0] SUB C0 M5 sky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhv
    ... Xsky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhv_0[0|1] SUB C0 M5 sky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhv
    ... Xsky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhv_0[1|1] SUB C0 M5 sky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhv
    ... .ends
    ... ''').strip())
    .subckt sky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhvtop M5 C0 SUB
    Xsky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhv_0[0|0] SUB C0 M5 sky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhv
    Xsky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhv_0[1|0] SUB C0 M5 sky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhv
    Xsky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhv_0[0|1] SUB C0 M5 sky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhv
    Xsky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhv_0[1|1] SUB C0 M5 sky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhv
    .ends

    >>> print(rewrite_contents_spice('RANDOM/sky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhvtop.spice', '''
    ... .subckt sky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhv C1 C0 MET5
    ... X0 C1 C0 C1 C1 sky130_fd_pr__nfet_05v0_nvt w=1e+07u l=4e+06u
    ... X1 C1 C0 C1 C1 sky130_fd_pr__nfet_05v0_nvt w=1e+07u l=4e+06u
    ... .ends
    ... ''').strip())
    .subckt sky130_fd_pr__cap_vpp_11p3x11p8_l1m1m2m3m4_shieldm5_nhv C1 C0 MET5
    X0 C1 C0 C1 C1 sky130_fd_pr__nfet_05v0_nvt w=1e+07u l=4e+06u
    X1 C1 C0 C1 C1 sky130_fd_pr__nfet_05v0_nvt w=1e+07u l=4e+06u
    .ends

    """
    contents = contents.replace('technology: (null)', 'technology: sky130A')

    contents = RE_NGSPICE.sub('', contents)
    contents = RE_WHITESPACE.sub('\n', contents)

    if contents.count('.subckt') > 1:
        basename = spice_file.rsplit('/', 1)[-1].split('.', 1)[0]
        def replace_subckt(m):
            if m.group('subckt') != basename:
                return ''
            return m.group(0)

        new_contents = RE_SUBCKT.sub(replace_subckt, contents)
        if new_contents.count('.subckt') == 1:
            contents = new_contents
    return contents


def rewrite_out_spice(spice_file):
    """Clean up a generated SPICE netlist in place, then add a copyright header.

    The file is read, passed through ``rewrite_contents_spice`` (banner
    removal, blank-line collapsing, and possibly dropping subckts not named
    after the file), written back over itself, and finally
    ``prepend_copyright`` prefixes the header matching the file's extension.
    """
    with open(spice_file) as src:
        netlist = src.read()

    cleaned_netlist = rewrite_contents_spice(spice_file, netlist)

    with open(spice_file, 'w') as dst:
        dst.write(cleaned_netlist)

    prepend_copyright(spice_file)


def filemain(input_path, temp_dir, final_dir, args):
    """Convert one input GDS file into per-cell output files.

    Steps:
      1. Copy ``input_path`` to ``<temp_dir>/input.gds`` and rewrite its
         library/structure names, then (in a second pass, once all cell
         renames are known) its remaining strings, using a ``Rewriter`` built
         from ``input_path``.
      2. Write bookkeeping files into ``temp_dir``: ``cells.list`` (cell
         renames), ``<cell>.gds.src.pins`` (pins per structure) and
         ``rewrite.list`` (all string replacements).
      3. Split the rewritten GDS with ``magic_split_gds`` and, in a process
         pool, run ``pool_write_new`` for every cell whose name contains
         neither ``libcell`` nor ``vcells``.

    Args:
        input_path: path of the source GDS file.
        temp_dir: scratch directory for intermediate files.
        final_dir: output directory, passed through to ``pool_write_new``.
        args: parsed command-line arguments; only ``args.techfile`` is read.
    """
    techfile = args.techfile

    tmp_gds = os.path.join(temp_dir, 'input.gds')
    copyfile(input_path, tmp_gds)

    rewriter = Rewriter(input_path)

    if debug:
        gds_strings_print('Initial', input_path, tmp_gds)

    # First rewrite the cell names
    if debug:
        print()
        print("Rewriting library and structure names")
        print("-------------------------------------")
    replace_gds_strings(tmp_gds, ['library_name', 'structure_name'], rewriter)
    if debug:
        print("-------------------------------------")

    cell_rewrites = list(rewriter.structures().items())
    if not cell_rewrites:
        print("WARNING: No cells found!")
        return

    # Write out the cell list.  cell_rewrites holds (from, to) pairs, so the
    # padding width must be taken from the source names (not the pairs) for
    # the "->" columns to line up.
    with open(os.path.join(temp_dir, 'cells.list'), 'w') as f:
        max_len = max(len(from_str) for from_str, _ in cell_rewrites)
        for from_str, to_str in sorted(cell_rewrites):
            f.write("  {:s} -> {:s}".format(from_str.rjust(max_len), to_str))
            f.write('\n')

    # Second rewrite any remaining strings.
    # -----
    # The strings could have references to cell names and there is no guarantee
    # that the cell instance definition will appear before a string which
    # happens to contain the name.
    # Hence, we want to know all the cell rewrites before rewriting strings.
    if debug:
        print()
        print("Rewriting strings (inc pin names)")
        print("-------------------------------------")
    rewriter.structures_reset()
    replace_gds_strings(tmp_gds, ['string'], rewriter)
    if debug:
        print("-------------------------------------")

    if debug:
        gds_strings_print('Before split', input_path, tmp_gds)

    # One "<name>.gds.src.pins" file per structure, one pin per line.
    for name, _count in list(rewriter.structures_usecount()):
        with open(os.path.join(temp_dir, name + '.gds.src.pins'), 'w') as f:
            for pin in rewriter.pins(name):
                f.write(pin)
                f.write('\n')

    # Write out the rewrite list
    with open(os.path.join(temp_dir, 'rewrite.list'), 'w') as f:
        r = rewriter.replacements()
        # default=0 so an empty replacement table does not make max() raise.
        max_len = max((len(x) for x in r), default=0)
        for from_str, to_str in sorted(r.items()):
            f.write(from_str.rjust(max_len))
            f.write(" -> ")
            f.write(to_str)
            f.write('\n')

    # Cells whose names contain any of these markers are not processed.
    skip_markers = ('libcell', 'vcells')
    filtered_cells = set()
    for name in rewriter.structures():
        if any(marker in name for marker in skip_markers):
            print("Skipping", name)
            continue
        filtered_cells.add(name)
    filtered_cells = sorted(filtered_cells)

    # Split apart the GDS file
    input_techfile = os.path.abspath(techfile)
    magic_split_gds(tmp_gds, input_techfile, filtered_cells)

    # Extract the netlist
    def pool_write_new_args(cells):
        for new_cellname in cells:
            yield temp_dir, final_dir, input_techfile, input_path, new_cellname

    # The context manager terminates the workers if anything below raises;
    # on the normal path close()+join() wait for them to exit cleanly.
    with Pool() as pool:
        for cellname in pool.imap_unordered(pool_write_new, pool_write_new_args(filtered_cells)):
            if not cellname:
                continue
            if superdebug:
                print("Finished creating new files for:", cellname)
        pool.close()
        pool.join()



# Absolute path of the helper script shipped next to this file.
CHANGE_GDS_DATE = os.path.join(
    os.path.dirname(os.path.realpath(__file__)), 'change_gds_date.py')
# Directory containing this script (resolved through symlinks).
# NOTE(review): since Python 3.7 (PEP 562) a module-level ``__dir__`` is
# called by ``dir(module)``; binding it to a string makes ``dir()`` on this
# module raise TypeError.  Consider renaming once every use of ``__dir__`` in
# this file has been checked.
__dir__ = os.path.dirname(CHANGE_GDS_DATE)


if __name__ == "__main__":
    import doctest

    # Self-check the embedded doctests before touching any data; refuse to
    # run the conversion if they fail.
    if doctest.testmod().failed:
        sys.exit("Some test failed")

    arg_parser = argparse.ArgumentParser()
    # (positional name, help text, converter), in command-line order.
    positional_args = (
        ("input", "The path to the source directory/file", Path),
        ("output", "The path to the output directory", Path),
        ("techfile", "Full path to the techfile", str),
        ("temp", "The path to the temp directory", Path),
    )
    for arg_name, arg_help, arg_type in positional_args:
        arg_parser.add_argument(arg_name, help=arg_help, type=arg_type)

    common.main('gds', filemain, arg_parser.parse_args())
