]> git.phdru.name Git - sqlconvert.git/blobdiff - mysql2sql/process_tokens.py
Fix Python3 compatibility
[sqlconvert.git] / mysql2sql / process_tokens.py
index 7f30a5589b7fad13265b4a9e147fe6308afba545..7e323543f81f9baa5f65444e7edc8327e16d7eb7 100644 (file)
@@ -1,5 +1,6 @@
 
-from sqlparse.sql import Statement
+from sqlparse import parse
+from sqlparse.compat import PY3
 from sqlparse.tokens import Name, Error, Punctuation, Comment, Newline, \
     Whitespace
 
@@ -25,26 +26,48 @@ def find_error(token_list):
     return False
 
 
+if PY3:
+    xrange = range
+
+
 class StatementGrouper(object):
+    """Collect lines and reparse until the last statement is complete"""
+
     def __init__(self):
-        self.tokens = []
+        self.lines = []
         self.statements = []
 
+    def process_line(self, line):
+        self.lines.append(line)
+        self.process_lines()
+
+    def process_lines(self):
+        statements = parse('\n'.join(self.lines))
+        last_stmt = statements[-1]
+        for i in xrange(len(last_stmt.tokens) - 1, 0, -1):
+            token = last_stmt.tokens[i]
+            if token.ttype in (Comment.Single, Comment.Multiline,
+                               Newline, Whitespace):
+                continue
+            if token.ttype is Punctuation and token.value == ';':
+                break  # The last statement is complete
+            # The last statement is still incomplete - wait for the next line
+            return
+        self.lines = []
+        self.statements = statements
+
     def get_statements(self):
-        for statement in self.statements:
-            yield statement
+        for stmt in self.statements:
+            yield stmt
         self.statements = []
 
-    def process(self, tokens):
-        for token in tokens:
-            self.tokens.append(token)
-            if (token.ttype == Punctuation) and (token.value == ';'):
-                self.statements.append(Statement(self.tokens))
-                self.tokens = []
-
     def close(self):
-        for token in self.tokens:
+        if not self.lines:
+            return
+        tokens = parse('\n'.join(self.lines))
+        for token in tokens:
             if (token.ttype not in (Comment.Single, Comment.Multiline,
                                     Newline, Whitespace)):
-                raise ValueError("Incomplete SQL statement: %s" % self.tokens)
-        return self.tokens
+                raise ValueError("Incomplete SQL statement: %s" %
+                                 tokens)
+        return tokens