From ad12dfd9c03a4f6bbd03d22238b6399ab09962ed Mon Sep 17 00:00:00 2001 From: Oleg Broytman Date: Wed, 7 Sep 2016 22:26:13 +0300 Subject: [PATCH] Skip semicolons and newlines /*! directives */; --- TODO | 3 --- scripts/mysql2sql | 13 ++++++++-- sqlconvert/process_mysql.py | 50 ++++++++++++++++++++++++------------ sqlconvert/process_tokens.py | 21 ++++++++++----- tests/test_tokens.py | 12 ++++++--- 5 files changed, 68 insertions(+), 31 deletions(-) diff --git a/TODO b/TODO index 590f4bb..26baf84 100644 --- a/TODO +++ b/TODO @@ -1,6 +1,3 @@ -Fix semicolons and newlines after /*! directives */ - - Convert string escapes to generic SQL, Postgres- or SQLite-specific. diff --git a/scripts/mysql2sql b/scripts/mysql2sql index 17ee67c..3bb0302 100755 --- a/scripts/mysql2sql +++ b/scripts/mysql2sql @@ -8,8 +8,8 @@ import sys from sqlparse.compat import text_type from sqlconvert.print_tokens import print_tokens -from sqlconvert.process_mysql import process_statement -from sqlconvert.process_tokens import StatementGrouper +from sqlconvert.process_mysql import is_directive_statement, process_statement +from sqlconvert.process_tokens import is_newline_statement, StatementGrouper from m_lib.defenc import default_encoding from m_lib.pbar.tty_pbar import ttyProgressBar @@ -39,6 +39,7 @@ def main(infile, encoding, outfile, output_encoding, use_pbar): cur_pos = 0 grouper = StatementGrouper(encoding=encoding) + got_directive = False for line in infile: if use_pbar: if isinstance(line, text_type): @@ -49,6 +50,14 @@ def main(infile, encoding, outfile, output_encoding, use_pbar): grouper.process_line(line) if grouper.statements: for statement in grouper.get_statements(): + if got_directive and is_newline_statement(statement): + # Replace a sequence of newlines after a /*! directive */; + # with one newline + #outfile.write(u'\n') + continue + got_directive = is_directive_statement(statement) + if got_directive: + continue process_statement(statement) print_tokens(statement, outfile=outfile, encoding=output_encoding) diff --git a/sqlconvert/process_mysql.py b/sqlconvert/process_mysql.py index 5c63d0d..3f2f6aa 100644 --- a/sqlconvert/process_mysql.py +++ b/sqlconvert/process_mysql.py @@ -3,6 +3,39 @@ from sqlparse.sql import Comment from sqlparse import tokens as T +def _is_directive_token(token): + if isinstance(token, Comment): + subtokens = token.tokens + if subtokens: + comment = subtokens[0] + if comment.ttype is T.Comment.Multiline and \ + comment.value.startswith('/*!'): + return True + return False + + +def is_directive_statement(statement): + tokens = statement.tokens + if not _is_directive_token(tokens[0]): + return False + if tokens[-1].ttype is not T.Punctuation or tokens[-1].value != ';': + return False + for token in tokens[1:-1]: + if token.ttype not in (T.Newline, T.Whitespace): + return False + return True + + +def remove_directives(statement): + """Remove /*! directives */ from the first-level""" + new_tokens = [] + for token in statement.tokens: + if _is_directive_token(token): + continue + new_tokens.append(token) + statement.tokens = new_tokens + + def requote_names(token_list): """Remove backticks, quote non-lowercase identifiers""" for token in token_list.flatten(): @@ -16,21 +49,6 @@ def requote_names(token_list): token.normalized = token.value = '"%s"' % value -def remove_directives(statement): - """Remove /*! directives */ from the first-level""" - new_tokens = [] - for token in statement.tokens: - if isinstance(token, Comment): - subtokens = token.tokens - if subtokens: - comment = subtokens[0] - if comment.ttype is T.Comment.Multiline and \ - comment.value.startswith('/*!'): - continue - new_tokens.append(token) - statement.tokens = new_tokens - - def process_statement(statement): - requote_names(statement) remove_directives(statement) + requote_names(statement) diff --git a/sqlconvert/process_tokens.py b/sqlconvert/process_tokens.py index 0bbf94c..924ba4a 100644 --- a/sqlconvert/process_tokens.py +++ b/sqlconvert/process_tokens.py @@ -1,17 +1,24 @@ from sqlparse import parse from sqlparse.compat import PY3 -from sqlparse.tokens import Error, Punctuation, Comment, Newline, Whitespace +from sqlparse import tokens as T def find_error(token_list): """Find an error""" for token in token_list.flatten(): - if token.ttype is Error: + if token.ttype is T.Error: return True return False +def is_newline_statement(statement): + for token in statement.tokens[:]: + if token.ttype is not T.Newline: + return False + return True + + if PY3: xrange = range @@ -33,10 +40,10 @@ class StatementGrouper(object): last_stmt = statements[-1] for i in xrange(len(last_stmt.tokens) - 1, 0, -1): token = last_stmt.tokens[i] - if token.ttype in (Comment.Single, Comment.Multiline, - Newline, Whitespace): + if token.ttype in (T.Comment.Single, T.Comment.Multiline, + T.Newline, T.Whitespace): continue - if token.ttype is Punctuation and token.value == ';': + if token.ttype is T.Punctuation and token.value == ';': break # The last statement is complete # The last statement is still incomplete - wait for the next line return @@ -53,8 +60,8 @@ class StatementGrouper(object): return tokens = parse(''.join(self.lines), encoding=self.encoding) for token in tokens: - if (token.ttype not in (Comment.Single, Comment.Multiline, - Newline, Whitespace)): + if (token.ttype not in (T.Comment.Single, T.Comment.Multiline, + T.Newline, T.Whitespace)): raise ValueError("Incomplete SQL statement: %s" % tokens) return tokens diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 3b0452a..2c930f7 100755 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -5,8 +5,8 @@ import unittest from sqlparse import parse from sqlconvert.print_tokens import tlist2str -from sqlconvert.process_mysql import requote_names, remove_directives, \ - process_statement +from sqlconvert.process_mysql import remove_directives, requote_names, \ + is_directive_statement, process_statement from tests import main @@ -28,12 +28,18 @@ class TestTokens(unittest.TestCase): query = tlist2str(parsed) self.assertEqual(query, 'SELECT * FROM "T"') - def test_directives(self): + def test_directive(self): parsed = parse("select /*! test */ * from /* test */ `T`")[0] remove_directives(parsed) query = tlist2str(parsed) self.assertEqual(query, 'SELECT * FROM /* test */ `T`') + def test_directive_statement(self): + parsed = parse("/*! test */ test ;")[0] + self.assertFalse(is_directive_statement(parsed)) + parsed = parse("/*! test */ ;")[0] + self.assertTrue(is_directive_statement(parsed)) + def test_process(self): parsed = parse("select /*! test */ * from /* test */ `T`")[0] process_statement(parsed) -- 2.39.5