]> git.phdru.name Git - sqlconvert.git/blobdiff - sqlconvert/process_tokens.py
Tests: Use tox instead of tests/Makefile
[sqlconvert.git] / sqlconvert / process_tokens.py
index b1c26022f7ad2adbecaf943fee47cd74c7794a7d..cb14467ccf958b478d888714ebaa6240e10fede2 100644 (file)
@@ -1,4 +1,6 @@
 
+from sqlparse.sql import Comment
+from sqlobject.converters import sqlrepr
 from sqlparse import parse
 from sqlparse.compat import PY3
 from sqlparse import tokens as T
@@ -12,6 +14,12 @@ def find_error(token_list):
     return False
 
 
+def is_comment_or_space(token):
+    return isinstance(token, Comment) or \
+        token.ttype in (T.Comment, T.Comment.Single, T.Comment.Multiline,
+                        T.Newline, T.Whitespace)
+
+
 def is_newline_statement(statement):
     for token in statement.tokens[:]:
         if token.ttype is not T.Newline:
@@ -19,6 +27,15 @@ def is_newline_statement(statement):
     return True
 
 
+def escape_strings(token_list, dbname):
+    """Escape strings"""
+    for token in token_list.flatten():
+        if token.ttype is T.String.Single:
+            value = token.value[1:-1]  # unquote by removing apostrophes
+            value = sqlrepr(value, dbname)
+            token.normalized = token.value = value
+
+
 if PY3:
     xrange = range
 
@@ -40,8 +57,7 @@ class StatementGrouper(object):
         last_stmt = statements[-1]
         for i in xrange(len(last_stmt.tokens) - 1, 0, -1):
             token = last_stmt.tokens[i]
-            if token.ttype in (T.Comment.Single, T.Comment.Multiline,
-                               T.Newline, T.Whitespace):
+            if is_comment_or_space(token):
                 continue
             if token.ttype is T.Punctuation and token.value == ';':
                 break  # The last statement is complete
@@ -54,15 +70,16 @@ class StatementGrouper(object):
         for stmt in self.statements:
             yield stmt
         self.statements = []
-        raise StopIteration
+        return
 
     def close(self):
         if not self.lines:
             return
         tokens = parse(''.join(self.lines), encoding=self.encoding)
         for token in tokens:
-            if (token.ttype not in (T.Comment.Single, T.Comment.Multiline,
-                                    T.Newline, T.Whitespace)):
+            if not is_comment_or_space(token):
                 raise ValueError("Incomplete SQL statement: %s" %
                                  tokens)
+        self.lines = []
+        self.statements = []
         return tokens