]> git.phdru.name Git - sqlconvert.git/commitdiff
Group statements separated by semicolons
authorOleg Broytman <phd@phdru.name>
Sun, 21 Aug 2016 08:09:55 +0000 (11:09 +0300)
committerOleg Broytman <phd@phdru.name>
Sun, 21 Aug 2016 08:30:51 +0000 (11:30 +0300)
mysql2sql/process_tokens.py
scripts/group-file.py [new file with mode: 0755]
scripts/group-sql.py [new file with mode: 0755]
tests/test_stgrouper.py [new file with mode: 0755]

index ac9930eb81963b1e0c376ff67af035aaaca37679..94879cc555ea2bbbb399ecb655657dfead8ceab0 100644 (file)
@@ -1,6 +1,6 @@
 
-from sqlparse.sql import TokenList
-from sqlparse.tokens import Name, Error
+from sqlparse.sql import Statement
+from sqlparse.tokens import Name, Error, Punctuation
 
 
 def requote_names(token_list):
@@ -22,3 +22,25 @@ def find_error(token_list):
         if token.ttype is Error:
             return True
     return False
+
+
+class StatementGrouper(object):
+    def __init__(self):
+        self.tokens = []
+        self.statements = []
+
+    def get_statements(self):
+        for statement in self.statements:
+            yield statement
+        self.statements = []
+
+    def process(self, tokens):
+        for token in tokens:
+            self.tokens.append(token)
+            if (token.ttype == Punctuation) and (token.value == ';'):
+                self.statements.append(Statement(self.tokens))
+                self.tokens = []
+
+    def close(self):
+        if self.tokens:
+            raise ValueError("Incomplete SQL statement")
diff --git a/scripts/group-file.py b/scripts/group-file.py
new file mode 100755 (executable)
index 0000000..41f9b33
--- /dev/null
@@ -0,0 +1,32 @@
+#! /usr/bin/env python
+from __future__ import print_function
+
+import sys
+from sqlparse import parse
+from mysql2sql.print_tokens import print_tokens
+from mysql2sql.process_tokens import requote_names, find_error, \
+    StatementGrouper
+
+
+def main(filename):
+    grouper = StatementGrouper()
+    with open(filename) as infile:
+        for line in infile:
+            grouper.process(parse(line)[0])
+            if grouper.statements:
+                for statement in grouper.get_statements():
+                    print("----------")
+                    if find_error(statement):
+                        print("ERRORS IN QUERY")
+                    requote_names(statement)
+                    print_tokens(statement)
+                    print()
+                    statement._pprint_tree()
+                print("----------")
+    grouper.close()
+
+
+if __name__ == '__main__':
+    if len(sys.argv) <= 1:
+        sys.exit("Usage: %s file" % sys.argv[0])
+    main(sys.argv[1])
diff --git a/scripts/group-sql.py b/scripts/group-sql.py
new file mode 100755 (executable)
index 0000000..f1f0988
--- /dev/null
@@ -0,0 +1,45 @@
+#! /usr/bin/env python
+from __future__ import print_function
+
+import sys
+from sqlparse import parse
+from mysql2sql.print_tokens import print_tokens
+from mysql2sql.process_tokens import requote_names, find_error, \
+    StatementGrouper
+
+
+def main(*queries):
+    grouper = StatementGrouper()
+    for query in queries:
+        grouper.process(parse(query)[0])
+        if grouper.statements:
+            for statement in grouper.get_statements():
+                print("----------")
+                if find_error(statement):
+                    print("ERRORS IN QUERY")
+                requote_names(statement)
+                print_tokens(statement)
+                print()
+                statement._pprint_tree()
+            print("----------")
+    grouper.close()
+
+
+def test():
+    main(
+        "SELECT * FROM `mytable`; -- line-comment",
+        "INSERT into /* inline comment */ mytable VALUES (1, 'one');",
+        "/*! directive*/ INSERT INTO `MyTable` (`Id`, `Name`) "
+        "VALUES (1, 'one');"
+    )
+
+
+if __name__ == '__main__':
+    if len(sys.argv) <= 1:
+        sys.exit("Usage: %s [-t | sql_query_string [; sql_query_string ...]]" %
+                 sys.argv[0])
+    if sys.argv[1] == '-t':
+        test()
+    else:
+        queries = sys.argv[1:]
+        main(*queries)
diff --git a/tests/test_stgrouper.py b/tests/test_stgrouper.py
new file mode 100755 (executable)
index 0000000..0e72ced
--- /dev/null
@@ -0,0 +1,37 @@
+#! /usr/bin/env python
+
+
+import unittest
+from sqlparse import parse
+
+from mysql2sql.print_tokens import tlist2str
+from mysql2sql.process_tokens import requote_names, StatementGrouper
+from tests import main
+
+
+class TestStGrouper(unittest.TestCase):
+    def test_incomplete(self):
+        grouper = StatementGrouper()
+        parsed = parse("select * from `T`")[0]
+        grouper.process(parsed)
+        self.assertFalse(grouper.statements)
+        self.assertEqual(len(grouper.statements), 0)
+        self.assertRaises(ValueError, grouper.close)
+
+    def test_statements(self):
+        grouper = StatementGrouper()
+        parsed = parse("select * from `T`;")[0]
+        grouper.process(parsed)
+        self.assertTrue(grouper.statements)
+        self.assertEqual(len(grouper.statements), 1)
+        g = grouper.get_statements()
+        statement = next(g)
+        requote_names(statement)
+        query = tlist2str(parsed)
+        self.assertEqual(query, 'SELECT * FROM "T";')
+        self.assertRaises(StopIteration, next, g)
+        self.assertEqual(len(grouper.statements), 0)
+        self.assertIsNone(grouper.close())
+
+if __name__ == "__main__":
+    main()