]> git.phdru.name Git - sqlconvert.git/blobdiff - tests/test_tokens.py
Add MySQL-specific remove_directives() and process_statement()
[sqlconvert.git] / tests / test_tokens.py
index 83951c5cc3152d89ed965351804ff5076bd14924..3b0452a50cdcdc4c770b6ffbbfb653d29fc5286f 100755 (executable)
@@ -4,28 +4,42 @@
 import unittest
 from sqlparse import parse
 
-from mysql2sql.process_tokens import requote_names
-from mysql2sql.print_tokens import tlist2str
+from sqlconvert.print_tokens import tlist2str
+from sqlconvert.process_mysql import requote_names, remove_directives, \
+        process_statement
 from tests import main
 
 
 class TestTokens(unittest.TestCase):
-    def test_requote(self):
-        parsed = parse("select * from `T`")[0]
-        requote_names(parsed)
-        query = tlist2str(parsed)
-        self.assertEqual(query, 'SELECT * FROM "T"')
-
     def test_encoding(self):
         parsed = parse("insert into test (1, 'тест')", 'utf-8')[0]
         query = tlist2str(parsed).encode('utf-8')
-        self.assertEqual(query, "INSERT INTO test (1, 'тест')")
+        self.assertEqual(query,
+                         u"INSERT INTO test (1, 'тест')".encode('utf-8'))
 
     def test_unicode(self):
         parsed = parse(u"insert into test (1, 'тест')")[0]
         query = tlist2str(parsed)
         self.assertEqual(query, u"INSERT INTO test (1, 'тест')")
 
+    def test_requote(self):
+        parsed = parse("select * from `T`")[0]
+        requote_names(parsed)
+        query = tlist2str(parsed)
+        self.assertEqual(query, 'SELECT * FROM "T"')
+
+    def test_directives(self):
+        parsed = parse("select /*! test */ * from /* test */ `T`")[0]
+        remove_directives(parsed)
+        query = tlist2str(parsed)
+        self.assertEqual(query, 'SELECT * FROM /* test */ `T`')
+
+    def test_process(self):
+        parsed = parse("select /*! test */ * from /* test */ `T`")[0]
+        process_statement(parsed)
+        query = tlist2str(parsed)
+        self.assertEqual(query, 'SELECT * FROM /* test */ "T"')
+
 
 if __name__ == "__main__":
     main()