#!/usr/bin/env python3
# Copyright 2020 The Skywater PDK Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import io
import textwrap

import unittest

import common


def read_bits(file, comments, include_delimeters=False):
    """Split source text into comment and non-comment pieces.

    This is a generator. It walks the text from start to end and yields
    ``(is_comment, text)`` tuples in which ``is_comment`` is ``True`` for
    comment pieces and ``False`` for everything else.

    Consecutive single line comments that start at the same column are
    merged into one comment piece as long as the text between them holds no
    newline; any such in-between text is yielded *before* the merged
    comment.

    Args:
        file: Either an object with a ``read()`` method, or a value accepted
            by ``open()``. In the second case the file is opened, read in
            full and closed again before the first piece is produced.
        comments: Either a key into ``common.COMMENTS`` or a
            ``(line_starts, multiline_pairs)`` tuple, where ``line_starts``
            is an iterable of single line comment openers and
            ``multiline_pairs`` an iterable of ``(start, end)`` delimiter
            pairs.
        include_delimeters: (spelling kept for backward compatibility) When
            true, comment pieces are yielded verbatim including their
            delimiters. When false, the delimiters are removed, the first
            line is left-stripped, the remaining lines are dedented, and
            leading/trailing spaces and tabs are stripped. For comments
            whose opener has ``*`` as its second character, a ``*`` aligned
            under that character is also removed from continuation lines.

    Yields:
        ``(bool, str)`` tuples as described above.

    Note:
        A comment without a closing delimiter is reported with ``print`` and
        the rest of the text is treated as the comment body.
    """
    if isinstance(comments, str):
        cmt_lines, cmt_multilines = common.COMMENTS[comments]
    else:
        cmt_lines, cmt_multilines = comments

    if hasattr(file, 'read'):
        data = file.read()
    else:
        # We opened it, so we close it; the caller's own file objects are
        # left untouched.
        with open(file) as handle:
            data = handle.read()

    # Treat a single line comment as a multiline comment whose closing
    # delimiter is the newline, so both kinds share one code path.
    cmt_delimiters = [(cmt_start, '\n') for cmt_start in cmt_lines]
    cmt_delimiters.extend(cmt_multilines)

    unprocessed = data

    comments_pending = []
    comments_pending_type = ''

    while True:
        # Work out the starting position of the next comment...
        next_start_pos = {}
        for cmt_start, cmt_end in cmt_delimiters:
            start_pos = unprocessed.find(cmt_start)
            if start_pos > -1:
                assert start_pos not in next_start_pos, (start_pos, next_start_pos)
                next_start_pos[start_pos] = (cmt_start, cmt_end)

        # No comment left: pretend an empty "comment" sits at the very end so
        # the trailing text and any pending comment still get flushed.
        if not next_start_pos:
            next_start_pos[len(unprocessed)] = ('', '')

        # Use the earliest found comment
        min_start_pos = min(next_start_pos)
        cmt_start, cmt_end = next_start_pos[min_start_pos]

        # Split into everything before the comment and after it
        non_comment, unprocessed = unprocessed[:min_start_pos], unprocessed[min_start_pos:]

        nl_pos = non_comment.rfind('\n')

        # Signature of this comment: one 'X' per character between the last
        # newline of the preceding text (or its start) and the comment,
        # followed by the opening delimiter.
        if nl_pos > -1:
            comments_current_type = 'X'*(len(non_comment)-nl_pos-1)+cmt_start
        else:
            comments_current_type = 'X'*len(non_comment)+cmt_start

        # Keep collecting only while this is another single line comment with
        # the same signature and no newline in the text in between;
        # otherwise flush what has been collected so far.
        if comments_pending and (cmt_end != '\n' or comments_current_type != comments_pending_type or nl_pos > -1):
            if not include_delimeters:
                lines = ''.join(comments_pending).splitlines(True)
                if not lines:
                    lines.append('')

                # FIXME: This is a hack....
                # Unindent the comment: when the opener's second character is
                # '*', strip a '*' aligned under it from continuation lines.
                rpos = comments_pending_type.rfind('X')
                if len(comments_pending_type) > (rpos+2) and comments_pending_type[rpos+2] == '*':
                    rpos += 2
                    prefix = ' '*rpos+'*'

                    for i in range(1, len(lines)):
                        if lines[i].startswith(prefix):
                            lines[i] = lines[i][len(prefix):]

                comment = lines[0].lstrip()
                comment += textwrap.dedent(''.join(lines[1:]))
                comment = comment.strip(' \t')
            else:
                comment = ''.join(comments_pending)
            yield (True, comment)
            comments_pending.clear()

        if non_comment:
            yield (False, non_comment)

        # Have we finished processing?
        if not unprocessed and not comments_pending:
            break

        next_end_pos_a = unprocessed.find(cmt_end)
        if next_end_pos_a < 0:
            print('  !!!', 'Missing %r for closing %r' % (cmt_end, cmt_start))
            next_end_pos_a = len(unprocessed)

        # *_a bounds the text kept as the comment, *_b is where scanning
        # resumes (always past the closing delimiter).
        next_end_pos_b = next_end_pos_a+len(cmt_end)

        next_start_pos_a = len(cmt_start)
        if include_delimeters:
            next_start_pos_a = 0
            next_end_pos_a += len(cmt_end)
        elif cmt_end == '\n':
            # Keep the newline of a single line comment so merged comments
            # stay on separate lines.
            next_end_pos_a += 1
        comment, unprocessed = unprocessed[next_start_pos_a:next_end_pos_a], unprocessed[next_end_pos_b:]

        comments_pending.append(comment)
        comments_pending_type = comments_current_type


class Test(unittest.TestCase):
    """Checks ``read_bits`` against a fixed sample using the 'c' comment style."""

    # Always show the full diff when the multi-line comparison fails.
    maxDiff = None

    # Sample input mixing trailing, full-line, inline, adjacent and
    # star-prefixed block comments.
    EXAMPLE = '''\
no comment
end comment // here1
end comment // here2
// full line end comment1
// full line end comment2

// full line end comment3
/* full line multiline comment */
before /* line one
          line two
          line three */ after
before /* inside */ after
a/*b*/c/*d*/
/*a*/b/*c*//*d*/
/*a*//*b*//*c*/d/*e*/
/* a1
 * a2
 * a3
 * a4
 */
    /* b1
     * b2
     * b3
     * b4
     */
/* list intro
   * list 1
   * list 2
 */
'''

    def check_bits(self, input_str, expected=None, **kw):
        """Render the pieces of ``input_str`` and compare them to ``expected``.

        Each piece becomes one line: ``# `` for a comment or ``> `` for other
        text, followed by the ``repr`` of the piece. The first character of
        ``expected`` is skipped before comparing. Extra keyword arguments are
        passed on to ``read_bits``.
        """
        pieces = read_bits(io.StringIO(input_str), 'c', **kw)
        rendered = "".join(
            '%s%r\n' % ('# ' if is_comment else '> ', text)
            for is_comment, text in pieces
        )
        print("---")
        print(rendered)
        print("---")
        self.assertMultiLineEqual(rendered, expected[1:])

    def test_read_bits_del(self):
        # Delimiters are kept in the comment pieces; the sample is fed
        # without its final newline.
        want = r'''
> 'no comment\nend comment '
> 'end comment '
# '// here1\n// here2\n'
# '// full line end comment1\n// full line end comment2\n'
> '\n'
# '// full line end comment3\n'
# '/* full line multiline comment */'
> '\nbefore '
# '/* line one\n          line two\n          line three */'
> ' after\nbefore '
# '/* inside */'
> ' after\na'
# '/*b*/'
> 'c'
# '/*d*/'
> '\n'
# '/*a*/'
> 'b'
# '/*c*/'
# '/*d*/'
> '\n'
# '/*a*/'
# '/*b*/'
# '/*c*/'
> 'd'
# '/*e*/'
> '\n'
# '/* a1\n * a2\n * a3\n * a4\n */'
> '\n    '
# '/* b1\n     * b2\n     * b3\n     * b4\n     */'
> '\n'
# '/* list intro\n   * list 1\n   * list 2\n */'
'''
        self.check_bits(self.EXAMPLE[:-1], expected=want, include_delimeters=True)

    def test_read_bits_nodel(self):
        # Delimiters are stripped and the comment text is unindented; the
        # sample is fed without its final newline.
        want = r'''
> 'no comment\nend comment '
> 'end comment '
# 'here1\nhere2\n'
# 'full line end comment1\nfull line end comment2\n'
> '\n'
# 'full line end comment3\n'
# 'full line multiline comment'
> '\nbefore '
# 'line one\nline two\nline three'
> ' after\nbefore '
# 'inside'
> ' after\na'
# 'b'
> 'c'
# 'd'
> '\n'
# 'a'
> 'b'
# 'c'
# 'd'
> '\n'
# 'a'
# 'b'
# 'c'
> 'd'
# 'e'
> '\n'
# 'a1\na2\na3\na4\n'
> '\n    '
# 'b1\nb2\nb3\nb4\n'
> '\n'
# 'list intro\n* list 1\n* list 2\n'
'''
        self.check_bits(self.EXAMPLE[:-1], expected=want, include_delimeters=False)



# Script entry point: run the module's doctests first, then the unittest
# suite defined in this file.
if __name__ == "__main__":
    # Imported here so doctest is only loaded when run as a script.
    import doctest
    # Runs docstring examples found in this module (none are visible in this
    # part of the file).
    doctest.testmod()
    # Discovers and runs the TestCase classes of this module.
    unittest.main()
