Use encoding (default is utf-8) and unicode
authorOleg Broytman <phd@phdru.name>
Fri, 2 Sep 2016 22:26:02 +0000 (01:26 +0300)
committerOleg Broytman <phd@phdru.name>
Fri, 2 Sep 2016 22:26:02 +0000 (01:26 +0300)
demo/sample.sql
docs/index.rst
mysql2sql/print_tokens.py
mysql2sql/process_tokens.py
requirements.txt
scripts/mysql-to-sql.py
tests/test_tokens.py

index fb31ec3..64818c3 100644 (file)
@@ -1,5 +1,5 @@
 SELECT * FROM `mytable`; -- line-comment"
-INSERT into /* inline comment */ mytable VALUES (1, 'one');
+INSERT into /* inline comment */ mytable VALUES (1, 'тест');
 /*! directive*/ INSERT INTO `MyTable` (`Id`, `Name`)
 VALUES (1, 'one');
 
index 356f1d1..49cb311 100644 (file)
@@ -25,10 +25,16 @@ mysql-to-sql.py
 
 Usage::
 
-    mysql-to-sql.py [infile] [[-o] outfile]
+    mysql-to-sql.py [-e encoding] [-E output_encoding] [infile] [[-o] outfile]
 
 Options::
 
+   -e ENCODING, --encoding ENCODING
+                           input/output encoding, default is utf-8
+   -E OUTPUT_ENCODING, --output-encoding OUTPUT_ENCODING
+                           separate output encoding, default is the same as
+                           `-e` except for console; for console output charset
+                           from the current locale is used
     infile                 Input file, stdin if absent or '-'
     -o, --outfile outfile  Output file, stdout if absent or '-'
 
index 142391f..3e2b0d5 100644 (file)
@@ -2,10 +2,15 @@
 import sys
 
 
-def print_tokens(token_list, outfile=sys.stdout):
+def print_tokens(token_list, outfile=sys.stdout, encoding=None):
+    if encoding:
+        outfile = getattr(outfile, 'buffer', outfile)
     for token in token_list.flatten():
-        outfile.write(token.normalized)
+        normalized = token.normalized
+        if encoding:
+            normalized = normalized.encode(encoding)
+        outfile.write(normalized)
 
 
 def tlist2str(token_list):
-    return ''.join(token.normalized for token in token_list.flatten())
+    return u''.join(token.normalized for token in token_list.flatten())
index 9e1e760..1e74ac9 100644 (file)
@@ -33,16 +33,17 @@ if PY3:
 class StatementGrouper(object):
     """Collect lines and reparse until the last statement is complete"""
 
-    def __init__(self):
+    def __init__(self, encoding=None):
         self.lines = []
         self.statements = []
+        self.encoding = encoding
 
     def process_line(self, line):
         self.lines.append(line)
         self.process_lines()
 
     def process_lines(self):
-        statements = parse(''.join(self.lines))
+        statements = parse(''.join(self.lines), encoding=self.encoding)
         last_stmt = statements[-1]
         for i in xrange(len(last_stmt.tokens) - 1, 0, -1):
             token = last_stmt.tokens[i]
@@ -64,7 +65,7 @@ class StatementGrouper(object):
     def close(self):
         if not self.lines:
             return
-        tokens = parse(''.join(self.lines))
+        tokens = parse(''.join(self.lines), encoding=self.encoding)
         for token in tokens:
             if (token.ttype not in (Comment.Single, Comment.Multiline,
                                     Newline, Whitespace)):
index 90c942b..f1bf8b4 100644 (file)
@@ -4,3 +4,5 @@
 
 argparse; python_version == '2.6'
 sqlparse
+m_lib>=2.0; python_version >= '2.6' and python_version < '3.0'
+m_lib>=3.0; python_version >= '3.4'
index c40563f..23197ff 100755 (executable)
@@ -2,28 +2,38 @@
 from __future__ import print_function
 
 import argparse
+from io import open
 import sys
 
 from mysql2sql.print_tokens import print_tokens
 from mysql2sql.process_tokens import requote_names, StatementGrouper
 
+from m_lib.defenc import default_encoding
 
-def main(infile, outfile):
-    grouper = StatementGrouper()
+
+def main(infile, encoding, outfile, output_encoding):
+    grouper = StatementGrouper(encoding=encoding)
     for line in infile:
         grouper.process_line(line)
         if grouper.statements:
             for statement in grouper.get_statements():
                 requote_names(statement)
-                print_tokens(statement, outfile=outfile)
+                print_tokens(statement, outfile=outfile,
+                             encoding=output_encoding)
     tokens = grouper.close()
     if tokens:
         for token in tokens:
-            print_tokens(token, outfile=outfile)
+            print_tokens(token, outfile=outfile, encoding=output_encoding)
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Convert MySQL to SQL')
+    parser.add_argument('-e', '--encoding', default='utf-8',
+                        help='input/output encoding, default is utf-8')
+    parser.add_argument('-E', '--output-encoding',
+                        help='separate output encoding, default is the same '
+                        'as -e except for console; for console output '
+                        'charset from the current locale is used')
     parser.add_argument('-o', '--outfile', help='output file name')
     parser.add_argument('infile', help='input file name')
     parser.add_argument('output_file', nargs='?', help='output file name')
@@ -33,7 +43,7 @@ if __name__ == '__main__':
         if args.infile == '-':
             infile = sys.stdin
         else:
-            infile = open(args.infile, 'rt')
+            infile = open(args.infile, 'rt', encoding=args.encoding)
     else:
         infile = sys.stdin
 
@@ -56,14 +66,21 @@ if __name__ == '__main__':
     else:
         outfile = '-'
 
+    if args.output_encoding:
+        output_encoding = args.output_encoding
+    elif outfile == '-':
+        output_encoding = default_encoding
+    else:
+        output_encoding = args.encoding
+
     if outfile == '-':
         outfile = sys.stdout
     else:
         try:
-            outfile = open(outfile, 'wt')
+            outfile = open(outfile, 'wt', encoding=output_encoding)
         except:
             if infile is not sys.stdin:
                 infile.close()
             raise
 
-    main(infile, outfile)
+    main(infile, args.encoding, outfile, output_encoding)
index c39cd31..83951c5 100755 (executable)
@@ -1,5 +1,5 @@
 #! /usr/bin/env python
-
+# -*- coding: utf-8 -*-
 
 import unittest
 from sqlparse import parse
@@ -16,6 +16,16 @@ class TestTokens(unittest.TestCase):
         query = tlist2str(parsed)
         self.assertEqual(query, 'SELECT * FROM "T"')
 
+    def test_encoding(self):
+        parsed = parse("insert into test (1, 'тест')", 'utf-8')[0]
+        query = tlist2str(parsed).encode('utf-8')
+        self.assertEqual(query, "INSERT INTO test (1, 'тест')")
+
+    def test_unicode(self):
+        parsed = parse(u"insert into test (1, 'тест')")[0]
+        query = tlist2str(parsed)
+        self.assertEqual(query, u"INSERT INTO test (1, 'тест')")
+
 
 if __name__ == "__main__":
     main()