From 4c93c3d89685aba33fc45082022373eb93b6583e Mon Sep 17 00:00:00 2001 From: Oleg Broytman Date: Sat, 3 Sep 2016 01:26:02 +0300 Subject: [PATCH] Use encoding (default is utf-8) and unicode --- demo/sample.sql | 2 +- docs/index.rst | 8 +++++++- mysql2sql/print_tokens.py | 11 ++++++++--- mysql2sql/process_tokens.py | 7 ++++--- requirements.txt | 2 ++ scripts/mysql-to-sql.py | 31 ++++++++++++++++++++++++------- tests/test_tokens.py | 12 +++++++++++- 7 files changed, 57 insertions(+), 16 deletions(-) diff --git a/demo/sample.sql b/demo/sample.sql index fb31ec3..64818c3 100644 --- a/demo/sample.sql +++ b/demo/sample.sql @@ -1,5 +1,5 @@ SELECT * FROM `mytable`; -- line-comment" -INSERT into /* inline comment */ mytable VALUES (1, 'one'); +INSERT into /* inline comment */ mytable VALUES (1, 'тест'); /*! directive*/ INSERT INTO `MyTable` (`Id`, `Name`) VALUES (1, 'one'); diff --git a/docs/index.rst b/docs/index.rst index 356f1d1..49cb311 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -25,10 +25,16 @@ mysql-to-sql.py Usage:: - mysql-to-sql.py [infile] [[-o] outfile] + mysql-to-sql.py [-e encoding] [-E output_encoding] [infile] [[-o] outfile] Options:: + -e ENCODING, --encoding ENCODING + input/output encoding, default is utf-8 + -E OUTPUT_ENCODING, --output-encoding OUTPUT_ENCODING + separate output encoding, default is the same as + `-e` except for console; for console output charset + from the current locale is used infile Input file, stdin if absent or '-' -o, --outfile outfile Output file, stdout if absent or '-' diff --git a/mysql2sql/print_tokens.py b/mysql2sql/print_tokens.py index 142391f..3e2b0d5 100644 --- a/mysql2sql/print_tokens.py +++ b/mysql2sql/print_tokens.py @@ -2,10 +2,15 @@ import sys -def print_tokens(token_list, outfile=sys.stdout): +def print_tokens(token_list, outfile=sys.stdout, encoding=None): + if encoding: + outfile = getattr(outfile, 'buffer', outfile) for token in token_list.flatten(): - outfile.write(token.normalized) + normalized = token.normalized + if encoding: + normalized = normalized.encode(encoding) + outfile.write(normalized) def tlist2str(token_list): - return ''.join(token.normalized for token in token_list.flatten()) + return u''.join(token.normalized for token in token_list.flatten()) diff --git a/mysql2sql/process_tokens.py b/mysql2sql/process_tokens.py index 9e1e760..1e74ac9 100644 --- a/mysql2sql/process_tokens.py +++ b/mysql2sql/process_tokens.py @@ -33,16 +33,17 @@ if PY3: class StatementGrouper(object): """Collect lines and reparse until the last statement is complete""" - def __init__(self): + def __init__(self, encoding=None): self.lines = [] self.statements = [] + self.encoding = encoding def process_line(self, line): self.lines.append(line) self.process_lines() def process_lines(self): - statements = parse(''.join(self.lines)) + statements = parse(''.join(self.lines), encoding=self.encoding) last_stmt = statements[-1] for i in xrange(len(last_stmt.tokens) - 1, 0, -1): token = last_stmt.tokens[i] @@ -64,7 +65,7 @@ class StatementGrouper(object): def close(self): if not self.lines: return - tokens = parse(''.join(self.lines)) + tokens = parse(''.join(self.lines), encoding=self.encoding) for token in tokens: if (token.ttype not in (Comment.Single, Comment.Multiline, Newline, Whitespace)): diff --git a/requirements.txt b/requirements.txt index 90c942b..f1bf8b4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,5 @@ argparse; python_version == '2.6' sqlparse +m_lib>=2.0; python_version >= '2.6' and python_version < '3.0' +m_lib>=3.0; python_version >= '3.4' diff --git a/scripts/mysql-to-sql.py b/scripts/mysql-to-sql.py index c40563f..23197ff 100755 --- a/scripts/mysql-to-sql.py +++ b/scripts/mysql-to-sql.py @@ -2,28 +2,38 @@ from __future__ import print_function import argparse +from io import open import sys from mysql2sql.print_tokens import print_tokens from mysql2sql.process_tokens import requote_names, StatementGrouper +from m_lib.defenc import default_encoding -def main(infile, outfile): - grouper = StatementGrouper() + +def main(infile, encoding, outfile, output_encoding): + grouper = StatementGrouper(encoding=encoding) for line in infile: grouper.process_line(line) if grouper.statements: for statement in grouper.get_statements(): requote_names(statement) - print_tokens(statement, outfile=outfile) + print_tokens(statement, outfile=outfile, + encoding=output_encoding) tokens = grouper.close() if tokens: for token in tokens: - print_tokens(token, outfile=outfile) + print_tokens(token, outfile=outfile, encoding=output_encoding) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Convert MySQL to SQL') + parser.add_argument('-e', '--encoding', default='utf-8', + help='input/output encoding, default is utf-8') + parser.add_argument('-E', '--output-encoding', + help='separate output encoding, default is the same ' + 'as -e except for console; for console output ' + 'charset from the current locale is used') parser.add_argument('-o', '--outfile', help='output file name') parser.add_argument('infile', help='input file name') parser.add_argument('output_file', nargs='?', help='output file name') @@ -33,7 +43,7 @@ if __name__ == '__main__': if args.infile == '-': infile = sys.stdin else: - infile = open(args.infile, 'rt') + infile = open(args.infile, 'rt', encoding=args.encoding) else: infile = sys.stdin @@ -56,14 +66,21 @@ if __name__ == '__main__': else: outfile = '-' + if args.output_encoding: + output_encoding = args.output_encoding + elif outfile == '-': + output_encoding = default_encoding + else: + output_encoding = args.encoding + if outfile == '-': outfile = sys.stdout else: try: - outfile = open(outfile, 'wt') + outfile = open(outfile, 'wt', encoding=output_encoding) except: if infile is not sys.stdin: infile.close() raise - main(infile, outfile) + main(infile, args.encoding, outfile, output_encoding) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index c39cd31..83951c5 100755 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -1,5 +1,5 @@ #! /usr/bin/env python - +# -*- coding: utf-8 -*- import unittest from sqlparse import parse @@ -16,6 +16,16 @@ class TestTokens(unittest.TestCase): query = tlist2str(parsed) self.assertEqual(query, 'SELECT * FROM "T"') + def test_encoding(self): + parsed = parse("insert into test (1, 'тест')", 'utf-8')[0] + query = tlist2str(parsed).encode('utf-8') + self.assertEqual(query, "INSERT INTO test (1, 'тест')") + + def test_unicode(self): + parsed = parse(u"insert into test (1, 'тест')")[0] + query = tlist2str(parsed) + self.assertEqual(query, u"INSERT INTO test (1, 'тест')") + if __name__ == "__main__": main() -- 2.39.2