git.phdru.name Git - sqlconvert.git/blob - sqlconvert/process_mysql.py
Change quoting style to MySQL, PostgreSQL or SQLite
[sqlconvert.git] / sqlconvert / process_mysql.py
1
2 from sqlparse.sql import Comment
3 from sqlparse import tokens as T
4 from .process_tokens import escape_strings
5
6
def _is_directive_token(token):
    """Tell whether *token* is a MySQL ``/*! ... */`` conditional comment.

    True only for a ``Comment`` group whose first sub-token is a multiline
    comment starting with ``/*!``; every other token yields False.
    """
    if not isinstance(token, Comment) or not token.tokens:
        return False
    head = token.tokens[0]
    return (head.ttype is T.Comment.Multiline
            and head.value.startswith('/*!'))
16
17
def is_directive_statement(statement):
    """Tell whether *statement* consists solely of a ``/*! ... */`` directive.

    The statement qualifies when its first token is a MySQL conditional
    comment, its last token is the ``;`` punctuation, and everything in
    between is whitespace or newlines.

    A statement with no tokens is not a directive statement; False is
    returned instead of raising IndexError.
    """
    tokens = statement.tokens
    if not tokens:
        return False
    if not _is_directive_token(tokens[0]):
        return False
    last = tokens[-1]
    if last.ttype is not T.Punctuation or last.value != ';':
        return False
    # Only layout may sit between the directive and the terminating ';'.
    return all(token.ttype in (T.Newline, T.Whitespace)
               for token in tokens[1:-1])
28
29
def remove_directive_tokens(statement):
    """Remove /*! directives */ from the first-level

    Only the statement's direct children are inspected; nested token
    groups are left untouched.  The statement is modified in place.
    """
    statement.tokens = [
        tok for tok in statement.tokens if not _is_directive_token(tok)
    ]
38
39
40 def requote_names(token_list):
41     """Remove backticks, quote non-lowercase identifiers"""
42     for token in token_list.flatten():
43         if token.ttype is T.Name:
44             value = token.value
45             if (value[0] == "`") and (value[-1] == "`"):
46                 value = value[1:-1]
47             if value.islower():
48                 token.normalized = token.value = value
49             else:
50                 token.normalized = token.value = '"%s"' % value
51
52
53 def unescape_strings(token_list):
54     """Unescape strings"""
55     for token in token_list.flatten():
56         if token.ttype is T.String.Single:
57             value = token.value
58             for orig, repl in (
59                 ('\\"', '"'),
60                 ("\\'", "'"),
61                 ("''", "'"),
62                 ('\\b', '\b'),
63                 ('\\n', '\n'),
64                 ('\\r', '\r'),
65                 ('\\t', '\t'),
66                 ('\\\032', '\032'),
67                 ('\\\\', '\\'),
68             ):
69                 value = value.replace(orig, repl)
70             token.normalized = token.value = value
71
72
def process_statement(statement, quoting_style='sqlite'):
    """Convert a parsed MySQL *statement* in place.

    Drops top-level ``/*! ... */`` directives, re-quotes identifiers,
    decodes MySQL string escapes, and finally hands the statement to
    ``escape_strings`` together with *quoting_style*.  The order matters:
    strings must be decoded before they are re-escaped.
    """
    for transform in (remove_directive_tokens, requote_names,
                      unescape_strings):
        transform(statement)
    escape_strings(statement, quoting_style)