use super::expression_types::NotSelectable;
use crate::backend::{sql_dialect, Backend, SqlDialect};
use crate::expression::subselect::Subselect;
use crate::expression::{
AppearsOnTable, AsExpression, Expression, SelectableExpression, TypedExpressionType,
ValidGrouping,
};
use crate::query_builder::combination_clause::CombinationClause;
use crate::query_builder::{
AstPass, BoxedSelectStatement, QueryFragment, QueryId, SelectQuery, SelectStatement,
};
use crate::result::QueryResult;
use crate::serialize::ToSql;
use crate::sql_types::{self, HasSqlType, SingleValue, SqlType};
use std::marker::PhantomData;
/// Query node representing the SQL comparison `left IN (values)`.
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot build it with a
/// struct literal; inside the crate it is created through `In::new`.
#[derive(Debug, Copy, Clone, QueryId, ValidGrouping)]
#[non_exhaustive]
pub struct In<T, U> {
/// Expression on the left-hand side of `IN`.
pub left: T,
/// Right-hand side: the candidate values rendered inside the parentheses.
pub values: U,
}
/// Query node representing the SQL comparison `left NOT IN (values)`.
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot build it with a
/// struct literal; inside the crate it is created through `NotIn::new`.
#[derive(Debug, Copy, Clone, QueryId, ValidGrouping)]
#[non_exhaustive]
pub struct NotIn<T, U> {
/// Expression on the left-hand side of `NOT IN`.
pub left: T,
/// Right-hand side: the candidate values rendered inside the parentheses.
pub values: U,
}
impl<T, U> In<T, U> {
    /// Creates an `IN` node from its left-hand expression and its candidate values.
    pub(crate) fn new(left: T, values: U) -> Self {
        Self { left, values }
    }

    /// Writes the ANSI SQL form of this comparison into `out`.
    ///
    /// With no candidate values the constant predicate `1=0` is emitted instead of
    /// an empty parenthesised list; otherwise the output is `<left> IN (<values>)`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while walking `left` or `values`.
    pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()>
    where
        DB: Backend,
        T: QueryFragment<DB>,
        U: QueryFragment<DB> + InExpression,
    {
        // Nothing can match an empty list: short-circuit with an always-false predicate.
        if self.values.is_empty() {
            out.push_sql("1=0");
            return Ok(());
        }

        self.left.walk_ast(out.reborrow())?;
        out.push_sql(" IN (");
        self.values.walk_ast(out.reborrow())?;
        out.push_sql(")");
        Ok(())
    }
}
impl<T, U> NotIn<T, U> {
    /// Creates a `NOT IN` node from its left-hand expression and its candidate values.
    pub(crate) fn new(left: T, values: U) -> Self {
        Self { left, values }
    }

    /// Writes the ANSI SQL form of this comparison into `out`.
    ///
    /// With no candidate values the constant predicate `1=1` is emitted instead of
    /// an empty parenthesised list; otherwise the output is `<left> NOT IN (<values>)`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while walking `left` or `values`.
    pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()>
    where
        DB: Backend,
        T: QueryFragment<DB>,
        U: QueryFragment<DB> + InExpression,
    {
        // Everything is "not in" an empty list: short-circuit with an always-true predicate.
        if self.values.is_empty() {
            out.push_sql("1=1");
            return Ok(());
        }

        self.left.walk_ast(out.reborrow())?;
        out.push_sql(" NOT IN (");
        self.values.walk_ast(out.reborrow())?;
        out.push_sql(")");
        Ok(())
    }
}
/// `In` is an expression whenever the right-hand side is an [`InExpression`]
/// over the same SQL type as the left-hand side.
impl<T, U> Expression for In<T, U>
where
T: Expression,
// The candidate values must have exactly the SQL type of the left operand.
U: InExpression<SqlType = T::SqlType>,
T::SqlType: SqlType,
sql_types::is_nullable::IsSqlTypeNullable<T::SqlType>:
sql_types::MaybeNullableType<sql_types::Bool>,
{
// The result type is `Bool` passed through `MaybeNullable`, keyed on the
// nullability marker of the left operand's SQL type — presumably
// `Nullable<Bool>` for a nullable left operand and plain `Bool` otherwise
// (confirm in `sql_types::is_nullable`).
type SqlType = sql_types::is_nullable::MaybeNullable<
sql_types::is_nullable::IsSqlTypeNullable<T::SqlType>,
sql_types::Bool,
>;
}
/// `NotIn` is an expression whenever the right-hand side is an [`InExpression`]
/// over the same SQL type as the left-hand side.
impl<T, U> Expression for NotIn<T, U>
where
T: Expression,
// The candidate values must have exactly the SQL type of the left operand.
U: InExpression<SqlType = T::SqlType>,
T::SqlType: SqlType,
sql_types::is_nullable::IsSqlTypeNullable<T::SqlType>:
sql_types::MaybeNullableType<sql_types::Bool>,
{
// The result type is `Bool` passed through `MaybeNullable`, keyed on the
// nullability marker of the left operand's SQL type — presumably
// `Nullable<Bool>` for a nullable left operand and plain `Bool` otherwise
// (confirm in `sql_types::is_nullable`).
type SqlType = sql_types::is_nullable::MaybeNullable<
sql_types::is_nullable::IsSqlTypeNullable<T::SqlType>,
sql_types::Bool,
>;
}
impl<T, U, DB> QueryFragment<DB> for In<T, U>
where
DB: Backend,
Self: QueryFragment<DB, DB::ArrayComparison>,
{
fn walk_ast<'b>(&'b self, pass: AstPass<'_, 'b, DB>) -> QueryResult<()> {
<Self as QueryFragment<DB, DB::ArrayComparison>>::walk_ast(self, pass)
}
}
/// `IN` rendering for backends whose array-comparison dialect is plain ANSI SQL.
impl<T, U, DB> QueryFragment<DB, sql_dialect::array_comparison::AnsiSqlArrayComparison> for In<T, U>
where
    DB: Backend
        + SqlDialect<ArrayComparison = sql_dialect::array_comparison::AnsiSqlArrayComparison>,
    T: QueryFragment<DB>,
    U: QueryFragment<DB> + InExpression,
{
    fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
        // All the work lives in the inherent helper.
        Self::walk_ansi_ast(self, out)
    }
}
impl<T, U, DB> QueryFragment<DB> for NotIn<T, U>
where
DB: Backend,
Self: QueryFragment<DB, DB::ArrayComparison>,
{
fn walk_ast<'b>(&'b self, pass: AstPass<'_, 'b, DB>) -> QueryResult<()> {
<Self as QueryFragment<DB, DB::ArrayComparison>>::walk_ast(self, pass)
}
}
/// `NOT IN` rendering for backends whose array-comparison dialect is plain ANSI SQL.
impl<T, U, DB> QueryFragment<DB, sql_dialect::array_comparison::AnsiSqlArrayComparison>
    for NotIn<T, U>
where
    DB: Backend
        + SqlDialect<ArrayComparison = sql_dialect::array_comparison::AnsiSqlArrayComparison>,
    T: QueryFragment<DB>,
    U: QueryFragment<DB> + InExpression,
{
    fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
        // All the work lives in the inherent helper.
        Self::walk_ansi_ast(self, out)
    }
}
// Macro-generated trait impls for the two comparison nodes. Judging by the macro's
// name these are the `SelectableExpression` (and likely `AppearsOnTable`) impls —
// the macro body is not visible here, so confirm at its definition.
impl_selectable_expression!(In<T, U>);
impl_selectable_expression!(NotIn<T, U>);
/// Conversion into something usable as the right-hand side of `IN` / `NOT IN`
/// for a left-hand side of SQL type `T`.
pub trait AsInExpression<T: SqlType> {
/// Concrete right-hand-side type produced by the conversion; its SQL type is `T`.
type InExpression: InExpression<SqlType = T>;
/// Consumes `self` and returns the right-hand-side expression.
// The `as_` prefix on a method taking `self` by value is what trips the lint;
// the method name is kept as is.
#[allow(clippy::wrong_self_convention)]
fn as_in_expression(self) -> Self::InExpression;
}
/// Any iterable whose items convert into `ST` expressions can serve as an `IN` list.
impl<I, T, ST> AsInExpression<ST> for I
where
    I: IntoIterator<Item = T>,
    T: AsExpression<ST>,
    ST: SqlType + TypedExpressionType,
{
    type InExpression = Many<ST, T>;

    /// Drains the iterator into a `Vec` and wraps it in [`Many`].
    fn as_in_expression(self) -> Self::InExpression {
        let values: Vec<T> = self.into_iter().collect();
        Many {
            values,
            p: PhantomData,
        }
    }
}
/// Right-hand side of an `IN` / `NOT IN` comparison.
pub trait InExpression {
/// SQL type of the values being compared against; `In` and `NotIn` require it
/// to equal the SQL type of their left-hand expression.
type SqlType: SqlType;
/// Returns `true` when there are no values at all. `In` and `NotIn` use this to
/// emit a constant predicate instead of an empty parenthesised list.
fn is_empty(&self) -> bool;
/// Returns whether the values are of an array SQL type (the `Many`
/// implementation reports `ST::IS_ARRAY`).
fn is_array(&self) -> bool;
}
/// A select statement yielding SQL type `ST` can be used as the right-hand side of
/// `IN` / `NOT IN` by wrapping it in a [`Subselect`].
impl<ST, F, S, D, W, O, LOf, G, H, LC> AsInExpression<ST>
    for SelectStatement<F, S, D, W, O, LOf, G, H, LC>
where
    ST: SqlType,
    Subselect<Self, ST>: Expression<SqlType = ST>,
    Self: SelectQuery<SqlType = ST>,
{
    type InExpression = Subselect<Self, ST>;

    /// Wraps the statement itself; no values are materialised.
    fn as_in_expression(self) -> Self::InExpression {
        Subselect::<Self, ST>::new(self)
    }
}
/// A boxed select statement yielding SQL type `ST` can be used as the right-hand
/// side of `IN` / `NOT IN` by wrapping it in a [`Subselect`].
impl<'a, ST, QS, DB, GB> AsInExpression<ST> for BoxedSelectStatement<'a, ST, QS, DB, GB>
where
    ST: SqlType,
    Subselect<BoxedSelectStatement<'a, ST, QS, DB, GB>, ST>: Expression<SqlType = ST>,
{
    type InExpression = Subselect<Self, ST>;

    /// Wraps the statement itself; no values are materialised.
    fn as_in_expression(self) -> Self::InExpression {
        Subselect::<Self, ST>::new(self)
    }
}
/// A combination clause whose query yields SQL type `ST` can be used as the
/// right-hand side of `IN` / `NOT IN` by wrapping it in a [`Subselect`].
impl<ST, Combinator, Rule, Source, Rhs> AsInExpression<ST>
    for CombinationClause<Combinator, Rule, Source, Rhs>
where
    ST: SqlType,
    Self: SelectQuery<SqlType = ST>,
    Subselect<Self, ST>: Expression<SqlType = ST>,
{
    type InExpression = Subselect<Self, ST>;

    /// Wraps the combined query itself; no values are materialised.
    fn as_in_expression(self) -> Self::InExpression {
        Subselect::<Self, ST>::new(self)
    }
}
/// Right-hand side of `IN` / `NOT IN` built from an in-memory list of values.
///
/// `ST` is the SQL type each value is bound as; `I` is the Rust type of a value.
#[derive(Debug, Clone)]
pub struct Many<ST, I> {
/// The values; the ANSI rendering emits them as comma-separated bind parameters.
pub values: Vec<I>,
// Zero-sized marker that ties the otherwise unused `ST` parameter to the struct.
p: PhantomData<ST>,
}
/// Grouping validity of the list follows that of the expression its element
/// type converts into.
impl<ST, I, GB> ValidGrouping<GB> for Many<ST, I>
where
ST: SingleValue,
I: AsExpression<ST>,
I::Expression: ValidGrouping<GB>,
{
// Delegate to the expression produced by `I: AsExpression<ST>`.
type IsAggregate = <I::Expression as ValidGrouping<GB>>::IsAggregate;
}
/// `Many` implements `Expression`, but with the `NotSelectable` SQL type; the
/// element type `ST` is exposed through `InExpression::SqlType` instead.
impl<ST, I> Expression for Many<ST, I>
where
ST: TypedExpressionType,
{
// Presumably chosen so that a bare value list cannot be used where a real,
// selectable SQL type is required — confirm against `NotSelectable`'s definition.
type SqlType = NotSelectable;
}
/// A value list is an `IN` right-hand side over its element SQL type `ST`.
impl<ST, I> InExpression for Many<ST, I>
where
    ST: SqlType,
{
    type SqlType = ST;

    /// `true` when the list holds no values.
    fn is_empty(&self) -> bool {
        Vec::is_empty(&self.values)
    }

    /// Reports the `IS_ARRAY` constant of the element SQL type.
    fn is_array(&self) -> bool {
        ST::IS_ARRAY
    }
}
// Marker impl: the list is selectable from `QS` when it appears on `QS` and the
// expression its element type converts into is itself selectable from `QS`.
impl<ST, I, QS> SelectableExpression<QS> for Many<ST, I>
where
Many<ST, I>: AppearsOnTable<QS>,
ST: SingleValue,
I: AsExpression<ST>,
<I as AsExpression<ST>>::Expression: SelectableExpression<QS>,
{
}
// Marker impl: the list appears on `QS` when it is an expression and the
// expression its element type converts into is selectable from `QS`.
// NOTE(review): the element bound below is `SelectableExpression<QS>` rather than
// `AppearsOnTable<QS>`; that looks stricter than this trait strictly needs —
// confirm it is intentional before relaxing it.
impl<ST, I, QS> AppearsOnTable<QS> for Many<ST, I>
where
Many<ST, I>: Expression,
I: AsExpression<ST>,
ST: SingleValue,
<I as AsExpression<ST>>::Expression: SelectableExpression<QS>,
{
}
impl<ST, I, DB> QueryFragment<DB> for Many<ST, I>
where
Self: QueryFragment<DB, DB::ArrayComparison>,
DB: Backend,
{
fn walk_ast<'b>(&'b self, pass: AstPass<'_, 'b, DB>) -> QueryResult<()> {
<Self as QueryFragment<DB, DB::ArrayComparison>>::walk_ast(self, pass)
}
}
/// Value-list rendering for backends whose array-comparison dialect is plain ANSI SQL.
impl<ST, I, DB> QueryFragment<DB, sql_dialect::array_comparison::AnsiSqlArrayComparison>
    for Many<ST, I>
where
    DB: Backend
        + HasSqlType<ST>
        + SqlDialect<ArrayComparison = sql_dialect::array_comparison::AnsiSqlArrayComparison>,
    ST: SingleValue,
    I: ToSql<ST, DB>,
{
    fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
        // All the work lives in the inherent helper.
        Self::walk_ansi_ast(self, out)
    }
}
impl<ST, I> Many<ST, I> {
    /// Writes the values into `out` as bind parameters separated by `", "`.
    ///
    /// The surrounding parentheses are not written here. The pass is first flagged
    /// via `unsafe_to_cache_prepared` — presumably because the SQL text changes
    /// with the number of values.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while pushing a bind parameter.
    pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()>
    where
        DB: Backend + HasSqlType<ST>,
        ST: SingleValue,
        I: ToSql<ST, DB>,
    {
        out.unsafe_to_cache_prepared();
        for (position, value) in self.values.iter().enumerate() {
            // A separator goes before every value except the first one.
            if position > 0 {
                out.push_sql(", ");
            }
            out.push_bind_param(value)?;
        }
        Ok(())
    }
}
/// `Many` opts out of a static query id.
impl<ST, I> QueryId for Many<ST, I> {
// No identifying type is provided.
type QueryId = ();
// Presumably `false` because the rendered SQL depends on how many values the
// list holds at runtime, so the type alone cannot identify the query — confirm
// against the `QueryId` trait's contract.
const HAS_STATIC_QUERY_ID: bool = false;
}