diff --git a/Cargo.lock b/Cargo.lock index 284ef84b..89a65ac7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2686,11 +2686,12 @@ dependencies = [ [[package]] name = "sqlparser" -version = "0.34.0" +version = "0.61.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37d3706eefb17039056234df6b566b0014f303f867f2656108334a55b8096f59" +checksum = "dbf5ea8d4d7c808e1af1cbabebca9a2abe603bcefc22294c5b95018d53200cb7" dependencies = [ "log", + "recursive", "serde", ] diff --git a/Cargo.toml b/Cargo.toml index 15c93b57..79a7d579 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,7 +56,7 @@ rust_decimal = { version = "1" } serde = { version = "1", features = ["derive", "rc"] } kite_sql_serde_macros = { version = "0.1.0", path = "kite_sql_serde_macros" } siphasher = { version = "1", features = ["serde"] } -sqlparser = { version = "0.34", features = ["serde"] } +sqlparser = { version = "0.61", features = ["serde"] } thiserror = { version = "1" } typetag = { version = "0.2" } ulid = { version = "1", features = ["serde"] } diff --git a/src/binder/aggregate.rs b/src/binder/aggregate.rs index d6780d23..2765e65a 100644 --- a/src/binder/aggregate.rs +++ b/src/binder/aggregate.rs @@ -87,18 +87,14 @@ impl> Binder<'_, '_, T, A> let return_orderby = if !orderbys.is_empty() { let mut return_orderby = vec![]; for orderby in orderbys { - let OrderByExpr { - expr, - asc, - nulls_first, - } = orderby; + let OrderByExpr { expr, options, .. } = orderby; let mut expr = self.bind_expr(expr)?; self.visit_column_agg_expr(&mut expr)?; return_orderby.push(SortField::new( expr, - asc.is_none_or(|asc| asc), - nulls_first.unwrap_or(false), + options.asc.is_none_or(|asc| asc), + options.nulls_first.unwrap_or(false), )); } Some(return_orderby) diff --git a/src/binder/alter_table.rs b/src/binder/alter_table.rs index 1f1fc218..36f0f3f2 100644 --- a/src/binder/alter_table.rs +++ b/src/binder/alter_table.rs @@ -16,7 +16,7 @@ use sqlparser::ast::{AlterTableOperation, ObjectName}; use std::sync::Arc; -use super::{is_valid_identifier, Binder}; +use super::{attach_span_if_absent, is_valid_identifier, Binder}; use crate::binder::lower_case_name; use crate::errors::DatabaseError; use crate::planner::operator::alter_table::add_column::AddColumnOperator; @@ -43,13 +43,15 @@ impl> Binder<'_, '_, T, A> column_keyword: _, if_not_exists, column_def, + .. } => { let plan = TableScanOperator::build(table_name.clone(), table, true)?; let column = self.bind_column(column_def, None)?; if !is_valid_identifier(column.name()) { - return Err(DatabaseError::InvalidColumn( - "illegal column naming".to_string(), + return Err(attach_span_if_absent( + DatabaseError::invalid_column("illegal column naming".to_string()), + column_def, )); } LogicalPlan::new( @@ -62,12 +64,17 @@ impl> Binder<'_, '_, T, A> ) } AlterTableOperation::DropColumn { - column_name, + column_names, if_exists, .. } => { let plan = TableScanOperator::build(table_name.clone(), table, true)?; - let column_name = column_name.value.clone(); + if column_names.len() != 1 { + return Err(DatabaseError::UnsupportedStmt( + "only dropping a single column is supported".to_string(), + )); + } + let column_name = column_names[0].value.clone(); LogicalPlan::new( Operator::DropColumn(DropColumnOperator { diff --git a/src/binder/create_index.rs b/src/binder/create_index.rs index 06d75df8..515f58da 100644 --- a/src/binder/create_index.rs +++ b/src/binder/create_index.rs @@ -22,23 +22,25 @@ use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; use crate::types::index::IndexType; use crate::types::value::DataValue; -use sqlparser::ast::{ObjectName, OrderByExpr}; +use sqlparser::ast::{IndexColumn, ObjectName}; use std::sync::Arc; impl> Binder<'_, '_, T, A> { pub(crate) fn bind_create_index( &mut self, table_name: &ObjectName, - name: &ObjectName, - exprs: &[OrderByExpr], + name: Option<&ObjectName>, + index_columns: &[IndexColumn], if_not_exists: bool, is_unique: bool, ) -> Result { let table_name: Arc = lower_case_name(table_name)?.into(); - let index_name = lower_case_name(name)?; + let index_name = name + .ok_or(DatabaseError::InvalidIndex) + .and_then(lower_case_name)?; let ty = if is_unique { IndexType::Unique - } else if exprs.len() == 1 { + } else if index_columns.len() == 1 { IndexType::Normal } else { IndexType::Composite @@ -52,11 +54,11 @@ impl> Binder<'_, '_, T, A> Source::Table(table) => TableScanOperator::build(table_name.clone(), table, true)?, Source::View(view) => LogicalPlan::clone(&view.plan), }; - let mut columns = Vec::with_capacity(exprs.len()); + let mut columns = Vec::with_capacity(index_columns.len()); - for expr in exprs { + for index_column in index_columns { // TODO: Expression Index - match self.bind_expr(&expr.expr)? { + match self.bind_expr(&index_column.column.expr)? { ScalarExpression::ColumnRef { column, .. } => columns.push(column), expr => { return Err(DatabaseError::UnsupportedStmt(format!( diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index e22ab879..4b22679d 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use super::{is_valid_identifier, Binder}; +use super::{attach_span_if_absent, is_valid_identifier, Binder}; use crate::binder::lower_case_name; use crate::catalog::{ColumnCatalog, ColumnDesc}; use crate::errors::DatabaseError; @@ -24,7 +24,7 @@ use crate::storage::Transaction; use crate::types::value::DataValue; use crate::types::LogicalType; use itertools::Itertools; -use sqlparser::ast::{ColumnDef, ColumnOption, ObjectName, TableConstraint}; +use sqlparser::ast::{ColumnDef, ColumnOption, Expr, IndexColumn, ObjectName, TableConstraint}; use std::collections::HashSet; use std::sync::Arc; @@ -40,8 +40,9 @@ impl> Binder<'_, '_, T, A> let table_name: Arc = lower_case_name(name)?.into(); if !is_valid_identifier(&table_name) { - return Err(DatabaseError::InvalidTable( - "illegal table naming".to_string(), + return Err(attach_span_if_absent( + DatabaseError::invalid_table("illegal table naming".to_string()), + name, )); } { @@ -53,8 +54,9 @@ impl> Binder<'_, '_, T, A> return Err(DatabaseError::DuplicateColumn(col_name.clone())); } if !is_valid_identifier(col_name) { - return Err(DatabaseError::InvalidColumn( - "illegal column naming".to_string(), + return Err(attach_span_if_absent( + DatabaseError::invalid_column("illegal column naming".to_string()), + col, )); } } @@ -66,27 +68,15 @@ impl> Binder<'_, '_, T, A> .try_collect()?; for constraint in constraints { match constraint { - TableConstraint::Unique { - columns: column_names, - is_primary, - .. - } => { - for (i, column_name) in column_names - .iter() - .map(|ident| ident.value.to_lowercase()) - .enumerate() - { - if let Some(column) = columns - .iter_mut() - .find(|column| column.name() == column_name) - { - if *is_primary { - column.desc_mut().set_primary(Some(i)); - } else { - column.desc_mut().set_unique(true); - } - } - } + TableConstraint::PrimaryKey(primary) => { + Self::bind_constraint(&mut columns, &primary.columns, |i, desc| { + desc.set_primary(Some(i)) + })?; + } + TableConstraint::Unique(unique) => { + Self::bind_constraint(&mut columns, &unique.columns, |_, desc| { + desc.set_unique() + })?; } constraint => { return Err(DatabaseError::UnsupportedStmt(format!( @@ -97,8 +87,11 @@ impl> Binder<'_, '_, T, A> } if columns.iter().filter(|col| col.desc().is_primary()).count() == 0 { - return Err(DatabaseError::InvalidTable( - "the primary key field must exist and have at least one".to_string(), + return Err(attach_span_if_absent( + DatabaseError::invalid_table( + "the primary key field must exist and have at least one".to_string(), + ), + name, )); } @@ -112,6 +105,29 @@ impl> Binder<'_, '_, T, A> )) } + fn bind_constraint( + table_columns: &mut [ColumnCatalog], + exprs: &[IndexColumn], + fn_constraint: F, + ) -> Result<(), DatabaseError> { + for (i, index_column) in exprs.iter().enumerate() { + let Expr::Identifier(ident) = &index_column.column.expr else { + return Err(DatabaseError::UnsupportedStmt( + "only identifier columns are supported in `PRIMARY KEY/UNIQUE`".to_string(), + )); + }; + let column_name = ident.value.to_lowercase(); + + if let Some(column) = table_columns + .iter_mut() + .find(|column| column.name() == column_name) + { + fn_constraint(i, column.desc_mut()) + } + } + Ok(()) + } + pub fn bind_column( &mut self, column_def: &ColumnDef, @@ -130,16 +146,13 @@ impl> Binder<'_, '_, T, A> match &option_def.option { ColumnOption::Null => nullable = true, ColumnOption::NotNull => nullable = false, - ColumnOption::Unique { is_primary, .. } => { - if *is_primary { - column_desc.set_primary(column_index); - nullable = false; - // Skip other options when using primary key - break; - } else { - column_desc.set_unique(true); - } + ColumnOption::PrimaryKey(_) => { + column_desc.set_primary(column_index); + nullable = false; + // Skip other options when using primary key + break; } + ColumnOption::Unique(_) => column_desc.set_unique(), ColumnOption::Default(expr) => { let mut expr = self.bind_expr(expr)?; diff --git a/src/binder/create_view.rs b/src/binder/create_view.rs index f8b32df7..9d7b684f 100644 --- a/src/binder/create_view.rs +++ b/src/binder/create_view.rs @@ -23,7 +23,7 @@ use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; use crate::types::value::DataValue; use itertools::Itertools; -use sqlparser::ast::{Ident, ObjectName, Query}; +use sqlparser::ast::{ObjectName, Query, ViewColumnDef}; use std::sync::Arc; use ulid::Ulid; @@ -32,41 +32,55 @@ impl> Binder<'_, '_, T, A> &mut self, or_replace: &bool, name: &ObjectName, - columns: &[Ident], + columns: &[ViewColumnDef], query: &Query, ) -> Result { + fn projection_exprs( + view_name: &Arc, + mapping_schema: &[ColumnRef], + column_names: impl Iterator, + ) -> Vec { + column_names + .enumerate() + .map(|(i, column_name)| { + let mapping_column = &mapping_schema[i]; + let mut column = ColumnCatalog::new( + column_name, + mapping_column.nullable(), + mapping_column.desc().clone(), + ); + column.set_ref_table(view_name.clone(), Ulid::new(), true); + + ScalarExpression::Alias { + expr: Box::new(ScalarExpression::column_expr(mapping_column.clone())), + alias: AliasType::Expr(Box::new(ScalarExpression::column_expr( + ColumnRef::from(column), + ))), + } + }) + .collect_vec() + } + let view_name: Arc = lower_case_name(name)?.into(); let mut plan = self.bind_query(query)?; let mapping_schema = plan.output_schema(); - let exprs = if columns.is_empty() { - Box::new( + let exprs: Vec = if columns.is_empty() { + projection_exprs( + &view_name, + mapping_schema, mapping_schema .iter() .map(|column| column.name().to_string()), - ) as Box> + ) } else { - Box::new(columns.iter().map(lower_ident)) as Box> - } - .enumerate() - .map(|(i, column_name)| { - let mapping_column = &mapping_schema[i]; - let mut column = ColumnCatalog::new( - column_name, - mapping_column.nullable(), - mapping_column.desc().clone(), - ); - column.set_ref_table(view_name.clone(), Ulid::new(), true); - - ScalarExpression::Alias { - expr: Box::new(ScalarExpression::column_expr(mapping_column.clone())), - alias: AliasType::Expr(Box::new(ScalarExpression::column_expr(ColumnRef::from( - column, - )))), - } - }) - .collect_vec(); + projection_exprs( + &view_name, + mapping_schema, + columns.iter().map(|column| lower_ident(&column.name)), + ) + }; plan = self.bind_project(plan, exprs)?; Ok(LogicalPlan::new( diff --git a/src/binder/delete.rs b/src/binder/delete.rs index 76f5826b..42703c20 100644 --- a/src/binder/delete.rs +++ b/src/binder/delete.rs @@ -35,7 +35,7 @@ impl> Binder<'_, '_, T, A> let mut table_alias = None; let mut alias_idents = None; - if let Some(TableAlias { name, columns }) = alias { + if let Some(TableAlias { name, columns, .. }) = alias { table_alias = Some(name.value.to_lowercase().into()); alias_idents = Some(columns); } diff --git a/src/binder/drop_index.rs b/src/binder/drop_index.rs index 1a2dd6f9..5c6bd1e9 100644 --- a/src/binder/drop_index.rs +++ b/src/binder/drop_index.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::binder::{lower_ident, Binder}; +use crate::binder::{attach_span_if_absent, lower_name_part, Binder}; use crate::errors::DatabaseError; use crate::planner::operator::drop_index::DropIndexOperator; use crate::planner::operator::Operator; @@ -27,14 +27,13 @@ impl> Binder<'_, '_, T, A> name: &ObjectName, if_exists: &bool, ) -> Result { - let table_name = name - .0 - .first() - .ok_or(DatabaseError::InvalidTable(name.to_string()))?; + let table_name = name.0.first().ok_or_else(|| { + attach_span_if_absent(DatabaseError::invalid_table(name.to_string()), name) + })?; let index_name = name.0.get(1).ok_or(DatabaseError::InvalidIndex)?; - let table_name = lower_ident(table_name).into(); - let index_name = lower_ident(index_name); + let table_name = lower_name_part(table_name)?.into(); + let index_name = lower_name_part(index_name)?; Ok(LogicalPlan::new( Operator::DropIndex(DropIndexOperator { diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 360a9515..293fce54 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -18,13 +18,16 @@ use crate::expression; use crate::expression::agg::AggKind; use itertools::Itertools; use sqlparser::ast::{ - BinaryOperator, CharLengthUnits, DataType, Expr, Function, FunctionArg, FunctionArgExpr, Ident, - Query, UnaryOperator, Value, + BinaryOperator, CharLengthUnits, DataType, DuplicateTreatment, Expr, Function, FunctionArg, + FunctionArgExpr, FunctionArguments, Ident, Query, TypedString, UnaryOperator, Value, }; use std::collections::HashMap; use std::slice; -use super::{lower_ident, Binder, BinderContext, QueryBindStep, SubQueryType}; +use super::{ + attach_span_from_sqlparser_span_if_absent, attach_span_if_absent, lower_ident, Binder, + BinderContext, QueryBindStep, SubQueryType, +}; use crate::expression::function::scala::{ArcScalarFunctionImpl, ScalarFunction}; use crate::expression::function::table::{ArcTableFunctionImpl, TableFunction}; use crate::expression::function::FunctionSummary; @@ -54,6 +57,29 @@ macro_rules! try_default { } impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T, A> { + fn parse_like_escape_char(escape_char: &Option) -> Result, DatabaseError> { + match escape_char { + None => Ok(None), + Some(value) => match value { + Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => { + let mut chars = s.chars(); + let ch = chars.next().ok_or(DatabaseError::InvalidValue( + "escape character must not be empty".to_string(), + ))?; + if chars.next().is_some() { + return Err(DatabaseError::InvalidValue( + "escape character must be a single character".to_string(), + )); + } + Ok(Some(ch)) + } + _ => Err(DatabaseError::InvalidValue( + "escape character must be a quoted string".to_string(), + )), + }, + } + } + pub(crate) fn bind_expr(&mut self, expr: &Expr) -> Result { match expr { Expr::Identifier(ident) => { @@ -62,14 +88,21 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T Expr::CompoundIdentifier(idents) => self.bind_column_ref_from_identifiers(idents, None), Expr::BinaryOp { left, right, op } => self.bind_binary_op_internal(left, right, op), Expr::Value(v) => { - let value = if let Value::Placeholder(name) = v { + let value = if let Value::Placeholder(name) = &v.value { self.args .as_ref() .iter() .find_map(|(key, value)| (key == name).then(|| value.clone())) - .ok_or_else(|| DatabaseError::ParametersNotFound(name.to_string()))? + .ok_or_else(|| { + attach_span_if_absent( + DatabaseError::parameter_not_found(name.to_string()), + v, + ) + })? } else { - v.try_into()? + (&v.value) + .try_into() + .map_err(|err| attach_span_if_absent(err, v))? }; Ok(ScalarExpression::Constant(value)) } @@ -81,6 +114,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T expr, pattern, escape_char, + any: _, } => self.bind_like(*negated, expr, pattern, escape_char), Expr::IsNull(expr) => self.bind_is_null(expr, false), Expr::IsNotNull(expr) => self.bind_is_null(expr, true), @@ -92,14 +126,20 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T Expr::Cast { expr, data_type, .. } => self.bind_cast(expr, data_type), - Expr::TypedString { data_type, value } => { + Expr::TypedString(TypedString { + data_type, value, .. + }) => { let logical_type = LogicalType::try_from(data_type.clone())?; + let raw = value.clone().into_string().ok_or_else(|| { + DatabaseError::InvalidValue("typed string literal must be a string".to_string()) + })?; let value = DataValue::Utf8 { - value: value.to_string(), + value: raw, ty: Utf8Type::Variable(None), unit: CharLengthUnits::Characters, } - .cast(&logical_type)?; + .cast(&logical_type) + .map_err(|err| attach_span_if_absent(err, expr))?; Ok(ScalarExpression::Constant(value)) } @@ -144,6 +184,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T expr, trim_what, trim_where, + .. } => { let mut trim_what_expr = None; if let Some(trim_what) = trim_what { @@ -214,8 +255,8 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T Expr::Case { operand, conditions, - results, else_result, + .. } => { let fn_check_ty = |ty: &mut LogicalType, result_ty| { if result_ty != LogicalType::SqlNull { @@ -234,12 +275,12 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T operand_expr = Some(Box::new(self.bind_expr(expr)?)); } let mut expr_pairs = Vec::with_capacity(conditions.len()); - for i in 0..conditions.len() { - let result = self.bind_expr(&results[i])?; + for when in conditions { + let result = self.bind_expr(&when.result)?; let result_ty = result.return_type(); fn_check_ty(&mut ty, result_ty)?; - expr_pairs.push((self.bind_expr(&conditions[i])?, result)) + expr_pairs.push((self.bind_expr(&when.condition)?, result)) } let mut else_expr = None; @@ -340,14 +381,15 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T negated: bool, expr: &Expr, pattern: &Expr, - escape_char: &Option, + escape_char: &Option, ) -> Result { let left_expr = Box::new(self.bind_expr(expr)?); let right_expr = Box::new(self.bind_expr(pattern)?); + let escape_char = Self::parse_like_escape_char(escape_char)?; let op = if negated { - expression::BinaryOperator::NotLike(*escape_char) + expression::BinaryOperator::NotLike(escape_char) } else { - expression::BinaryOperator::Like(*escape_char) + expression::BinaryOperator::Like(escape_char) }; Ok(ScalarExpression::Binary { op, @@ -367,13 +409,16 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T [column] => (None, lower_ident(column)), [table, column] => (Some(lower_ident(table)), lower_ident(column)), _ => { - return Err(DatabaseError::InvalidColumn( - idents - .iter() - .map(|ident| ident.value.clone()) - .join(".") - .to_string(), - )) + let invalid_name = idents + .iter() + .map(|ident| ident.value.clone()) + .join(".") + .to_string(); + let err = DatabaseError::invalid_column(invalid_name); + return Err(match idents.last() { + Some(ident) => attach_span_from_sqlparser_span_if_absent(err, ident.span), + None => err, + }); } }; try_alias!(self.context, full_name); @@ -381,13 +426,23 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T try_default!(&full_name.0, full_name.1); } if let Some(table) = full_name.0.or(bind_table_name) { - let source = self.context.bind_source(&table)?; + let source = self.context.bind_source(&table).map_err(|err| { + if let [table_ident, _] = idents { + attach_span_from_sqlparser_span_if_absent(err, table_ident.span) + } else { + err + } + })?; let schema_buf = self.table_schema_buf.entry(table.into()).or_default(); Ok(ScalarExpression::column_expr( - source - .column(&full_name.1, schema_buf) - .ok_or_else(|| DatabaseError::ColumnNotFound(full_name.1.to_string()))?, + source.column(&full_name.1, schema_buf).ok_or_else(|| { + let err = DatabaseError::column_not_found(full_name.1.to_string()); + match idents.last() { + Some(ident) => attach_span_from_sqlparser_span_if_absent(err, ident.span), + None => err, + } + })?, )) } else { let op = @@ -427,7 +482,16 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T if let Some(parent) = self.parent { op(&mut got_column, &parent.context, &mut self.table_schema_buf); } - Ok(got_column.ok_or(DatabaseError::ColumnNotFound(full_name.1))?) + match got_column { + Some(column) => Ok(column), + None => { + let err = DatabaseError::column_not_found(full_name.1.clone()); + Err(match idents.last() { + Some(ident) => attach_span_from_sqlparser_span_if_absent(err, ident.span), + None => err, + }) + } + } } } @@ -505,11 +569,24 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T "`TableFunction` cannot bind in non-From step".to_string(), )); } - let mut args = Vec::with_capacity(func.args.len()); + let (func_args, is_distinct) = match &func.args { + FunctionArguments::List(args) => ( + args.args.as_slice(), + matches!(args.duplicate_treatment, Some(DuplicateTreatment::Distinct)), + ), + FunctionArguments::None => (&[][..], false), + FunctionArguments::Subquery(_) => { + return Err(DatabaseError::UnsupportedStmt( + "subquery function args are not supported".to_string(), + )) + } + }; + let mut args = Vec::with_capacity(func_args.len()); - for arg in func.args.iter() { + for arg in func_args { let arg_expr = match arg { FunctionArg::Named { arg, .. } => arg, + FunctionArg::ExprNamed { arg, .. } => arg, FunctionArg::Unnamed(arg) => arg, }; match arg_expr { @@ -530,7 +607,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T return Err(DatabaseError::MisMatch("number of count() parameters", "1")); } return Ok(ScalarExpression::AggCall { - distinct: func.distinct, + distinct: is_distinct, kind: AggKind::Count, args, ty: LogicalType::Integer, @@ -543,7 +620,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T let ty = args[0].return_type(); return Ok(ScalarExpression::AggCall { - distinct: func.distinct, + distinct: is_distinct, kind: AggKind::Sum, args, ty, @@ -556,7 +633,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T let ty = args[0].return_type(); return Ok(ScalarExpression::AggCall { - distinct: func.distinct, + distinct: is_distinct, kind: AggKind::Min, args, ty, @@ -569,7 +646,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T let ty = args[0].return_type(); return Ok(ScalarExpression::AggCall { - distinct: func.distinct, + distinct: is_distinct, kind: AggKind::Max, args, ty, @@ -581,7 +658,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T } return Ok(ScalarExpression::AggCall { - distinct: func.distinct, + distinct: is_distinct, kind: AggKind::Avg, args, ty: LogicalType::Double, @@ -678,7 +755,10 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T })); } - Err(DatabaseError::FunctionNotFound(summary.name.to_string())) + Err(attach_span_if_absent( + DatabaseError::function_not_found(summary.name.to_string()), + func, + )) } fn return_type( diff --git a/src/binder/insert.rs b/src/binder/insert.rs index cdd2ed8d..46249419 100644 --- a/src/binder/insert.rs +++ b/src/binder/insert.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::binder::{lower_case_name, Binder}; +use crate::binder::{attach_span_if_absent, lower_case_name, Binder}; use crate::errors::DatabaseError; use crate::expression::simplify::ConstantCalculator; use crate::expression::visitor_mut::VisitorMut; @@ -96,6 +96,12 @@ impl> Binder<'_, '_, T, A> } // Check if the value length is too long value.check_len(ty)?; + if value.is_null() && !schema_ref[i].nullable() { + return Err(attach_span_if_absent( + DatabaseError::not_null_column(schema_ref[i].name().to_string()), + expr, + )); + } row.push(value); } @@ -103,6 +109,12 @@ impl> Binder<'_, '_, T, A> let default_value = schema_ref[i] .default_value()? .ok_or(DatabaseError::DefaultNotExist)?; + if default_value.is_null() && !schema_ref[i].nullable() { + return Err(attach_span_if_absent( + DatabaseError::not_null_column(schema_ref[i].name().to_string()), + expr, + )); + } row.push(default_value); } _ => return Err(DatabaseError::UnsupportedStmt(expr.to_string())), diff --git a/src/binder/mod.rs b/src/binder/mod.rs index e7afb0fe..9052af22 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -34,7 +34,11 @@ mod show_view; mod truncate; mod update; -use sqlparser::ast::{Ident, ObjectName, ObjectType, SetExpr, Statement}; +use sqlparser::ast::{ + DescribeAlias, FromTable, Ident, ObjectName, ObjectNamePart, ObjectType, SetExpr, Spanned, + Statement, TableObject, +}; +use sqlparser::tokenizer::Span; use std::collections::{BTreeMap, HashMap, HashSet}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -42,7 +46,7 @@ use std::sync::Arc; use crate::catalog::view::View; use crate::catalog::{ColumnRef, TableCatalog, TableName}; use crate::db::{ScalaFunctions, TableFunctions}; -use crate::errors::DatabaseError; +use crate::errors::{DatabaseError, SqlErrorSpan}; use crate::expression::ScalarExpression; use crate::planner::operator::join::JoinType; use crate::planner::{LogicalPlan, SchemaOutput}; @@ -61,23 +65,67 @@ pub enum CommandType { DDL, } +fn annotate_bind_error(stmt: &Statement, err: DatabaseError) -> DatabaseError { + attach_span_if_absent(err, stmt) +} + +pub(crate) fn attach_span_from_sqlparser_span_if_absent( + err: DatabaseError, + span: Span, +) -> DatabaseError { + if err.sql_error_span().is_some() { + return err; + } + + match sqlparser_span_to_sql_error_span(span) { + Some(span) => err.with_span(span), + None => err, + } +} + +pub(crate) fn attach_span_if_absent( + err: DatabaseError, + node: &T, +) -> DatabaseError { + attach_span_from_sqlparser_span_if_absent(err, node.span()) +} + +pub(crate) fn sqlparser_span_to_sql_error_span(span: Span) -> Option { + if span == Span::empty() { + return None; + } + + let start = span.start.column as usize; + let mut end = span.end.column as usize; + if end <= start { + end = start.saturating_add(1); + } + + Some(SqlErrorSpan { + start, + end, + line: span.start.line as usize, + highlight: None, + }) +} + pub fn command_type(stmt: &Statement) -> Result { match stmt { - Statement::CreateTable { .. } - | Statement::CreateIndex { .. } - | Statement::CreateView { .. } - | Statement::AlterTable { .. } + Statement::CreateTable(_) + | Statement::CreateIndex(_) + | Statement::CreateView(_) + | Statement::AlterTable(_) | Statement::Drop { .. } => Ok(CommandType::DDL), Statement::Query(_) | Statement::Explain { .. } | Statement::ExplainTable { .. } | Statement::ShowTables { .. } - | Statement::ShowVariable { .. } => Ok(CommandType::DQL), - Statement::Analyze { .. } - | Statement::Truncate { .. } - | Statement::Update { .. } - | Statement::Delete { .. } - | Statement::Insert { .. } + | Statement::ShowViews { .. } => Ok(CommandType::DQL), + Statement::Analyze(_) + | Statement::Truncate(_) + | Statement::Update(_) + | Statement::Delete(_) + | Statement::Insert(_) | Statement::Copy { .. } => Ok(CommandType::DML), stmt => Err(DatabaseError::UnsupportedStmt(stmt.to_string())), } @@ -298,7 +346,7 @@ impl<'a, T: Transaction> BinderContext<'a, T> { }) { Ok(source.1) } else { - Err(DatabaseError::InvalidTable(table_name.into())) + Err(DatabaseError::invalid_table(table_name)) } } @@ -375,17 +423,23 @@ impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, ' false } - pub fn bind(&mut self, stmt: &Statement) -> Result { + fn bind_inner(&mut self, stmt: &Statement) -> Result { let plan = match stmt { Statement::Query(query) => self.bind_query(query)?, - Statement::AlterTable { name, operation } => self.bind_alter_table(name, operation)?, - Statement::CreateTable { - name, - columns, - constraints, - if_not_exists, - .. - } => self.bind_create_table(name, columns, constraints, *if_not_exists)?, + Statement::AlterTable(alter) => { + if alter.operations.len() != 1 { + return Err(DatabaseError::UnsupportedStmt( + "only a single ALTER TABLE operation is supported".to_string(), + )); + } + self.bind_alter_table(&alter.name, &alter.operations[0])? + } + Statement::CreateTable(create) => self.bind_create_table( + &create.name, + &create.columns, + &create.constraints, + create.if_not_exists, + )?, Statement::Drop { object_type, names, @@ -408,16 +462,29 @@ impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, ' } } } - Statement::Insert { - table_name, - columns, - source, - overwrite, - .. - } => { + Statement::Insert(insert) => { + let table_name = match &insert.table { + TableObject::TableName(table_name) => table_name, + TableObject::TableFunction(_) => { + return Err(DatabaseError::UnsupportedStmt( + "insert into table function is not supported".to_string(), + )) + } + }; + let source = insert.source.as_ref().ok_or_else(|| { + DatabaseError::UnsupportedStmt( + "insert without source is not supported".to_string(), + ) + })?; // TODO: support body on Insert if let SetExpr::Values(values) = source.body.as_ref() { - self.bind_insert(table_name, columns, &values.rows, *overwrite, false)? + self.bind_insert( + table_name, + &insert.columns, + &values.rows, + insert.overwrite, + false, + )? } else { return Err(DatabaseError::UnsupportedStmt(format!( "insert body: {:#?}", @@ -425,36 +492,44 @@ impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, ' ))); } } - Statement::Update { - table, - selection, - assignments, - .. - } => { + Statement::Update(update) => { + let table = &update.table; if !table.joins.is_empty() { unimplemented!() } else { - self.bind_update(table, selection, assignments)? + self.bind_update(table, &update.selection, &update.assignments)? } } - Statement::Delete { - from, selection, .. - } => { + Statement::Delete(delete) => { + let from = match &delete.from { + FromTable::WithFromKeyword(from) | FromTable::WithoutKeyword(from) => from, + }; let table = &from[0]; if !table.joins.is_empty() { unimplemented!() } else { - self.bind_delete(table, selection)? + self.bind_delete(table, &delete.selection)? + } + } + Statement::Analyze(analyze) => { + let table_name = analyze.table_name.as_ref().ok_or_else(|| { + DatabaseError::UnsupportedStmt( + "ANALYZE without table is not supported".to_string(), + ) + })?; + self.bind_analyze(table_name)? + } + Statement::Truncate(truncate) => { + if truncate.table_names.len() != 1 { + return Err(DatabaseError::UnsupportedStmt( + "only truncate a single table is supported".to_string(), + )); } + self.bind_truncate(&truncate.table_names[0].name)? } - Statement::Analyze { table_name, .. } => self.bind_analyze(table_name)?, - Statement::Truncate { table_name, .. } => self.bind_truncate(table_name)?, Statement::ShowTables { .. } => self.bind_show_tables()?, - Statement::ShowVariable { variable } => match &variable[0].value.to_lowercase()[..] { - "views" => self.bind_show_views()?, - _ => return Err(DatabaseError::UnsupportedStmt(stmt.to_string())), - }, + Statement::ShowViews { .. } => self.bind_show_views()?, Statement::Copy { source, to, @@ -463,37 +538,41 @@ impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, ' .. } => self.bind_copy(source.clone(), *to, target.clone(), options)?, Statement::Explain { statement, .. } => { - let plan = self.bind(statement)?; + let plan = self.bind_inner(statement)?; self.bind_explain(plan)? } Statement::ExplainTable { - describe_alias: true, - table_name, - } => self.bind_describe(table_name)?, - Statement::CreateIndex { + describe_alias: DescribeAlias::Describe | DescribeAlias::Desc, table_name, - name, - columns, - if_not_exists, - unique, - .. - } => self.bind_create_index(table_name, name, columns, *if_not_exists, *unique)?, - Statement::CreateView { - or_replace, - name, - columns, - query, .. - } => self.bind_create_view(or_replace, name, columns, query)?, + } => self.bind_describe(table_name)?, + Statement::CreateIndex(create) => self.bind_create_index( + &create.table_name, + create.name.as_ref(), + &create.columns, + create.if_not_exists, + create.unique, + )?, + Statement::CreateView(create) => self.bind_create_view( + &create.or_replace, + &create.name, + &create.columns, + &create.query, + )?, _ => return Err(DatabaseError::UnsupportedStmt(stmt.to_string())), }; Ok(plan) } + pub fn bind(&mut self, stmt: &Statement) -> Result { + self.bind_inner(stmt) + .map_err(|err| annotate_bind_error(stmt, err)) + } + pub fn bind_set_expr(&mut self, set_expr: &SetExpr) -> Result { match set_expr { - SetExpr::Select(select) => self.bind_select(select, &[]), + SetExpr::Select(select) => self.bind_select(select, None), SetExpr::Query(query) => self.bind_query(query), SetExpr::SetOperation { op, @@ -524,12 +603,21 @@ fn lower_ident(ident: &Ident) -> String { ident.value.to_lowercase() } +fn lower_name_part(part: &ObjectNamePart) -> Result { + part.as_ident() + .map(lower_ident) + .ok_or_else(|| attach_span_if_absent(DatabaseError::invalid_table(part.to_string()), part)) +} + /// Convert an object name into lower case fn lower_case_name(name: &ObjectName) -> Result { if name.0.len() == 1 { - return Ok(lower_ident(&name.0[0])); + return lower_name_part(&name.0[0]); } - Err(DatabaseError::InvalidTable(name.to_string())) + Err(attach_span_if_absent( + DatabaseError::invalid_table(name.to_string()), + name, + )) } pub(crate) fn is_valid_identifier(s: &str) -> bool { diff --git a/src/binder/select.rs b/src/binder/select.rs index a1b18b39..baa87804 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -28,7 +28,8 @@ use std::collections::HashSet; use std::sync::Arc; use super::{ - lower_case_name, lower_ident, Binder, BinderContext, QueryBindStep, Source, SubQueryType, + attach_span_if_absent, lower_case_name, lower_ident, Binder, BinderContext, QueryBindStep, + Source, SubQueryType, }; use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, ColumnSummary, TableName}; @@ -52,9 +53,10 @@ use crate::types::value::Utf8Type; use crate::types::{ColumnId, LogicalType}; use itertools::Itertools; use sqlparser::ast::{ - CharLengthUnits, Distinct, Expr, Ident, Join, JoinConstraint, JoinOperator, Offset, - OrderByExpr, Query, Select, SelectInto, SelectItem, SetExpr, SetOperator, SetQuantifier, - TableAlias, TableFactor, TableWithJoins, + CharLengthUnits, Distinct, Expr, GroupByExpr, Join, JoinConstraint, JoinOperator, LimitClause, + OrderByExpr, OrderByKind, Query, Select, SelectInto, SelectItem, + SelectItemQualifiedWildcardKind, SetExpr, SetOperator, SetQuantifier, TableAlias, + TableAliasColumnDef, TableFactor, TableWithJoins, }; impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, 'b, T, A> { @@ -65,8 +67,20 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' // TODO support with clause. } + let order_by_exprs = if let Some(order_by) = &query.order_by { + match &order_by.kind { + OrderByKind::Expressions(exprs) => Some(exprs.as_slice()), + OrderByKind::All(_) => { + return Err(DatabaseError::UnsupportedStmt( + "ORDER BY ALL is not supported".to_string(), + )) + } + } + } else { + None + }; let mut plan = match query.body.borrow() { - SetExpr::Select(select) => self.bind_select(select, &query.order_by), + SetExpr::Select(select) => self.bind_select(select, order_by_exprs), SetExpr::Query(query) => self.bind_query(query), SetExpr::SetOperation { op, @@ -82,11 +96,8 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } }?; - let limit = &query.limit; - let offset = &query.offset; - - if limit.is_some() || offset.is_some() { - plan = self.bind_limit(plan, limit, offset)?; + if let Some(limit_clause) = query.limit_clause.clone() { + plan = self.bind_limit(plan, limit_clause)?; } self.context.step(origin_step); @@ -96,7 +107,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' pub(crate) fn bind_select( &mut self, select: &Select, - orderby: &[OrderByExpr], + orderby: Option<&[OrderByExpr]>, ) -> Result { let mut plan = if select.from.is_empty() { LogicalPlan::new(Operator::Dummy, Childrens::None) @@ -123,14 +134,29 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' self.extract_select_join(&mut select_list); self.extract_select_aggregate(&mut select_list)?; - if !select.group_by.is_empty() { - self.extract_group_by_aggregate(&mut select_list, &select.group_by)?; + match &select.group_by { + GroupByExpr::Expressions(group_by_exprs, modifiers) => { + if !modifiers.is_empty() { + return Err(DatabaseError::UnsupportedStmt( + "GROUP BY modifiers are not supported".to_string(), + )); + } + if !group_by_exprs.is_empty() { + self.extract_group_by_aggregate(&mut select_list, group_by_exprs)?; + } + } + GroupByExpr::All(_) => { + return Err(DatabaseError::UnsupportedStmt( + "GROUP BY ALL is not supported".to_string(), + )) + } } let mut having_orderby = (None, None); - if select.having.is_some() || !orderby.is_empty() { - having_orderby = self.extract_having_orderby_aggregate(&select.having, orderby)?; + if select.having.is_some() || orderby.is_some() { + having_orderby = + self.extract_having_orderby_aggregate(&select.having, orderby.unwrap_or(&[]))?; } if !self.context.agg_calls.is_empty() || !self.context.group_by_exprs.is_empty() { @@ -286,6 +312,11 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' let is_all = match set_quantifier { SetQuantifier::All => true, SetQuantifier::Distinct | SetQuantifier::None => false, + SetQuantifier::ByName | SetQuantifier::AllByName | SetQuantifier::DistinctByName => { + return Err(DatabaseError::UnsupportedStmt( + "set quantifier BY NAME is not supported".to_string(), + )) + } }; let mut left_plan = self.bind_set_expr(left)?; let mut right_plan = self.bind_set_expr(right)?; @@ -418,6 +449,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' if let Some(TableAlias { name, columns: alias_column, + .. }) = alias { if tables.len() > 1 { @@ -442,6 +474,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' if let Some(TableAlias { name, columns: alias_column, + .. }) = alias { table_alias = Some(name.value.to_lowercase().into()); @@ -471,7 +504,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' pub(crate) fn bind_alias( &mut self, mut plan: LogicalPlan, - alias_column: &[Ident], + alias_column: &[TableAliasColumnDef], table_alias: TableName, table_name: TableName, ) -> Result { @@ -488,7 +521,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } else { alias_column .iter() - .map(lower_ident) + .map(|column| lower_ident(&column.name)) .zip(input_schema.iter().cloned()) .collect_vec() }; @@ -530,7 +563,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' let mut table_alias: Option = None; let mut alias_idents = None; - if let Some(TableAlias { name, columns }) = alias { + if let Some(TableAlias { name, columns, .. }) = alias { table_alias = Some(name.value.to_lowercase().into()); alias_idents = Some(columns); } @@ -599,7 +632,16 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } } SelectItem::QualifiedWildcard(table_name, _) => { - let table_name: Arc = lower_case_name(table_name)?.into(); + let table_name: Arc = match table_name { + SelectItemQualifiedWildcardKind::ObjectName(name) => { + lower_case_name(name)?.into() + } + SelectItemQualifiedWildcardKind::Expr(expr) => { + return Err(DatabaseError::UnsupportedStmt(format!( + "qualified wildcard expr: {expr}" + ))) + } + }; let schema_buf = self.table_schema_buf.entry(table_name.clone()).or_default(); Self::bind_table_column_refs( @@ -676,15 +718,34 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' let Join { relation, join_operator, + .. } = join; let (join_type, joint_condition) = match join_operator { - JoinOperator::Inner(constraint) => (JoinType::Inner, Some(constraint)), - JoinOperator::LeftOuter(constraint) => (JoinType::LeftOuter, Some(constraint)), - JoinOperator::RightOuter(constraint) => (JoinType::RightOuter, Some(constraint)), + JoinOperator::Join(constraint) + | JoinOperator::Inner(constraint) + | JoinOperator::StraightJoin(constraint) => (JoinType::Inner, Some(constraint)), + JoinOperator::Left(constraint) | JoinOperator::LeftOuter(constraint) => { + (JoinType::LeftOuter, Some(constraint)) + } + JoinOperator::Right(constraint) | JoinOperator::RightOuter(constraint) => { + (JoinType::RightOuter, Some(constraint)) + } JoinOperator::FullOuter(constraint) => (JoinType::Full, Some(constraint)), - JoinOperator::CrossJoin => (JoinType::Cross, None), - _ => unimplemented!(), + JoinOperator::Semi(constraint) | JoinOperator::LeftSemi(constraint) => { + (JoinType::LeftSemi, Some(constraint)) + } + JoinOperator::Anti(constraint) | JoinOperator::LeftAnti(constraint) => { + (JoinType::LeftAnti, Some(constraint)) + } + JoinOperator::CrossJoin(constraint) => (JoinType::Cross, Some(constraint)), + JoinOperator::RightSemi(_) + | JoinOperator::RightAnti(_) + | JoinOperator::CrossApply + | JoinOperator::OuterApply + | JoinOperator::AsOf { .. } => { + return Err(DatabaseError::UnsupportedStmt(format!("{join_operator:?}"))) + } }; let BinderContext { table_cache, @@ -870,49 +931,60 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' ) } + fn bind_non_negative_limit_value(&mut self, expr: &Expr) -> Result { + let bound_expr = self.bind_expr(expr)?; + match bound_expr { + ScalarExpression::Constant(dv) => match &dv { + DataValue::Int32(v) if *v >= 0 => Ok(*v as usize), + DataValue::Int64(v) if *v >= 0 => Ok(*v as usize), + _ => Err(DatabaseError::InvalidType), + }, + _ => Err(attach_span_if_absent( + DatabaseError::invalid_column("invalid limit expression.".to_owned()), + expr, + )), + } + } + fn bind_limit( &mut self, children: LogicalPlan, - limit_expr: &Option, - offset_expr: &Option, + limit: LimitClause, ) -> Result { self.context.step(QueryBindStep::Limit); - let mut limit = None; - let mut offset = None; - if let Some(expr) = limit_expr { - let expr = self.bind_expr(expr)?; - match expr { - ScalarExpression::Constant(dv) => match &dv { - DataValue::Int32(v) if *v >= 0 => limit = Some(*v as usize), - DataValue::Int64(v) if *v >= 0 => limit = Some(*v as usize), - _ => return Err(DatabaseError::InvalidType), - }, - _ => { - return Err(DatabaseError::InvalidColumn( - "invalid limit expression.".to_owned(), - )) + let mut limit_value = None; + let mut offset_value = None; + match limit { + LimitClause::LimitOffset { + limit: limit_expr, + offset: offset_expr, + limit_by, + } => { + if !limit_by.is_empty() { + return Err(DatabaseError::UnsupportedStmt( + "LIMIT BY is not supported".to_string(), + )); } - } - } - if let Some(expr) = offset_expr { - let expr = self.bind_expr(&expr.value)?; - match expr { - ScalarExpression::Constant(dv) => match &dv { - DataValue::Int32(v) if *v >= 0 => offset = Some(*v as usize), - DataValue::Int64(v) if *v >= 0 => offset = Some(*v as usize), - _ => return Err(DatabaseError::InvalidType), - }, - _ => { - return Err(DatabaseError::InvalidColumn( - "invalid limit expression.".to_owned(), - )) + if let Some(limit_ast) = limit_expr.as_ref() { + limit_value = Some(self.bind_non_negative_limit_value(limit_ast)?); + } + + if let Some(offset_ast) = offset_expr.as_ref() { + offset_value = Some(self.bind_non_negative_limit_value(&offset_ast.value)?); } } + LimitClause::OffsetCommaLimit { + offset: offset_expr, + limit: limit_expr, + } => { + limit_value = Some(self.bind_non_negative_limit_value(&limit_expr)?); + offset_value = Some(self.bind_non_negative_limit_value(&offset_expr)?); + } } - Ok(LimitOperator::build(offset, limit, children)) + Ok(LimitOperator::build(offset_value, limit_value, children)) } pub fn extract_select_join(&mut self, select_items: &mut [ScalarExpression]) { @@ -1006,12 +1078,15 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = Vec::new(); for ident in idents { - let name = lower_ident(ident); + let name = lower_case_name(ident)?; let (Some(left_column), Some(right_column)) = ( find_column(left_schema, &name), find_column(right_schema, &name), ) else { - return Err(DatabaseError::InvalidColumn("not found column".to_string())); + return Err(attach_span_if_absent( + DatabaseError::invalid_column("not found column".to_string()), + ident, + )); }; self.context.add_using(join_type, left_column, right_column); on_keys.push(( diff --git a/src/binder/update.rs b/src/binder/update.rs index d29b23a7..35245bb6 100644 --- a/src/binder/update.rs +++ b/src/binder/update.rs @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::binder::{lower_case_name, Binder}; +use crate::binder::{ + attach_span_from_sqlparser_span_if_absent, attach_span_if_absent, lower_case_name, Binder, +}; use crate::errors::DatabaseError; use crate::expression::ScalarExpression; use crate::planner::operator::update::UpdateOperator; @@ -20,11 +22,25 @@ use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; use crate::types::value::DataValue; -use sqlparser::ast::{Assignment, Expr, TableFactor, TableWithJoins}; +use sqlparser::ast::{ + Assignment, AssignmentTarget, Expr, Ident, ObjectName, TableFactor, TableWithJoins, +}; use std::slice; use std::sync::Arc; impl> Binder<'_, '_, T, A> { + fn single_ident_from_object_name(name: &ObjectName) -> Result<&Ident, DatabaseError> { + if name.0.len() != 1 { + return Err(attach_span_if_absent( + DatabaseError::invalid_column(name.to_string()), + name, + )); + } + name.0[0].as_ident().ok_or_else(|| { + attach_span_if_absent(DatabaseError::invalid_column(name.to_string()), name) + }) + } + pub(crate) fn bind_update( &mut self, to: &TableWithJoins, @@ -47,10 +63,21 @@ impl> Binder<'_, '_, T, A> if assignments.is_empty() { return Err(DatabaseError::ColumnsEmpty); } - for Assignment { id, value } in assignments { + for Assignment { target, value } in assignments { let expression = self.bind_expr(value)?; + let mut idents = vec![]; + match target { + AssignmentTarget::ColumnName(name) => { + idents.push(Self::single_ident_from_object_name(name)?); + } + AssignmentTarget::Tuple(_) => { + return Err(DatabaseError::UnsupportedStmt( + "UPDATE assignment tuple target is not supported".to_string(), + )) + } + } - for ident in id { + for ident in idents { match self.bind_column_ref_from_identifiers( slice::from_ref(ident), Some(table_name.to_string()), @@ -72,7 +99,12 @@ impl> Binder<'_, '_, T, A> } value_exprs.push((column, expr)); } - _ => return Err(DatabaseError::InvalidColumn(ident.to_string())), + _ => { + return Err(attach_span_from_sqlparser_span_if_absent( + DatabaseError::invalid_column(ident.to_string()), + ident.span, + )) + } } } } diff --git a/src/catalog/column.rs b/src/catalog/column.rs index a6988cba..36df8e02 100644 --- a/src/catalog/column.rs +++ b/src/catalog/column.rs @@ -243,7 +243,7 @@ impl ColumnDesc { self.is_unique } - pub(crate) fn set_unique(&mut self, is_unique: bool) { - self.is_unique = is_unique + pub(crate) fn set_unique(&mut self) { + self.is_unique = true } } diff --git a/src/catalog/table.rs b/src/catalog/table.rs index b041b420..be86eca3 100644 --- a/src/catalog/table.rs +++ b/src/catalog/table.rs @@ -154,7 +154,7 @@ impl TableCatalog { for column_id in column_ids.iter() { let val_ty = self .get_column_by_id(column_id) - .ok_or_else(|| DatabaseError::ColumnNotFound(column_id.to_string()))? + .ok_or_else(|| DatabaseError::column_not_found(column_id.to_string()))? .datatype() .clone(); val_tys.push(val_ty) @@ -233,7 +233,7 @@ impl TableCatalog { let mut columns = BTreeMap::new(); for (i, column_ref) in column_refs.iter().enumerate() { - let column_id = column_ref.id().ok_or(DatabaseError::InvalidColumn( + let column_id = column_ref.id().ok_or(DatabaseError::invalid_column( "column does not belong to table".to_string(), ))?; diff --git a/src/db.rs b/src/db.rs index d6992386..fef32549 100644 --- a/src/db.rs +++ b/src/db.rs @@ -321,10 +321,18 @@ impl State { } fn prepare>(&self, sql: T) -> Result { - let mut stmts = parse_sql(sql)?; + let mut stmts = self.prepare_all(sql)?; stmts.pop().ok_or(DatabaseError::EmptyStatement) } + fn prepare_all>(&self, sql: T) -> Result, DatabaseError> { + let stmts = parse_sql(sql)?; + if stmts.is_empty() { + return Err(DatabaseError::EmptyStatement); + } + Ok(stmts) + } + fn execute<'a, A: AsRef<[(&'static str, DataValue)]>>( &'a self, transaction: &'a mut S::TransactionType<'_>, @@ -361,9 +369,64 @@ pub struct Database { impl Database { /// Run SQL queries. pub fn run>(&self, sql: T) -> Result, DatabaseError> { - let statement = self.prepare(sql)?; + let sql = sql.as_ref(); + let statements = self + .state + .prepare_all(sql) + .map_err(|err| err.with_sql_context(sql))?; + let has_ddl = statements + .iter() + .try_fold(false, |has_ddl, stmt| { + Ok::<_, DatabaseError>(has_ddl || matches!(command_type(stmt)?, CommandType::DDL)) + }) + .map_err(|err| err.with_sql_context(sql))?; + + if statements.len() > 1 && has_ddl { + return Err(DatabaseError::UnsupportedStmt( + "DDL is not allowed in multi-statement execution".to_string(), + ) + .with_sql_context(sql)); + } - self.execute(&statement, &[]) + let guard = if has_ddl { + MetaDataLock::Write(self.mdl.write_arc()) + } else { + MetaDataLock::Read(self.mdl.read_arc()) + }; + + let transaction = Box::into_raw(Box::new(self.storage.transaction()?)); + let mut statements = statements.into_iter().peekable(); + + while let Some(statement) = statements.next() { + let (schema, executor) = + match self + .state + .execute(unsafe { &mut (*transaction) }, &statement, &[]) + { + Ok(result) => result, + Err(err) => { + unsafe { drop(Box::from_raw(transaction)) }; + return Err(err.with_sql_context(sql)); + } + }; + + if statements.peek().is_some() { + if let Err(err) = TransactionIter::new(schema, executor).done() { + unsafe { drop(Box::from_raw(transaction)) }; + return Err(err.with_sql_context(sql)); + } + } else { + let inner = Box::into_raw(Box::new(TransactionIter::new(schema, executor))); + return Ok(DatabaseIter { + transaction, + inner, + _guard: guard, + }); + } + } + + unsafe { drop(Box::from_raw(transaction)) }; + Err(DatabaseError::EmptyStatement.with_sql_context(sql)) } pub fn prepare>(&self, sql: T) -> Result { @@ -375,7 +438,7 @@ impl Database { statement: &Statement, params: A, ) -> Result, DatabaseError> { - let _guard = if matches!(command_type(statement)?, CommandType::DDL) { + let guard = if matches!(command_type(statement)?, CommandType::DDL) { MetaDataLock::Write(self.mdl.write_arc()) } else { MetaDataLock::Read(self.mdl.read_arc()) @@ -393,7 +456,11 @@ impl Database { } }; let inner = Box::into_raw(Box::new(TransactionIter::new(schema, executor))); - Ok(DatabaseIter { transaction, inner }) + Ok(DatabaseIter { + transaction, + inner, + _guard: guard, + }) } pub fn new_transaction(&self) -> Result, DatabaseError> { @@ -418,6 +485,7 @@ pub trait ResultIter: Iterator> { pub struct DatabaseIter<'a, S: Storage + 'a> { transaction: *mut S::TransactionType<'a>, inner: *mut TransactionIter<'a>, + _guard: MetaDataLock, } impl Drop for DatabaseIter<'_, S> { @@ -463,9 +531,24 @@ pub struct DBTransaction<'a, S: Storage + 'a> { impl DBTransaction<'_, S> { pub fn run>(&mut self, sql: T) -> Result, DatabaseError> { - let statement = self.state.prepare(sql)?; + let sql = sql.as_ref(); + let mut statements = self + .state + .prepare_all(sql) + .map_err(|err| err.with_sql_context(sql))?; + let last_statement = statements + .pop() + .ok_or_else(|| DatabaseError::EmptyStatement.with_sql_context(sql))?; + + for statement in statements { + self.execute(&statement, &[]) + .map_err(|err| err.with_sql_context(sql))? + .done() + .map_err(|err| err.with_sql_context(sql))?; + } - self.execute(&statement, &[]) + self.execute(&last_statement, &[]) + .map_err(|err| err.with_sql_context(sql)) } pub fn prepare>(&self, sql: T) -> Result { @@ -648,9 +731,9 @@ pub(crate) mod test { // Filter { - let statement = kite_sql.prepare("explain select * from t1 where b > ?1")?; + let statement = kite_sql.prepare("explain select * from t1 where b > $1")?; - let mut iter = kite_sql.execute(&statement, &[("?1", DataValue::Int32(0))])?; + let mut iter = kite_sql.execute(&statement, &[("$1", DataValue::Int32(0))])?; assert_eq!( iter.next().unwrap()?.values[0].utf8().unwrap(), @@ -662,16 +745,16 @@ pub(crate) mod test { // Aggregate { let statement = kite_sql.prepare( - "explain select a + ?1, max(b + ?2) from t1 where b > ?3 group by a + ?4", + "explain select a + $1, max(b + $2) from t1 where b > $3 group by a + $4", )?; let mut iter = kite_sql.execute( &statement, &[ - ("?1", DataValue::Int32(0)), - ("?2", DataValue::Int32(0)), - ("?3", DataValue::Int32(1)), - ("?4", DataValue::Int32(0)), + ("$1", DataValue::Int32(0)), + ("$2", DataValue::Int32(0)), + ("$3", DataValue::Int32(1)), + ("$4", DataValue::Int32(0)), ], )?; assert_eq!( @@ -683,15 +766,15 @@ pub(crate) mod test { ) } { - let statement = kite_sql.prepare("explain select *, ?1 from (select * from t1 where b > ?2) left join (select * from t1 where a > ?3) on a > ?4")?; + let statement = kite_sql.prepare("explain select *, $1 from (select * from t1 where b > $2) left join (select * from t1 where a > $3) on a > $4")?; let mut iter = kite_sql.execute( &statement, &[ - ("?1", DataValue::Int32(9)), - ("?2", DataValue::Int32(0)), - ("?3", DataValue::Int32(1)), - ("?4", DataValue::Int32(0)), + ("$1", DataValue::Int32(9)), + ("$2", DataValue::Int32(0)), + ("$3", DataValue::Int32(1)), + ("$4", DataValue::Int32(0)), ], )?; assert_eq!( @@ -710,6 +793,116 @@ pub(crate) mod test { Ok(()) } + #[test] + fn test_run_multi_statement() -> Result<(), DatabaseError> { + let temp_dir = TempDir::new().expect("unable to create temporary working directory"); + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + + kite_sql + .run("create table t_multi (a int primary key, b int)")? + .done()?; + + let mut iter = kite_sql.run( + "insert into t_multi values(0, 0); insert into t_multi values(1, 1); select * from t_multi order by a", + )?; + assert_eq!( + iter.next().unwrap()?.values, + vec![DataValue::Int32(0), DataValue::Int32(0)] + ); + assert_eq!( + iter.next().unwrap()?.values, + vec![DataValue::Int32(1), DataValue::Int32(1)] + ); + assert!(iter.next().is_none()); + iter.done()?; + + Ok(()) + } + + #[test] + fn test_run_multi_statement_disallow_ddl() -> Result<(), DatabaseError> { + let temp_dir = TempDir::new().expect("unable to create temporary working directory"); + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + + let err = match kite_sql.run("create table t_multi_ddl (a int primary key); select 1") { + Ok(_) => panic!("multi-statement execution with DDL should be rejected"), + Err(err) => err, + }; + match err { + DatabaseError::UnsupportedStmt(msg) => { + assert!(msg.contains("multi-statement execution")); + } + other => panic!("unexpected error type: {other:?}"), + } + + Ok(()) + } + + #[test] + fn test_bind_error_with_span() -> Result<(), DatabaseError> { + let temp_dir = TempDir::new().expect("unable to create temporary working directory"); + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + + kite_sql + .run("create table t_bind_span(id int primary key)")? + .done()?; + + let err = match kite_sql.run("select id, missing_col from t_bind_span") { + Ok(_) => panic!("expected bind error"), + Err(err) => err, + }; + println!("{}", err); + + match err { + DatabaseError::ColumnNotFound { span, .. } + | DatabaseError::InvalidColumn { span, .. } => { + let span = span.expect("bind error should include span"); + assert_eq!(span.line, 1); + assert!(span.start >= 12); + assert!(span.end > span.start); + assert!(span + .highlight + .as_deref() + .is_some_and(|h| h.contains("missing_col"))); + } + other => panic!("unexpected error type: {other:?}"), + } + + Ok(()) + } + + #[test] + fn test_bind_function_error_with_span() -> Result<(), DatabaseError> { + let temp_dir = TempDir::new().expect("unable to create temporary working directory"); + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + + kite_sql + .run("create table t_bind_fn_span(id int primary key)")? + .done()?; + + let err = match kite_sql.run("select missing_fn(id) from t_bind_fn_span") { + Ok(_) => panic!("expected function bind error"), + Err(err) => err, + }; + println!("{}", err); + + match err { + DatabaseError::FunctionNotFound { span, .. } => { + let span = span.expect("function bind error should include span"); + assert_eq!(span.line, 1); + assert!(span.start >= 8); + assert!(span.end > span.start); + assert!(span + .highlight + .as_deref() + .is_some_and(|h| h.contains("missing_fn(id)"))); + } + other => panic!("unexpected error type: {other:?}"), + } + + Ok(()) + } + #[test] fn test_transaction_sql() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); @@ -757,6 +950,41 @@ pub(crate) mod test { Ok(()) } + #[test] + fn test_transaction_run_multi_statement() -> Result<(), DatabaseError> { + let temp_dir = TempDir::new().expect("unable to create temporary working directory"); + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + + kite_sql + .run("create table t_multi_tx (a int primary key, b int)")? + .done()?; + + let mut tx = kite_sql.new_transaction()?; + let mut iter = tx.run( + "insert into t_multi_tx values(0, 0); insert into t_multi_tx values(1, 1); select * from t_multi_tx order by a", + )?; + assert_eq!( + iter.next().unwrap()?.values, + vec![DataValue::Int32(0), DataValue::Int32(0)] + ); + assert_eq!( + iter.next().unwrap()?.values, + vec![DataValue::Int32(1), DataValue::Int32(1)] + ); + assert!(iter.next().is_none()); + iter.done()?; + tx.commit()?; + + let mut check_iter = kite_sql.run("select count(*) from t_multi_tx")?; + assert_eq!( + check_iter.next().unwrap()?.values, + vec![DataValue::Int32(2)] + ); + check_iter.done()?; + + Ok(()) + } + #[test] fn test_optimistic_transaction_sql() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); diff --git a/src/errors.rs b/src/errors.rs index ef4da5d3..2d6f2cda 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -21,6 +21,36 @@ use std::num::{ParseFloatError, ParseIntError, TryFromIntError}; use std::str::{ParseBoolError, Utf8Error}; use std::string::FromUtf8Error; +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SqlErrorSpan { + pub start: usize, + pub end: usize, + pub line: usize, + pub highlight: Option, +} + +fn format_sql_error_loc(span: &Option) -> String { + span.as_ref() + .map(|s| { + if let Some(highlight) = &s.highlight { + format!("\n{highlight}") + } else { + format!(" at line {}, range {}..{}", s.line, s.start, s.end) + } + }) + .unwrap_or_default() +} + +fn format_not_null_message(column: &Option, span: &Option) -> String { + match column { + Some(column) => format!( + "column: `{column}` cannot be null{}", + format_sql_error_loc(span) + ), + None => format!("cannot be null{}", format_sql_error_loc(span)), + } +} + #[derive(thiserror::Error, Debug)] pub enum DatabaseError { #[error("agg miss: {0}")] @@ -33,16 +63,29 @@ pub enum DatabaseError { ), #[error("cache size overflow")] CacheSizeOverFlow, - #[error("cast fail: {from} -> {to}")] - CastFail { from: LogicalType, to: LogicalType }, + #[error( + "cast fail: {from} -> {to}{loc}", + loc = format_sql_error_loc(span) + )] + CastFail { + from: LogicalType, + to: LogicalType, + span: Option, + }, #[error("channel close")] ChannelClose, #[error("columns empty")] ColumnsEmpty, - #[error("column id: {0} not found")] + #[error("column id: `{0}` not found")] ColumnIdNotFound(String), - #[error("column: {0} not found")] - ColumnNotFound(String), + #[error( + "column: `{name}` not found{loc}", + loc = format_sql_error_loc(span) + )] + ColumnNotFound { + name: String, + span: Option, + }, #[error("csv error: {0}")] Csv( #[from] @@ -53,18 +96,24 @@ pub enum DatabaseError { DefaultNotColumnRef, #[error("default does not exist")] DefaultNotExist, - #[error("column: {0} already exists")] + #[error("column: `{0}` already exists")] DuplicateColumn(String), - #[error("table or view: {0} hash already exists")] + #[error("table or view: `{0}` hash already exists")] DuplicateSourceHash(String), - #[error("index: {0} already exists")] + #[error("index: `{0}` already exists")] DuplicateIndex(String), #[error("duplicate primary key")] DuplicatePrimaryKey, #[error("the column has been declared unique and the value already exists")] DuplicateUniqueValue, - #[error("function: {0} not found")] - FunctionNotFound(String), + #[error( + "function: `{name}` not found{loc}", + loc = format_sql_error_loc(span) + )] + FunctionNotFound { + name: String, + span: Option, + }, #[error("empty plan")] EmptyPlan, #[error("sql statement is empty")] @@ -79,12 +128,24 @@ pub enum DatabaseError { ), #[error("can not compare two types: {0} and {1}")] Incomparable(LogicalType, LogicalType), - #[error("invalid column: {0}")] - InvalidColumn(String), + #[error( + "invalid column: `{name}`{loc}", + loc = format_sql_error_loc(span) + )] + InvalidColumn { + name: String, + span: Option, + }, #[error("invalid index")] InvalidIndex, - #[error("invalid table: {0}")] - InvalidTable(String), + #[error( + "invalid table: `{name}`{loc}", + loc = format_sql_error_loc(span) + )] + InvalidTable { + name: String, + span: Option, + }, #[error("invalid type")] InvalidType, #[error("invalid value: {0}")] @@ -99,12 +160,21 @@ pub enum DatabaseError { MisMatch(&'static str, &'static str), #[error("add column must be nullable or specify a default value")] NeedNullAbleOrDefault, - #[error("parameter: {0} not found")] - ParametersNotFound(String), + #[error( + "parameter: `{name}` not found{loc}", + loc = format_sql_error_loc(span) + )] + ParametersNotFound { + name: String, + span: Option, + }, #[error("no transaction begin")] NoTransactionBegin, - #[error("cannot be null")] - NotNull, + #[error("{msg}", msg = format_not_null_message(column, span))] + NotNull { + column: Option, + span: Option, + }, #[error("over flow")] OverFlow, #[error("parser bool: {0}")] @@ -197,3 +267,185 @@ pub enum DatabaseError { #[error("the view not found")] ViewNotFound, } + +impl DatabaseError { + pub fn invalid_column(name: impl Into) -> Self { + Self::InvalidColumn { + name: name.into(), + span: None, + } + } + + pub fn column_not_found(name: impl Into) -> Self { + Self::ColumnNotFound { + name: name.into(), + span: None, + } + } + + pub fn invalid_table(name: impl Into) -> Self { + Self::InvalidTable { + name: name.into(), + span: None, + } + } + + pub fn function_not_found(name: impl Into) -> Self { + Self::FunctionNotFound { + name: name.into(), + span: None, + } + } + + pub fn parameter_not_found(name: impl Into) -> Self { + Self::ParametersNotFound { + name: name.into(), + span: None, + } + } + + pub fn not_null() -> Self { + Self::NotNull { + column: None, + span: None, + } + } + + pub fn not_null_column(name: impl Into) -> Self { + Self::NotNull { + column: Some(name.into()), + span: None, + } + } + + pub fn with_span(self, span: SqlErrorSpan) -> Self { + match self { + Self::CastFail { from, to, .. } => Self::CastFail { + from, + to, + span: Some(span), + }, + Self::InvalidColumn { name, .. } => Self::InvalidColumn { + name, + span: Some(span), + }, + Self::ColumnNotFound { name, .. } => Self::ColumnNotFound { + name, + span: Some(span), + }, + Self::InvalidTable { name, .. } => Self::InvalidTable { + name, + span: Some(span), + }, + Self::FunctionNotFound { name, .. } => Self::FunctionNotFound { + name, + span: Some(span), + }, + Self::ParametersNotFound { name, .. } => Self::ParametersNotFound { + name, + span: Some(span), + }, + Self::NotNull { column, .. } => Self::NotNull { + column, + span: Some(span), + }, + other => other, + } + } + + pub fn with_sql_context(self, sql: &str) -> Self { + let annotate = |span: Option| -> Option { + span.map(|mut span| { + if span.highlight.is_none() { + span.highlight = build_sql_highlight(sql, &span); + } + span + }) + }; + + match self { + Self::CastFail { from, to, span } => Self::CastFail { + from, + to, + span: annotate(span), + }, + Self::InvalidColumn { name, span } => Self::InvalidColumn { + name, + span: annotate(span), + }, + Self::ColumnNotFound { name, span } => Self::ColumnNotFound { + name, + span: annotate(span), + }, + Self::InvalidTable { name, span } => Self::InvalidTable { + name, + span: annotate(span), + }, + Self::FunctionNotFound { name, span } => Self::FunctionNotFound { + name, + span: annotate(span), + }, + Self::ParametersNotFound { name, span } => Self::ParametersNotFound { + name, + span: annotate(span), + }, + Self::NotNull { column, span } => Self::NotNull { + column, + span: annotate(span), + }, + other => other, + } + } + + pub fn sql_error_span(&self) -> Option<&SqlErrorSpan> { + match self { + DatabaseError::CastFail { span, .. } + | DatabaseError::InvalidColumn { span, .. } + | DatabaseError::ColumnNotFound { span, .. } + | DatabaseError::InvalidTable { span, .. } + | DatabaseError::FunctionNotFound { span, .. } + | DatabaseError::ParametersNotFound { span, .. } + | DatabaseError::NotNull { span, .. } => span.as_ref(), + _ => None, + } + } +} + +fn build_sql_highlight(sql: &str, span: &SqlErrorSpan) -> Option { + if span.line == 0 || span.start == 0 { + return None; + } + + let lines = sql + .lines() + .map(|line| line.trim_end_matches('\r').to_string()) + .collect::>(); + if lines.is_empty() || span.line > lines.len() { + return None; + } + + let width = lines.len().to_string().len(); + let mut out = String::new(); + out.push_str(&format!("--> line {}\n", span.line)); + + for (i, line) in lines.iter().enumerate() { + let line_no = i + 1; + out.push_str(&format!("{line_no:>width$} | {line}\n", width = width)); + + if line_no == span.line { + let char_len = line.chars().count(); + let start = span.start.saturating_sub(1).min(char_len); + let end = span.end.min(char_len).max(start + 1); + let marker_len = end.saturating_sub(start).max(1); + out.push_str(&format!( + "{:>width$} | {}{}\n", + "", + " ".repeat(start), + "^".repeat(marker_len), + width = width + )); + } + } + + Some(out.trim_end().to_string()) +} diff --git a/src/execution/ddl/drop_column.rs b/src/execution/ddl/drop_column.rs index 32d40097..f17dda21 100644 --- a/src/execution/ddl/drop_column.rs +++ b/src/execution/ddl/drop_column.rs @@ -56,7 +56,7 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for DropColumn { if is_primary { throw!( co, - Err(DatabaseError::InvalidColumn( + Err(DatabaseError::invalid_column( "drop of primary key column is not allowed.".to_owned(), )) ); @@ -104,7 +104,7 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for DropColumn { co.yield_(Ok(TupleBuilder::build_result("1".to_string()))) .await; } else if !if_exists { - co.yield_(Err(DatabaseError::ColumnNotFound(column_name))) + co.yield_(Err(DatabaseError::column_not_found(column_name))) .await; } }) diff --git a/src/execution/dml/insert.rs b/src/execution/dml/insert.rs index 6f244300..b24ddeaa 100644 --- a/src/execution/dml/insert.rs +++ b/src/execution/dml/insert.rs @@ -96,7 +96,7 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Insert { .map(|(_, col)| col.key(is_mapping_by_name)) .collect_vec(); if primary_keys.is_empty() { - throw!(co, Err(DatabaseError::NotNull)) + throw!(co, Err(DatabaseError::not_null())) } if let Some(table_catalog) = throw!( @@ -152,7 +152,8 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Insert { value.unwrap_or(DataValue::Null) }; if value.is_null() && !col.nullable() { - co.yield_(Err(DatabaseError::NotNull)).await; + co.yield_(Err(DatabaseError::not_null_column(col.name().to_string()))) + .await; return; } values.push(value) diff --git a/src/function/numbers.rs b/src/function/numbers.rs index cf4a67fd..2a21e994 100644 --- a/src/function/numbers.rs +++ b/src/function/numbers.rs @@ -71,7 +71,9 @@ impl TableFunctionImpl for Numbers { if value.logical_type() != LogicalType::Integer { value = value.cast(&LogicalType::Integer)?; } - let num = value.i32().ok_or(DatabaseError::NotNull)?; + let num = value + .i32() + .ok_or_else(|| DatabaseError::not_null_column("numbers() arg"))?; Ok( Box::new((0..num).map(|i| Ok(Tuple::new(None, vec![DataValue::Int32(i)])))) diff --git a/src/optimizer/plan_utils.rs b/src/optimizer/plan_utils.rs index 8e16739e..39666769 100644 --- a/src/optimizer/plan_utils.rs +++ b/src/optimizer/plan_utils.rs @@ -120,7 +120,7 @@ pub fn replace_child_with_only_child(plan: &mut LogicalPlan, child_idx: usize) - pub fn wrap_child_with(plan: &mut LogicalPlan, child_idx: usize, operator: Operator) -> bool { if let Some(slot) = child_mut(plan, child_idx) { let previous = mem::replace(slot, LogicalPlan::new(operator, Childrens::None)); - slot.childrens = Box::new(Childrens::Only(Box::new(previous))); + *slot.childrens = Childrens::Only(Box::new(previous)); true } else { false diff --git a/src/planner/operator/table_scan.rs b/src/planner/operator/table_scan.rs index 72d9fa1f..aefd78f2 100644 --- a/src/planner/operator/table_scan.rs +++ b/src/planner/operator/table_scan.rs @@ -65,7 +65,7 @@ impl TableScanOperator { let mut sort_fields = Vec::with_capacity(index_meta.column_ids.len()); for col_id in &index_meta.column_ids { let column = table_catalog.get_column_by_id(col_id).ok_or_else(|| { - DatabaseError::ColumnNotFound(format!("index column id: {col_id} not found")) + DatabaseError::column_not_found(format!("index column id: {col_id} not found")) })?; sort_fields.push(SortField { expr: ScalarExpression::column_expr(column.clone()), diff --git a/src/serdes/column.rs b/src/serdes/column.rs index a8ccfd1c..5ce8caae 100644 --- a/src/serdes/column.rs +++ b/src/serdes/column.rs @@ -68,7 +68,7 @@ impl ReferenceSerialization for ColumnRef { .ok_or(DatabaseError::TableNotFound)?; let column = table .get_column_by_id(column_id) - .ok_or(DatabaseError::InvalidColumn(format!( + .ok_or(DatabaseError::invalid_column(format!( "column id: {column_id} not found" )))?; Ok(nullable_for_join diff --git a/src/storage/table_codec.rs b/src/storage/table_codec.rs index 193f33b3..402fff23 100644 --- a/src/storage/table_codec.rs +++ b/src/storage/table_codec.rs @@ -75,7 +75,7 @@ impl TableCodec { return Err(DatabaseError::PrimaryKeyTooManyLayers); } if value.is_null() { - return Err(DatabaseError::NotNull); + return Err(DatabaseError::not_null_column("primary key")); } if let DataValue::Tuple(values, _) = &value { @@ -451,7 +451,7 @@ impl TableCodec { Ok((key_prefix, column_bytes)) } else { - Err(DatabaseError::InvalidColumn( + Err(DatabaseError::invalid_column( "column does not belong to table".to_string(), )) } diff --git a/src/types/index.rs b/src/types/index.rs index b009aaeb..40919e77 100644 --- a/src/types/index.rs +++ b/src/types/index.rs @@ -71,7 +71,7 @@ impl IndexMeta { if let Some(column) = table.get_column_by_id(column_id) { exprs.push(ScalarExpression::column_expr(column.clone())); } else { - return Err(DatabaseError::ColumnNotFound(column_id.to_string())); + return Err(DatabaseError::column_not_found(column_id.to_string())); } } Ok(exprs) diff --git a/src/types/mod.rs b/src/types/mod.rs index 6a17416f..6753f8dc 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -396,9 +396,18 @@ impl TryFrom for LogicalType { | sqlparser::ast::DataType::Character(char_len) => { let mut len = 1; let mut char_unit = None; - if let Some(sqlparser::ast::CharacterLength { length, unit }) = char_len { - len = cmp::max(len, length); - char_unit = unit; + if let Some(char_len) = char_len { + match char_len { + sqlparser::ast::CharacterLength::IntegerLength { length, unit } => { + len = cmp::max(len, length); + char_unit = unit; + } + sqlparser::ast::CharacterLength::Max => { + return Err(DatabaseError::UnsupportedStmt( + "CHAR(MAX) is not supported".to_string(), + )); + } + } } Ok(LogicalType::Char( len as u32, @@ -410,35 +419,62 @@ impl TryFrom for LogicalType { | sqlparser::ast::DataType::Varchar(varchar_len) => { let mut len = None; let mut char_unit = None; - if let Some(sqlparser::ast::CharacterLength { length, unit }) = varchar_len { - len = Some(length as u32); - char_unit = unit; + if let Some(varchar_len) = varchar_len { + match varchar_len { + sqlparser::ast::CharacterLength::IntegerLength { length, unit } => { + len = Some(length as u32); + char_unit = unit; + } + sqlparser::ast::CharacterLength::Max => { + return Err(DatabaseError::UnsupportedStmt( + "VARCHAR(MAX) is not supported".to_string(), + )); + } + } } Ok(LogicalType::Varchar( len, char_unit.unwrap_or(CharLengthUnits::Characters), )) } - sqlparser::ast::DataType::String | sqlparser::ast::DataType::Text => { + sqlparser::ast::DataType::String(_) | sqlparser::ast::DataType::Text => { Ok(LogicalType::Varchar(None, CharLengthUnits::Characters)) } - sqlparser::ast::DataType::Float(_) | sqlparser::ast::DataType::Real => { - Ok(LogicalType::Float) - } - sqlparser::ast::DataType::Double | sqlparser::ast::DataType::DoublePrecision => { - Ok(LogicalType::Double) - } + sqlparser::ast::DataType::Float(_) + | sqlparser::ast::DataType::Float4 + | sqlparser::ast::DataType::Float32 + | sqlparser::ast::DataType::Real => Ok(LogicalType::Float), + sqlparser::ast::DataType::Double(_) + | sqlparser::ast::DataType::DoublePrecision + | sqlparser::ast::DataType::Float8 + | sqlparser::ast::DataType::Float64 => Ok(LogicalType::Double), sqlparser::ast::DataType::TinyInt(_) => Ok(LogicalType::Tinyint), - sqlparser::ast::DataType::UnsignedTinyInt(_) => Ok(LogicalType::UTinyint), - sqlparser::ast::DataType::SmallInt(_) => Ok(LogicalType::Smallint), - sqlparser::ast::DataType::UnsignedSmallInt(_) => Ok(LogicalType::USmallint), - sqlparser::ast::DataType::Int(_) | sqlparser::ast::DataType::Integer(_) => { - Ok(LogicalType::Integer) + sqlparser::ast::DataType::TinyIntUnsigned(_) | sqlparser::ast::DataType::UTinyInt => { + Ok(LogicalType::UTinyint) + } + sqlparser::ast::DataType::SmallInt(_) | sqlparser::ast::DataType::Int2(_) => { + Ok(LogicalType::Smallint) } - sqlparser::ast::DataType::UnsignedInt(_) - | sqlparser::ast::DataType::UnsignedInteger(_) => Ok(LogicalType::UInteger), - sqlparser::ast::DataType::BigInt(_) => Ok(LogicalType::Bigint), - sqlparser::ast::DataType::UnsignedBigInt(_) => Ok(LogicalType::UBigint), + sqlparser::ast::DataType::SmallIntUnsigned(_) + | sqlparser::ast::DataType::Int2Unsigned(_) + | sqlparser::ast::DataType::USmallInt => Ok(LogicalType::USmallint), + sqlparser::ast::DataType::Int(_) + | sqlparser::ast::DataType::Integer(_) + | sqlparser::ast::DataType::Int4(_) + | sqlparser::ast::DataType::Int32 => Ok(LogicalType::Integer), + sqlparser::ast::DataType::IntUnsigned(_) + | sqlparser::ast::DataType::IntegerUnsigned(_) + | sqlparser::ast::DataType::Int4Unsigned(_) + | sqlparser::ast::DataType::Unsigned + | sqlparser::ast::DataType::UnsignedInteger + | sqlparser::ast::DataType::UInt32 => Ok(LogicalType::UInteger), + sqlparser::ast::DataType::BigInt(_) + | sqlparser::ast::DataType::Int8(_) + | sqlparser::ast::DataType::Int64 => Ok(LogicalType::Bigint), + sqlparser::ast::DataType::BigIntUnsigned(_) + | sqlparser::ast::DataType::Int8Unsigned(_) + | sqlparser::ast::DataType::UBigInt + | sqlparser::ast::DataType::UInt64 => Ok(LogicalType::UBigint), sqlparser::ast::DataType::Boolean => Ok(LogicalType::Boolean), sqlparser::ast::DataType::Date => Ok(LogicalType::Date), sqlparser::ast::DataType::Datetime(precision) => { @@ -481,7 +517,9 @@ impl TryFrom for LogicalType { Ok(LogicalType::TimeStamp(precision, zone)) } sqlparser::ast::DataType::Decimal(info) + | sqlparser::ast::DataType::DecimalUnsigned(info) | sqlparser::ast::DataType::Dec(info) + | sqlparser::ast::DataType::DecUnsigned(info) | sqlparser::ast::DataType::Numeric(info) => match info { ExactNumberInfo::None => Ok(Self::Decimal(None, None)), ExactNumberInfo::Precision(p) => Ok(Self::Decimal(Some(p as u8), None)), diff --git a/src/types/value.rs b/src/types/value.rs index d7b060ca..7a631ffa 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -354,6 +354,7 @@ macro_rules! numeric_to_boolean { _ => Err(DatabaseError::CastFail { from: $from_ty, to: LogicalType::Boolean, + span: None, }), } }; @@ -1129,6 +1130,7 @@ impl DataValue { _ => Err(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }), }, DataValue::Float32(value) => match to { @@ -1146,6 +1148,7 @@ impl DataValue { Decimal::from_f32(value.0).ok_or(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, })?; Self::decimal_round_f(option, &mut decimal); @@ -1192,6 +1195,7 @@ impl DataValue { _ => Err(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }), }, DataValue::Float64(value) => match to { @@ -1209,6 +1213,7 @@ impl DataValue { Decimal::from_f64(value.0).ok_or(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, })?; Self::decimal_round_f(option, &mut decimal); @@ -1255,6 +1260,7 @@ impl DataValue { _ => Err(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }), }, DataValue::Int8(value) => match to { @@ -1285,6 +1291,7 @@ impl DataValue { _ => Err(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }), }, DataValue::Int16(value) => match to { @@ -1315,6 +1322,7 @@ impl DataValue { _ => Err(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }), }, DataValue::Int32(value) => match to { @@ -1345,6 +1353,7 @@ impl DataValue { _ => Err(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }), }, DataValue::Int64(value) => match to { @@ -1375,6 +1384,7 @@ impl DataValue { _ => Err(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }), }, DataValue::UInt8(value) => match to { @@ -1405,6 +1415,7 @@ impl DataValue { _ => Err(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }), }, DataValue::UInt16(value) => match to { @@ -1435,6 +1446,7 @@ impl DataValue { _ => Err(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }), }, DataValue::UInt32(value) => match to { @@ -1465,6 +1477,7 @@ impl DataValue { _ => Err(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }), }, DataValue::UInt64(value) => match to { @@ -1495,6 +1508,7 @@ impl DataValue { _ => Err(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }), }, DataValue::Utf8 { ref value, .. } => match to { @@ -1517,9 +1531,7 @@ impl DataValue { varchar_cast!(value, len, Utf8Type::Variable(*len), *unit) } LogicalType::Date => { - let value = NaiveDate::parse_from_str(value, DATE_FMT) - .map(|date| date.num_days_from_ce()) - .unwrap(); + let value = NaiveDate::parse_from_str(value, DATE_FMT)?.num_days_from_ce(); Ok(DataValue::Date32(value)) } LogicalType::DateTime => { @@ -1596,6 +1608,7 @@ impl DataValue { return Err(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }); } } @@ -1608,6 +1621,7 @@ impl DataValue { _ => Err(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }), }, DataValue::Date32(value) => match to { @@ -1616,7 +1630,8 @@ impl DataValue { varchar_cast!( Self::format_date(value).ok_or(DatabaseError::CastFail { from: self.logical_type(), - to: to.clone() + to: to.clone(), + span: None, })?, Some(len), Utf8Type::Fixed(*len), @@ -1627,7 +1642,8 @@ impl DataValue { varchar_cast!( Self::format_date(value).ok_or(DatabaseError::CastFail { from: self.logical_type(), - to: to.clone() + to: to.clone(), + span: None, })?, len, Utf8Type::Variable(*len), @@ -1640,11 +1656,13 @@ impl DataValue { .ok_or(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, })? .and_hms_opt(0, 0, 0) .ok_or(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, })? .and_utc() .timestamp(); @@ -1654,6 +1672,7 @@ impl DataValue { _ => Err(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }), }, DataValue::Date64(value) => match to { @@ -1662,7 +1681,8 @@ impl DataValue { varchar_cast!( Self::format_datetime(value).ok_or(DatabaseError::CastFail { from: self.logical_type(), - to: to.clone() + to: to.clone(), + span: None, })?, Some(len), Utf8Type::Fixed(*len), @@ -1673,7 +1693,8 @@ impl DataValue { varchar_cast!( Self::format_datetime(value).ok_or(DatabaseError::CastFail { from: self.logical_type(), - to: to.clone() + to: to.clone(), + span: None, })?, len, Utf8Type::Variable(*len), @@ -1685,6 +1706,7 @@ impl DataValue { .ok_or(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, })? .naive_utc() .date() @@ -1703,6 +1725,7 @@ impl DataValue { .ok_or(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, })?; Ok(DataValue::Time32(Self::pack(value, 0, 0), precision)) @@ -1717,6 +1740,7 @@ impl DataValue { _ => Err(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }), }, DataValue::Time32(value, precision) => match to { @@ -1725,7 +1749,8 @@ impl DataValue { varchar_cast!( Self::format_time(value, precision).ok_or(DatabaseError::CastFail { from: self.logical_type(), - to: to.clone() + to: to.clone(), + span: None, })?, Some(len), Utf8Type::Fixed(*len), @@ -1736,7 +1761,8 @@ impl DataValue { varchar_cast!( Self::format_time(value, precision).ok_or(DatabaseError::CastFail { from: self.logical_type(), - to: to.clone() + to: to.clone(), + span: None, })?, len, Utf8Type::Variable(*len), @@ -1749,6 +1775,7 @@ impl DataValue { _ => Err(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }), }, DataValue::Time64(value, precision, _) => match to { @@ -1758,7 +1785,8 @@ impl DataValue { Self::format_timestamp(value, precision).ok_or( DatabaseError::CastFail { from: self.logical_type(), - to: to.clone() + to: to.clone(), + span: None, } )?, Some(len), @@ -1771,7 +1799,8 @@ impl DataValue { Self::format_timestamp(value, precision).ok_or( DatabaseError::CastFail { from: self.logical_type(), - to: to.clone() + to: to.clone(), + span: None, } )?, len, @@ -1784,6 +1813,7 @@ impl DataValue { .ok_or(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, })? .naive_utc() .date() @@ -1796,6 +1826,7 @@ impl DataValue { .ok_or(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, })? .timestamp(); Ok(DataValue::Date64(value)) @@ -1812,6 +1843,7 @@ impl DataValue { .ok_or(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, })?; Ok(DataValue::Time32(Self::pack(value, nano, p), p)) } @@ -1821,6 +1853,7 @@ impl DataValue { _ => Err(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }), }, DataValue::Decimal(value) => match to { @@ -1829,12 +1862,14 @@ impl DataValue { DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }, )?))), LogicalType::Double => Ok(DataValue::Float64(OrderedFloat(value.to_f64().ok_or( DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }, )?))), LogicalType::Decimal(_, _) => Ok(DataValue::Decimal(value)), @@ -1855,6 +1890,7 @@ impl DataValue { _ => Err(DatabaseError::CastFail { from: self.logical_type(), to: to.clone(), + span: None, }), }, DataValue::Tuple(mut values, is_upper) => match to { @@ -1869,6 +1905,7 @@ impl DataValue { _ => Err(DatabaseError::CastFail { from: LogicalType::Tuple(values.iter().map(DataValue::logical_type).collect()), to: to.clone(), + span: None, }), }, }?; diff --git a/tests/macros-test/Cargo.toml b/tests/macros-test/Cargo.toml index b18b88e2..e2dcfbc1 100644 --- a/tests/macros-test/Cargo.toml +++ b/tests/macros-test/Cargo.toml @@ -7,5 +7,5 @@ edition = "2021" "kite_sql" = { path = "../.." } lazy_static = { version = "1" } serde = { version = "1", features = ["derive", "rc"] } -sqlparser = { version = "0.34", features = ["serde"] } -typetag = { version = "0.2" } \ No newline at end of file +sqlparser = { version = "0.61", features = ["serde"] } +typetag = { version = "0.2" } diff --git a/tests/slt/stream_distinct_explain.slt b/tests/slt/stream_distinct_explain.slt new file mode 100644 index 00000000..4dc32b46 --- /dev/null +++ b/tests/slt/stream_distinct_explain.slt @@ -0,0 +1,29 @@ +statement ok +create table distinct_t(id int primary key, c1 int, c2 int); + +statement ok +copy distinct_t from 'tests/data/distinct_rows.csv' ( DELIMITER '|' ); + +statement ok +create index distinct_t_c1_index on distinct_t (c1); + +statement ok +analyze table distinct_t; + +# stream distinct +query T +explain select distinct c1 from distinct_t where c1 < 10 and c1 > 0; +---- +Projection [distinct_t.c1] [Project => (Sort Option: Follow)] Aggregate [] -> Group By [distinct_t.c1] [StreamDistinct => (Sort Option: Follow)] Filter ((distinct_t.c1 < 10) && (distinct_t.c1 > 0)), Is Having: false [Filter => (Sort Option: Follow)] TableScan distinct_t -> [c1] [IndexScan By distinct_t_c1_index => (0, 10) Covered => (Sort Option: OrderBy: (distinct_t.c1 Asc Nulls Last) ignore_prefix_len: 0)] + +statement ok +drop index distinct_t.distinct_t_c1_index; + +# hash distinct +query T +explain select distinct c1 from distinct_t where c1 < 10 and c1 > 0; +---- +Projection [distinct_t.c1] [Project => (Sort Option: Follow)] Aggregate [] -> Group By [distinct_t.c1] [HashAggregate => (Sort Option: None)] Filter ((distinct_t.c1 < 10) && (distinct_t.c1 > 0)), Is Having: false [Filter => (Sort Option: Follow)] TableScan distinct_t -> [c1] [SeqScan => (Sort Option: None)] + +statement ok +drop table distinct_t; diff --git a/tests/slt/where_by_index_explain.slt b/tests/slt/where_by_index_explain.slt new file mode 100644 index 00000000..f31f5f6b --- /dev/null +++ b/tests/slt/where_by_index_explain.slt @@ -0,0 +1,221 @@ +statement ok +create table t1(id int primary key, c1 int, c2 int); + +statement ok +copy t1 from 'tests/data/row_20000.csv' ( DELIMITER '|' ); + +statement ok +insert into t1 values(100000000, null, null); + +statement ok +create unique index u_c1_index on t1 (c1); + +statement ok +create index c2_index on t1 (c2); + +statement ok +create index p_index on t1 (c1, c2); + +statement ok +analyze table t1; + +query T +explain select * from t1 limit 10; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2], Limit: 10 [SeqScan => (Sort Option: None)] + +query T +explain select * from t1 where id = 0; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (t1.id = 0), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => 0 => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] + +query T +explain select * from t1 where id = 0 and id = 1; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.id = 0) && (t1.id = 1)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => Dummy => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] + +query T +explain select * from t1 where id = 0 and id != 0; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.id = 0) && (t1.id != 0)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => 0 => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] + +query T +explain select * from t1 where id = 0 or id != 0 limit 10; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Limit 10 [Limit => (Sort Option: Follow)] Filter ((t1.id = 0) || (t1.id != 0)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [SeqScan => (Sort Option: None)] + +query T +explain select * from t1 where id = 0 and id != 0 and id = 3; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (((t1.id = 0) && (t1.id != 0)) && (t1.id = 3)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => Dummy => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] + +query T +explain select * from t1 where id = 0 and id != 0 or id = 3; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (((t1.id = 0) && (t1.id != 0)) || (t1.id = 3)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => 0, 3 => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] + +query T +explain select * from t1 where id > 0 and id = 3; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.id > 0) && (t1.id = 3)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => 3 => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] + +query T +explain select * from t1 where id >= 0 and id <= 3; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.id >= 0) && (t1.id <= 3)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => [0, 3] => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] + +query T +explain select * from t1 where id <= 0 and id >= 3; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.id <= 0) && (t1.id >= 3)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => Dummy => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] + +query T +explain select * from t1 where id >= 3 or id <= 9 limit 10; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Limit 10 [Limit => (Sort Option: Follow)] Filter ((t1.id >= 3) || (t1.id <= 9)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [SeqScan => (Sort Option: None)] + +query T +explain select * from t1 where id <= 3 or id >= 9 limit 10; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Limit 10 [Limit => (Sort Option: Follow)] Filter ((t1.id <= 3) || (t1.id >= 9)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => (-inf, 3], [9, +inf) => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] + +query T +explain select * from t1 where (id >= 0 and id <= 3) or (id >= 9 and id <= 12); +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (((t1.id >= 0) && (t1.id <= 3)) || ((t1.id >= 9) && (t1.id <= 12))), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => [0, 3], [9, 12] => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] + +query T +explain select * from t1 where (id >= 0 or id <= 3) and (id >= 9 or id <= 12) limit 10; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Limit 10 [Limit => (Sort Option: Follow)] Filter (((t1.id >= 0) || (t1.id <= 3)) && ((t1.id >= 9) || (t1.id <= 12))), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [SeqScan => (Sort Option: None)] + +query T +explain select * from t1 where id = 5 or (id > 5 and (id > 6 or id < 8) and id < 12); +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.id = 5) || (((t1.id > 5) && ((t1.id > 6) || (t1.id < 8))) && (t1.id < 12))), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => [5, 12) => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] + +query T +explain select * from t1 where c1 = 7 and c2 = 8; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.c1 = 7) && (t1.c2 = 8)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By u_c1_index => 7 => (Sort Option: OrderBy: (t1.c1 Asc Nulls Last) ignore_prefix_len: 0)] + +query T +explain select * from t1 where c1 = 7 and c2 < 9; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.c1 = 7) && (t1.c2 < 9)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By u_c1_index => 7 => (Sort Option: OrderBy: (t1.c1 Asc Nulls Last) ignore_prefix_len: 0)] + +query T +explain select * from t1 where (c1 = 7 or c1 = 10) and c2 < 9; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (((t1.c1 = 7) || (t1.c1 = 10)) && (t1.c2 < 9)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By u_c1_index => 7, 10 => (Sort Option: OrderBy: (t1.c1 Asc Nulls Last) ignore_prefix_len: 0)] + +query T +explain select * from t1 where c1 is null and c2 is null; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (t1.c1 is null && t1.c2 is null), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By c2_index => null => (Sort Option: OrderBy: (t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] + +query T +explain select * from t1 where c1 > 0 and c1 < 8; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.c1 > 0) && (t1.c1 < 8)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By u_c1_index => (0, 8) => (Sort Option: OrderBy: (t1.c1 Asc Nulls Last) ignore_prefix_len: 0)] + +query T +explain select * from t1 where c2 > 0 and c2 < 9; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.c2 > 0) && (t1.c2 < 9)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By c2_index => (0, 9) => (Sort Option: OrderBy: (t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] + +query T +explain select * from t1 where c2 = 5; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (t1.c2 = 5), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By c2_index => 5 => (Sort Option: OrderBy: (t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] + +statement ok +update t1 set c2 = 9 where c1 = 1 + +query T +explain select * from t1 where c2 > 0 and c2 < 10; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.c2 > 0) && (t1.c2 < 10)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By c2_index => (0, 10) => (Sort Option: OrderBy: (t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] + +statement ok +delete from t1 where c1 = 4 + +query T +explain select * from t1 where c2 > 0 and c2 < 10; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.c2 > 0) && (t1.c2 < 10)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By c2_index => (0, 10) => (Sort Option: OrderBy: (t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] + + +# unique covered +query T +explain select c1 from t1 where c1 < 10; +---- +Projection [t1.c1] [Project => (Sort Option: Follow)] Filter (t1.c1 < 10), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [c1] [IndexScan By p_index => (-inf, (10)) Covered => (Sort Option: OrderBy: (t1.c1 Asc Nulls Last, t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] + +# unique covered with primary key projection +query T +explain select c1, id from t1 where c1 < 10; +---- +Projection [t1.c1, t1.id] [Project => (Sort Option: Follow)] Filter (t1.c1 < 10), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1] [IndexScan By p_index => (-inf, (10)) => (Sort Option: OrderBy: (t1.c1 Asc Nulls Last, t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] + +statement ok +drop index t1.u_c1_index; + +# normal covered +query T +explain select c2 from t1 where c2 < 10 and c2 > 0; +---- +Projection [t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.c2 < 10) && (t1.c2 > 0)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [c2] [IndexScan By c2_index => (0, 10) Covered => (Sort Option: OrderBy: (t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] + +statement ok +insert into t1 values(100000002, 100000002, 8); + +# stream distinct +query T +explain select distinct c2 from t1 where c2 < 10 and c2 > 0; +---- +Projection [t1.c2] [Project => (Sort Option: Follow)] Aggregate [] -> Group By [t1.c2] [StreamDistinct => (Sort Option: Follow)] Filter ((t1.c2 < 10) && (t1.c2 > 0)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [c2] [IndexScan By c2_index => (0, 10) Covered => (Sort Option: OrderBy: (t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] + +statement ok +delete from t1 where id = 100000002; + +statement ok +drop index t1.c2_index; + +# composite covered +query T +explain select c1, c2 from t1 where c1 < 10 and c1 > 0 and c2 >0 and c2 < 10; +---- +Projection [t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((((t1.c1 < 10) && (t1.c1 > 0)) && (t1.c2 > 0)) && (t1.c2 < 10)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [c1, c2] [IndexScan By p_index => ((0), (10)) Covered => (Sort Option: OrderBy: (t1.c1 Asc Nulls Last, t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] + +# composite covered projection reorder +query T +explain select c2, c1 from t1 where c1 < 10 and c1 > 0 and c2 > 0 and c2 < 10; +---- +Projection [t1.c2, t1.c1] [Project => (Sort Option: Follow)] Filter ((((t1.c1 < 10) && (t1.c1 > 0)) && (t1.c2 > 0)) && (t1.c2 < 10)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [c1, c2] [IndexScan By p_index => ((0), (10)) Covered => (Sort Option: OrderBy: (t1.c1 Asc Nulls Last, t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] + + +statement ok +drop table t1; + +statement ok +create table t_cover(id int primary key, c1 int, c2 int, c3 int); + +statement ok +insert into t_cover values + (1, 1, 10, 11), + (2, 2, 20, 21), + (3, 2, 22, 23), + (4, 3, 30, 31); + +statement ok +create index idx_cover on t_cover (c1, c2, c3); + +# composite index trailing columns cover (index columns > output columns) +query T +explain select c2, c3 from t_cover where c1 = 2; +---- +Projection [t_cover.c2, t_cover.c3] [Project => (Sort Option: Follow)] Filter (t_cover.c1 = 2), Is Having: false [Filter => (Sort Option: Follow)] TableScan t_cover -> [c1, c2, c3] [SeqScan => (Sort Option: None)] + +statement ok +drop table t_cover; diff --git a/tpcc/Cargo.toml b/tpcc/Cargo.toml index 5d219f84..363b6ae9 100644 --- a/tpcc/Cargo.toml +++ b/tpcc/Cargo.toml @@ -13,4 +13,4 @@ rand = { version = "0.8" } rust_decimal = { version = "1" } thiserror = { version = "1" } sqlite = { version = "0.34" } -sqlparser = { version = "0.34" } +sqlparser = { version = "0.61" } diff --git a/tpcc/src/delivery.rs b/tpcc/src/delivery.rs index a752168c..994a1226 100644 --- a/tpcc/src/delivery.rs +++ b/tpcc/src/delivery.rs @@ -50,8 +50,8 @@ impl TpccTransaction for Delivery { let tuple = tx.query_one( &statements[0], &[ - ("?1", DataValue::Int8(d_id as i8)), - ("?2", DataValue::Int16(args.w_id as i16)), + ("$1", DataValue::Int8(d_id as i8)), + ("$2", DataValue::Int16(args.w_id as i16)), ], )?; let no_o_id = tuple.values[0].i32().unwrap(); @@ -63,18 +63,18 @@ impl TpccTransaction for Delivery { tx.execute_drain( &statements[1], &[ - ("?1", DataValue::Int32(no_o_id)), - ("?2", DataValue::Int8(d_id as i8)), - ("?3", DataValue::Int16(args.w_id as i16)), + ("$1", DataValue::Int32(no_o_id)), + ("$2", DataValue::Int8(d_id as i8)), + ("$3", DataValue::Int16(args.w_id as i16)), ], )?; // "SELECT o_c_id FROM orders WHERE o_id = ? AND o_d_id = ? AND o_w_id = ?" let tuple = tx.query_one( &statements[2], &[ - ("?1", DataValue::Int32(no_o_id)), - ("?2", DataValue::Int8(d_id as i8)), - ("?3", DataValue::Int16(args.w_id as i16)), + ("$1", DataValue::Int32(no_o_id)), + ("$2", DataValue::Int8(d_id as i8)), + ("$3", DataValue::Int16(args.w_id as i16)), ], )?; let c_id = tuple.values[0].i32().unwrap(); @@ -82,29 +82,29 @@ impl TpccTransaction for Delivery { tx.execute_drain( &statements[3], &[ - ("?1", DataValue::Int8(args.o_carrier_id as i8)), - ("?2", DataValue::Int32(no_o_id)), - ("?3", DataValue::Int8(d_id as i8)), - ("?4", DataValue::Int16(args.w_id as i16)), + ("$1", DataValue::Int8(args.o_carrier_id as i8)), + ("$2", DataValue::Int32(no_o_id)), + ("$3", DataValue::Int8(d_id as i8)), + ("$4", DataValue::Int16(args.w_id as i16)), ], )?; // "UPDATE order_line SET ol_delivery_d = ? WHERE ol_o_id = ? AND ol_d_id = ? AND ol_w_id = ?" tx.execute_drain( &statements[4], &[ - ("?1", DataValue::from(&now)), - ("?2", DataValue::Int32(no_o_id)), - ("?3", DataValue::Int8(d_id as i8)), - ("?4", DataValue::Int16(args.w_id as i16)), + ("$1", DataValue::from(&now)), + ("$2", DataValue::Int32(no_o_id)), + ("$3", DataValue::Int8(d_id as i8)), + ("$4", DataValue::Int16(args.w_id as i16)), ], )?; // "SELECT SUM(ol_amount) FROM order_line WHERE ol_o_id = ? AND ol_d_id = ? AND ol_w_id = ?" let tuple = tx.query_one( &statements[5], &[ - ("?1", DataValue::Int32(no_o_id)), - ("?2", DataValue::Int8(d_id as i8)), - ("?3", DataValue::Int16(args.w_id as i16)), + ("$1", DataValue::Int32(no_o_id)), + ("$2", DataValue::Int8(d_id as i8)), + ("$3", DataValue::Int16(args.w_id as i16)), ], )?; let ol_total = tuple.values[0].decimal().unwrap(); @@ -112,10 +112,10 @@ impl TpccTransaction for Delivery { tx.execute_drain( &statements[6], &[ - ("?1", DataValue::Decimal(ol_total)), - ("?2", DataValue::Int32(c_id)), - ("?3", DataValue::Int8(d_id as i8)), - ("?4", DataValue::Int16(args.w_id as i16)), + ("$1", DataValue::Decimal(ol_total)), + ("$2", DataValue::Int32(c_id)), + ("$3", DataValue::Int8(d_id as i8)), + ("$4", DataValue::Int16(args.w_id as i16)), ], )?; } diff --git a/tpcc/src/main.rs b/tpcc/src/main.rs index 72c389ed..37e0100a 100644 --- a/tpcc/src/main.rs +++ b/tpcc/src/main.rs @@ -60,8 +60,8 @@ const TX_NAMES: [&str; 5] = [ "Delivery", "Stock-Level", ]; -pub(crate) const STOCK_LEVEL_DISTINCT_SQL: &str = "SELECT DISTINCT ol_i_id FROM order_line WHERE ol_w_id = ?1 AND ol_d_id = ?2 AND ol_o_id < ?3 AND ol_o_id >= (?4 - 20)"; -pub(crate) const STOCK_LEVEL_DISTINCT_SQLITE: &str = "SELECT DISTINCT ol_i_id FROM (SELECT ol_i_id FROM order_line WHERE ol_w_id = ?1 AND ol_d_id = ?2 AND ol_o_id < ?3 AND ol_o_id >= (?4 - 20) ORDER BY ol_w_id, ol_d_id, ol_o_id)"; +pub(crate) const STOCK_LEVEL_DISTINCT_SQL: &str = "SELECT DISTINCT ol_i_id FROM order_line WHERE ol_w_id = $1 AND ol_d_id = $2 AND ol_o_id < $3 AND ol_o_id >= ($4 - 20)"; +pub(crate) const STOCK_LEVEL_DISTINCT_SQLITE: &str = "SELECT DISTINCT ol_i_id FROM (SELECT ol_i_id FROM order_line WHERE ol_w_id = $1 AND ol_d_id = $2 AND ol_o_id < $3 AND ol_o_id >= ($4 - 20) ORDER BY ol_w_id, ol_d_id, ol_o_id)"; pub(crate) trait TpccTransaction { type Args; @@ -274,39 +274,39 @@ fn statement_specs() -> Vec> { vec![ vec![ stmt( - "SELECT c.c_discount, c.c_last, c.c_credit, w.w_tax FROM customer AS c JOIN warehouse AS w ON c.c_w_id = w_id AND w.w_id = ?1 AND c.c_w_id = ?2 AND c.c_d_id = ?3 AND c.c_id = ?4", + "SELECT c.c_discount, c.c_last, c.c_credit, w.w_tax FROM customer AS c JOIN warehouse AS w ON c.c_w_id = w_id AND w.w_id = $1 AND c.c_w_id = $2 AND c.c_d_id = $3 AND c.c_id = $4", &[ColumnType::Decimal, ColumnType::Utf8, ColumnType::Utf8, ColumnType::Decimal], ), stmt( - "SELECT c_discount, c_last, c_credit FROM customer WHERE c_w_id = ?1 AND c_d_id = ?2 AND c_id = ?3", + "SELECT c_discount, c_last, c_credit FROM customer WHERE c_w_id = $1 AND c_d_id = $2 AND c_id = $3", &[ColumnType::Decimal, ColumnType::Utf8, ColumnType::Utf8], ), stmt( - "SELECT w_tax FROM warehouse WHERE w_id = ?1", + "SELECT w_tax FROM warehouse WHERE w_id = $1", &[ColumnType::Decimal], ), stmt( - "SELECT d_next_o_id, d_tax FROM district WHERE d_id = ?1 AND d_w_id = ?2", + "SELECT d_next_o_id, d_tax FROM district WHERE d_id = $1 AND d_w_id = $2", &[ColumnType::Int32, ColumnType::Decimal], ), stmt( - "UPDATE district SET d_next_o_id = ?1 + 1 WHERE d_id = ?2 AND d_w_id = ?3", + "UPDATE district SET d_next_o_id = $1 + 1 WHERE d_id = $2 AND d_w_id = $3", &[], ), stmt( - "INSERT INTO orders (o_id, o_d_id, o_w_id, o_c_id, o_entry_d, o_ol_cnt, o_all_local) VALUES(?1, ?2, ?3, ?4, ?5, ?6, ?7)", + "INSERT INTO orders (o_id, o_d_id, o_w_id, o_c_id, o_entry_d, o_ol_cnt, o_all_local) VALUES($1, $2, $3, $4, $5, $6, $7)", &[], ), stmt( - "INSERT INTO new_orders (no_o_id, no_d_id, no_w_id) VALUES (?1,?2,?3)", + "INSERT INTO new_orders (no_o_id, no_d_id, no_w_id) VALUES ($1,$2,$3)", &[], ), stmt( - "SELECT i_price, i_name, i_data FROM item WHERE i_id = ?1", + "SELECT i_price, i_name, i_data FROM item WHERE i_id = $1", &[ColumnType::Decimal, ColumnType::Utf8, ColumnType::Utf8], ), stmt( - "SELECT s_quantity, s_data, s_dist_01, s_dist_02, s_dist_03, s_dist_04, s_dist_05, s_dist_06, s_dist_07, s_dist_08, s_dist_09, s_dist_10 FROM stock WHERE s_i_id = ?1 AND s_w_id = ?2", + "SELECT s_quantity, s_data, s_dist_01, s_dist_02, s_dist_03, s_dist_04, s_dist_05, s_dist_06, s_dist_07, s_dist_08, s_dist_09, s_dist_10 FROM stock WHERE s_i_id = $1 AND s_w_id = $2", &[ ColumnType::Int16, ColumnType::Utf8, @@ -323,21 +323,21 @@ fn statement_specs() -> Vec> { ], ), stmt( - "UPDATE stock SET s_quantity = ?1 WHERE s_i_id = ?2 AND s_w_id = ?3", + "UPDATE stock SET s_quantity = $1 WHERE s_i_id = $2 AND s_w_id = $3", &[], ), stmt( - "INSERT INTO order_line (ol_o_id, ol_d_id, ol_w_id, ol_number, ol_i_id, ol_supply_w_id, ol_quantity, ol_amount, ol_dist_info) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", + "INSERT INTO order_line (ol_o_id, ol_d_id, ol_w_id, ol_number, ol_i_id, ol_supply_w_id, ol_quantity, ol_amount, ol_dist_info) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)", &[], ), ], vec![ stmt( - "UPDATE warehouse SET w_ytd = w_ytd + ?1 WHERE w_id = ?2", + "UPDATE warehouse SET w_ytd = w_ytd + $1 WHERE w_id = $2", &[], ), stmt( - "SELECT w_street_1, w_street_2, w_city, w_state, w_zip, w_name FROM warehouse WHERE w_id = ?1", + "SELECT w_street_1, w_street_2, w_city, w_state, w_zip, w_name FROM warehouse WHERE w_id = $1", &[ ColumnType::Utf8, ColumnType::Utf8, @@ -348,11 +348,11 @@ fn statement_specs() -> Vec> { ], ), stmt( - "UPDATE district SET d_ytd = d_ytd + ?1 WHERE d_w_id = ?2 AND d_id = ?3", + "UPDATE district SET d_ytd = d_ytd + $1 WHERE d_w_id = $2 AND d_id = $3", &[], ), stmt( - "SELECT d_street_1, d_street_2, d_city, d_state, d_zip, d_name FROM district WHERE d_w_id = ?1 AND d_id = ?2", + "SELECT d_street_1, d_street_2, d_city, d_state, d_zip, d_name FROM district WHERE d_w_id = $1 AND d_id = $2", &[ ColumnType::Utf8, ColumnType::Utf8, @@ -363,15 +363,15 @@ fn statement_specs() -> Vec> { ], ), stmt( - "SELECT count(c_id) FROM customer WHERE c_w_id = ?1 AND c_d_id = ?2 AND c_last = ?3", + "SELECT count(c_id) FROM customer WHERE c_w_id = $1 AND c_d_id = $2 AND c_last = $3", &[ColumnType::Int32], ), stmt( - "SELECT c_id FROM customer WHERE c_w_id = ?1 AND c_d_id = ?2 AND c_last = ?3 ORDER BY c_first", + "SELECT c_id FROM customer WHERE c_w_id = $1 AND c_d_id = $2 AND c_last = $3 ORDER BY c_first", &[ColumnType::Int32], ), stmt( - "SELECT c_first, c_middle, c_last, c_street_1, c_street_2, c_city, c_state, c_zip, c_phone, c_credit, c_credit_lim, c_discount, c_balance, c_since FROM customer WHERE c_w_id = ?1 AND c_d_id = ?2 AND c_id = ?3", + "SELECT c_first, c_middle, c_last, c_street_1, c_street_2, c_city, c_state, c_zip, c_phone, c_credit, c_credit_lim, c_discount, c_balance, c_since FROM customer WHERE c_w_id = $1 AND c_d_id = $2 AND c_id = $3", &[ ColumnType::Utf8, ColumnType::Utf8, @@ -390,31 +390,31 @@ fn statement_specs() -> Vec> { ], ), stmt( - "SELECT c_data FROM customer WHERE c_w_id = ?1 AND c_d_id = ?2 AND c_id = ?3", + "SELECT c_data FROM customer WHERE c_w_id = $1 AND c_d_id = $2 AND c_id = $3", &[ColumnType::Utf8], ), stmt( - "UPDATE customer SET c_balance = ?1, c_data = ?2 WHERE c_w_id = ?3 AND c_d_id = ?4 AND c_id = ?5", + "UPDATE customer SET c_balance = $1, c_data = $2 WHERE c_w_id = $3 AND c_d_id = $4 AND c_id = $5", &[], ), stmt( - "UPDATE customer SET c_balance = ?1 WHERE c_w_id = ?2 AND c_d_id = ?3 AND c_id = ?4", + "UPDATE customer SET c_balance = $1 WHERE c_w_id = $2 AND c_d_id = $3 AND c_id = $4", &[], ), stmt( - "INSERT INTO history(h_c_d_id, h_c_w_id, h_c_id, h_d_id, h_w_id, h_date, h_amount, h_data) VALUES(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", + "INSERT INTO history(h_c_d_id, h_c_w_id, h_c_id, h_d_id, h_w_id, h_date, h_amount, h_data) VALUES($1, $2, $3, $4, $5, $6, $7, $8)", &[], ), ], vec![ - // "SELECT count(c_id) FROM customer WHERE c_w_id = ?1 AND c_d_id = ?2 AND c_last = ?3" + // "SELECT count(c_id) FROM customer WHERE c_w_id = $1 AND c_d_id = $2 AND c_last = $3" stmt( - "SELECT count(c_id) FROM customer WHERE c_w_id = ?1 AND c_d_id = ?2 AND c_last = ?3", + "SELECT count(c_id) FROM customer WHERE c_w_id = $1 AND c_d_id = $2 AND c_last = $3", &[ColumnType::Int32], ), // "SELECT c_balance, c_first, c_middle, c_last FROM customer WHERE ... ORDER BY c_first" stmt( - "SELECT c_balance, c_first, c_middle, c_last FROM customer WHERE c_w_id = ?1 AND c_d_id = ?2 AND c_last = ?3 ORDER BY c_first", + "SELECT c_balance, c_first, c_middle, c_last FROM customer WHERE c_w_id = $1 AND c_d_id = $2 AND c_last = $3 ORDER BY c_first", &[ ColumnType::Decimal, ColumnType::Utf8, @@ -422,9 +422,9 @@ fn statement_specs() -> Vec> { ColumnType::Utf8, ], ), - // "SELECT c_balance, c_first, c_middle, c_last FROM customer WHERE c_w_id = ?1 AND c_d_id = ?2 AND c_id = ?3" + // "SELECT c_balance, c_first, c_middle, c_last FROM customer WHERE c_w_id = $1 AND c_d_id = $2 AND c_id = $3" stmt( - "SELECT c_balance, c_first, c_middle, c_last FROM customer WHERE c_w_id = ?1 AND c_d_id = ?2 AND c_id = ?3", + "SELECT c_balance, c_first, c_middle, c_last FROM customer WHERE c_w_id = $1 AND c_d_id = $2 AND c_id = $3", &[ ColumnType::Decimal, ColumnType::Utf8, @@ -434,12 +434,12 @@ fn statement_specs() -> Vec> { ), // "SELECT o_id, o_entry_d, COALESCE(o_carrier_id,0) FROM orders ..." stmt( - "SELECT o_id, o_entry_d, COALESCE(o_carrier_id,0) FROM orders WHERE o_w_id = ?1 AND o_d_id = ?2 AND o_c_id = ?3 AND o_id = (SELECT MAX(o_id) FROM orders WHERE o_w_id = ?4 AND o_d_id = ?5 AND o_c_id = ?6)", + "SELECT o_id, o_entry_d, COALESCE(o_carrier_id,0) FROM orders WHERE o_w_id = $1 AND o_d_id = $2 AND o_c_id = $3 AND o_id = (SELECT MAX(o_id) FROM orders WHERE o_w_id = $4 AND o_d_id = $5 AND o_c_id = $6)", &[ColumnType::Int32, ColumnType::DateTime, ColumnType::Int32], ), // "SELECT ol_i_id, ol_supply_w_id, ol_quantity, ol_amount, ol_delivery_d FROM order_line ..." stmt( - "SELECT ol_i_id, ol_supply_w_id, ol_quantity, ol_amount, ol_delivery_d FROM order_line WHERE ol_w_id = ?1 AND ol_d_id = ?2 AND ol_o_id = ?3", + "SELECT ol_i_id, ol_supply_w_id, ol_quantity, ol_amount, ol_delivery_d FROM order_line WHERE ol_w_id = $1 AND ol_d_id = $2 AND ol_o_id = $3", &[ ColumnType::Int32, ColumnType::Int16, @@ -450,52 +450,52 @@ fn statement_specs() -> Vec> { ), ], vec![ - // "SELECT COALESCE(MIN(no_o_id),0) FROM new_orders WHERE no_d_id = ?1 AND no_w_id = ?2" + // "SELECT COALESCE(MIN(no_o_id),0) FROM new_orders WHERE no_d_id = $1 AND no_w_id = $2" stmt( - "SELECT COALESCE(MIN(no_o_id),0) FROM new_orders WHERE no_d_id = ?1 AND no_w_id = ?2", + "SELECT COALESCE(MIN(no_o_id),0) FROM new_orders WHERE no_d_id = $1 AND no_w_id = $2", &[ColumnType::Int32], ), - // "DELETE FROM new_orders WHERE no_o_id = ?1 AND no_d_id = ?2 AND no_w_id = ?3" + // "DELETE FROM new_orders WHERE no_o_id = $1 AND no_d_id = $2 AND no_w_id = $3" stmt( - "DELETE FROM new_orders WHERE no_o_id = ?1 AND no_d_id = ?2 AND no_w_id = ?3", + "DELETE FROM new_orders WHERE no_o_id = $1 AND no_d_id = $2 AND no_w_id = $3", &[], ), - // "SELECT o_c_id FROM orders WHERE o_id = ?1 AND o_d_id = ?2 AND o_w_id = ?3" + // "SELECT o_c_id FROM orders WHERE o_id = $1 AND o_d_id = $2 AND o_w_id = $3" stmt( - "SELECT o_c_id FROM orders WHERE o_id = ?1 AND o_d_id = ?2 AND o_w_id = ?3", + "SELECT o_c_id FROM orders WHERE o_id = $1 AND o_d_id = $2 AND o_w_id = $3", &[ColumnType::Int32], ), - // "UPDATE orders SET o_carrier_id = ?1 WHERE o_id = ?2 AND o_d_id = ?3 AND o_w_id = ?4" + // "UPDATE orders SET o_carrier_id = $1 WHERE o_id = $2 AND o_d_id = $3 AND o_w_id = $4" stmt( - "UPDATE orders SET o_carrier_id = ?1 WHERE o_id = ?2 AND o_d_id = ?3 AND o_w_id = ?4", + "UPDATE orders SET o_carrier_id = $1 WHERE o_id = $2 AND o_d_id = $3 AND o_w_id = $4", &[], ), - // "UPDATE order_line SET ol_delivery_d = ?1 WHERE ol_o_id = ?2 AND ol_d_id = ?3 AND ol_w_id = ?4" + // "UPDATE order_line SET ol_delivery_d = $1 WHERE ol_o_id = $2 AND ol_d_id = $3 AND ol_w_id = $4" stmt( - "UPDATE order_line SET ol_delivery_d = ?1 WHERE ol_o_id = ?2 AND ol_d_id = ?3 AND ol_w_id = ?4", + "UPDATE order_line SET ol_delivery_d = $1 WHERE ol_o_id = $2 AND ol_d_id = $3 AND ol_w_id = $4", &[], ), - // "SELECT SUM(ol_amount) FROM order_line WHERE ol_o_id = ?1 AND ol_d_id = ?2 AND ol_w_id = ?3" + // "SELECT SUM(ol_amount) FROM order_line WHERE ol_o_id = $1 AND ol_d_id = $2 AND ol_w_id = $3" stmt( - "SELECT SUM(ol_amount) FROM order_line WHERE ol_o_id = ?1 AND ol_d_id = ?2 AND ol_w_id = ?3", + "SELECT SUM(ol_amount) FROM order_line WHERE ol_o_id = $1 AND ol_d_id = $2 AND ol_w_id = $3", &[ColumnType::Decimal], ), - // "UPDATE customer SET c_balance = c_balance + ?1 , c_delivery_cnt = c_delivery_cnt + 1 WHERE c_id = ?2 ..." + // "UPDATE customer SET c_balance = c_balance + $1 , c_delivery_cnt = c_delivery_cnt + 1 WHERE c_id = $2 ..." stmt( - "UPDATE customer SET c_balance = c_balance + ?1 , c_delivery_cnt = c_delivery_cnt + 1 WHERE c_id = ?2 AND c_d_id = ?3 AND c_w_id = ?4", + "UPDATE customer SET c_balance = c_balance + $1 , c_delivery_cnt = c_delivery_cnt + 1 WHERE c_id = $2 AND c_d_id = $3 AND c_w_id = $4", &[], ), ], vec![ - // "SELECT d_next_o_id FROM district WHERE d_id = ?1 AND d_w_id = ?2" + // "SELECT d_next_o_id FROM district WHERE d_id = $1 AND d_w_id = $2" stmt( - "SELECT d_next_o_id FROM district WHERE d_id = ?1 AND d_w_id = ?2", + "SELECT d_next_o_id FROM district WHERE d_id = $1 AND d_w_id = $2", &[ColumnType::Int32], ), stmt(STOCK_LEVEL_DISTINCT_SQL, &[ColumnType::Int32]), - // "SELECT count(*) FROM stock WHERE s_w_id = ?1 AND s_i_id = ?2 AND s_quantity < ?3" + // "SELECT count(*) FROM stock WHERE s_w_id = $1 AND s_i_id = $2 AND s_quantity < $3" stmt( - "SELECT count(*) FROM stock WHERE s_w_id = ?1 AND s_i_id = ?2 AND s_quantity < ?3", + "SELECT count(*) FROM stock WHERE s_w_id = $1 AND s_i_id = $2 AND s_quantity < $3", &[ColumnType::Int32], ), ], diff --git a/tpcc/src/new_ord.rs b/tpcc/src/new_ord.rs index f0493f4b..d70b9623 100644 --- a/tpcc/src/new_ord.rs +++ b/tpcc/src/new_ord.rs @@ -84,10 +84,10 @@ impl TpccTransaction for NewOrd { let tuple = tx.query_one( &statements[0], &[ - ("?1", DataValue::Int16(args.w_id as i16)), - ("?2", DataValue::Int16(args.w_id as i16)), - ("?3", DataValue::Int8(args.d_id as i8)), - ("?4", DataValue::Int64(args.c_id as i64)), + ("$1", DataValue::Int16(args.w_id as i16)), + ("$2", DataValue::Int16(args.w_id as i16)), + ("$3", DataValue::Int8(args.d_id as i8)), + ("$4", DataValue::Int64(args.c_id as i64)), ], )?; let c_discount = tuple.values[0].decimal().unwrap(); @@ -101,9 +101,9 @@ impl TpccTransaction for NewOrd { let tuple = tx.query_one( &statements[1], &[ - ("?1", DataValue::Int16(args.w_id as i16)), - ("?2", DataValue::Int8(args.d_id as i8)), - ("?3", DataValue::Int32(args.c_id as i32)), + ("$1", DataValue::Int16(args.w_id as i16)), + ("$2", DataValue::Int8(args.d_id as i8)), + ("$3", DataValue::Int32(args.c_id as i32)), ], )?; let c_discount = tuple.values[0].decimal().unwrap(); @@ -112,7 +112,7 @@ impl TpccTransaction for NewOrd { // "SELECT w_tax FROM warehouse WHERE w_id = ?" let tuple = tx.query_one( &statements[2], - &[("?1", DataValue::Int16(args.w_id as i16))], + &[("$1", DataValue::Int16(args.w_id as i16))], )?; let w_tax = tuple.values[0].decimal().unwrap(); @@ -122,8 +122,8 @@ impl TpccTransaction for NewOrd { let tuple = tx.query_one( &statements[3], &[ - ("?1", DataValue::Int8(args.d_id as i8)), - ("?2", DataValue::Int16(args.w_id as i16)), + ("$1", DataValue::Int8(args.d_id as i8)), + ("$2", DataValue::Int16(args.w_id as i16)), ], )?; let d_next_o_id = tuple.values[0].i32().unwrap(); @@ -132,9 +132,9 @@ impl TpccTransaction for NewOrd { tx.execute_drain( &statements[4], &[ - ("?1", DataValue::Int32(d_next_o_id)), - ("?2", DataValue::Int8(args.d_id as i8)), - ("?3", DataValue::Int16(args.w_id as i16)), + ("$1", DataValue::Int32(d_next_o_id)), + ("$2", DataValue::Int8(args.d_id as i8)), + ("$3", DataValue::Int16(args.w_id as i16)), ], )?; let o_id = d_next_o_id; @@ -142,22 +142,22 @@ impl TpccTransaction for NewOrd { tx.execute_drain( &statements[5], &[ - ("?1", DataValue::Int32(o_id)), - ("?2", DataValue::Int8(args.d_id as i8)), - ("?3", DataValue::Int16(args.w_id as i16)), - ("?4", DataValue::Int32(args.c_id as i32)), - ("?5", DataValue::from(&now)), - ("?6", DataValue::Int8(args.o_ol_cnt as i8)), - ("?7", DataValue::Int8(args.o_all_local as i8)), + ("$1", DataValue::Int32(o_id)), + ("$2", DataValue::Int8(args.d_id as i8)), + ("$3", DataValue::Int16(args.w_id as i16)), + ("$4", DataValue::Int32(args.c_id as i32)), + ("$5", DataValue::from(&now)), + ("$6", DataValue::Int8(args.o_ol_cnt as i8)), + ("$7", DataValue::Int8(args.o_all_local as i8)), ], )?; // "INSERT INTO new_orders (no_o_id, no_d_id, no_w_id) VALUES (?,?,?)" tx.execute_drain( &statements[6], &[ - ("?1", DataValue::Int32(o_id)), - ("?2", DataValue::Int8(args.d_id as i8)), - ("?3", DataValue::Int16(args.w_id as i16)), + ("$1", DataValue::Int32(o_id)), + ("$2", DataValue::Int8(args.d_id as i8)), + ("$3", DataValue::Int16(args.w_id as i16)), ], )?; let mut ol_num_seq = vec![0; MAX_NUM_ITEMS]; @@ -188,7 +188,7 @@ impl TpccTransaction for NewOrd { let ol_i_id = args.item_id[ol_num_seq[ol_number - 1]]; let ol_quantity = args.qty[ol_num_seq[ol_number - 1]]; // "SELECT i_price, i_name, i_data FROM item WHERE i_id = ?" - let params = [("?1", DataValue::Int32(ol_i_id as i32))]; + let params = [("$1", DataValue::Int32(ol_i_id as i32))]; let tuple = tx.query_one(&statements[7], ¶ms)?; let i_price = tuple.values[0].decimal().unwrap(); let i_name = tuple.values[1].utf8().unwrap(); @@ -199,8 +199,8 @@ impl TpccTransaction for NewOrd { // "SELECT s_quantity, s_data, s_dist_01, s_dist_02, s_dist_03, s_dist_04, s_dist_05, s_dist_06, s_dist_07, s_dist_08, s_dist_09, s_dist_10 FROM stock WHERE s_i_id = ? AND s_w_id = ? FOR UPDATE" let params = [ - ("?1", DataValue::Int32(ol_i_id as i32)), - ("?2", DataValue::Int16(ol_supply_w_id as i16)), + ("$1", DataValue::Int32(ol_i_id as i32)), + ("$2", DataValue::Int16(ol_supply_w_id as i16)), ]; let tuple = tx.query_one(&statements[8], ¶ms)?; let mut s_quantity = tuple.values[0].i16().unwrap(); @@ -235,9 +235,9 @@ impl TpccTransaction for NewOrd { }; // "UPDATE stock SET s_quantity = ? WHERE s_i_id = ? AND s_w_id = ?" let params = [ - ("?1", DataValue::Int16(s_quantity)), - ("?2", DataValue::Int32(ol_i_id as i32)), - ("?3", DataValue::Int16(ol_supply_w_id as i16)), + ("$1", DataValue::Int16(s_quantity)), + ("$2", DataValue::Int32(ol_i_id as i32)), + ("$3", DataValue::Int16(ol_supply_w_id as i16)), ]; tx.execute_drain(&statements[9], ¶ms)?; @@ -253,15 +253,15 @@ impl TpccTransaction for NewOrd { amt[ol_num_seq[ol_number - 1]] = ol_amount; // "INSERT INTO order_line (ol_o_id, ol_d_id, ol_w_id, ol_number, ol_i_id, ol_supply_w_id, ol_quantity, ol_amount, ol_dist_info) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" let params = [ - ("?1", DataValue::Int32(o_id)), - ("?2", DataValue::Int8(args.d_id as i8)), - ("?3", DataValue::Int16(args.w_id as i16)), - ("?4", DataValue::Int8(ol_number as i8)), - ("?5", DataValue::Int32(ol_i_id as i32)), - ("?6", DataValue::Int16(ol_supply_w_id as i16)), - ("?7", DataValue::Int8(ol_quantity as i8)), - ("?8", DataValue::Decimal(ol_amount.round_dp(2))), - ("?9", DataValue::from(ol_dist_info)), + ("$1", DataValue::Int32(o_id)), + ("$2", DataValue::Int8(args.d_id as i8)), + ("$3", DataValue::Int16(args.w_id as i16)), + ("$4", DataValue::Int8(ol_number as i8)), + ("$5", DataValue::Int32(ol_i_id as i32)), + ("$6", DataValue::Int16(ol_supply_w_id as i16)), + ("$7", DataValue::Int8(ol_quantity as i8)), + ("$8", DataValue::Decimal(ol_amount.round_dp(2))), + ("$9", DataValue::from(ol_dist_info)), ]; tx.execute_drain(&statements[10], ¶ms)?; } diff --git a/tpcc/src/order_stat.rs b/tpcc/src/order_stat.rs index 70902ab8..3c7490fa 100644 --- a/tpcc/src/order_stat.rs +++ b/tpcc/src/order_stat.rs @@ -63,17 +63,17 @@ impl TpccTransaction for OrderStat { let tuple = tx.query_one( &statements[0], &[ - ("?1", DataValue::Int16(args.w_id as i16)), - ("?2", DataValue::Int8(args.d_id as i8)), - ("?3", DataValue::from(args.c_last.clone())), + ("$1", DataValue::Int16(args.w_id as i16)), + ("$2", DataValue::Int8(args.d_id as i8)), + ("$3", DataValue::from(args.c_last.clone())), ], )?; let mut name_cnt = tuple.values[0].i32().unwrap() as usize; // "SELECT c_balance, c_first, c_middle, c_last FROM customer WHERE c_w_id = ? AND c_d_id = ? AND c_last = ? ORDER BY c_first" let params = [ - ("?1", DataValue::Int16(args.w_id as i16)), - ("?2", DataValue::Int8(args.d_id as i8)), - ("?3", DataValue::from(args.c_last.clone())), + ("$1", DataValue::Int16(args.w_id as i16)), + ("$2", DataValue::Int8(args.d_id as i8)), + ("$3", DataValue::from(args.c_last.clone())), ]; let mut tuple_iter = tx.execute(&statements[1], ¶ms)?; @@ -99,9 +99,9 @@ impl TpccTransaction for OrderStat { let tuple = tx.query_one( &statements[2], &[ - ("?1", DataValue::Int16(args.w_id as i16)), - ("?2", DataValue::Int8(args.d_id as i8)), - ("?3", DataValue::Int32(args.c_id as i32)), + ("$1", DataValue::Int16(args.w_id as i16)), + ("$2", DataValue::Int8(args.d_id as i8)), + ("$3", DataValue::Int32(args.c_id as i32)), ], )?; let c_balance = tuple.values[0].decimal().unwrap(); @@ -112,20 +112,20 @@ impl TpccTransaction for OrderStat { }; // "SELECT o_id, o_entry_d, COALESCE(o_carrier_id,0) FROM orders WHERE o_w_id = ? AND o_d_id = ? AND o_c_id = ? AND o_id = (SELECT MAX(o_id) FROM orders WHERE o_w_id = ? AND o_d_id = ? AND o_c_id = ?)" let params = [ - ("?1", DataValue::Int16(args.w_id as i16)), - ("?2", DataValue::Int8(args.d_id as i8)), - ("?3", DataValue::Int32(args.c_id as i32)), - ("?4", DataValue::Int16(args.w_id as i16)), - ("?5", DataValue::Int8(args.d_id as i8)), - ("?6", DataValue::Int32(args.c_id as i32)), + ("$1", DataValue::Int16(args.w_id as i16)), + ("$2", DataValue::Int8(args.d_id as i8)), + ("$3", DataValue::Int32(args.c_id as i32)), + ("$4", DataValue::Int16(args.w_id as i16)), + ("$5", DataValue::Int8(args.d_id as i8)), + ("$6", DataValue::Int32(args.c_id as i32)), ]; let tuple = tx.query_one(&statements[3], ¶ms)?; let o_id = tuple.values[0].i32().unwrap(); // "SELECT ol_i_id, ol_supply_w_id, ol_quantity, ol_amount, ol_delivery_d FROM order_line WHERE ol_w_id = ? AND ol_d_id = ? AND ol_o_id = ?" let params = [ - ("?1", DataValue::Int16(args.w_id as i16)), - ("?2", DataValue::Int8(args.d_id as i8)), - ("?3", DataValue::Int32(o_id)), + ("$1", DataValue::Int16(args.w_id as i16)), + ("$2", DataValue::Int8(args.d_id as i8)), + ("$3", DataValue::Int32(o_id)), ]; let _tuple = tx.query_one(&statements[4], ¶ms)?; // let ol_i_id = tuple.values[0].i32(); diff --git a/tpcc/src/payment.rs b/tpcc/src/payment.rs index b6ad0ba6..9c64e411 100644 --- a/tpcc/src/payment.rs +++ b/tpcc/src/payment.rs @@ -75,14 +75,14 @@ impl TpccTransaction for Payment { tx.execute_drain( &statements[0], &[ - ("?1", DataValue::Decimal(args.h_amount)), - ("?2", DataValue::Int16(args.w_id as i16)), + ("$1", DataValue::Decimal(args.h_amount)), + ("$2", DataValue::Int16(args.w_id as i16)), ], )?; // "SELECT w_street_1, w_street_2, w_city, w_state, w_zip, w_name FROM warehouse WHERE w_id = ?" let tuple = tx.query_one( &statements[1], - &[("?1", DataValue::Int16(args.w_id as i16))], + &[("$1", DataValue::Int16(args.w_id as i16))], )?; let w_street_1 = tuple.values[0].utf8().unwrap(); let w_street_2 = tuple.values[1].utf8().unwrap(); @@ -95,9 +95,9 @@ impl TpccTransaction for Payment { tx.execute_drain( &statements[2], &[ - ("?1", DataValue::Decimal(args.h_amount)), - ("?2", DataValue::Int16(args.w_id as i16)), - ("?3", DataValue::Int8(args.d_id as i8)), + ("$1", DataValue::Decimal(args.h_amount)), + ("$2", DataValue::Int16(args.w_id as i16)), + ("$3", DataValue::Int8(args.d_id as i8)), ], )?; @@ -105,8 +105,8 @@ impl TpccTransaction for Payment { let tuple = tx.query_one( &statements[3], &[ - ("?1", DataValue::Int16(args.w_id as i16)), - ("?2", DataValue::Int8(args.d_id as i8)), + ("$1", DataValue::Int16(args.w_id as i16)), + ("$2", DataValue::Int8(args.d_id as i8)), ], )?; let d_street_1 = tuple.values[0].utf8().unwrap(); @@ -122,17 +122,17 @@ impl TpccTransaction for Payment { let tuple = tx.query_one( &statements[4], &[ - ("?1", DataValue::Int16(args.c_w_id as i16)), - ("?2", DataValue::Int8(args.c_d_id as i8)), - ("?3", DataValue::from(args.c_last.clone())), + ("$1", DataValue::Int16(args.c_w_id as i16)), + ("$2", DataValue::Int8(args.c_d_id as i8)), + ("$3", DataValue::from(args.c_last.clone())), ], )?; let mut name_cnt = tuple.values[0].i32().unwrap(); // "SELECT c_id FROM customer WHERE c_w_id = ? AND c_d_id = ? AND c_last = ? ORDER BY c_first" let params = [ - ("?1", DataValue::Int16(args.c_w_id as i16)), - ("?2", DataValue::Int8(args.c_d_id as i8)), - ("?3", DataValue::from(args.c_last.clone())), + ("$1", DataValue::Int16(args.c_w_id as i16)), + ("$2", DataValue::Int8(args.c_d_id as i8)), + ("$3", DataValue::from(args.c_last.clone())), ]; let mut tuple_iter = tx.execute(&statements[5], ¶ms)?; if name_cnt % 2 == 1 { @@ -147,9 +147,9 @@ impl TpccTransaction for Payment { let tuple = tx.query_one( &statements[6], &[ - ("?1", DataValue::Int16(args.c_w_id as i16)), - ("?2", DataValue::Int8(args.c_d_id as i8)), - ("?3", DataValue::Int32(c_id)), + ("$1", DataValue::Int16(args.c_w_id as i16)), + ("$2", DataValue::Int8(args.c_d_id as i8)), + ("$3", DataValue::Int32(c_id)), ], )?; let c_first = tuple.values[0].utf8().unwrap(); @@ -174,9 +174,9 @@ impl TpccTransaction for Payment { let tuple = tx.query_one( &statements[7], &[ - ("?1", DataValue::Int16(args.c_w_id as i16)), - ("?2", DataValue::Int8(args.c_d_id as i8)), - ("?3", DataValue::Int32(c_id)), + ("$1", DataValue::Int16(args.c_w_id as i16)), + ("$2", DataValue::Int8(args.c_d_id as i8)), + ("$3", DataValue::Int32(c_id)), ], )?; let c_data = tuple.values[0].utf8().unwrap(); @@ -188,11 +188,11 @@ impl TpccTransaction for Payment { tx.execute_drain( &statements[8], &[ - ("?1", DataValue::Decimal(c_balance)), - ("?2", DataValue::from(c_data.to_string())), - ("?3", DataValue::Int16(args.c_w_id as i16)), - ("?4", DataValue::Int8(args.c_d_id as i8)), - ("?5", DataValue::Int32(c_id)), + ("$1", DataValue::Decimal(c_balance)), + ("$2", DataValue::from(c_data.to_string())), + ("$3", DataValue::Int16(args.c_w_id as i16)), + ("$4", DataValue::Int8(args.c_d_id as i8)), + ("$5", DataValue::Int32(c_id)), ], )?; } else { @@ -200,10 +200,10 @@ impl TpccTransaction for Payment { tx.execute_drain( &statements[9], &[ - ("?1", DataValue::Decimal(c_balance)), - ("?2", DataValue::Int16(args.c_w_id as i16)), - ("?3", DataValue::Int8(args.c_d_id as i8)), - ("?4", DataValue::Int32(c_id)), + ("$1", DataValue::Decimal(c_balance)), + ("$2", DataValue::Int16(args.c_w_id as i16)), + ("$3", DataValue::Int8(args.c_d_id as i8)), + ("$4", DataValue::Int32(c_id)), ], )?; } @@ -212,10 +212,10 @@ impl TpccTransaction for Payment { tx.execute_drain( &statements[9], &[ - ("?1", DataValue::Decimal(c_balance)), - ("?2", DataValue::Int16(args.c_w_id as i16)), - ("?3", DataValue::Int8(args.c_d_id as i8)), - ("?4", DataValue::Int32(c_id)), + ("$1", DataValue::Decimal(c_balance)), + ("$2", DataValue::Int16(args.c_w_id as i16)), + ("$3", DataValue::Int8(args.c_d_id as i8)), + ("$4", DataValue::Int32(c_id)), ], )?; } @@ -224,14 +224,14 @@ impl TpccTransaction for Payment { tx.execute_drain( &statements[10], &[ - ("?1", DataValue::Int8(args.c_d_id as i8)), - ("?2", DataValue::Int16(args.c_w_id as i16)), - ("?3", DataValue::Int32(c_id)), - ("?4", DataValue::Int8(args.d_id as i8)), - ("?5", DataValue::Int16(args.w_id as i16)), - ("?6", DataValue::from(&now.naive_utc())), - ("?7", DataValue::Decimal(args.h_amount)), - ("?8", DataValue::from(h_data)), + ("$1", DataValue::Int8(args.c_d_id as i8)), + ("$2", DataValue::Int16(args.c_w_id as i16)), + ("$3", DataValue::Int32(c_id)), + ("$4", DataValue::Int8(args.d_id as i8)), + ("$5", DataValue::Int16(args.w_id as i16)), + ("$6", DataValue::from(&now.naive_utc())), + ("$7", DataValue::Decimal(args.h_amount)), + ("$8", DataValue::from(h_data)), ], )?; diff --git a/tpcc/src/slev.rs b/tpcc/src/slev.rs index bf7108e0..24c477af 100644 --- a/tpcc/src/slev.rs +++ b/tpcc/src/slev.rs @@ -47,8 +47,8 @@ impl TpccTransaction for Slev { let tuple = tx.query_one( &statements[0], &[ - ("?1", DataValue::Int8(args.d_id as i8)), - ("?2", DataValue::Int16(args.w_id as i16)), + ("$1", DataValue::Int8(args.d_id as i8)), + ("$2", DataValue::Int16(args.w_id as i16)), ], )?; let d_next_o_id = tuple.values[0].i32().unwrap(); @@ -56,10 +56,10 @@ impl TpccTransaction for Slev { let tuple = tx.query_one( &statements[1], &[ - ("?1", DataValue::Int16(args.w_id as i16)), - ("?2", DataValue::Int8(args.d_id as i8)), - ("?3", DataValue::Int32(d_next_o_id)), - ("?4", DataValue::Int32(d_next_o_id)), + ("$1", DataValue::Int16(args.w_id as i16)), + ("$2", DataValue::Int8(args.d_id as i8)), + ("$3", DataValue::Int32(d_next_o_id)), + ("$4", DataValue::Int32(d_next_o_id)), ], )?; let ol_i_id = tuple.values[0].i32().unwrap(); @@ -67,9 +67,9 @@ impl TpccTransaction for Slev { let _tuple = tx.query_one( &statements[2], &[ - ("?1", DataValue::Int16(args.w_id as i16)), - ("?2", DataValue::Int8(ol_i_id as i8)), - ("?3", DataValue::Int16(args.level as i16)), + ("$1", DataValue::Int16(args.w_id as i16)), + ("$2", DataValue::Int8(ol_i_id as i8)), + ("$3", DataValue::Int16(args.level as i16)), ], )?; // let i_count = tuple.values[0].i32().unwrap();