+from sqlparse import parse
from sqlparse.sql import Statement
from sqlparse.tokens import Name, Error, Punctuation, Comment, Newline, \
Whitespace
class StatementGrouper(object):
def __init__(self):
- self.tokens = []
self.statements = []
+ self.tokens = []
+ self.lines = []
- def get_statements(self):
- for statement in self.statements:
- yield statement
- self.statements = []
+ def process_line(self, line):
+ lines = self.lines
+ lines.append(line)
+ tokens = parse('\n'.join(lines))[0]
+ self.process_tokens(tokens)
+ self.lines = []
- def process(self, tokens):
+ def process_tokens(self, tokens):
for token in tokens:
self.tokens.append(token)
if (token.ttype == Punctuation) and (token.value == ';'):
self.statements.append(Statement(self.tokens))
self.tokens = []
+ def get_statements(self):
+ for statement in self.statements:
+ yield statement
+ self.statements = []
+
def close(self):
for token in self.tokens:
if (token.ttype not in (Comment.Single, Comment.Multiline,
from __future__ import print_function
import sys
-from sqlparse import parse
from mysql2sql.print_tokens import print_tokens
from mysql2sql.process_tokens import requote_names, find_error, \
StatementGrouper
grouper = StatementGrouper()
with open(filename) as infile:
for line in infile:
- grouper.process(parse(line)[0])
+ grouper.process_line(line)
if grouper.statements:
for statement in grouper.get_statements():
print("----------")
from __future__ import print_function
import sys
-from sqlparse import parse
from mysql2sql.print_tokens import print_tokens
from mysql2sql.process_tokens import requote_names, find_error, \
StatementGrouper
def main(*queries):
grouper = StatementGrouper()
for query in queries:
- grouper.process(parse(query)[0])
+ grouper.process_line(query)
if grouper.statements:
for statement in grouper.get_statements():
print("----------")
import unittest
-from sqlparse import parse
from mysql2sql.print_tokens import tlist2str
from mysql2sql.process_tokens import requote_names, StatementGrouper
class TestStGrouper(unittest.TestCase):
def test_incomplete(self):
grouper = StatementGrouper()
- parsed = parse("select * from `T`")[0]
- grouper.process(parsed)
+ grouper.process_line("select * from `T`")
self.assertFalse(grouper.statements)
self.assertEqual(len(grouper.statements), 0)
self.assertRaises(ValueError, grouper.close)
def test_statements(self):
grouper = StatementGrouper()
- parsed = parse("select * from `T`;")[0]
- grouper.process(parsed)
+ grouper.process_line("select * from `T`;")
self.assertTrue(grouper.statements)
self.assertEqual(len(grouper.statements), 1)
for statement in grouper.get_statements():