]> git.phdru.name Git - sqlconvert.git/blobdiff - scripts/mysql2sql
Build(GHActions): Use `checkout@v4` instead of outdated `v2`
[sqlconvert.git] / scripts / mysql2sql
index 23197ffce6b2cbe7b312d7ab7b0f60d02ba7a7c2..e859f2a533df665041d6320902efb1c3245b0cc2 100755 (executable)
@@ -3,28 +3,78 @@ from __future__ import print_function
 
 import argparse
 from io import open
+import os
 import sys
 
-from mysql2sql.print_tokens import print_tokens
-from mysql2sql.process_tokens import requote_names, StatementGrouper
+from sqlconvert.print_tokens import print_tokens
+from sqlconvert.process_mysql import is_directive_statement, process_statement
+from sqlconvert.process_tokens import is_newline_statement, StatementGrouper
 
 from m_lib.defenc import default_encoding
+from m_lib.pbar.tty_pbar import ttyProgressBar
 
+try:
+    text_type = unicode
+except NameError:
+    text_type = str
+
+
+def get_fsize(fp):
+    try:
+        fp.seek(0, os.SEEK_END)
+    except IOError:
+        return None  # File size is unknown
+    size = fp.tell()
+    fp.seek(0, os.SEEK_SET)
+    return size
+
+
+def main(infile, encoding, outfile, output_encoding, use_pbar, quoting_style):
+    if use_pbar:
+        size = get_fsize(infile)
+        if size is None:
+            use_pbar = False
+
+    if use_pbar:
+        print("Converting", end='', file=sys.stderr)
+        if infile.name != '<stdin>':
+            print(' ' + infile.name, end='', file=sys.stderr)
+        print(": ", end='', file=sys.stderr)
+        sys.stderr.flush()
+
+    if use_pbar:
+        pbar = ttyProgressBar(0, size-1)
+        cur_pos = 0
 
-def main(infile, encoding, outfile, output_encoding):
     grouper = StatementGrouper(encoding=encoding)
+    got_directive = False
     for line in infile:
+        if use_pbar:
+            if isinstance(line, text_type):
+                cur_pos += len(line.encode(encoding))
+            else:
+                cur_pos += len(line)
+            pbar.display(cur_pos)
         grouper.process_line(line)
-        if grouper.statements:
-            for statement in grouper.get_statements():
-                requote_names(statement)
-                print_tokens(statement, outfile=outfile,
+        for statement in grouper.get_statements():
+            if got_directive and is_newline_statement(statement):
+                # Condense a sequence of newlines after a /*! directive */;
+                got_directive = False
+                continue
+            got_directive = is_directive_statement(statement)
+            if got_directive:
+                continue
+            for _statement in process_statement(statement, quoting_style):
+                print_tokens(_statement, outfile=outfile,
                              encoding=output_encoding)
     tokens = grouper.close()
     if tokens:
         for token in tokens:
             print_tokens(token, outfile=outfile, encoding=output_encoding)
 
+    if use_pbar:
+        pbar.erase()
+        print("done.")
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Convert MySQL to SQL')
@@ -34,11 +84,27 @@ if __name__ == '__main__':
                         help='separate output encoding, default is the same '
                         'as -e except for console; for console output '
                         'charset from the current locale is used')
+    parser.add_argument('-m', '--mysql', action='store_true',
+                        help='MySQL/MariaDB quoting style')
+    parser.add_argument('-p', '--pg', '--postgres', action='store_true',
+                        help='PostgreSQL quoting style')
+    parser.add_argument('-s', '--sqlite', action='store_true',
+                        help='Generic SQL/SQLite quoting style; '
+                        'this is the default')
     parser.add_argument('-o', '--outfile', help='output file name')
+    parser.add_argument('-P', '--no-pbar', action='store_true',
+                        help='inhibit progress bar')
     parser.add_argument('infile', help='input file name')
     parser.add_argument('output_file', nargs='?', help='output file name')
     args = parser.parse_args()
 
+    if int(args.mysql) + int(args.pg) + int(args.sqlite) > 1:
+        print("Error: options -m/-p/-s are mutually incompatible, "
+              "use only one of them",
+              file=sys.stderr)
+        parser.print_help()
+        sys.exit(1)
+
     if args.infile:
         if args.infile == '-':
             infile = sys.stdin
@@ -75,6 +141,7 @@ if __name__ == '__main__':
 
     if outfile == '-':
         outfile = sys.stdout
+        args.no_pbar = True
     else:
         try:
             outfile = open(outfile, 'wt', encoding=output_encoding)
@@ -83,4 +150,12 @@ if __name__ == '__main__':
                 infile.close()
             raise
 
-    main(infile, args.encoding, outfile, output_encoding)
+    if args.mysql:
+        quoting_style = 'mysql'
+    elif args.pg:
+        quoting_style = 'postgres'
+    else:
+        quoting_style = 'sqlite'
+
+    main(infile, args.encoding, outfile, output_encoding, not args.no_pbar,
+         quoting_style)