1#!/usr/bin/env python
2"""A glorified C pre-processor parser."""
3
4import ctypes
5import logging
6import os
7import re
8import site
9import unittest
10import utils
11
12top = os.getenv('ANDROID_BUILD_TOP')
13if top is None:
14    utils.panic('ANDROID_BUILD_TOP not set.\n')
15
16# Set up the env vars for libclang.
17site.addsitedir(os.path.join(top, 'external/clang/bindings/python'))
18
19import clang.cindex
20from clang.cindex import conf
21from clang.cindex import Cursor
22from clang.cindex import CursorKind
23from clang.cindex import SourceLocation
24from clang.cindex import SourceRange
25from clang.cindex import TokenGroup
26from clang.cindex import TokenKind
27from clang.cindex import TranslationUnit
28
29# Set up LD_LIBRARY_PATH to include libclang.so, libLLVM.so, and etc.
30# Note that setting LD_LIBRARY_PATH with os.putenv() sometimes doesn't help.
31clang.cindex.Config.set_library_file(os.path.join(top, 'prebuilts/sdk/tools/linux/lib64/libclang_android.so'))
32
33from defaults import *
34
35
36debugBlockParser = False
37debugCppExpr = False
38debugOptimIf01 = False
39
40###############################################################################
41###############################################################################
42#####                                                                     #####
43#####           C P P   T O K E N S                                       #####
44#####                                                                     #####
45###############################################################################
46###############################################################################
47
48# the list of supported C-preprocessor tokens
49# plus a couple of C tokens as well
50tokEOF = "\0"
51tokLN = "\n"
52tokSTRINGIFY = "#"
53tokCONCAT = "##"
54tokLOGICAND = "&&"
55tokLOGICOR = "||"
56tokSHL = "<<"
57tokSHR = ">>"
58tokEQUAL = "=="
59tokNEQUAL = "!="
60tokLT = "<"
61tokLTE = "<="
62tokGT = ">"
63tokGTE = ">="
64tokELLIPSIS = "..."
65tokSPACE = " "
66tokDEFINED = "defined"
67tokLPAREN = "("
68tokRPAREN = ")"
69tokNOT = "!"
70tokPLUS = "+"
71tokMINUS = "-"
72tokMULTIPLY = "*"
73tokDIVIDE = "/"
74tokMODULUS = "%"
75tokBINAND = "&"
76tokBINOR = "|"
77tokBINXOR = "^"
78tokCOMMA = ","
79tokLBRACE = "{"
80tokRBRACE = "}"
81tokARROW = "->"
82tokINCREMENT = "++"
83tokDECREMENT = "--"
84tokNUMBER = "<number>"
85tokIDENT = "<ident>"
86tokSTRING = "<string>"
87
88
89class Token(clang.cindex.Token):
90    """A class that represents one token after parsing.
91
92    It inherits the class in libclang, with an extra id property to hold the
93    new spelling of the token. The spelling property in the base class is
94    defined as read-only. New names after macro instantiation are saved in
95    their ids now. It also facilitates the renaming of directive optimizations
96    like replacing 'ifndef X' with 'if !defined(X)'.
97
98    It also overrides the cursor property of the base class. Because the one
99    in libclang always queries based on a single token, which usually doesn't
100    hold useful information. The cursor in this class can be set by calling
101    CppTokenizer.getTokensWithCursors(). Otherwise it returns the one in the
102    base class.
103    """
104
105    def __init__(self, tu=None, group=None, int_data=None, ptr_data=None,
106                 cursor=None):
107        clang.cindex.Token.__init__(self)
108        self._id = None
109        self._tu = tu
110        self._group = group
111        self._cursor = cursor
112        # self.int_data and self.ptr_data are from the base class. But
113        # self.int_data doesn't accept a None value.
114        if int_data is not None:
115            self.int_data = int_data
116        self.ptr_data = ptr_data
117
118    @property
119    def id(self):
120        """Name of the token."""
121        if self._id is None:
122            return self.spelling
123        else:
124            return self._id
125
126    @id.setter
127    def id(self, new_id):
128        """Setting name of the token."""
129        self._id = new_id
130
131    @property
132    def cursor(self):
133        if self._cursor is None:
134            self._cursor = clang.cindex.Token.cursor
135        return self._cursor
136
137    @cursor.setter
138    def cursor(self, new_cursor):
139        self._cursor = new_cursor
140
141    def __repr__(self):
142        if self.id == 'defined':
143            return self.id
144        elif self.kind == TokenKind.IDENTIFIER:
145            return "(ident %s)" % self.id
146
147        return self.id
148
149    def __str__(self):
150        return self.id
151
152
153class BadExpectedToken(Exception):
154    """An exception that will be raised for unexpected tokens."""
155    pass
156
157
158# The __contains__ function in libclang SourceRange class contains a bug. It
159# gives wrong result when dealing with single line range.
160# Bug filed with upstream:
161# http://llvm.org/bugs/show_bug.cgi?id=22243, http://reviews.llvm.org/D7277
162def SourceRange__contains__(self, other):
163    """Determine if a given location is inside the range."""
164    if not isinstance(other, SourceLocation):
165        return False
166    if other.file is None and self.start.file is None:
167        pass
168    elif (self.start.file.name != other.file.name or
169          other.file.name != self.end.file.name):
170        # same file name
171        return False
172    # same file, in between lines
173    if self.start.line < other.line < self.end.line:
174        return True
175    # same file, same line
176    elif self.start.line == other.line == self.end.line:
177        if self.start.column <= other.column <= self.end.column:
178            return True
179    elif self.start.line == other.line:
180        # same file first line
181        if self.start.column <= other.column:
182            return True
183    elif other.line == self.end.line:
184        # same file last line
185        if other.column <= self.end.column:
186            return True
187    return False
188
189
190SourceRange.__contains__ = SourceRange__contains__
191
192
193################################################################################
194################################################################################
195#####                                                                      #####
196#####           C P P   T O K E N I Z E R                                  #####
197#####                                                                      #####
198################################################################################
199################################################################################
200
201
202class CppTokenizer(object):
203    """A tokenizer that converts some input text into a list of tokens.
204
205    It calls libclang's tokenizer to get the parsed tokens. In addition, it
206    updates the cursor property in each token after parsing, by calling
207    getTokensWithCursors().
208    """
209
210    clang_flags = ['-E', '-x', 'c']
211    options = TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD
212
213    def __init__(self):
214        """Initialize a new CppTokenizer object."""
215        self._indexer = clang.cindex.Index.create()
216        self._tu = None
217        self._index = 0
218        self.tokens = None
219
220    def _getTokensWithCursors(self):
221        """Helper method to return all tokens with their cursors.
222
223        The cursor property in a clang Token doesn't provide enough
224        information. Because it is queried based on single token each time
225        without any context, i.e. via calling conf.lib.clang_annotateTokens()
226        with only one token given. So we often see 'INVALID_FILE' in one
227        token's cursor. In this function it passes all the available tokens
228        to get more informative cursors.
229        """
230
231        tokens_memory = ctypes.POINTER(clang.cindex.Token)()
232        tokens_count = ctypes.c_uint()
233
234        conf.lib.clang_tokenize(self._tu, self._tu.cursor.extent,
235                                ctypes.byref(tokens_memory),
236                                ctypes.byref(tokens_count))
237
238        count = int(tokens_count.value)
239
240        # If we get no tokens, no memory was allocated. Be sure not to return
241        # anything and potentially call a destructor on nothing.
242        if count < 1:
243            return
244
245        cursors = (Cursor * count)()
246        cursors_memory = ctypes.cast(cursors, ctypes.POINTER(Cursor))
247
248        conf.lib.clang_annotateTokens(self._tu, tokens_memory, count,
249                                      cursors_memory)
250
251        tokens_array = ctypes.cast(
252            tokens_memory,
253            ctypes.POINTER(clang.cindex.Token * count)).contents
254        token_group = TokenGroup(self._tu, tokens_memory, tokens_count)
255
256        tokens = []
257        for i in xrange(0, count):
258            token = Token(self._tu, token_group,
259                          int_data=tokens_array[i].int_data,
260                          ptr_data=tokens_array[i].ptr_data,
261                          cursor=cursors[i])
262            # We only want non-comment tokens.
263            if token.kind != TokenKind.COMMENT:
264                tokens.append(token)
265
266        return tokens
267
268    def parseString(self, lines):
269        """Parse a list of text lines into a BlockList object."""
270        file_ = 'no-filename-available.c'
271        self._tu = self._indexer.parse(file_, self.clang_flags,
272                                       unsaved_files=[(file_, lines)],
273                                       options=self.options)
274        self.tokens = self._getTokensWithCursors()
275
276    def parseFile(self, file_):
277        """Parse a file into a BlockList object."""
278        self._tu = self._indexer.parse(file_, self.clang_flags,
279                                       options=self.options)
280        self.tokens = self._getTokensWithCursors()
281
282    def nextToken(self):
283        """Return next token from the list."""
284        if self._index < len(self.tokens):
285            t = self.tokens[self._index]
286            self._index += 1
287            return t
288        else:
289            return None
290
291
292class CppStringTokenizer(CppTokenizer):
293    """A CppTokenizer derived class that accepts a string of text as input."""
294
295    def __init__(self, line):
296        CppTokenizer.__init__(self)
297        self.parseString(line)
298
299
300class CppFileTokenizer(CppTokenizer):
301    """A CppTokenizer derived class that accepts a file as input."""
302
303    def __init__(self, file_):
304        CppTokenizer.__init__(self)
305        self.parseFile(file_)
306
307
308# Unit testing
309#
310class CppTokenizerTests(unittest.TestCase):
311    """CppTokenizer tests."""
312
313    def get_tokens(self, token_string, line_col=False):
314        tokens = CppStringTokenizer(token_string)
315        token_list = []
316        while True:
317            token = tokens.nextToken()
318            if not token:
319                break
320            if line_col:
321                token_list.append((token.id, token.location.line,
322                                   token.location.column))
323            else:
324                token_list.append(token.id)
325        return token_list
326
327    def test_hash(self):
328        self.assertEqual(self.get_tokens("#an/example  && (01923_xy)"),
329                         ["#", "an", "/", "example", tokLOGICAND, tokLPAREN,
330                          "01923_xy", tokRPAREN])
331
332    def test_parens(self):
333        self.assertEqual(self.get_tokens("FOO(BAR) && defined(BAZ)"),
334                         ["FOO", tokLPAREN, "BAR", tokRPAREN, tokLOGICAND,
335                          "defined", tokLPAREN, "BAZ", tokRPAREN])
336
337    def test_comment(self):
338        self.assertEqual(self.get_tokens("/*\n#\n*/"), [])
339
340    def test_line_cross(self):
341        self.assertEqual(self.get_tokens("first\nsecond"), ["first", "second"])
342
343    def test_line_cross_line_col(self):
344        self.assertEqual(self.get_tokens("first second\n  third", True),
345                         [("first", 1, 1), ("second", 1, 7), ("third", 2, 3)])
346
347    def test_comment_line_col(self):
348        self.assertEqual(self.get_tokens("boo /* what the\nhell */", True),
349                         [("boo", 1, 1)])
350
351    def test_escapes(self):
352        self.assertEqual(self.get_tokens("an \\\n example", True),
353                         [("an", 1, 1), ("example", 2, 2)])
354
355
356################################################################################
357################################################################################
358#####                                                                      #####
359#####           C P P   E X P R E S S I O N S                              #####
360#####                                                                      #####
361################################################################################
362################################################################################
363
364
365class CppExpr(object):
366    """A class that models the condition of #if directives into an expr tree.
367
368    Each node in the tree is of the form (op, arg) or (op, arg1, arg2) where
369    "op" is a string describing the operation
370    """
371
372    unaries = ["!", "~"]
373    binaries = ["+", "-", "<", "<=", ">=", ">", "&&", "||", "*", "/", "%",
374                "&", "|", "^", "<<", ">>", "==", "!=", "?", ":"]
375    precedences = {
376        "?": 1, ":": 1,
377        "||": 2,
378        "&&": 3,
379        "|": 4,
380        "^": 5,
381        "&": 6,
382        "==": 7, "!=": 7,
383        "<": 8, "<=": 8, ">": 8, ">=": 8,
384        "<<": 9, ">>": 9,
385        "+": 10, "-": 10,
386        "*": 11, "/": 11, "%": 11,
387        "!": 12, "~": 12
388    }
389
390    def __init__(self, tokens):
391        """Initialize a CppExpr. 'tokens' must be a CppToken list."""
392        self.tokens = tokens
393        self._num_tokens = len(tokens)
394        self._index = 0
395
396        if debugCppExpr:
397            print "CppExpr: trying to parse %s" % repr(tokens)
398        self.expr = self.parseExpression(0)
399        if debugCppExpr:
400            print "CppExpr: got " + repr(self.expr)
401        if self._index != self._num_tokens:
402            self.throw(BadExpectedToken, "crap at end of input (%d != %d): %s"
403                       % (self._index, self._num_tokens, repr(tokens)))
404
405    def throw(self, exception, msg):
406        if self._index < self._num_tokens:
407            tok = self.tokens[self._index]
408            print "%d:%d: %s" % (tok.location.line, tok.location.column, msg)
409        else:
410            print "EOF: %s" % msg
411        raise exception(msg)
412
413    def expectId(self, id):
414        """Check that a given token id is at the current position."""
415        token = self.tokens[self._index]
416        if self._index >= self._num_tokens or token.id != id:
417            self.throw(BadExpectedToken,
418                       "### expecting '%s' in expression, got '%s'" % (
419                           id, token.id))
420        self._index += 1
421
422    def is_decimal(self):
423        token = self.tokens[self._index].id
424        if token[-1] in "ULul":
425            token = token[:-1]
426        try:
427            val = int(token, 10)
428            self._index += 1
429            return ('int', val)
430        except ValueError:
431            return None
432
433    def is_octal(self):
434        token = self.tokens[self._index].id
435        if token[-1] in "ULul":
436            token = token[:-1]
437        if len(token) < 2 or token[0] != '0':
438            return None
439        try:
440            val = int(token, 8)
441            self._index += 1
442            return ('oct', val)
443        except ValueError:
444            return None
445
446    def is_hexadecimal(self):
447        token = self.tokens[self._index].id
448        if token[-1] in "ULul":
449            token = token[:-1]
450        if len(token) < 3 or (token[:2] != '0x' and token[:2] != '0X'):
451            return None
452        try:
453            val = int(token, 16)
454            self._index += 1
455            return ('hex', val)
456        except ValueError:
457            return None
458
459    def is_integer(self):
460        if self.tokens[self._index].kind != TokenKind.LITERAL:
461            return None
462
463        c = self.is_hexadecimal()
464        if c:
465            return c
466
467        c = self.is_octal()
468        if c:
469            return c
470
471        c = self.is_decimal()
472        if c:
473            return c
474
475        return None
476
477    def is_number(self):
478        t = self.tokens[self._index]
479        if t.id == tokMINUS and self._index + 1 < self._num_tokens:
480            self._index += 1
481            c = self.is_integer()
482            if c:
483                op, val = c
484                return (op, -val)
485        if t.id == tokPLUS and self._index + 1 < self._num_tokens:
486            self._index += 1
487            c = self.is_integer()
488            if c:
489                return c
490
491        return self.is_integer()
492
493    def is_defined(self):
494        t = self.tokens[self._index]
495        if t.id != tokDEFINED:
496            return None
497
498        # We have the defined keyword, check the rest.
499        self._index += 1
500        used_parens = False
501        if (self._index < self._num_tokens and
502            self.tokens[self._index].id == tokLPAREN):
503            used_parens = True
504            self._index += 1
505
506        if self._index >= self._num_tokens:
507            self.throw(BadExpectedToken,
508                       "### 'defined' must be followed by macro name or left "
509                       "paren")
510
511        t = self.tokens[self._index]
512        if t.kind != TokenKind.IDENTIFIER:
513            self.throw(BadExpectedToken,
514                       "### 'defined' must be followed by macro name")
515
516        self._index += 1
517        if used_parens:
518            self.expectId(tokRPAREN)
519
520        return ("defined", t.id)
521
522    def is_call_or_ident(self):
523        if self._index >= self._num_tokens:
524            return None
525
526        t = self.tokens[self._index]
527        if t.kind != TokenKind.IDENTIFIER:
528            return None
529
530        name = t.id
531
532        self._index += 1
533        if (self._index >= self._num_tokens or
534            self.tokens[self._index].id != tokLPAREN):
535            return ("ident", name)
536
537        params = []
538        depth = 1
539        self._index += 1
540        j = self._index
541        while self._index < self._num_tokens:
542            id = self.tokens[self._index].id
543            if id == tokLPAREN:
544                depth += 1
545            elif depth == 1 and (id == tokCOMMA or id == tokRPAREN):
546                k = self._index
547                param = self.tokens[j:k]
548                params.append(param)
549                if id == tokRPAREN:
550                    break
551                j = self._index + 1
552            elif id == tokRPAREN:
553                depth -= 1
554            self._index += 1
555
556        if self._index >= self._num_tokens:
557            return None
558
559        self._index += 1
560        return ("call", (name, params))
561
562    # Implements the "precedence climbing" algorithm from
563    # http://www.engr.mun.ca/~theo/Misc/exp_parsing.htm.
564    # The "classic" algorithm would be fine if we were using a tool to
565    # generate the parser, but we're not. Dijkstra's "shunting yard"
566    # algorithm hasn't been necessary yet.
567
568    def parseExpression(self, minPrecedence):
569        if self._index >= self._num_tokens:
570            return None
571
572        node = self.parsePrimary()
573        while (self.token() and self.isBinary(self.token()) and
574               self.precedence(self.token()) >= minPrecedence):
575            op = self.token()
576            self.nextToken()
577            rhs = self.parseExpression(self.precedence(op) + 1)
578            node = (op.id, node, rhs)
579
580        return node
581
582    def parsePrimary(self):
583        op = self.token()
584        if self.isUnary(op):
585            self.nextToken()
586            return (op.id, self.parseExpression(self.precedence(op)))
587
588        primary = None
589        if op.id == tokLPAREN:
590            self.nextToken()
591            primary = self.parseExpression(0)
592            self.expectId(tokRPAREN)
593        elif op.id == "?":
594            self.nextToken()
595            primary = self.parseExpression(0)
596            self.expectId(":")
597        elif op.id == '+' or op.id == '-' or op.kind == TokenKind.LITERAL:
598            primary = self.is_number()
599        # Checking for 'defined' needs to come first now because 'defined' is
600        # recognized as IDENTIFIER.
601        elif op.id == tokDEFINED:
602            primary = self.is_defined()
603        elif op.kind == TokenKind.IDENTIFIER:
604            primary = self.is_call_or_ident()
605        else:
606            self.throw(BadExpectedToken,
607                       "didn't expect to see a %s in factor" % (
608                           self.tokens[self._index].id))
609        return primary
610
611    def isBinary(self, token):
612        return token.id in self.binaries
613
614    def isUnary(self, token):
615        return token.id in self.unaries
616
617    def precedence(self, token):
618        return self.precedences.get(token.id)
619
620    def token(self):
621        if self._index >= self._num_tokens:
622            return None
623        return self.tokens[self._index]
624
625    def nextToken(self):
626        self._index += 1
627        if self._index >= self._num_tokens:
628            return None
629        return self.tokens[self._index]
630
631    def dump_node(self, e):
632        op = e[0]
633        line = "(" + op
634        if op == "int":
635            line += " %d)" % e[1]
636        elif op == "oct":
637            line += " 0%o)" % e[1]
638        elif op == "hex":
639            line += " 0x%x)" % e[1]
640        elif op == "ident":
641            line += " %s)" % e[1]
642        elif op == "defined":
643            line += " %s)" % e[1]
644        elif op == "call":
645            arg = e[1]
646            line += " %s [" % arg[0]
647            prefix = ""
648            for param in arg[1]:
649                par = ""
650                for tok in param:
651                    par += str(tok)
652                line += "%s%s" % (prefix, par)
653                prefix = ","
654            line += "])"
655        elif op in CppExpr.unaries:
656            line += " %s)" % self.dump_node(e[1])
657        elif op in CppExpr.binaries:
658            line += " %s %s)" % (self.dump_node(e[1]), self.dump_node(e[2]))
659        else:
660            line += " ?%s)" % repr(e[1])
661
662        return line
663
664    def __repr__(self):
665        return self.dump_node(self.expr)
666
667    def source_node(self, e):
668        op = e[0]
669        if op == "int":
670            return "%d" % e[1]
671        if op == "hex":
672            return "0x%x" % e[1]
673        if op == "oct":
674            return "0%o" % e[1]
675        if op == "ident":
676            # XXX: should try to expand
677            return e[1]
678        if op == "defined":
679            return "defined(%s)" % e[1]
680
681        prec = CppExpr.precedences.get(op, 1000)
682        arg = e[1]
683        if op in CppExpr.unaries:
684            arg_src = self.source_node(arg)
685            arg_op = arg[0]
686            arg_prec = CppExpr.precedences.get(arg_op, 1000)
687            if arg_prec < prec:
688                return "!(" + arg_src + ")"
689            else:
690                return "!" + arg_src
691        if op in CppExpr.binaries:
692            arg2 = e[2]
693            arg1_op = arg[0]
694            arg2_op = arg2[0]
695            arg1_src = self.source_node(arg)
696            arg2_src = self.source_node(arg2)
697            if CppExpr.precedences.get(arg1_op, 1000) < prec:
698                arg1_src = "(%s)" % arg1_src
699            if CppExpr.precedences.get(arg2_op, 1000) < prec:
700                arg2_src = "(%s)" % arg2_src
701
702            return "%s %s %s" % (arg1_src, op, arg2_src)
703        return "???"
704
705    def __str__(self):
706        return self.source_node(self.expr)
707
708    @staticmethod
709    def int_node(e):
710        if e[0] in ["int", "oct", "hex"]:
711            return e[1]
712        else:
713            return None
714
715    def toInt(self):
716        return self.int_node(self.expr)
717
718    def optimize_node(self, e, macros=None):
719        if macros is None:
720            macros = {}
721        op = e[0]
722
723        if op == "defined":
724            op, name = e
725            if macros.has_key(name):
726                if macros[name] == kCppUndefinedMacro:
727                    return ("int", 0)
728                else:
729                    try:
730                        value = int(macros[name])
731                        return ("int", value)
732                    except ValueError:
733                        return ("defined", macros[name])
734
735            if kernel_remove_config_macros and name.startswith("CONFIG_"):
736                return ("int", 0)
737
738            return e
739
740        elif op == "ident":
741            op, name = e
742            if macros.has_key(name):
743                try:
744                    value = int(macros[name])
745                    expanded = ("int", value)
746                except ValueError:
747                    expanded = ("ident", macros[name])
748                return self.optimize_node(expanded, macros)
749            return e
750
751        elif op == "!":
752            op, v = e
753            v = self.optimize_node(v, macros)
754            if v[0] == "int":
755                if v[1] == 0:
756                    return ("int", 1)
757                else:
758                    return ("int", 0)
759            return ('!', v)
760
761        elif op == "&&":
762            op, l, r = e
763            l = self.optimize_node(l, macros)
764            r = self.optimize_node(r, macros)
765            li = self.int_node(l)
766            ri = self.int_node(r)
767            if li is not None:
768                if li == 0:
769                    return ("int", 0)
770                else:
771                    return r
772            elif ri is not None:
773                if ri == 0:
774                    return ("int", 0)
775                else:
776                    return l
777            return (op, l, r)
778
779        elif op == "||":
780            op, l, r = e
781            l = self.optimize_node(l, macros)
782            r = self.optimize_node(r, macros)
783            li = self.int_node(l)
784            ri = self.int_node(r)
785            if li is not None:
786                if li == 0:
787                    return r
788                else:
789                    return ("int", 1)
790            elif ri is not None:
791                if ri == 0:
792                    return l
793                else:
794                    return ("int", 1)
795            return (op, l, r)
796
797        else:
798            return e
799
800    def optimize(self, macros=None):
801        if macros is None:
802            macros = {}
803        self.expr = self.optimize_node(self.expr, macros)
804
805class CppExprTest(unittest.TestCase):
806    """CppExpr unit tests."""
807
808    def get_expr(self, expr):
809        return repr(CppExpr(CppStringTokenizer(expr).tokens))
810
811    def test_cpp_expr(self):
812        self.assertEqual(self.get_expr("0"), "(int 0)")
813        self.assertEqual(self.get_expr("1"), "(int 1)")
814        self.assertEqual(self.get_expr("-5"), "(int -5)")
815        self.assertEqual(self.get_expr("+1"), "(int 1)")
816        self.assertEqual(self.get_expr("0U"), "(int 0)")
817        self.assertEqual(self.get_expr("015"), "(oct 015)")
818        self.assertEqual(self.get_expr("015l"), "(oct 015)")
819        self.assertEqual(self.get_expr("0x3e"), "(hex 0x3e)")
820        self.assertEqual(self.get_expr("(0)"), "(int 0)")
821        self.assertEqual(self.get_expr("1 && 1"), "(&& (int 1) (int 1))")
822        self.assertEqual(self.get_expr("1 && 0"), "(&& (int 1) (int 0))")
823        self.assertEqual(self.get_expr("EXAMPLE"), "(ident EXAMPLE)")
824        self.assertEqual(self.get_expr("EXAMPLE - 3"),
825                         "(- (ident EXAMPLE) (int 3))")
826        self.assertEqual(self.get_expr("defined(EXAMPLE)"),
827                         "(defined EXAMPLE)")
828        self.assertEqual(self.get_expr("defined ( EXAMPLE ) "),
829                         "(defined EXAMPLE)")
830        self.assertEqual(self.get_expr("!defined(EXAMPLE)"),
831                         "(! (defined EXAMPLE))")
832        self.assertEqual(self.get_expr("defined(ABC) || defined(BINGO)"),
833                         "(|| (defined ABC) (defined BINGO))")
834        self.assertEqual(self.get_expr("FOO(BAR,5)"), "(call FOO [BAR,5])")
835        self.assertEqual(self.get_expr("A == 1 || defined(B)"),
836                         "(|| (== (ident A) (int 1)) (defined B))")
837
838    def get_expr_optimize(self, expr, macros=None):
839        if macros is None:
840            macros = {}
841        e = CppExpr(CppStringTokenizer(expr).tokens)
842        e.optimize(macros)
843        return repr(e)
844
845    def test_cpp_expr_optimize(self):
846        self.assertEqual(self.get_expr_optimize("0"), "(int 0)")
847        self.assertEqual(self.get_expr_optimize("1"), "(int 1)")
848        self.assertEqual(self.get_expr_optimize("1 && 1"), "(int 1)")
849        self.assertEqual(self.get_expr_optimize("1 && +1"), "(int 1)")
850        self.assertEqual(self.get_expr_optimize("0x1 && 01"), "(oct 01)")
851        self.assertEqual(self.get_expr_optimize("1 && 0"), "(int 0)")
852        self.assertEqual(self.get_expr_optimize("0 && 1"), "(int 0)")
853        self.assertEqual(self.get_expr_optimize("0 && 0"), "(int 0)")
854        self.assertEqual(self.get_expr_optimize("1 || 1"), "(int 1)")
855        self.assertEqual(self.get_expr_optimize("1 || 0"), "(int 1)")
856        self.assertEqual(self.get_expr_optimize("0 || 1"), "(int 1)")
857        self.assertEqual(self.get_expr_optimize("0 || 0"), "(int 0)")
858        self.assertEqual(self.get_expr_optimize("A"), "(ident A)")
859        self.assertEqual(self.get_expr_optimize("A", {"A": 1}), "(int 1)")
860        self.assertEqual(self.get_expr_optimize("A || B", {"A": 1}), "(int 1)")
861        self.assertEqual(self.get_expr_optimize("A || B", {"B": 1}), "(int 1)")
862        self.assertEqual(self.get_expr_optimize("A && B", {"A": 1}), "(ident B)")
863        self.assertEqual(self.get_expr_optimize("A && B", {"B": 1}), "(ident A)")
864        self.assertEqual(self.get_expr_optimize("A && B"), "(&& (ident A) (ident B))")
865        self.assertEqual(self.get_expr_optimize("EXAMPLE"), "(ident EXAMPLE)")
866        self.assertEqual(self.get_expr_optimize("EXAMPLE - 3"), "(- (ident EXAMPLE) (int 3))")
867        self.assertEqual(self.get_expr_optimize("defined(EXAMPLE)"), "(defined EXAMPLE)")
868        self.assertEqual(self.get_expr_optimize("defined(EXAMPLE)",
869                                                {"EXAMPLE": "XOWOE"}),
870                         "(defined XOWOE)")
871        self.assertEqual(self.get_expr_optimize("defined(EXAMPLE)",
872                                                {"EXAMPLE": kCppUndefinedMacro}),
873                         "(int 0)")
874        self.assertEqual(self.get_expr_optimize("!defined(EXAMPLE)"), "(! (defined EXAMPLE))")
875        self.assertEqual(self.get_expr_optimize("!defined(EXAMPLE)",
876                                                {"EXAMPLE": "XOWOE"}),
877                         "(! (defined XOWOE))")
878        self.assertEqual(self.get_expr_optimize("!defined(EXAMPLE)",
879                                                {"EXAMPLE": kCppUndefinedMacro}),
880                         "(int 1)")
881        self.assertEqual(self.get_expr_optimize("defined(A) || defined(B)"),
882                        "(|| (defined A) (defined B))")
883        self.assertEqual(self.get_expr_optimize("defined(A) || defined(B)",
884                                                {"A": "1"}),
885                         "(int 1)")
886        self.assertEqual(self.get_expr_optimize("defined(A) || defined(B)",
887                                                {"B": "1"}),
888                         "(int 1)")
889        self.assertEqual(self.get_expr_optimize("defined(A) || defined(B)",
890                                                {"B": kCppUndefinedMacro}),
891                         "(defined A)")
892        self.assertEqual(self.get_expr_optimize("defined(A) || defined(B)",
893                                                {"A": kCppUndefinedMacro,
894                                                 "B": kCppUndefinedMacro}),
895                         "(int 0)")
896        self.assertEqual(self.get_expr_optimize("defined(A) && defined(B)"),
897                         "(&& (defined A) (defined B))")
898        self.assertEqual(self.get_expr_optimize("defined(A) && defined(B)",
899                                                {"A": "1"}),
900                         "(defined B)")
901        self.assertEqual(self.get_expr_optimize("defined(A) && defined(B)",
902                                                {"B": "1"}),
903                         "(defined A)")
904        self.assertEqual(self.get_expr_optimize("defined(A) && defined(B)",
905                                                {"B": kCppUndefinedMacro}),
906                        "(int 0)")
907        self.assertEqual(self.get_expr_optimize("defined(A) && defined(B)",
908                                                {"A": kCppUndefinedMacro}),
909                        "(int 0)")
910        self.assertEqual(self.get_expr_optimize("A == 1 || defined(B)"),
911                         "(|| (== (ident A) (int 1)) (defined B))")
912        self.assertEqual(self.get_expr_optimize(
913              "defined(__KERNEL__) || !defined(__GLIBC__) || (__GLIBC__ < 2)",
914              {"__KERNEL__": kCppUndefinedMacro}),
915              "(|| (! (defined __GLIBC__)) (< (ident __GLIBC__) (int 2)))")
916
917    def get_expr_string(self, expr):
918        return str(CppExpr(CppStringTokenizer(expr).tokens))
919
920    def test_cpp_expr_string(self):
921        self.assertEqual(self.get_expr_string("0"), "0")
922        self.assertEqual(self.get_expr_string("1"), "1")
923        self.assertEqual(self.get_expr_string("1 && 1"), "1 && 1")
924        self.assertEqual(self.get_expr_string("1 && 0"), "1 && 0")
925        self.assertEqual(self.get_expr_string("0 && 1"), "0 && 1")
926        self.assertEqual(self.get_expr_string("0 && 0"), "0 && 0")
927        self.assertEqual(self.get_expr_string("1 || 1"), "1 || 1")
928        self.assertEqual(self.get_expr_string("1 || 0"), "1 || 0")
929        self.assertEqual(self.get_expr_string("0 || 1"), "0 || 1")
930        self.assertEqual(self.get_expr_string("0 || 0"), "0 || 0")
931        self.assertEqual(self.get_expr_string("EXAMPLE"), "EXAMPLE")
932        self.assertEqual(self.get_expr_string("EXAMPLE - 3"), "EXAMPLE - 3")
933        self.assertEqual(self.get_expr_string("defined(EXAMPLE)"), "defined(EXAMPLE)")
934        self.assertEqual(self.get_expr_string("defined EXAMPLE"), "defined(EXAMPLE)")
935        self.assertEqual(self.get_expr_string("A == 1 || defined(B)"), "A == 1 || defined(B)")
936
937
938################################################################################
939################################################################################
940#####                                                                      #####
941#####          C P P   B L O C K                                           #####
942#####                                                                      #####
943################################################################################
944################################################################################
945
946
947class Block(object):
948    """A class used to model a block of input source text.
949
950    There are two block types:
951      - directive blocks: contain the tokens of a single pre-processor
952        directive (e.g. #if)
953      - text blocks, contain the tokens of non-directive blocks
954
955    The cpp parser class below will transform an input source file into a list
956    of Block objects (grouped in a BlockList object for convenience)
957    """
958
959    def __init__(self, tokens, directive=None, lineno=0, identifier=None):
960        """Initialize a new block, if 'directive' is None, it is a text block.
961
962        NOTE: This automatically converts '#ifdef MACRO' into
963        '#if defined(MACRO)' and '#ifndef MACRO' into '#if !defined(MACRO)'.
964        """
965
966        if directive == "ifdef":
967            tok = Token()
968            tok.id = tokDEFINED
969            tokens = [tok] + tokens
970            directive = "if"
971
972        elif directive == "ifndef":
973            tok1 = Token()
974            tok2 = Token()
975            tok1.id = tokNOT
976            tok2.id = tokDEFINED
977            tokens = [tok1, tok2] + tokens
978            directive = "if"
979
980        self.tokens = tokens
981        self.directive = directive
982        self.define_id = identifier
983        if lineno > 0:
984            self.lineno = lineno
985        else:
986            self.lineno = self.tokens[0].location.line
987
988        if self.isIf():
989            self.expr = CppExpr(self.tokens)
990
991    def isDirective(self):
992        """Return True iff this is a directive block."""
993        return self.directive is not None
994
995    def isConditional(self):
996        """Return True iff this is a conditional directive block."""
997        return self.directive in ["if", "ifdef", "ifndef", "else", "elif",
998                                  "endif"]
999
1000    def isDefine(self):
1001        """Return the macro name in a #define directive, or None otherwise."""
1002        if self.directive != "define":
1003            return None
1004        return self.define_id
1005
1006    def isIf(self):
1007        """Return True iff this is an #if-like directive block."""
1008        return self.directive in ["if", "ifdef", "ifndef", "elif"]
1009
1010    def isEndif(self):
1011        """Return True iff this is an #endif directive block."""
1012        return self.directive == "endif"
1013
1014    def isInclude(self):
1015        """Check whether this is a #include directive.
1016
1017        If true, returns the corresponding file name (with brackets or
1018        double-qoutes). None otherwise.
1019        """
1020
1021        if self.directive != "include":
1022            return None
1023        return ''.join([str(x) for x in self.tokens])
1024
1025    @staticmethod
1026    def format_blocks(tokens, indent=0):
1027        """Return the formatted lines of strings with proper indentation."""
1028        newline = True
1029        result = []
1030        buf = ''
1031        i = 0
1032        while i < len(tokens):
1033            t = tokens[i]
1034            if t.id == '{':
1035                buf += ' {'
1036                result.append(strip_space(buf))
1037                # Do not indent if this is extern "C" {
1038                if i < 2 or tokens[i-2].id != 'extern' or tokens[i-1].id != '"C"':
1039                    indent += 2
1040                buf = ''
1041                newline = True
1042            elif t.id == '}':
1043                if indent >= 2:
1044                    indent -= 2
1045                if not newline:
1046                    result.append(strip_space(buf))
1047                # Look ahead to determine if it's the end of line.
1048                if (i + 1 < len(tokens) and
1049                    (tokens[i+1].id == ';' or
1050                     tokens[i+1].id in ['else', '__attribute__',
1051                                        '__attribute', '__packed'] or
1052                     tokens[i+1].kind == TokenKind.IDENTIFIER)):
1053                    buf = ' ' * indent + '}'
1054                    newline = False
1055                else:
1056                    result.append(' ' * indent + '}')
1057                    buf = ''
1058                    newline = True
1059            elif t.id == ';':
1060                result.append(strip_space(buf) + ';')
1061                buf = ''
1062                newline = True
1063            # We prefer a new line for each constant in enum.
1064            elif t.id == ',' and t.cursor.kind == CursorKind.ENUM_DECL:
1065                result.append(strip_space(buf) + ',')
1066                buf = ''
1067                newline = True
1068            else:
1069                if newline:
1070                    buf += ' ' * indent + str(t)
1071                else:
1072                    buf += ' ' + str(t)
1073                newline = False
1074            i += 1
1075
1076        if buf:
1077            result.append(strip_space(buf))
1078
1079        return result, indent
1080
1081    def write(self, out, indent):
1082        """Dump the current block."""
1083        # removeWhiteSpace() will sometimes creates non-directive blocks
1084        # without any tokens. These come from blocks that only contained
1085        # empty lines and spaces. They should not be printed in the final
1086        # output, and then should not be counted for this operation.
1087        #
1088        if self.directive is None and not self.tokens:
1089            return indent
1090
1091        if self.directive:
1092            out.write(str(self) + '\n')
1093        else:
1094            lines, indent = self.format_blocks(self.tokens, indent)
1095            for line in lines:
1096                out.write(line + '\n')
1097
1098        return indent
1099
1100    def __repr__(self):
1101        """Generate the representation of a given block."""
1102        if self.directive:
1103            result = "#%s " % self.directive
1104            if self.isIf():
1105                result += repr(self.expr)
1106            else:
1107                for tok in self.tokens:
1108                    result += repr(tok)
1109        else:
1110            result = ""
1111            for tok in self.tokens:
1112                result += repr(tok)
1113
1114        return result
1115
1116    def __str__(self):
1117        """Generate the string representation of a given block."""
1118        if self.directive:
1119            # "#if"
1120            if self.directive == "if":
1121                # small optimization to re-generate #ifdef and #ifndef
1122                e = self.expr.expr
1123                op = e[0]
1124                if op == "defined":
1125                    result = "#ifdef %s" % e[1]
1126                elif op == "!" and e[1][0] == "defined":
1127                    result = "#ifndef %s" % e[1][1]
1128                else:
1129                    result = "#if " + str(self.expr)
1130
1131            # "#define"
1132            elif self.isDefine():
1133                result = "#%s %s" % (self.directive, self.define_id)
1134                if self.tokens:
1135                    result += " "
1136                expr = strip_space(' '.join([tok.id for tok in self.tokens]))
1137                # remove the space between name and '(' in function call
1138                result += re.sub(r'(\w+) \(', r'\1(', expr)
1139
1140            # "#error"
1141            # Concatenating tokens with a space separator, because they may
1142            # not be quoted and broken into several tokens
1143            elif self.directive == "error":
1144                result = "#error %s" % ' '.join([tok.id for tok in self.tokens])
1145
1146            else:
1147                result = "#%s" % self.directive
1148                if self.tokens:
1149                    result += " "
1150                result += ''.join([tok.id for tok in self.tokens])
1151        else:
1152            lines, _ = self.format_blocks(self.tokens)
1153            result = '\n'.join(lines)
1154
1155        return result
1156
1157
1158class BlockList(object):
1159    """A convenience class used to hold and process a list of blocks.
1160
1161    It calls the cpp parser to get the blocks.
1162    """
1163
1164    def __init__(self, blocks):
1165        self.blocks = blocks
1166
1167    def __len__(self):
1168        return len(self.blocks)
1169
1170    def __getitem__(self, n):
1171        return self.blocks[n]
1172
1173    def __repr__(self):
1174        return repr(self.blocks)
1175
1176    def __str__(self):
1177        result = '\n'.join([str(b) for b in self.blocks])
1178        return result
1179
1180    def dump(self):
1181        """Dump all the blocks in current BlockList."""
1182        print '##### BEGIN #####'
1183        for i, b in enumerate(self.blocks):
1184            print '### BLOCK %d ###' % i
1185            print b
1186        print '##### END #####'
1187
1188    def optimizeIf01(self):
1189        """Remove the code between #if 0 .. #endif in a BlockList."""
1190        self.blocks = optimize_if01(self.blocks)
1191
1192    def optimizeMacros(self, macros):
1193        """Remove known defined and undefined macros from a BlockList."""
1194        for b in self.blocks:
1195            if b.isIf():
1196                b.expr.optimize(macros)
1197
1198    def removeStructs(self, structs):
1199        """Remove structs."""
1200        for b in self.blocks:
1201            # Have to look in each block for a top-level struct definition.
1202            if b.directive:
1203                continue
1204            num_tokens = len(b.tokens)
1205            # A struct definition has at least 5 tokens:
1206            #   struct
1207            #   ident
1208            #   {
1209            #   }
1210            #   ;
1211            if num_tokens < 5:
1212                continue
1213            # This is a simple struct finder, it might fail if a top-level
1214            # structure has an #if type directives that confuses the algorithm
1215            # for finding th end of the structure. Or if there is another
1216            # structure definition embedded in the structure.
1217            i = 0
1218            while i < num_tokens - 2:
1219                if (b.tokens[i].kind != TokenKind.KEYWORD or
1220                    b.tokens[i].id != "struct"):
1221                    i += 1
1222                    continue
1223                if (b.tokens[i + 1].kind == TokenKind.IDENTIFIER and
1224                    b.tokens[i + 2].kind == TokenKind.PUNCTUATION and
1225                    b.tokens[i + 2].id == "{" and b.tokens[i + 1].id in structs):
1226                    # Search forward for the end of the structure.
1227                    # Very simple search, look for } and ; tokens. If something
1228                    # more complicated is needed we can add it later.
1229                    j = i + 3
1230                    while j < num_tokens - 1:
1231                        if (b.tokens[j].kind == TokenKind.PUNCTUATION and
1232                            b.tokens[j].id == "}" and
1233                            b.tokens[j + 1].kind == TokenKind.PUNCTUATION and
1234                            b.tokens[j + 1].id == ";"):
1235                            b.tokens = b.tokens[0:i] + b.tokens[j + 2:num_tokens]
1236                            num_tokens = len(b.tokens)
1237                            j = i
1238                            break
1239                        j += 1
1240                    i = j
1241                    continue
1242                i += 1
1243
1244    def optimizeAll(self, macros):
1245        self.optimizeMacros(macros)
1246        self.optimizeIf01()
1247        return
1248
1249    def findIncludes(self):
1250        """Return the list of included files in a BlockList."""
1251        result = []
1252        for b in self.blocks:
1253            i = b.isInclude()
1254            if i:
1255                result.append(i)
1256        return result
1257
1258    def write(self, out):
1259        indent = 0
1260        for b in self.blocks:
1261            indent = b.write(out, indent)
1262
1263    def removeVarsAndFuncs(self, keep):
1264        """Remove variable and function declarations.
1265
1266        All extern and static declarations corresponding to variable and
1267        function declarations are removed. We only accept typedefs and
1268        enum/structs/union declarations.
1269
1270        In addition, remove any macros expanding in the headers. Usually,
1271        these macros are static inline functions, which is why they are
1272        removed.
1273
1274        However, we keep the definitions corresponding to the set of known
1275        static inline functions in the set 'keep', which is useful
1276        for optimized byteorder swap functions and stuff like that.
1277        """
1278
1279        # state = NORMAL => normal (i.e. LN + spaces)
1280        # state = OTHER_DECL => typedef/struct encountered, ends with ";"
1281        # state = VAR_DECL => var declaration encountered, ends with ";"
1282        # state = FUNC_DECL => func declaration encountered, ends with "}"
1283        NORMAL = 0
1284        OTHER_DECL = 1
1285        VAR_DECL = 2
1286        FUNC_DECL = 3
1287
1288        state = NORMAL
1289        depth = 0
1290        blocksToKeep = []
1291        blocksInProgress = []
1292        blocksOfDirectives = []
1293        ident = ""
1294        state_token = ""
1295        macros = set()
1296        for block in self.blocks:
1297            if block.isDirective():
1298                # Record all macros.
1299                if block.directive == 'define':
1300                    macro_name = block.define_id
1301                    paren_index = macro_name.find('(')
1302                    if paren_index == -1:
1303                        macros.add(macro_name)
1304                    else:
1305                        macros.add(macro_name[0:paren_index])
1306                blocksInProgress.append(block)
1307                # If this is in a function/variable declaration, we might need
1308                # to emit the directives alone, so save them separately.
1309                blocksOfDirectives.append(block)
1310                continue
1311
1312            numTokens = len(block.tokens)
1313            lastTerminatorIndex = 0
1314            i = 0
1315            while i < numTokens:
1316                token_id = block.tokens[i].id
1317                terminator = False
1318                if token_id == '{':
1319                    depth += 1
1320                    if (i >= 2 and block.tokens[i-2].id == 'extern' and
1321                        block.tokens[i-1].id == '"C"'):
1322                        # For an extern "C" { pretend as though this is depth 0.
1323                        depth -= 1
1324                elif token_id == '}':
1325                    if depth > 0:
1326                        depth -= 1
1327                    if depth == 0:
1328                        if state == OTHER_DECL:
1329                            # Loop through until we hit the ';'
1330                            i += 1
1331                            while i < numTokens:
1332                                if block.tokens[i].id == ';':
1333                                    token_id = ';'
1334                                    break
1335                                i += 1
1336                            # If we didn't hit the ';', just consider this the
1337                            # terminator any way.
1338                        terminator = True
1339                elif depth == 0:
1340                    if token_id == ';':
1341                        if state == NORMAL:
1342                            blocksToKeep.extend(blocksInProgress)
1343                            blocksInProgress = []
1344                            blocksOfDirectives = []
1345                            state = FUNC_DECL
1346                        terminator = True
1347                    elif (state == NORMAL and token_id == '(' and i >= 1 and
1348                          block.tokens[i-1].kind == TokenKind.IDENTIFIER and
1349                          block.tokens[i-1].id in macros):
1350                        # This is a plain macro being expanded in the header
1351                        # which needs to be removed.
1352                        blocksToKeep.extend(blocksInProgress)
1353                        if lastTerminatorIndex < i - 1:
1354                            blocksToKeep.append(Block(block.tokens[lastTerminatorIndex:i-1]))
1355                        blocksInProgress = []
1356                        blocksOfDirectives = []
1357
1358                        # Skip until we see the terminating ')'
1359                        i += 1
1360                        paren_depth = 1
1361                        while i < numTokens:
1362                            if block.tokens[i].id == ')':
1363                                paren_depth -= 1
1364                                if paren_depth == 0:
1365                                    break
1366                            elif block.tokens[i].id == '(':
1367                                paren_depth += 1
1368                            i += 1
1369                        lastTerminatorIndex = i + 1
1370                    elif (state != FUNC_DECL and token_id == '(' and
1371                          state_token != 'typedef'):
1372                        blocksToKeep.extend(blocksInProgress)
1373                        blocksInProgress = []
1374                        blocksOfDirectives = []
1375                        state = VAR_DECL
1376                    elif state == NORMAL and token_id in ['struct', 'typedef',
1377                                                          'enum', 'union',
1378                                                          '__extension__']:
1379                        state = OTHER_DECL
1380                        state_token = token_id
1381                    elif block.tokens[i].kind == TokenKind.IDENTIFIER:
1382                        if state != VAR_DECL or ident == "":
1383                            ident = token_id
1384
1385                if terminator:
1386                    if state != VAR_DECL and state != FUNC_DECL or ident in keep:
1387                        blocksInProgress.append(Block(block.tokens[lastTerminatorIndex:i+1]))
1388                        blocksToKeep.extend(blocksInProgress)
1389                    else:
1390                        # Only keep the directives found.
1391                        blocksToKeep.extend(blocksOfDirectives)
1392                    lastTerminatorIndex = i + 1
1393                    blocksInProgress = []
1394                    blocksOfDirectives = []
1395                    state = NORMAL
1396                    ident = ""
1397                    state_token = ""
1398                i += 1
1399            if lastTerminatorIndex < numTokens:
1400                blocksInProgress.append(Block(block.tokens[lastTerminatorIndex:numTokens]))
1401        if len(blocksInProgress) > 0:
1402            blocksToKeep.extend(blocksInProgress)
1403        self.blocks = blocksToKeep
1404
1405    def replaceTokens(self, replacements):
1406        """Replace tokens according to the given dict."""
1407        extra_includes = []
1408        for b in self.blocks:
1409            made_change = False
1410            if b.isInclude() is None:
1411                i = 0
1412                while i < len(b.tokens):
1413                    tok = b.tokens[i]
1414                    if (tok.kind == TokenKind.KEYWORD and tok.id == 'struct'
1415                        and (i + 2) < len(b.tokens) and b.tokens[i + 2].id == '{'):
1416                        struct_name = b.tokens[i + 1].id
1417                        if struct_name in kernel_struct_replacements:
1418                            extra_includes.append("<bits/%s.h>" % struct_name)
1419                            end = i + 2
1420                            while end < len(b.tokens) and b.tokens[end].id != '}':
1421                                end += 1
1422                            end += 1 # Swallow '}'
1423                            while end < len(b.tokens) and b.tokens[end].id != ';':
1424                                end += 1
1425                            end += 1 # Swallow ';'
1426                            # Remove these tokens. We'll replace them later with a #include block.
1427                            b.tokens[i:end] = []
1428                            made_change = True
1429                            # We've just modified b.tokens, so revisit the current offset.
1430                            continue
1431                    if tok.kind == TokenKind.IDENTIFIER:
1432                        if tok.id in replacements:
1433                            tok.id = replacements[tok.id]
1434                            made_change = True
1435                    i += 1
1436
1437                if b.isDefine() and b.define_id in replacements:
1438                    b.define_id = replacements[b.define_id]
1439                    made_change = True
1440
1441            if made_change and b.isIf():
1442                # Keep 'expr' in sync with 'tokens'.
1443                b.expr = CppExpr(b.tokens)
1444
1445        for extra_include in extra_includes:
1446            replacement = CppStringTokenizer(extra_include)
1447            self.blocks.insert(2, Block(replacement.tokens, directive='include'))
1448
1449
1450
1451def strip_space(s):
1452    """Strip out redundant space in a given string."""
1453
1454    # NOTE: It ought to be more clever to not destroy spaces in string tokens.
1455    replacements = {' . ': '.',
1456                    ' [': '[',
1457                    '[ ': '[',
1458                    ' ]': ']',
1459                    '( ': '(',
1460                    ' )': ')',
1461                    ' ,': ',',
1462                    '# ': '#',
1463                    ' ;': ';',
1464                    '~ ': '~',
1465                    ' -> ': '->'}
1466    result = s
1467    for r in replacements:
1468        result = result.replace(r, replacements[r])
1469
1470    # Remove the space between function name and the parenthesis.
1471    result = re.sub(r'(\w+) \(', r'\1(', result)
1472    return result
1473
1474
1475class BlockParser(object):
1476    """A class that converts an input source file into a BlockList object."""
1477
1478    def __init__(self, tokzer=None):
1479        """Initialize a block parser.
1480
1481        The input source is provided through a Tokenizer object.
1482        """
1483        self._tokzer = tokzer
1484        self._parsed = False
1485
1486    @property
1487    def parsed(self):
1488        return self._parsed
1489
1490    @staticmethod
1491    def _short_extent(extent):
1492        return '%d:%d - %d:%d' % (extent.start.line, extent.start.column,
1493                                  extent.end.line, extent.end.column)
1494
1495    def getBlocks(self, tokzer=None):
1496        """Return all the blocks parsed."""
1497
1498        def consume_extent(i, tokens, extent=None, detect_change=False):
1499            """Return tokens that belong to the given extent.
1500
1501            It parses all the tokens that follow tokens[i], until getting out
1502            of the extent. When detect_change is True, it may terminate early
1503            when detecting preprocessing directives inside the extent.
1504            """
1505
1506            result = []
1507            if extent is None:
1508                extent = tokens[i].cursor.extent
1509
1510            while i < len(tokens) and tokens[i].location in extent:
1511                t = tokens[i]
1512                if debugBlockParser:
1513                    print ' ' * 2, t.id, t.kind, t.cursor.kind
1514                if (detect_change and t.cursor.extent != extent and
1515                    t.cursor.kind == CursorKind.PREPROCESSING_DIRECTIVE):
1516                    break
1517                result.append(t)
1518                i += 1
1519            return (i, result)
1520
1521        def consume_line(i, tokens):
1522            """Return tokens that follow tokens[i] in the same line."""
1523            result = []
1524            line = tokens[i].location.line
1525            while i < len(tokens) and tokens[i].location.line == line:
1526                if tokens[i].cursor.kind == CursorKind.PREPROCESSING_DIRECTIVE:
1527                    break
1528                result.append(tokens[i])
1529                i += 1
1530            return (i, result)
1531
1532        if tokzer is None:
1533            tokzer = self._tokzer
1534        tokens = tokzer.tokens
1535
1536        blocks = []
1537        buf = []
1538        i = 0
1539
1540        while i < len(tokens):
1541            t = tokens[i]
1542            cursor = t.cursor
1543
1544            if debugBlockParser:
1545                print ("%d: Processing [%s], kind=[%s], cursor=[%s], "
1546                       "extent=[%s]" % (t.location.line, t.spelling, t.kind,
1547                                        cursor.kind,
1548                                        self._short_extent(cursor.extent)))
1549
1550            if cursor.kind == CursorKind.PREPROCESSING_DIRECTIVE:
1551                if buf:
1552                    blocks.append(Block(buf))
1553                    buf = []
1554
1555                j = i
1556                if j + 1 >= len(tokens):
1557                    raise BadExpectedToken("### BAD TOKEN at %s" % (t.location))
1558                directive = tokens[j+1].id
1559
1560                if directive == 'define':
1561                    if i+2 >= len(tokens):
1562                        raise BadExpectedToken("### BAD TOKEN at %s" %
1563                                               (tokens[i].location))
1564
1565                    # Skip '#' and 'define'.
1566                    extent = tokens[i].cursor.extent
1567                    i += 2
1568                    id = ''
1569                    # We need to separate the id from the remaining of
1570                    # the line, especially for the function-like macro.
1571                    if (i + 1 < len(tokens) and tokens[i+1].id == '(' and
1572                        (tokens[i].location.column + len(tokens[i].spelling) ==
1573                         tokens[i+1].location.column)):
1574                        while i < len(tokens):
1575                            id += tokens[i].id
1576                            if tokens[i].spelling == ')':
1577                                i += 1
1578                                break
1579                            i += 1
1580                    else:
1581                        id += tokens[i].id
1582                        # Advance to the next token that follows the macro id
1583                        i += 1
1584
1585                    (i, ret) = consume_extent(i, tokens, extent=extent)
1586                    blocks.append(Block(ret, directive=directive,
1587                                        lineno=t.location.line, identifier=id))
1588
1589                else:
1590                    (i, ret) = consume_extent(i, tokens)
1591                    blocks.append(Block(ret[2:], directive=directive,
1592                                        lineno=t.location.line))
1593
1594            elif cursor.kind == CursorKind.INCLUSION_DIRECTIVE:
1595                if buf:
1596                    blocks.append(Block(buf))
1597                    buf = []
1598                directive = tokens[i+1].id
1599                (i, ret) = consume_extent(i, tokens)
1600
1601                blocks.append(Block(ret[2:], directive=directive,
1602                                    lineno=t.location.line))
1603
1604            elif cursor.kind == CursorKind.VAR_DECL:
1605                if buf:
1606                    blocks.append(Block(buf))
1607                    buf = []
1608
1609                (i, ret) = consume_extent(i, tokens, detect_change=True)
1610                buf += ret
1611
1612            elif cursor.kind == CursorKind.FUNCTION_DECL:
1613                if buf:
1614                    blocks.append(Block(buf))
1615                    buf = []
1616
1617                (i, ret) = consume_extent(i, tokens, detect_change=True)
1618                buf += ret
1619
1620            else:
1621                (i, ret) = consume_line(i, tokens)
1622                buf += ret
1623
1624        if buf:
1625            blocks.append(Block(buf))
1626
1627        # _parsed=True indicates a successful parsing, although may result an
1628        # empty BlockList.
1629        self._parsed = True
1630
1631        return BlockList(blocks)
1632
1633    def parse(self, tokzer):
1634        return self.getBlocks(tokzer)
1635
1636    def parseFile(self, path):
1637        return self.getBlocks(CppFileTokenizer(path))
1638
1639
1640class BlockParserTests(unittest.TestCase):
1641    """BlockParser unit tests."""
1642
1643    def get_blocks(self, lines):
1644        blocks = BlockParser().parse(CppStringTokenizer('\n'.join(lines)))
1645        return map(lambda a: str(a), blocks)
1646
1647    def test_hash(self):
1648        self.assertEqual(self.get_blocks(["#error hello"]), ["#error hello"])
1649
1650    def test_empty_line(self):
1651        self.assertEqual(self.get_blocks(["foo", "", "bar"]), ["foo bar"])
1652
1653    def test_hash_with_space(self):
1654        # We currently cannot handle the following case with libclang properly.
1655        # Fortunately it doesn't appear in current headers.
1656        #self.assertEqual(self.get_blocks(["foo", "  #  ", "bar"]), ["foo", "bar"])
1657        pass
1658
1659    def test_with_comment(self):
1660        self.assertEqual(self.get_blocks(["foo",
1661                                          "  #  /* ahah */ if defined(__KERNEL__) /* more */",
1662                                          "bar", "#endif"]),
1663                         ["foo", "#ifdef __KERNEL__", "bar", "#endif"])
1664
1665
1666################################################################################
1667################################################################################
1668#####                                                                      #####
1669#####        B L O C K   L I S T   O P T I M I Z A T I O N                 #####
1670#####                                                                      #####
1671################################################################################
1672################################################################################
1673
1674
1675def find_matching_endif(blocks, i):
1676    """Traverse the blocks to find out the matching #endif."""
1677    n = len(blocks)
1678    depth = 1
1679    while i < n:
1680        if blocks[i].isDirective():
1681            dir_ = blocks[i].directive
1682            if dir_ in ["if", "ifndef", "ifdef"]:
1683                depth += 1
1684            elif depth == 1 and dir_ in ["else", "elif"]:
1685                return i
1686            elif dir_ == "endif":
1687                depth -= 1
1688                if depth == 0:
1689                    return i
1690        i += 1
1691    return i
1692
1693
1694def optimize_if01(blocks):
1695    """Remove the code between #if 0 .. #endif in a list of CppBlocks."""
1696    i = 0
1697    n = len(blocks)
1698    result = []
1699    while i < n:
1700        j = i
1701        while j < n and not blocks[j].isIf():
1702            j += 1
1703        if j > i:
1704            logging.debug("appending lines %d to %d", blocks[i].lineno,
1705                          blocks[j-1].lineno)
1706            result += blocks[i:j]
1707        if j >= n:
1708            break
1709        expr = blocks[j].expr
1710        r = expr.toInt()
1711        if r is None:
1712            result.append(blocks[j])
1713            i = j + 1
1714            continue
1715
1716        if r == 0:
1717            # if 0 => skip everything until the corresponding #endif
1718            start_dir = blocks[j].directive
1719            j = find_matching_endif(blocks, j + 1)
1720            if j >= n:
1721                # unterminated #if 0, finish here
1722                break
1723            dir_ = blocks[j].directive
1724            if dir_ == "endif":
1725                logging.debug("remove 'if 0' .. 'endif' (lines %d to %d)",
1726                              blocks[i].lineno, blocks[j].lineno)
1727                if start_dir == "elif":
1728                    # Put an endif since we started with an elif.
1729                    result += blocks[j:j+1]
1730                i = j + 1
1731            elif dir_ == "else":
1732                # convert 'else' into 'if 1'
1733                logging.debug("convert 'if 0' .. 'else' into 'if 1' (lines %d "
1734                              "to %d)", blocks[i].lineno, blocks[j-1].lineno)
1735                if start_dir == "elif":
1736                    blocks[j].directive = "elif"
1737                else:
1738                    blocks[j].directive = "if"
1739                blocks[j].expr = CppExpr(CppStringTokenizer("1").tokens)
1740                i = j
1741            elif dir_ == "elif":
1742                # convert 'elif' into 'if'
1743                logging.debug("convert 'if 0' .. 'elif' into 'if'")
1744                if start_dir == "elif":
1745                    blocks[j].directive = "elif"
1746                else:
1747                    blocks[j].directive = "if"
1748                i = j
1749            continue
1750
1751        # if 1 => find corresponding endif and remove/transform them
1752        k = find_matching_endif(blocks, j + 1)
1753        if k >= n:
1754            # unterminated #if 1, finish here
1755            logging.debug("unterminated 'if 1'")
1756            result += blocks[j+1:k]
1757            break
1758
1759        start_dir = blocks[j].directive
1760        dir_ = blocks[k].directive
1761        if dir_ == "endif":
1762            logging.debug("convert 'if 1' .. 'endif' (lines %d to %d)",
1763                          blocks[j].lineno, blocks[k].lineno)
1764            if start_dir == "elif":
1765                # Add the elif in to the results and convert it to an elif 1.
1766                blocks[j].tokens = CppStringTokenizer("1").tokens
1767                result += blocks[j:j+1]
1768            result += optimize_if01(blocks[j+1:k])
1769            if start_dir == "elif":
1770                # Add the endif in to the results.
1771                result += blocks[k:k+1]
1772            i = k + 1
1773        elif dir_ == "else":
1774            # convert 'else' into 'if 0'
1775            logging.debug("convert 'if 1' .. 'else' (lines %d to %d)",
1776                          blocks[j].lineno, blocks[k].lineno)
1777            if start_dir == "elif":
1778                # Add the elif in to the results and convert it to an elif 1.
1779                blocks[j].tokens = CppStringTokenizer("1").tokens
1780                result += blocks[j:j+1]
1781            result += optimize_if01(blocks[j+1:k])
1782            if start_dir == "elif":
1783                blocks[k].directive = "elif"
1784            else:
1785                blocks[k].directive = "if"
1786            blocks[k].expr = CppExpr(CppStringTokenizer("0").tokens)
1787            i = k
1788        elif dir_ == "elif":
1789            # convert 'elif' into 'if 0'
1790            logging.debug("convert 'if 1' .. 'elif' (lines %d to %d)",
1791                          blocks[j].lineno, blocks[k].lineno)
1792            result += optimize_if01(blocks[j+1:k])
1793            blocks[k].expr = CppExpr(CppStringTokenizer("0").tokens)
1794            i = k
1795    return result
1796
1797class OptimizerTests(unittest.TestCase):
1798    def parse(self, text, macros=None):
1799        out = utils.StringOutput()
1800        blocks = BlockParser().parse(CppStringTokenizer(text))
1801        blocks.optimizeAll(macros)
1802        blocks.write(out)
1803        return out.get()
1804
1805    def test_if1(self):
1806        text = """\
1807#if 1
1808#define  GOOD
1809#endif
1810"""
1811        expected = """\
1812#define GOOD
1813"""
1814        self.assertEqual(self.parse(text), expected)
1815
1816    def test_if0(self):
1817        text = """\
1818#if 0
1819#define  SHOULD_SKIP1
1820#define  SHOULD_SKIP2
1821#endif
1822"""
1823        expected = ""
1824        self.assertEqual(self.parse(text), expected)
1825
1826    def test_if1_else(self):
1827        text = """\
1828#if 1
1829#define  GOOD
1830#else
1831#define  BAD
1832#endif
1833"""
1834        expected = """\
1835#define GOOD
1836"""
1837        self.assertEqual(self.parse(text), expected)
1838
1839    def test_if0_else(self):
1840        text = """\
1841#if 0
1842#define  BAD
1843#else
1844#define  GOOD
1845#endif
1846"""
1847        expected = """\
1848#define GOOD
1849"""
1850        self.assertEqual(self.parse(text), expected)
1851
1852    def test_if_elif1(self):
1853        text = """\
1854#if defined(something)
1855#define EXISTS
1856#elif 1
1857#define GOOD
1858#endif
1859"""
1860        expected = """\
1861#ifdef something
1862#define EXISTS
1863#elif 1
1864#define GOOD
1865#endif
1866"""
1867        self.assertEqual(self.parse(text), expected)
1868
1869    def test_if_elif1_macro(self):
1870        text = """\
1871#if defined(something)
1872#define EXISTS
1873#elif defined(WILL_BE_ONE)
1874#define GOOD
1875#endif
1876"""
1877        expected = """\
1878#ifdef something
1879#define EXISTS
1880#elif 1
1881#define GOOD
1882#endif
1883"""
1884        self.assertEqual(self.parse(text, {"WILL_BE_ONE": "1"}), expected)
1885
1886
1887    def test_if_elif1_else(self):
1888        text = """\
1889#if defined(something)
1890#define EXISTS
1891#elif 1
1892#define GOOD
1893#else
1894#define BAD
1895#endif
1896"""
1897        expected = """\
1898#ifdef something
1899#define EXISTS
1900#elif 1
1901#define GOOD
1902#endif
1903"""
1904        self.assertEqual(self.parse(text), expected)
1905
1906    def test_if_elif1_else_macro(self):
1907        text = """\
1908#if defined(something)
1909#define EXISTS
1910#elif defined(WILL_BE_ONE)
1911#define GOOD
1912#else
1913#define BAD
1914#endif
1915"""
1916        expected = """\
1917#ifdef something
1918#define EXISTS
1919#elif 1
1920#define GOOD
1921#endif
1922"""
1923        self.assertEqual(self.parse(text, {"WILL_BE_ONE": "1"}), expected)
1924
1925
1926    def test_if_elif1_else_macro(self):
1927        text = """\
1928#if defined(something)
1929#define EXISTS
1930#elif defined(WILL_BE_ONE)
1931#define GOOD
1932#else
1933#define BAD
1934#endif
1935"""
1936        expected = """\
1937#ifdef something
1938#define EXISTS
1939#elif 1
1940#define GOOD
1941#endif
1942"""
1943        self.assertEqual(self.parse(text, {"WILL_BE_ONE": "1"}), expected)
1944
1945    def test_macro_set_to_undefined_single(self):
1946        text = """\
1947#if defined(__KERNEL__)
1948#define BAD_KERNEL
1949#endif
1950"""
1951        expected = ""
1952        macros = {"__KERNEL__": kCppUndefinedMacro}
1953        self.assertEqual(self.parse(text, macros), expected)
1954
1955    def test_macro_set_to_undefined_if(self):
1956        text = """\
1957#if defined(__KERNEL__) || !defined(__GLIBC__) || (__GLIBC__ < 2)
1958#define CHECK
1959#endif
1960"""
1961        expected = """\
1962#if !defined(__GLIBC__) || __GLIBC__ < 2
1963#define CHECK
1964#endif
1965"""
1966        macros = {"__KERNEL__": kCppUndefinedMacro}
1967        self.assertEqual(self.parse(text, macros), expected)
1968
1969    def test_endif_comment_removed(self):
1970        text = """\
1971#ifndef SIGRTMAX
1972#define SIGRTMAX 123
1973#endif /* SIGRTMAX */
1974"""
1975        expected = """\
1976#ifndef SIGRTMAX
1977#define SIGRTMAX 123
1978#endif
1979"""
1980        self.assertEqual(self.parse(text), expected)
1981
1982    def test_multilevel_if0(self):
1983        text = """\
1984#if 0
1985#if 1
1986#define  BAD_6
1987#endif
1988#endif
1989"""
1990        expected = ""
1991        self.assertEqual(self.parse(text), expected)
1992
1993class RemoveStructsTests(unittest.TestCase):
1994    def parse(self, text, structs):
1995        out = utils.StringOutput()
1996        blocks = BlockParser().parse(CppStringTokenizer(text))
1997        blocks.removeStructs(structs)
1998        blocks.write(out)
1999        return out.get()
2000
2001    def test_remove_struct_from_start(self):
2002        text = """\
2003struct remove {
2004  int val1;
2005  int val2;
2006};
2007struct something {
2008  struct timeval val1;
2009  struct timeval val2;
2010};
2011"""
2012        expected = """\
2013struct something {
2014  struct timeval val1;
2015  struct timeval val2;
2016};
2017"""
2018        self.assertEqual(self.parse(text, set(["remove"])), expected)
2019
2020    def test_remove_struct_from_end(self):
2021        text = """\
2022struct something {
2023  struct timeval val1;
2024  struct timeval val2;
2025};
2026struct remove {
2027  int val1;
2028  int val2;
2029};
2030"""
2031        expected = """\
2032struct something {
2033  struct timeval val1;
2034  struct timeval val2;
2035};
2036"""
2037        self.assertEqual(self.parse(text, set(["remove"])), expected)
2038
2039    def test_remove_minimal_struct(self):
2040        text = """\
2041struct remove {
2042};
2043"""
2044        expected = "";
2045        self.assertEqual(self.parse(text, set(["remove"])), expected)
2046
2047    def test_remove_struct_with_struct_fields(self):
2048        text = """\
2049struct something {
2050  struct remove val1;
2051  struct remove val2;
2052};
2053struct remove {
2054  int val1;
2055  struct something val3;
2056  int val2;
2057};
2058"""
2059        expected = """\
2060struct something {
2061  struct remove val1;
2062  struct remove val2;
2063};
2064"""
2065        self.assertEqual(self.parse(text, set(["remove"])), expected)
2066
2067    def test_remove_consecutive_structs(self):
2068        text = """\
2069struct keep1 {
2070  struct timeval val1;
2071  struct timeval val2;
2072};
2073struct remove1 {
2074  int val1;
2075  int val2;
2076};
2077struct remove2 {
2078  int val1;
2079  int val2;
2080  int val3;
2081};
2082struct keep2 {
2083  struct timeval val1;
2084  struct timeval val2;
2085};
2086"""
2087        expected = """\
2088struct keep1 {
2089  struct timeval val1;
2090  struct timeval val2;
2091};
2092struct keep2 {
2093  struct timeval val1;
2094  struct timeval val2;
2095};
2096"""
2097        self.assertEqual(self.parse(text, set(["remove1", "remove2"])), expected)
2098
2099    def test_remove_multiple_structs(self):
2100        text = """\
2101struct keep1 {
2102  int val;
2103};
2104struct remove1 {
2105  int val1;
2106  int val2;
2107};
2108struct keep2 {
2109  int val;
2110};
2111struct remove2 {
2112  struct timeval val1;
2113  struct timeval val2;
2114};
2115struct keep3 {
2116  int val;
2117};
2118"""
2119        expected = """\
2120struct keep1 {
2121  int val;
2122};
2123struct keep2 {
2124  int val;
2125};
2126struct keep3 {
2127  int val;
2128};
2129"""
2130        self.assertEqual(self.parse(text, set(["remove1", "remove2"])), expected)
2131
2132
2133class FullPathTest(unittest.TestCase):
2134    """Test of the full path parsing."""
2135
2136    def parse(self, text, keep=None):
2137        if not keep:
2138            keep = set()
2139        out = utils.StringOutput()
2140        blocks = BlockParser().parse(CppStringTokenizer(text))
2141
2142        blocks.removeStructs(kernel_structs_to_remove)
2143        blocks.removeVarsAndFuncs(keep)
2144        blocks.replaceTokens(kernel_token_replacements)
2145        blocks.optimizeAll(None)
2146
2147        blocks.write(out)
2148        return out.get()
2149
2150    def test_function_removed(self):
2151        text = """\
2152static inline __u64 function()
2153{
2154}
2155"""
2156        expected = ""
2157        self.assertEqual(self.parse(text), expected)
2158
2159    def test_function_removed_with_struct(self):
2160        text = """\
2161static inline struct something* function()
2162{
2163}
2164"""
2165        expected = ""
2166        self.assertEqual(self.parse(text), expected)
2167
2168    def test_function_kept(self):
2169        text = """\
2170static inline __u64 function()
2171{
2172}
2173"""
2174        expected = """\
2175static inline __u64 function() {
2176}
2177"""
2178        self.assertEqual(self.parse(text, set(["function"])), expected)
2179
2180    def test_var_removed(self):
2181        text = "__u64 variable;"
2182        expected = ""
2183        self.assertEqual(self.parse(text), expected)
2184
2185    def test_var_kept(self):
2186        text = "__u64 variable;"
2187        expected = "__u64 variable;\n"
2188        self.assertEqual(self.parse(text, set(["variable"])), expected)
2189
2190    def test_keep_function_typedef(self):
2191        text = "typedef void somefunction_t(void);"
2192        expected = "typedef void somefunction_t(void);\n"
2193        self.assertEqual(self.parse(text), expected)
2194
2195    def test_struct_keep_attribute(self):
2196        text = """\
2197struct something_s {
2198  __u32 s1;
2199  __u32 s2;
2200} __attribute__((packed));
2201"""
2202        expected = """\
2203struct something_s {
2204  __u32 s1;
2205  __u32 s2;
2206} __attribute__((packed));
2207"""
2208        self.assertEqual(self.parse(text), expected)
2209
2210    def test_function_keep_attribute_structs(self):
2211        text = """\
2212static __inline__ struct some_struct1 * function(struct some_struct2 * e) {
2213}
2214"""
2215        expected = """\
2216static __inline__ struct some_struct1 * function(struct some_struct2 * e) {
2217}
2218"""
2219        self.assertEqual(self.parse(text, set(["function"])), expected)
2220
2221    def test_struct_after_struct(self):
2222        text = """\
2223struct first {
2224};
2225
2226struct second {
2227  unsigned short s1;
2228#define SOMETHING 8
2229  unsigned short s2;
2230};
2231"""
2232        expected = """\
2233struct first {
2234};
2235struct second {
2236  unsigned short s1;
2237#define SOMETHING 8
2238  unsigned short s2;
2239};
2240"""
2241        self.assertEqual(self.parse(text), expected)
2242
2243    def test_other_not_removed(self):
2244        text = """\
2245typedef union {
2246  __u64 tu1;
2247  __u64 tu2;
2248} typedef_name;
2249
2250union {
2251  __u64 u1;
2252  __u64 u2;
2253};
2254
2255struct {
2256  __u64 s1;
2257  __u64 s2;
2258};
2259
2260enum {
2261  ENUM1 = 0,
2262  ENUM2,
2263};
2264
2265__extension__ typedef __signed__ long long __s64;
2266"""
2267        expected = """\
2268typedef union {
2269  __u64 tu1;
2270  __u64 tu2;
2271} typedef_name;
2272union {
2273  __u64 u1;
2274  __u64 u2;
2275};
2276struct {
2277  __u64 s1;
2278  __u64 s2;
2279};
2280enum {
2281  ENUM1 = 0,
2282  ENUM2,
2283};
2284__extension__ typedef __signed__ long long __s64;
2285"""
2286
2287        self.assertEqual(self.parse(text), expected)
2288
2289    def test_semicolon_after_function(self):
2290        text = """\
2291static inline __u64 function()
2292{
2293};
2294
2295struct should_see {
2296        __u32                           field;
2297};
2298"""
2299        expected = """\
2300struct should_see {
2301  __u32 field;
2302};
2303"""
2304        self.assertEqual(self.parse(text), expected)
2305
2306    def test_define_in_middle_keep(self):
2307        text = """\
2308enum {
2309  ENUM0 = 0x10,
2310  ENUM1 = 0x20,
2311#define SOMETHING SOMETHING_ELSE
2312  ENUM2 = 0x40,
2313};
2314"""
2315        expected = """\
2316enum {
2317  ENUM0 = 0x10,
2318  ENUM1 = 0x20,
2319#define SOMETHING SOMETHING_ELSE
2320  ENUM2 = 0x40,
2321};
2322"""
2323        self.assertEqual(self.parse(text), expected)
2324
2325    def test_define_in_middle_remove(self):
2326        text = """\
2327static inline function() {
2328#define SOMETHING1 SOMETHING_ELSE1
2329  i = 0;
2330  {
2331    i = 1;
2332  }
2333#define SOMETHING2 SOMETHING_ELSE2
2334}
2335"""
2336        expected = """\
2337#define SOMETHING1 SOMETHING_ELSE1
2338#define SOMETHING2 SOMETHING_ELSE2
2339"""
2340        self.assertEqual(self.parse(text), expected)
2341
2342    def test_define_in_middle_force_keep(self):
2343        text = """\
2344static inline function() {
2345#define SOMETHING1 SOMETHING_ELSE1
2346  i = 0;
2347  {
2348    i = 1;
2349  }
2350#define SOMETHING2 SOMETHING_ELSE2
2351}
2352"""
2353        expected = """\
2354static inline function() {
2355#define SOMETHING1 SOMETHING_ELSE1
2356  i = 0;
2357 {
2358    i = 1;
2359  }
2360#define SOMETHING2 SOMETHING_ELSE2
2361}
2362"""
2363        self.assertEqual(self.parse(text, set(["function"])), expected)
2364
2365    def test_define_before_remove(self):
2366        text = """\
2367#define SHOULD_BE_KEPT NOTHING1
2368#define ANOTHER_TO_KEEP NOTHING2
2369static inline function() {
2370#define SOMETHING1 SOMETHING_ELSE1
2371  i = 0;
2372  {
2373    i = 1;
2374  }
2375#define SOMETHING2 SOMETHING_ELSE2
2376}
2377"""
2378        expected = """\
2379#define SHOULD_BE_KEPT NOTHING1
2380#define ANOTHER_TO_KEEP NOTHING2
2381#define SOMETHING1 SOMETHING_ELSE1
2382#define SOMETHING2 SOMETHING_ELSE2
2383"""
2384        self.assertEqual(self.parse(text), expected)
2385
2386    def test_extern_C(self):
2387        text = """\
2388#if defined(__cplusplus)
2389extern "C" {
2390#endif
2391
2392struct something {
2393};
2394
2395#if defined(__cplusplus)
2396}
2397#endif
2398"""
2399        expected = """\
2400#ifdef __cplusplus
2401extern "C" {
2402#endif
2403struct something {
2404};
2405#ifdef __cplusplus
2406}
2407#endif
2408"""
2409        self.assertEqual(self.parse(text), expected)
2410
2411    def test_macro_definition_removed(self):
2412        text = """\
2413#define MACRO_FUNCTION_NO_PARAMS static inline some_func() {}
2414MACRO_FUNCTION_NO_PARAMS()
2415
2416#define MACRO_FUNCTION_PARAMS(a) static inline some_func() { a; }
2417MACRO_FUNCTION_PARAMS(a = 1)
2418
2419something that should still be kept
2420MACRO_FUNCTION_PARAMS(b)
2421"""
2422        expected = """\
2423#define MACRO_FUNCTION_NO_PARAMS static inline some_func() { }
2424#define MACRO_FUNCTION_PARAMS(a) static inline some_func() { a; }
2425something that should still be kept
2426"""
2427        self.assertEqual(self.parse(text), expected)
2428
2429    def test_verify_timeval_itemerval(self):
2430        text = """\
2431struct __kernel_old_timeval {
2432  struct something val;
2433};
2434struct __kernel_old_itimerval {
2435  struct __kernel_old_timeval val;
2436};
2437struct fields {
2438  struct __kernel_old_timeval timeval;
2439  struct __kernel_old_itimerval itimerval;
2440};
2441"""
2442        expected = """\
2443struct fields {
2444  struct timeval timeval;
2445  struct itimerval itimerval;
2446};
2447"""
2448        self.assertEqual(self.parse(text), expected)
2449
2450    def test_token_replacement(self):
2451        text = """\
2452#define SIGRTMIN 32
2453#define SIGRTMAX _NSIG
2454"""
2455        expected = """\
2456#define __SIGRTMIN 32
2457#define __SIGRTMAX _KERNEL__NSIG
2458"""
2459        self.assertEqual(self.parse(text), expected)
2460
2461
2462if __name__ == '__main__':
2463    unittest.main()
2464