1use diesel::associations::HasTable;
2use diesel::backend::Backend;
3use diesel::dsl;
4use diesel::migration::{
5 Migration, MigrationConnection, MigrationSource, MigrationVersion, Result,
6};
7use diesel::prelude::*;
8use diesel::query_builder::{DeleteStatement, InsertStatement, IntoUpdateTarget};
9use diesel::query_dsl::methods::ExecuteDsl;
10use diesel::query_dsl::LoadQuery;
11use diesel::serialize::ToSql;
12use diesel::sql_types::Text;
13use std::cell::RefCell;
14use std::collections::HashMap;
15use std::io::Write;
16
17use crate::errors::MigrationError;
18
// Bookkeeping table for applied migrations. The harness in this file inserts a
// row (keyed by `version`) when a migration is run, deletes that row when the
// migration is reverted, and reads the `version` column back to work out which
// migrations are still pending.
diesel::table! {
    __diesel_schema_migrations (version) {
        // Migration version string; declared as the table's primary key, which is
        // what `find(version)` looks rows up by when reverting.
        version -> VarChar,
        // Not written by any insert in this file — presumably filled by a
        // database-side default when the row is created. TODO confirm.
        run_on -> Timestamp,
    }
}
25
26pub trait MigrationHarness<DB: Backend> {
28 fn has_pending_migration<S: MigrationSource<DB>>(&mut self, source: S) -> Result<bool> {
30 self.pending_migrations(source).map(|p| !p.is_empty())
31 }
32
33 fn run_pending_migrations<S: MigrationSource<DB>>(
35 &mut self,
36 source: S,
37 ) -> Result<Vec<MigrationVersion>> {
38 let pending = self.pending_migrations(source)?;
39 self.run_migrations(&pending)
40 }
41
42 #[doc(hidden)]
46 fn run_migrations(
47 &mut self,
48 migrations: &[Box<dyn Migration<DB>>],
49 ) -> Result<Vec<MigrationVersion>> {
50 migrations.iter().map(|m| self.run_migration(m)).collect()
51 }
52
53 fn run_next_migration<S: MigrationSource<DB>>(
55 &mut self,
56 source: S,
57 ) -> Result<MigrationVersion> {
58 let pending_migrations = self.pending_migrations(source)?;
59 let next_migration = pending_migrations
60 .first()
61 .ok_or(MigrationError::NoMigrationRun)?;
62 self.run_migration(next_migration)
63 }
64
65 fn revert_all_migrations<S: MigrationSource<DB>>(
67 &mut self,
68 source: S,
69 ) -> Result<Vec<MigrationVersion>> {
70 let applied_versions = self.applied_migrations()?;
71 let mut migrations = source
72 .migrations()?
73 .into_iter()
74 .map(|m| (m.name().version().as_owned(), m))
75 .collect::<HashMap<_, _>>();
76
77 applied_versions
78 .into_iter()
79 .map(|version| {
80 let migration_to_revert = migrations
81 .remove(&version)
82 .ok_or(MigrationError::UnknownMigrationVersion(version))?;
83 self.revert_migration(&migration_to_revert)
84 })
85 .collect()
86 }
87
88 fn revert_last_migration<S: MigrationSource<DB>>(
93 &mut self,
94 source: S,
95 ) -> Result<MigrationVersion<'static>> {
96 let applied_versions = self.applied_migrations()?;
97 let migrations = source.migrations()?;
98 let last_migration_version = applied_versions
99 .first()
100 .ok_or(MigrationError::NoMigrationRun)?;
101 let migration_to_revert = migrations
102 .iter()
103 .find(|m| m.name().version() == *last_migration_version)
104 .ok_or_else(|| {
105 MigrationError::UnknownMigrationVersion(last_migration_version.as_owned())
106 })?;
107 self.revert_migration(migration_to_revert)
108 }
109
110 fn pending_migrations<S: MigrationSource<DB>>(
115 &mut self,
116 source: S,
117 ) -> Result<Vec<Box<dyn Migration<DB>>>> {
118 let applied_versions = self.applied_migrations()?;
119 let mut migrations = source
120 .migrations()?
121 .into_iter()
122 .map(|m| (m.name().version().as_owned(), m))
123 .collect::<HashMap<_, _>>();
124
125 for applied_version in applied_versions {
126 migrations.remove(&applied_version);
127 }
128
129 let mut migrations = migrations.into_values().collect::<Vec<_>>();
130
131 migrations.sort_unstable_by(|a, b| a.name().version().cmp(&b.name().version()));
132
133 Ok(migrations)
134 }
135
136 fn run_migration(&mut self, migration: &dyn Migration<DB>)
141 -> Result<MigrationVersion<'static>>;
142
143 fn revert_migration(
148 &mut self,
149 migration: &dyn Migration<DB>,
150 ) -> Result<MigrationVersion<'static>>;
151
152 fn applied_migrations(&mut self) -> Result<Vec<MigrationVersion<'static>>>;
154}
155
/// Blanket [`MigrationHarness`] implementation for any diesel [`Connection`]
/// that is also a [`MigrationConnection`].
///
/// Each `where` bound below spells out, type for type, one of the queries the
/// method bodies issue against `__diesel_schema_migrations`. If a query shape in
/// a body changes, its matching bound must change with it.
impl<'b, C, DB> MigrationHarness<DB> for C
where
    DB: Backend,
    C: Connection<Backend = DB> + MigrationConnection + 'static,
    // `SELECT version ... ORDER BY version DESC`, loaded in `applied_migrations`.
    dsl::Order<
        dsl::Select<__diesel_schema_migrations::table, __diesel_schema_migrations::version>,
        dsl::Desc<__diesel_schema_migrations::version>,
    >: LoadQuery<'b, C, MigrationVersion<'static>>,
    // `INSERT ... (version) VALUES (..)`, executed in `run_migration`.
    // NOTE(review): the `for<'a>` binder introduces a lifetime this bound never
    // uses.
    for<'a> InsertStatement<
        __diesel_schema_migrations::table,
        <dsl::Eq<__diesel_schema_migrations::version, MigrationVersion<'static>> as Insertable<
            __diesel_schema_migrations::table,
        >>::Values,
    >: diesel::query_builder::QueryFragment<DB> + ExecuteDsl<C, DB>,
    // `DELETE` of the row found by primary key (`table.find(version)`),
    // executed in `revert_migration`.
    DeleteStatement<
        <dsl::Find<
            __diesel_schema_migrations::table,
            MigrationVersion<'static>,
        > as HasTable>::Table,
        <dsl::Find<
            __diesel_schema_migrations::table,
            MigrationVersion<'static>,
        > as IntoUpdateTarget>::WhereClause,
    >: ExecuteDsl<C>,
    str: ToSql<Text, DB>,
{
    /// Runs `migration`, then inserts its version into
    /// `__diesel_schema_migrations`, and returns that version.
    ///
    /// If the migration's metadata asks for a transaction, both steps share one
    /// transaction. Otherwise they run directly on the connection. In that case
    /// a failed insert after a successful run leaves the migration applied but
    /// unrecorded.
    fn run_migration(
        &mut self,
        migration: &dyn Migration<DB>,
    ) -> Result<MigrationVersion<'static>> {
        // Shared by both paths so the transactional and direct variants do the
        // same work.
        let apply_migration = |conn: &mut C| -> Result<()> {
            migration.run(conn)?;
            diesel::insert_into(__diesel_schema_migrations::table)
                .values(__diesel_schema_migrations::version.eq(migration.name().version().as_owned())).execute(conn)?;
            Ok(())
        };

        if migration.metadata().run_in_transaction() {
            self.transaction(apply_migration)?;
        } else {
            apply_migration(self)?;
        }
        Ok(migration.name().version().as_owned())
    }

    /// Reverts `migration`, then deletes its row from
    /// `__diesel_schema_migrations` by primary key, and returns its version.
    ///
    /// Transaction handling matches `run_migration`: both steps are wrapped in a
    /// transaction only when the migration's metadata requests it.
    fn revert_migration(
        &mut self,
        migration: &dyn Migration<DB>,
    ) -> Result<MigrationVersion<'static>> {
        let revert_migration = |conn: &mut C| -> Result<()> {
            migration.revert(conn)?;
            diesel::delete(__diesel_schema_migrations::table.find(migration.name().version().as_owned()))
                .execute(conn)?;
            Ok(())
        };

        if migration.metadata().run_in_transaction() {
            self.transaction(revert_migration)?;
        } else {
            revert_migration(self)?;
        }
        Ok(migration.name().version().as_owned())
    }

    /// Returns every recorded version, ordered by `version` descending.
    ///
    /// Calls `setup_database` first; presumably this creates the bookkeeping
    /// table when it is missing. TODO confirm against the connection's
    /// `MigrationConnection::setup`.
    fn applied_migrations(&mut self) -> Result<Vec<MigrationVersion<'static>>> {
        setup_database(self)?;
        Ok(__diesel_schema_migrations::table
            .select(__diesel_schema_migrations::version)
            .order(__diesel_schema_migrations::version.desc())
            .load(self)?)
    }
}
228
/// A [`MigrationHarness`] wrapper that writes a progress line to an output
/// writer before running or reverting a migration, then delegates the actual
/// work to the wrapped harness.
///
/// No progress line is written for migrations whose version is
/// `00000000000000`.
pub struct HarnessWithOutput<'a, C, W> {
    // The wrapped harness that performs every migration operation.
    connection: &'a mut C,
    // Destination for progress messages.
    output: RefCell<W>,
}
235
236impl<'a, C, W> HarnessWithOutput<'a, C, W> {
237 pub fn new<DB>(harness: &'a mut C, output: W) -> Self
239 where
240 C: MigrationHarness<DB>,
241 DB: Backend,
242 W: Write,
243 {
244 Self {
245 connection: harness,
246 output: RefCell::new(output),
247 }
248 }
249}
250
251impl<'a, C> HarnessWithOutput<'a, C, std::io::Stdout> {
252 pub fn write_to_stdout<DB>(harness: &'a mut C) -> Self
254 where
255 C: MigrationHarness<DB>,
256 DB: Backend,
257 {
258 Self {
259 connection: harness,
260 output: RefCell::new(std::io::stdout()),
261 }
262 }
263}
264
265impl<C, W, DB> MigrationHarness<DB> for HarnessWithOutput<'_, C, W>
266where
267 W: Write,
268 C: MigrationHarness<DB>,
269 DB: Backend,
270{
271 fn run_migration(
272 &mut self,
273 migration: &dyn Migration<DB>,
274 ) -> Result<MigrationVersion<'static>> {
275 if migration.name().version() != MigrationVersion::from("00000000000000") {
276 let mut output = self.output.try_borrow_mut()?;
277 writeln!(output, "Running migration {}", migration.name())?;
278 }
279 self.connection.run_migration(migration)
280 }
281
282 fn revert_migration(
283 &mut self,
284 migration: &dyn Migration<DB>,
285 ) -> Result<MigrationVersion<'static>> {
286 if migration.name().version() != MigrationVersion::from("00000000000000") {
287 let mut output = self.output.try_borrow_mut()?;
288 writeln!(output, "Rolling back migration {}", migration.name())?;
289 }
290 self.connection.revert_migration(migration)
291 }
292
293 fn applied_migrations(&mut self) -> Result<Vec<MigrationVersion<'static>>> {
294 self.connection.applied_migrations()
295 }
296}
297
298fn setup_database<Conn: MigrationConnection>(conn: &mut Conn) -> QueryResult<usize> {
299 conn.setup()
300}