-class TestStGrouper(unittest.TestCase):
- def test_incomplete(self):
- grouper = StatementGrouper()
- parsed = parse("select * from `T`")[0]
- grouper.process(parsed)
- self.assertFalse(grouper.statements)
- self.assertEqual(len(grouper.statements), 0)
- self.assertRaises(ValueError, grouper.close)
-
- def test_statements(self):
- grouper = StatementGrouper()
- parsed = parse("select * from `T`;")[0]
- grouper.process(parsed)
- self.assertTrue(grouper.statements)
- self.assertEqual(len(grouper.statements), 1)
- g = grouper.get_statements()
- statement = next(g)
- requote_names(statement)
- query = tlist2str(parsed)
- self.assertEqual(query, 'SELECT * FROM "T";')
- self.assertRaises(StopIteration, next, g)
- self.assertEqual(len(grouper.statements), 0)
- self.assertEqual(grouper.close(), [])
-
-if __name__ == "__main__":
- main()
+def test_statements():
+ grouper = StatementGrouper()
+ grouper.process_line("select * from T;")
+ assert grouper.statements
+ assert len(grouper.statements) == 1
+ for statement in grouper.get_statements():
+ query = tlist2str(statement)
+ assert query == 'SELECT * FROM T;'
+ assert len(grouper.statements) == 0
+ assert grouper.close() is None