]> git.phdru.name Git - sqlconvert.git/blobdiff - mysql2sql/process_tokens.py
Process input stream line by line
[sqlconvert.git] / mysql2sql / process_tokens.py
index e9a62f825a6d5625dcd6dfc2e03b2f150208a12b..37752e7c146a8d76ad63f9b7ccbbe395066eec4e 100644 (file)
@@ -1,19 +1,59 @@
 
-from sqlparse.sql import TokenList
-from sqlparse.tokens import Name
+from sqlparse import parse
+from sqlparse.sql import Statement
+from sqlparse.tokens import Name, Error, Punctuation, Comment, Newline, \
+    Whitespace
 
 
 def requote_names(token_list):
     """Remove backticks, quote non-lowercase identifiers"""
-    for token in token_list:
-        if isinstance(token, TokenList):
-            requote_names(token)
-        else:
-            if token.ttype is Name:
-                value = token.value
-                if (value[0] == "`") and (value[-1] == "`"):
-                    value = value[1:-1]
-                if value.islower():
-                    token.normalized = token.value = value
-                else:
-                    token.normalized = token.value = '"%s"' % value
+    for token in token_list.flatten():
+        if token.ttype is Name:
+            value = token.value
+            if (value[0] == "`") and (value[-1] == "`"):
+                value = value[1:-1]
+            if value.islower():
+                token.normalized = token.value = value
+            else:
+                token.normalized = token.value = '"%s"' % value
+
+
+def find_error(token_list):
+    """Find an error"""
+    for token in token_list.flatten():
+        if token.ttype is Error:
+            return True
+    return False
+
+
+class StatementGrouper(object):
+    def __init__(self):
+        self.statements = []
+        self.tokens = []
+        self.lines = []
+
+    def process_line(self, line):
+        lines = self.lines
+        lines.append(line)
+        tokens = parse('\n'.join(lines))[0]
+        self.process_tokens(tokens)
+        self.lines = []
+
+    def process_tokens(self, tokens):
+        for token in tokens:
+            self.tokens.append(token)
+            if (token.ttype == Punctuation) and (token.value == ';'):
+                self.statements.append(Statement(self.tokens))
+                self.tokens = []
+
+    def get_statements(self):
+        for statement in self.statements:
+            yield statement
+        self.statements = []
+
+    def close(self):
+        for token in self.tokens:
+            if (token.ttype not in (Comment.Single, Comment.Multiline,
+                                    Newline, Whitespace)):
+                raise ValueError("Incomplete SQL statement: %s" % self.tokens)
+        return self.tokens