From: Oleg Broytman Date: Sun, 21 Aug 2016 08:09:55 +0000 (+0300) Subject: Group statements separated by semicolons X-Git-Tag: 0.0.1~26 X-Git-Url: https://git.phdru.name/?p=sqlconvert.git;a=commitdiff_plain;h=031cc0d6a41717d4c5d7c4659290e05810202eb9 Group statements separated by semicolons --- diff --git a/mysql2sql/process_tokens.py b/mysql2sql/process_tokens.py index ac9930e..94879cc 100644 --- a/mysql2sql/process_tokens.py +++ b/mysql2sql/process_tokens.py @@ -1,6 +1,6 @@ -from sqlparse.sql import TokenList -from sqlparse.tokens import Name, Error +from sqlparse.sql import Statement +from sqlparse.tokens import Name, Error, Punctuation def requote_names(token_list): @@ -22,3 +22,25 @@ def find_error(token_list): if token.ttype is Error: return True return False + + +class StatementGrouper(object): + def __init__(self): + self.tokens = [] + self.statements = [] + + def get_statements(self): + for statement in self.statements: + yield statement + self.statements = [] + + def process(self, tokens): + for token in tokens: + self.tokens.append(token) + if (token.ttype == Punctuation) and (token.value == ';'): + self.statements.append(Statement(self.tokens)) + self.tokens = [] + + def close(self): + if self.tokens: + raise ValueError("Incomplete SQL statement") diff --git a/scripts/group-file.py b/scripts/group-file.py new file mode 100755 index 0000000..41f9b33 --- /dev/null +++ b/scripts/group-file.py @@ -0,0 +1,32 @@ +#! /usr/bin/env python +from __future__ import print_function + +import sys +from sqlparse import parse +from mysql2sql.print_tokens import print_tokens +from mysql2sql.process_tokens import requote_names, find_error, \ + StatementGrouper + + +def main(filename): + grouper = StatementGrouper() + with open(filename) as infile: + for line in infile: + grouper.process(parse(line)[0]) + if grouper.statements: + for statement in grouper.get_statements(): + print("----------") + if find_error(statement): + print("ERRORS IN QUERY") + requote_names(statement) + print_tokens(statement) + print() + statement._pprint_tree() + print("----------") + grouper.close() + + +if __name__ == '__main__': + if len(sys.argv) <= 1: + sys.exit("Usage: %s file" % sys.argv[0]) + main(sys.argv[1]) diff --git a/scripts/group-sql.py b/scripts/group-sql.py new file mode 100755 index 0000000..f1f0988 --- /dev/null +++ b/scripts/group-sql.py @@ -0,0 +1,45 @@ +#! /usr/bin/env python +from __future__ import print_function + +import sys +from sqlparse import parse +from mysql2sql.print_tokens import print_tokens +from mysql2sql.process_tokens import requote_names, find_error, \ + StatementGrouper + + +def main(*queries): + grouper = StatementGrouper() + for query in queries: + grouper.process(parse(query)[0]) + if grouper.statements: + for statement in grouper.get_statements(): + print("----------") + if find_error(statement): + print("ERRORS IN QUERY") + requote_names(statement) + print_tokens(statement) + print() + statement._pprint_tree() + print("----------") + grouper.close() + + +def test(): + main( + "SELECT * FROM `mytable`; -- line-comment", + "INSERT into /* inline comment */ mytable VALUES (1, 'one');", + "/*! directive*/ INSERT INTO `MyTable` (`Id`, `Name`) " + "VALUES (1, 'one');" + ) + + +if __name__ == '__main__': + if len(sys.argv) <= 1: + sys.exit("Usage: %s [-t | sql_query_string [; sql_query_string ...]]" % + sys.argv[0]) + if sys.argv[1] == '-t': + test() + else: + queries = sys.argv[1:] + main(*queries) diff --git a/tests/test_stgrouper.py b/tests/test_stgrouper.py new file mode 100755 index 0000000..0e72ced --- /dev/null +++ b/tests/test_stgrouper.py @@ -0,0 +1,37 @@ +#! /usr/bin/env python + + +import unittest +from sqlparse import parse + +from mysql2sql.print_tokens import tlist2str +from mysql2sql.process_tokens import requote_names, StatementGrouper +from tests import main + + +class TestStGrouper(unittest.TestCase): + def test_incomplete(self): + grouper = StatementGrouper() + parsed = parse("select * from `T`")[0] + grouper.process(parsed) + self.assertFalse(grouper.statements) + self.assertEqual(len(grouper.statements), 0) + self.assertRaises(ValueError, grouper.close) + + def test_statements(self): + grouper = StatementGrouper() + parsed = parse("select * from `T`;")[0] + grouper.process(parsed) + self.assertTrue(grouper.statements) + self.assertEqual(len(grouper.statements), 1) + g = grouper.get_statements() + statement = next(g) + requote_names(statement) + query = tlist2str(parsed) + self.assertEqual(query, 'SELECT * FROM "T";') + self.assertRaises(StopIteration, next, g) + self.assertEqual(len(grouper.statements), 0) + self.assertIsNone(grouper.close()) + +if __name__ == "__main__": + main()