2 from sqlparse.sql import Comment, Function, Identifier, Parenthesis, Statement
3 from sqlparse import tokens as T
4 from .process_tokens import escape_strings, is_comment_or_space
def _is_directive_token(token):
    """Return True if *token* is a MySQL ``/*! ... */`` directive comment.

    A directive is a :class:`sqlparse.sql.Comment` group whose first
    subtoken is a multiline comment starting with ``/*!`` (MySQL's
    conditional-execution syntax).
    """
    if isinstance(token, Comment):
        subtokens = token.tokens
        if subtokens:  # guard: an empty Comment group has no subtokens
            comment = subtokens[0]
            if comment.ttype is T.Comment.Multiline and \
                    comment.value.startswith('/*!'):
                return True
    return False
def is_directive_statement(statement):
    """Return True if *statement* consists solely of a directive.

    The statement must be: a ``/*! ... */`` directive comment, followed
    only by whitespace/newlines, terminated by a semicolon.
    """
    tokens = statement.tokens
    if not _is_directive_token(tokens[0]):
        return False
    if tokens[-1].ttype is not T.Punctuation or tokens[-1].value != ';':
        return False
    # anything besides whitespace between the directive and the ';'
    # makes this a real statement, not a bare directive
    for token in tokens[1:-1]:
        if token.ttype not in (T.Newline, T.Whitespace):
            return False
    return True
def remove_directive_tokens(statement):
    """Remove /*! directives */ from the first-level of *statement*.

    Mutates ``statement.tokens`` in place; only top-level directive
    comments are dropped, everything else is kept in order.
    """
    new_tokens = []
    for token in statement.tokens:
        if _is_directive_token(token):
            continue  # drop the directive, keep all other tokens
        new_tokens.append(token)
    statement.tokens = new_tokens
def requote_names(token_list):
    """Remove backticks, quote non-lowercase identifiers.

    MySQL quotes identifiers with backticks; standard SQL (and SQLite)
    uses double quotes. All-lowercase names are left bare, anything
    else is double-quoted to preserve its case.
    """
    for token in token_list.flatten():
        if token.ttype is T.Name:
            value = token.value
            # strip MySQL backtick quoting, if present
            if (value[0] == "`") and (value[-1] == "`"):
                value = value[1:-1]
            if value.islower():
                token.normalized = token.value = value
            else:
                # mixed/upper-case names need explicit quoting to
                # survive case-folding in standard SQL
                token.normalized = token.value = '"%s"' % value
def unescape_strings(token_list):
    """Unescape strings.

    Replaces MySQL backslash escape sequences inside single-quoted
    string tokens with the literal characters they denote, mutating
    the tokens in place.
    """
    for token in token_list.flatten():
        if token.ttype is T.String.Single:
            value = token.value
            # NOTE(review): the escape table below is reconstructed from
            # MySQL string-literal semantics — the original list was not
            # visible here; confirm against upstream. Order matters:
            # '\\\\' must be replaced last so it does not create new
            # sequences that earlier replacements would consume.
            for orig, repl in (
                ('\\"', '"'),
                ("\\'", "'"),
                ("''", "'"),
                ('\\b', '\b'),
                ('\\n', '\n'),
                ('\\r', '\r'),
                ('\\t', '\t'),
                ('\\\x1a', '\x1a'),
                ('\\0', '\0'),
                ('\\\\', '\\'),
            ):
                value = value.replace(orig, repl)
            token.normalized = token.value = value
def is_insert(statement):
    """Return True if the first meaningful token of *statement* is INSERT.

    Leading comments and whitespace are skipped; the decision is made
    on the first real token only.
    """
    for token in statement.tokens:
        if is_comment_or_space(token):
            continue
        return (token.ttype is T.DML) and (token.normalized == 'INSERT')
    return False  # statement contains only comments/whitespace
def split_ext_insert(statement):
    """Split extended INSERT into multiple standard INSERTs.

    A MySQL extended INSERT has the form
    ``INSERT INTO t VALUES (...), (...), ...;``. This generator yields
    one standard single-row ``INSERT INTO t VALUES (...);`` Statement
    per value group, sharing the leading ``INSERT INTO t VALUES``
    tokens.

    Raises ValueError on any token that does not fit the expected
    grammar.

    NOTE(review): gaps in the visible source were reconstructed here
    (state transitions, error raise, statement assembly) — confirm
    against upstream.
    """
    insert_tokens = []   # shared prefix: INSERT INTO <table> VALUES
    values_tokens = []   # one Parenthesis per value row
    last_token = None    # the terminating ';'
    expected = 'INSERT'  # simple state machine over the token stream
    for token in statement.tokens:
        if is_comment_or_space(token):
            insert_tokens.append(token)
            continue
        elif expected == 'INSERT':
            if (token.ttype is T.DML) and (token.normalized == 'INSERT'):
                insert_tokens.append(token)
                expected = 'INTO'
                continue
        elif expected == 'INTO':
            if (token.ttype is T.Keyword) and (token.normalized == 'INTO'):
                insert_tokens.append(token)
                expected = 'TABLE_NAME'
                continue
        elif expected == 'TABLE_NAME':
            if isinstance(token, (Function, Identifier)):
                insert_tokens.append(token)
                expected = 'VALUES'
                continue
        elif expected == 'VALUES':
            if (token.ttype is T.Keyword) and (token.normalized == 'VALUES'):
                insert_tokens.append(token)
                expected = 'VALUES_OR_SEMICOLON'
                continue
        elif expected == 'VALUES_OR_SEMICOLON':
            if isinstance(token, Parenthesis):
                values_tokens.append(token)
                continue
            elif token.ttype is T.Punctuation:
                if token.value == ',':
                    continue  # separator between value rows
                elif token.value == ';':
                    last_token = token
                    break
        # any token that falls through the state machine is a syntax error
        raise ValueError(
            'SQL syntax error: expected "%s", got %s "%s"' % (
                expected, token.ttype, token.normalized))
    for values in values_tokens:
        # The Statement constructor sets the `parent` attribute of every
        # token to itself, so build a fresh token list per statement
        # instead of mutating a shared one.
        vl = [values]
        if last_token is not None:  # keep the original ';' if there was one
            vl.append(last_token)
        statement = Statement(insert_tokens + vl)
        yield statement
133 def process_statement(statement, quoting_style='sqlite'):
134 requote_names(statement)
135 unescape_strings(statement)
136 remove_directive_tokens(statement)
137 escape_strings(statement, quoting_style)
138 if is_insert(statement):
139 for statement in split_ext_insert(statement):