]> git.phdru.name Git - sqlconvert.git/commitdiff
Process input stream line by line
authorOleg Broytman <phd@phdru.name>
Wed, 24 Aug 2016 14:44:41 +0000 (17:44 +0300)
committerOleg Broytman <phd@phdru.name>
Wed, 24 Aug 2016 17:21:17 +0000 (20:21 +0300)
mysql2sql/process_tokens.py
scripts/group-file.py
scripts/group-sql.py
tests/test_stgrouper.py

index 7f30a5589b7fad13265b4a9e147fe6308afba545..37752e7c146a8d76ad63f9b7ccbbe395066eec4e 100644 (file)
@@ -1,4 +1,5 @@
 
+from sqlparse import parse
 from sqlparse.sql import Statement
 from sqlparse.tokens import Name, Error, Punctuation, Comment, Newline, \
     Whitespace
@@ -27,21 +28,29 @@ def find_error(token_list):
 
 class StatementGrouper(object):
     def __init__(self):
-        self.tokens = []
         self.statements = []
+        self.tokens = []
+        self.lines = []
 
-    def get_statements(self):
-        for statement in self.statements:
-            yield statement
-        self.statements = []
+    def process_line(self, line):
+        lines = self.lines
+        lines.append(line)
+        tokens = parse('\n'.join(lines))[0]
+        self.process_tokens(tokens)
+        self.lines = []
 
-    def process(self, tokens):
+    def process_tokens(self, tokens):
         for token in tokens:
             self.tokens.append(token)
             if (token.ttype == Punctuation) and (token.value == ';'):
                 self.statements.append(Statement(self.tokens))
                 self.tokens = []
 
+    def get_statements(self):
+        for statement in self.statements:
+            yield statement
+        self.statements = []
+
     def close(self):
         for token in self.tokens:
             if (token.ttype not in (Comment.Single, Comment.Multiline,
index 53a11f10fb4362b5df18da0a9174a3b5f613491f..92f8fb9100d8839d7f0150e899db8c5886d231d8 100755 (executable)
@@ -2,7 +2,6 @@
 from __future__ import print_function
 
 import sys
-from sqlparse import parse
 from mysql2sql.print_tokens import print_tokens
 from mysql2sql.process_tokens import requote_names, find_error, \
     StatementGrouper
@@ -12,7 +11,7 @@ def main(filename):
     grouper = StatementGrouper()
     with open(filename) as infile:
         for line in infile:
-            grouper.process(parse(line)[0])
+            grouper.process_line(line)
             if grouper.statements:
                 for statement in grouper.get_statements():
                     print("----------")
index 0eab6b4042253cba3a629df455841924e2935c22..850dcb30b2ad1fd72f79274515feff33517f4c7b 100755 (executable)
@@ -2,7 +2,6 @@
 from __future__ import print_function
 
 import sys
-from sqlparse import parse
 from mysql2sql.print_tokens import print_tokens
 from mysql2sql.process_tokens import requote_names, find_error, \
     StatementGrouper
@@ -11,7 +10,7 @@ from mysql2sql.process_tokens import requote_names, find_error, \
 def main(*queries):
     grouper = StatementGrouper()
     for query in queries:
-        grouper.process(parse(query)[0])
+        grouper.process_line(query)
         if grouper.statements:
             for statement in grouper.get_statements():
                 print("----------")
index cb4f88f0aa30194d3cac8d7d351adb34c9459044..bd60168c020300a0969de0a142c9f9c3fb405f17 100755 (executable)
@@ -2,7 +2,6 @@
 
 
 import unittest
-from sqlparse import parse
 
 from mysql2sql.print_tokens import tlist2str
 from mysql2sql.process_tokens import requote_names, StatementGrouper
@@ -12,16 +11,14 @@ from tests import main
 class TestStGrouper(unittest.TestCase):
     def test_incomplete(self):
         grouper = StatementGrouper()
-        parsed = parse("select * from `T`")[0]
-        grouper.process(parsed)
+        grouper.process_line("select * from `T`")
         self.assertFalse(grouper.statements)
         self.assertEqual(len(grouper.statements), 0)
         self.assertRaises(ValueError, grouper.close)
 
     def test_statements(self):
         grouper = StatementGrouper()
-        parsed = parse("select * from `T`;")[0]
-        grouper.process(parsed)
+        grouper.process_line("select * from `T`;")
         self.assertTrue(grouper.statements)
         self.assertEqual(len(grouper.statements), 1)
         for statement in grouper.get_statements():