git.phdru.name Git - sqlconvert.git/commitdiff
Collect lines and reparse until the last statement is complete
author Oleg Broytman <phd@phdru.name>
Fri, 26 Aug 2016 21:48:02 +0000 (00:48 +0300)
committer Oleg Broytman <phd@phdru.name>
Fri, 26 Aug 2016 21:48:02 +0000 (00:48 +0300)
mysql2sql/process_tokens.py
scripts/group-file.py
scripts/group-sql.py
tests/test_stgrouper.py

index 37752e7c146a8d76ad63f9b7ccbbe395066eec4e..ddab2bf649eac2e69e67eef48e7b1edbf07b0dea 100644 (file)
@@ -1,6 +1,5 @@
 
 from sqlparse import parse
-from sqlparse.sql import Statement
 from sqlparse.tokens import Name, Error, Punctuation, Comment, Newline, \
     Whitespace
 
@@ -27,33 +26,43 @@ def find_error(token_list):
 
 
 class StatementGrouper(object):
+    """Collect lines and reparse until the last statement is complete"""
+
     def __init__(self):
-        self.statements = []
-        self.tokens = []
         self.lines = []
+        self.statements = []
 
     def process_line(self, line):
-        lines = self.lines
-        lines.append(line)
-        tokens = parse('\n'.join(lines))[0]
-        self.process_tokens(tokens)
-        self.lines = []
+        self.lines.append(line)
+        self.process_lines()
 
-    def process_tokens(self, tokens):
-        for token in tokens:
-            self.tokens.append(token)
-            if (token.ttype == Punctuation) and (token.value == ';'):
-                self.statements.append(Statement(self.tokens))
-                self.tokens = []
+    def process_lines(self):
+        statements = parse('\n'.join(self.lines))
+        last_stmt = statements[-1]
+        for i in xrange(len(last_stmt.tokens) - 1, 0, -1):
+            token = last_stmt.tokens[i]
+            if token.ttype in (Comment.Single, Comment.Multiline,
+                               Newline, Whitespace):
+                continue
+            if token.ttype is Punctuation and token.value == ';':
+                break  # The last statement is complete
+            # The last statement is still incomplete - wait for the next line
+            return
+        self.lines = []
+        self.statements = statements
 
     def get_statements(self):
-        for statement in self.statements:
-            yield statement
+        for stmt in self.statements:
+            yield stmt
         self.statements = []
 
     def close(self):
-        for token in self.tokens:
+        if not self.lines:
+            return
+        tokens = parse('\n'.join(self.lines))
+        for token in tokens:
             if (token.ttype not in (Comment.Single, Comment.Multiline,
                                     Newline, Whitespace)):
-                raise ValueError("Incomplete SQL statement: %s" % self.tokens)
-        return self.tokens
+                raise ValueError("Incomplete SQL statement: %s" %
+                                 tokens)
+        return tokens
index 92f8fb9100d8839d7f0150e899db8c5886d231d8..5ab1f447ea2f48bf61c5dae14db4ccc8a29f4675 100755 (executable)
@@ -23,9 +23,10 @@ def main(filename):
                     statement._pprint_tree()
                 print("----------")
     tokens = grouper.close()
-    for token in tokens:
-        print_tokens(token)
-        print(repr(token))
+    if tokens:
+        for token in tokens:
+            print_tokens(token)
+            print(repr(token))
 
 
 if __name__ == '__main__':
index 850dcb30b2ad1fd72f79274515feff33517f4c7b..953e9e7cfc66a03304af31cdc7473ec584fd1a7f 100755 (executable)
@@ -22,9 +22,10 @@ def main(*queries):
                 statement._pprint_tree()
             print("----------")
     tokens = grouper.close()
-    for token in tokens:
-        print_tokens(token)
-        print(repr(token))
+    if tokens:
+        for token in tokens:
+            print_tokens(token)
+            print(repr(token))
 
 
 def test():
index bd60168c020300a0969de0a142c9f9c3fb405f17..d77773e10e913a46c728dfecde680c14c5c2f603 100755 (executable)
@@ -26,7 +26,7 @@ class TestStGrouper(unittest.TestCase):
             query = tlist2str(statement)
             self.assertEqual(query, 'SELECT * FROM "T";')
         self.assertEqual(len(grouper.statements), 0)
-        self.assertEqual(grouper.close(), [])
+        self.assertEqual(grouper.close(), None)
 
 if __name__ == "__main__":
     main()