]> git.phdru.name Git - sqlconvert.git/commitdiff
Skip semicolons and newlines /*! directives */;
authorOleg Broytman <phd@phdru.name>
Wed, 7 Sep 2016 19:26:13 +0000 (22:26 +0300)
committerOleg Broytman <phd@phdru.name>
Wed, 7 Sep 2016 19:50:28 +0000 (22:50 +0300)
TODO
scripts/mysql2sql
sqlconvert/process_mysql.py
sqlconvert/process_tokens.py
tests/test_tokens.py

diff --git a/TODO b/TODO
index 590f4bb4fff8d0b9f9f5e98add486ae27f0bd2a1..26baf8483a8282730500a3584a44955f4c288b85 100644 (file)
--- a/TODO
+++ b/TODO
@@ -1,6 +1,3 @@
-Fix semicolons and newlines after /*! directives */
-
-
 Convert string escapes to generic SQL, Postgres- or SQLite-specific.
 
 
index 17ee67cd4fb123f08332e1fa37633f896bc66cc1..3bb03028d9d334b192fd27904e043fbdd74b4124 100755 (executable)
@@ -8,8 +8,8 @@ import sys
 
 from sqlparse.compat import text_type
 from sqlconvert.print_tokens import print_tokens
-from sqlconvert.process_mysql import process_statement
-from sqlconvert.process_tokens import StatementGrouper
+from sqlconvert.process_mysql import is_directive_statement, process_statement
+from sqlconvert.process_tokens import is_newline_statement, StatementGrouper
 
 from m_lib.defenc import default_encoding
 from m_lib.pbar.tty_pbar import ttyProgressBar
@@ -39,6 +39,7 @@ def main(infile, encoding, outfile, output_encoding, use_pbar):
         cur_pos = 0
 
     grouper = StatementGrouper(encoding=encoding)
+    got_directive = False
     for line in infile:
         if use_pbar:
             if isinstance(line, text_type):
@@ -49,6 +50,14 @@ def main(infile, encoding, outfile, output_encoding, use_pbar):
         grouper.process_line(line)
         if grouper.statements:
             for statement in grouper.get_statements():
+                if got_directive and is_newline_statement(statement):
+                    # Replace a sequence of newlines after a /*! directive */;
+                    # with one newline
+                    #outfile.write(u'\n')
+                    continue
+                got_directive = is_directive_statement(statement)
+                if got_directive:
+                    continue
                 process_statement(statement)
                 print_tokens(statement, outfile=outfile,
                              encoding=output_encoding)
index 5c63d0d1bc1de9c39c67455cd422ffc838a4220f..3f2f6aa3c7dee8ad84f60c3a630050c8f3d413b6 100644 (file)
@@ -3,6 +3,39 @@ from sqlparse.sql import Comment
 from sqlparse import tokens as T
 
 
+def _is_directive_token(token):
+    if isinstance(token, Comment):
+        subtokens = token.tokens
+        if subtokens:
+            comment = subtokens[0]
+            if comment.ttype is T.Comment.Multiline and \
+                    comment.value.startswith('/*!'):
+                return True
+    return False
+
+
+def is_directive_statement(statement):
+    tokens = statement.tokens
+    if not _is_directive_token(tokens[0]):
+        return False
+    if tokens[-1].ttype is not T.Punctuation or tokens[-1].value != ';':
+        return False
+    for token in tokens[1:-1]:
+        if token.ttype not in (T.Newline, T.Whitespace):
+            return False
+    return True
+
+
+def remove_directives(statement):
+    """Remove /*! directives */ from the first-level"""
+    new_tokens = []
+    for token in statement.tokens:
+        if _is_directive_token(token):
+            continue
+        new_tokens.append(token)
+    statement.tokens = new_tokens
+
+
 def requote_names(token_list):
     """Remove backticks, quote non-lowercase identifiers"""
     for token in token_list.flatten():
@@ -16,21 +49,6 @@ def requote_names(token_list):
                 token.normalized = token.value = '"%s"' % value
 
 
-def remove_directives(statement):
-    """Remove /*! directives */ from the first-level"""
-    new_tokens = []
-    for token in statement.tokens:
-        if isinstance(token, Comment):
-            subtokens = token.tokens
-            if subtokens:
-                comment = subtokens[0]
-                if comment.ttype is T.Comment.Multiline and \
-                        comment.value.startswith('/*!'):
-                    continue
-        new_tokens.append(token)
-    statement.tokens = new_tokens
-
-
 def process_statement(statement):
-    requote_names(statement)
     remove_directives(statement)
+    requote_names(statement)
index 0bbf94cb703ecb8963d93eeac147ecff0c19f1d4..924ba4abd7ce2f00c8b5ecb477a081c023cc674d 100644 (file)
@@ -1,17 +1,24 @@
 
 from sqlparse import parse
 from sqlparse.compat import PY3
-from sqlparse.tokens import Error, Punctuation, Comment, Newline, Whitespace
+from sqlparse import tokens as T
 
 
 def find_error(token_list):
     """Find an error"""
     for token in token_list.flatten():
-        if token.ttype is Error:
+        if token.ttype is T.Error:
             return True
     return False
 
 
+def is_newline_statement(statement):
+    for token in statement.tokens[:]:
+        if token.ttype is not T.Newline:
+            return False
+    return True
+
+
 if PY3:
     xrange = range
 
@@ -33,10 +40,10 @@ class StatementGrouper(object):
         last_stmt = statements[-1]
         for i in xrange(len(last_stmt.tokens) - 1, 0, -1):
             token = last_stmt.tokens[i]
-            if token.ttype in (Comment.Single, Comment.Multiline,
-                               Newline, Whitespace):
+            if token.ttype in (T.Comment.Single, T.Comment.Multiline,
+                               T.Newline, T.Whitespace):
                 continue
-            if token.ttype is Punctuation and token.value == ';':
+            if token.ttype is T.Punctuation and token.value == ';':
                 break  # The last statement is complete
             # The last statement is still incomplete - wait for the next line
             return
@@ -53,8 +60,8 @@ class StatementGrouper(object):
             return
         tokens = parse(''.join(self.lines), encoding=self.encoding)
         for token in tokens:
-            if (token.ttype not in (Comment.Single, Comment.Multiline,
-                                    Newline, Whitespace)):
+            if (token.ttype not in (T.Comment.Single, T.Comment.Multiline,
+                                    T.Newline, T.Whitespace)):
                 raise ValueError("Incomplete SQL statement: %s" %
                                  tokens)
         return tokens
index 3b0452a50cdcdc4c770b6ffbbfb653d29fc5286f..2c930f77e07593eafb835dcf4264d3cce3e14520 100755 (executable)
@@ -5,8 +5,8 @@ import unittest
 from sqlparse import parse
 
 from sqlconvert.print_tokens import tlist2str
-from sqlconvert.process_mysql import requote_names, remove_directives, \
-        process_statement
+from sqlconvert.process_mysql import remove_directives, requote_names, \
+        is_directive_statement, process_statement
 from tests import main
 
 
@@ -28,12 +28,18 @@ class TestTokens(unittest.TestCase):
         query = tlist2str(parsed)
         self.assertEqual(query, 'SELECT * FROM "T"')
 
-    def test_directives(self):
+    def test_directive(self):
         parsed = parse("select /*! test */ * from /* test */ `T`")[0]
         remove_directives(parsed)
         query = tlist2str(parsed)
         self.assertEqual(query, 'SELECT * FROM /* test */ `T`')
 
+    def test_directive_statement(self):
+        parsed = parse("/*! test */ test ;")[0]
+        self.assertFalse(is_directive_statement(parsed))
+        parsed = parse("/*! test */ ;")[0]
+        self.assertTrue(is_directive_statement(parsed))
+
     def test_process(self):
         parsed = parse("select /*! test */ * from /* test */ `T`")[0]
         process_statement(parsed)