-from sqlparse.sql import TokenList
-from sqlparse.tokens import Name, Error
+from sqlparse.sql import Statement
+from sqlparse.tokens import Name, Error, Punctuation
def requote_names(token_list):
if token.ttype is Error:
return True
return False
+
+
+class StatementGrouper(object):
+ def __init__(self):
+ self.tokens = []
+ self.statements = []
+
+ def get_statements(self):
+ for statement in self.statements:
+ yield statement
+ self.statements = []
+
+ def process(self, tokens):
+ for token in tokens:
+ self.tokens.append(token)
+ if (token.ttype == Punctuation) and (token.value == ';'):
+ self.statements.append(Statement(self.tokens))
+ self.tokens = []
+
+ def close(self):
+ if self.tokens:
+ raise ValueError("Incomplete SQL statement")
--- /dev/null
+#! /usr/bin/env python
+from __future__ import print_function
+
+import sys
+from sqlparse import parse
+from mysql2sql.print_tokens import print_tokens
+from mysql2sql.process_tokens import requote_names, find_error, \
+ StatementGrouper
+
+
+def main(filename):
+ grouper = StatementGrouper()
+ with open(filename) as infile:
+ for line in infile:
+ grouper.process(parse(line)[0])
+ if grouper.statements:
+ for statement in grouper.get_statements():
+ print("----------")
+ if find_error(statement):
+ print("ERRORS IN QUERY")
+ requote_names(statement)
+ print_tokens(statement)
+ print()
+ statement._pprint_tree()
+ print("----------")
+ grouper.close()
+
+
+if __name__ == '__main__':
+ if len(sys.argv) <= 1:
+ sys.exit("Usage: %s file" % sys.argv[0])
+ main(sys.argv[1])
--- /dev/null
+#! /usr/bin/env python
+from __future__ import print_function
+
+import sys
+from sqlparse import parse
+from mysql2sql.print_tokens import print_tokens
+from mysql2sql.process_tokens import requote_names, find_error, \
+ StatementGrouper
+
+
+def main(*queries):
+ grouper = StatementGrouper()
+ for query in queries:
+ grouper.process(parse(query)[0])
+ if grouper.statements:
+ for statement in grouper.get_statements():
+ print("----------")
+ if find_error(statement):
+ print("ERRORS IN QUERY")
+ requote_names(statement)
+ print_tokens(statement)
+ print()
+ statement._pprint_tree()
+ print("----------")
+ grouper.close()
+
+
+def test():
+ main(
+ "SELECT * FROM `mytable`; -- line-comment",
+ "INSERT into /* inline comment */ mytable VALUES (1, 'one');",
+ "/*! directive*/ INSERT INTO `MyTable` (`Id`, `Name`) "
+ "VALUES (1, 'one');"
+ )
+
+
+if __name__ == '__main__':
+ if len(sys.argv) <= 1:
+ sys.exit("Usage: %s [-t | sql_query_string [; sql_query_string ...]]" %
+ sys.argv[0])
+ if sys.argv[1] == '-t':
+ test()
+ else:
+ queries = sys.argv[1:]
+ main(*queries)
--- /dev/null
+#! /usr/bin/env python
+
+
+import unittest
+from sqlparse import parse
+
+from mysql2sql.print_tokens import tlist2str
+from mysql2sql.process_tokens import requote_names, StatementGrouper
+from tests import main
+
+
+class TestStGrouper(unittest.TestCase):
+ def test_incomplete(self):
+ grouper = StatementGrouper()
+ parsed = parse("select * from `T`")[0]
+ grouper.process(parsed)
+ self.assertFalse(grouper.statements)
+ self.assertEqual(len(grouper.statements), 0)
+ self.assertRaises(ValueError, grouper.close)
+
+ def test_statements(self):
+ grouper = StatementGrouper()
+ parsed = parse("select * from `T`;")[0]
+ grouper.process(parsed)
+ self.assertTrue(grouper.statements)
+ self.assertEqual(len(grouper.statements), 1)
+ g = grouper.get_statements()
+ statement = next(g)
+ requote_names(statement)
+ query = tlist2str(parsed)
+ self.assertEqual(query, 'SELECT * FROM "T";')
+ self.assertRaises(StopIteration, next, g)
+ self.assertEqual(len(grouper.statements), 0)
+ self.assertIsNone(grouper.close())
+
+if __name__ == "__main__":
+ main()