From 86cace63e9bc60457f310bc77aba8fc54b748bbe Mon Sep 17 00:00:00 2001 From: Oleg Broytman Date: Sun, 4 Sep 2016 17:29:40 +0300 Subject: [PATCH] Add MySQL-specific remove_directives() and process_statement() --- ChangeLog | 4 ++++ demo/group-file.py | 4 ++-- demo/group-sql.py | 4 ++-- demo/parse-file.py | 4 ++-- demo/parse-sql.py | 4 ++-- scripts/mysql2sql | 4 ++-- sqlconvert/__version__.py | 2 +- sqlconvert/process_mysql.py | 25 +++++++++++++++++++++++-- tests/test_tokens.py | 27 ++++++++++++++++++++------- 9 files changed, 58 insertions(+), 20 deletions(-) diff --git a/ChangeLog b/ChangeLog index e12e5ad..1862c1f 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,7 @@ +Version 0.0.4 (2016-09-04) + + Add MySQL-specific remove_directives() and process_statement(). + Version 0.0.3 (2016-09-04) Rename the project: mysql2py -> sqlconvert. diff --git a/demo/group-file.py b/demo/group-file.py index e59ff45..b8a7021 100755 --- a/demo/group-file.py +++ b/demo/group-file.py @@ -3,7 +3,7 @@ from __future__ import print_function import sys from sqlconvert.print_tokens import print_tokens -from sqlconvert.process_mysql import requote_names +from sqlconvert.process_mysql import process_statement from sqlconvert.process_tokens import find_error, StatementGrouper @@ -17,7 +17,7 @@ def main(filename): print("----------") if find_error(statement): print("ERRORS IN QUERY") - requote_names(statement) + process_statement(statement) print_tokens(statement) print() statement._pprint_tree() diff --git a/demo/group-sql.py b/demo/group-sql.py index 854a0d1..c0bf2d2 100755 --- a/demo/group-sql.py +++ b/demo/group-sql.py @@ -3,7 +3,7 @@ from __future__ import print_function import sys from sqlconvert.print_tokens import print_tokens -from sqlconvert.process_mysql import requote_names +from sqlconvert.process_mysql import process_statement from sqlconvert.process_tokens import find_error, StatementGrouper @@ -16,7 +16,7 @@ def main(*queries): print("----------") if find_error(statement): print("ERRORS IN QUERY") - requote_names(statement) + process_statement(statement) print_tokens(statement) print() statement._pprint_tree() diff --git a/demo/parse-file.py b/demo/parse-file.py index bf65eec..e19bde2 100755 --- a/demo/parse-file.py +++ b/demo/parse-file.py @@ -4,7 +4,7 @@ from __future__ import print_function import sys from sqlparse import parse from sqlconvert.print_tokens import print_tokens -from sqlconvert.process_mysql import requote_names +from sqlconvert.process_mysql import process_statement from sqlconvert.process_tokens import find_error @@ -15,7 +15,7 @@ def main(filename): print("----------") if find_error(parsed): print("ERRORS IN QUERY") - requote_names(parsed) + process_statement(parsed) print_tokens(parsed) print() parsed._pprint_tree() diff --git a/demo/parse-sql.py b/demo/parse-sql.py index 1cbe0a1..e4be365 100755 --- a/demo/parse-sql.py +++ b/demo/parse-sql.py @@ -4,7 +4,7 @@ from __future__ import print_function import sys from sqlparse import parse from sqlconvert.print_tokens import print_tokens -from sqlconvert.process_mysql import requote_names +from sqlconvert.process_mysql import process_statement from sqlconvert.process_tokens import find_error @@ -14,7 +14,7 @@ def main(*queries): print("----------") if find_error(parsed): print("ERRORS IN QUERY") - requote_names(parsed) + process_statement(parsed) print_tokens(parsed) print() parsed._pprint_tree() diff --git a/scripts/mysql2sql b/scripts/mysql2sql index 5ba8fd7..17ee67c 100755 --- a/scripts/mysql2sql +++ b/scripts/mysql2sql @@ -8,7 +8,7 @@ import sys from sqlparse.compat import text_type from sqlconvert.print_tokens import print_tokens -from sqlconvert.process_mysql import requote_names +from sqlconvert.process_mysql import process_statement from sqlconvert.process_tokens import StatementGrouper from m_lib.defenc import default_encoding @@ -49,7 +49,7 @@ def main(infile, encoding, outfile, output_encoding, use_pbar): grouper.process_line(line) if grouper.statements: for statement in grouper.get_statements(): - requote_names(statement) + process_statement(statement) print_tokens(statement, outfile=outfile, encoding=output_encoding) tokens = grouper.close() diff --git a/sqlconvert/__version__.py b/sqlconvert/__version__.py index ffcc925..156d6f9 100644 --- a/sqlconvert/__version__.py +++ b/sqlconvert/__version__.py @@ -1 +1 @@ -__version__ = '0.0.3' +__version__ = '0.0.4' diff --git a/sqlconvert/process_mysql.py b/sqlconvert/process_mysql.py index b11eccb..5c63d0d 100644 --- a/sqlconvert/process_mysql.py +++ b/sqlconvert/process_mysql.py @@ -1,11 +1,12 @@ -from sqlparse.tokens import Name +from sqlparse.sql import Comment +from sqlparse import tokens as T def requote_names(token_list): """Remove backticks, quote non-lowercase identifiers""" for token in token_list.flatten(): - if token.ttype is Name: + if token.ttype is T.Name: value = token.value if (value[0] == "`") and (value[-1] == "`"): value = value[1:-1] @@ -13,3 +14,23 @@ def requote_names(token_list): token.normalized = token.value = value else: token.normalized = token.value = '"%s"' % value + + +def remove_directives(statement): + """Remove /*! directives */ from the first-level""" + new_tokens = [] + for token in statement.tokens: + if isinstance(token, Comment): + subtokens = token.tokens + if subtokens: + comment = subtokens[0] + if comment.ttype is T.Comment.Multiline and \ + comment.value.startswith('/*!'): + continue + new_tokens.append(token) + statement.tokens = new_tokens + + +def process_statement(statement): + requote_names(statement) + remove_directives(statement) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 692121e..3b0452a 100755 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -5,17 +5,12 @@ import unittest from sqlparse import parse from sqlconvert.print_tokens import tlist2str -from sqlconvert.process_mysql import requote_names +from sqlconvert.process_mysql import requote_names, remove_directives, \ + process_statement from tests import main class TestTokens(unittest.TestCase): - def test_requote(self): - parsed = parse("select * from `T`")[0] - requote_names(parsed) - query = tlist2str(parsed) - self.assertEqual(query, 'SELECT * FROM "T"') - def test_encoding(self): parsed = parse("insert into test (1, 'тест')", 'utf-8')[0] query = tlist2str(parsed).encode('utf-8') @@ -27,6 +22,24 @@ class TestTokens(unittest.TestCase): query = tlist2str(parsed) self.assertEqual(query, u"INSERT INTO test (1, 'тест')") + def test_requote(self): + parsed = parse("select * from `T`")[0] + requote_names(parsed) + query = tlist2str(parsed) + self.assertEqual(query, 'SELECT * FROM "T"') + + def test_directives(self): + parsed = parse("select /*! test */ * from /* test */ `T`")[0] + remove_directives(parsed) + query = tlist2str(parsed) + self.assertEqual(query, 'SELECT * FROM /* test */ `T`') + + def test_process(self): + parsed = parse("select /*! test */ * from /* test */ `T`")[0] + process_statement(parsed) + query = tlist2str(parsed) + self.assertEqual(query, 'SELECT * FROM /* test */ "T"') + if __name__ == "__main__": main() -- 2.39.5