1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
use core::{
    fmt::{self, Debug, Formatter},
    mem::{self, MaybeUninit},
    ptr,
};

/// A piece of automation that is driven forward by repeated calls to
/// [`AutomationSequence::poll()`] until it either runs to completion or
/// aborts early with a fault.
pub trait AutomationSequence<Input, Output> {
    /// Additional information carried by [`Transition::Fault`] when the
    /// sequence aborts.
    type FaultInfo;

    /// Poll the sequence once.
    ///
    /// `inputs` is borrowed immutably and `outputs` mutably for the duration
    /// of the call. The returned [`Transition`] reports whether the sequence
    /// has completed, faulted, or is still running.
    fn poll(
        &mut self,
        inputs: &Input,
        outputs: &mut Output,
    ) -> Transition<Self::FaultInfo>;
}

/// The result of a single call to [`AutomationSequence::poll()`].
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so transitions can be
/// used wherever total equality or hashing is required (they are only
/// implemented when the fault payload `F` supports them).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Transition<F> {
    /// The [`AutomationSequence`] completed successfully.
    Complete,
    /// The [`AutomationSequence`] failed with a particular fault code.
    Fault(F),
    /// The [`AutomationSequence`] is still running.
    Incomplete,
}

impl<F> Transition<F> {
    /// Returns `true` if this transition ends the sequence, i.e. it is
    /// [`Transition::Complete`] or [`Transition::Fault`], and `false` for
    /// [`Transition::Incomplete`].
    pub fn at_end_state(&self) -> bool {
        match self {
            Transition::Complete | Transition::Fault(_) => true,
            Transition::Incomplete => false,
        }
    }
}

/// A combinator which combines many [`AutomationSequence`]s and will poll them
/// all to completion, stopping when either a fault is raised or there are no
/// more incomplete sequences.
///
/// Each call to `poll()` polls every sequence that has not yet finished, in
/// order. A sequence is discarded as soon as it reaches an end state. If a
/// sequence faults, that fault is returned immediately and the sequences after
/// it are not polled during that call.
///
/// # Examples
///
/// ```rust
/// use aimc_hal::automation::{AutomationSequence, Transition, All};
///
/// /// A simple automation sequence which will return `Transition::Incomplete`
/// /// until it reaches zero.
/// struct CountDown(usize);
///
/// impl<I, O> AutomationSequence<I, O> for CountDown {
///     type FaultInfo = ();
///
///     fn poll(&mut self, _inputs: &I, _outputs: &mut O) -> Transition<()> {
///         if self.0 == 0 {
///             Transition::Complete
///         } else {
///             self.0 -= 1;
///             Transition::Incomplete
///         }
///     }
/// }
///
/// // Combine the sequences into one big automation sequence
/// let mut seq = All::new([CountDown(1), CountDown(5), CountDown(2)]);
///
/// // we'll keep track of the number of polls
/// let mut polls = 0;
///
/// // keep polling until all the timers have reached zero
/// while seq.poll(&(), &mut ()) != Transition::Complete {
///     polls += 1;
/// }
///
/// // we should have polled 5 times (`max(1, 5, 2)`)
/// assert_eq!(polls, 5);
/// ```
pub struct All<A, const N: usize> {
    /// One slot per sequence; a slot becomes `None` once its sequence has
    /// finished (completed or faulted).
    sequences: [Option<A>; N],
}

impl<A, const N: usize> All<A, { N }> {
    /// Create an [`All`] that drives every sequence in `items`.
    ///
    /// All `N` sequences start out unfinished; each one is discarded once it
    /// reaches an end state while being polled.
    pub fn new(items: [A; N]) -> Self {
        let sequences = wrap_in_option(items);
        Self { sequences }
    }
}

/// Transform a `[T; N]` into a `[Option<T>; N]` by wrapping every element in
/// `Some`, preserving order.
///
/// This is essentially a poor man's `items.into_iter().map(Some).collect()`
/// for fixed-size arrays. Every element is moved exactly once, so nothing is
/// dropped twice or leaked.
fn wrap_in_option<T, const N: usize>(items: [T; N]) -> [Option<T>; N] {
    // Each element's ownership is moved out with `ptr::read` below, so the
    // source array must never run its own destructor.
    let items = mem::ManuallyDrop::new(items);
    let src: *const T = items.as_ptr();

    let mut wrapped = MaybeUninit::<[Option<T>; N]>::uninit();
    let dst = wrapped.as_mut_ptr() as *mut Option<T>;

    for i in 0..N {
        // SAFETY: `i < N`, so both pointers stay inside their `N`-element
        // arrays. Each source element is read exactly once and never touched
        // again (`items` is `ManuallyDrop`), and `write` does not drop the
        // uninitialized destination slot it overwrites.
        unsafe { dst.add(i).write(Some(ptr::read(src.add(i)))) };
    }

    // SAFETY: the loop above initialized all `N` elements of `wrapped`.
    unsafe { wrapped.assume_init() }
}

impl<I, O, A, const N: usize> AutomationSequence<I, O> for All<A, { N }>
where
    A: AutomationSequence<I, O>,
{
    type FaultInfo = A::FaultInfo;

    /// Poll every unfinished sequence once, in order.
    ///
    /// Returns the first fault encountered; sequences after the faulting one
    /// are not polled on this call. Otherwise returns
    /// [`Transition::Complete`] once every sequence has finished, or
    /// [`Transition::Incomplete`] while any are still running.
    fn poll(
        &mut self,
        inputs: &I,
        outputs: &mut O,
    ) -> Transition<Self::FaultInfo> {
        let mut everything_finished = true;

        for slot in self.sequences.iter_mut() {
            if let Transition::Fault(fault) =
                poll_variant(slot, inputs, outputs)
            {
                return Transition::Fault(fault);
            }

            // `poll_variant` empties a slot once its sequence has finished,
            // and later iterations never touch earlier slots.
            everything_finished &= slot.is_none();
        }

        if everything_finished {
            Transition::Complete
        } else {
            Transition::Incomplete
        }
    }
}

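// These impls are written by hand instead of with `#[derive]`: a derive would
// bound each impl on `A` (e.g. `A: Clone`), whereas these bound the field type
// `[Option<A>; N]` directly, keeping implementation details out of the `All`
// type signature. Any `core::array::LengthAtMost32` bounds are only needed by
// older nightly toolchains (the unstable `const_generic_impls_guard` feature)
// and can be dropped on compilers whose array trait impls cover every `N`.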

impl<A, const N: usize> Default for All<A, { N }>
where
    [Option<A>; N]: core::array::LengthAtMost32,
    [Option<A>; N]: Default,
{
    fn default() -> All<A, { N }> {
        All {
            sequences: Default::default(),
        }
    }
}

impl<A, const N: usize> Copy for All<A, { N }>
where
    [Option<A>; N]: core::array::LengthAtMost32,
    [Option<A>; N]: Copy,
{
}

impl<A, const N: usize> Clone for All<A, { N }>
where
    [Option<A>; N]: core::array::LengthAtMost32,
    [Option<A>; N]: Clone,
{
    fn clone(&self) -> All<A, { N }> {
        All {
            sequences: self.sequences.clone(),
        }
    }
}

impl<A, const N: usize> Debug for All<A, { N }>
where
    [Option<A>; N]: core::array::LengthAtMost32,
    [Option<A>; N]: Debug,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let All { ref sequences } = *self;

        f.debug_struct("All").field("sequences", sequences).finish()
    }
}

impl<A, const N: usize> PartialEq for All<A, { N }>
where
    [Option<A>; N]: core::array::LengthAtMost32,
    [Option<A>; N]: PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        let All { ref sequences } = *self;

        sequences == &other.sequences
    }
}

/// Poll the sequence stored in `variant`, if any, and clear the slot once it
/// reaches an end state.
///
/// An empty slot is treated as an already-finished sequence and reports
/// [`Transition::Complete`] without polling anything. When the sequence
/// completes or faults, it is taken out of the slot and dropped, so it is
/// never polled again.
fn poll_variant<I, O, A>(
    variant: &mut Option<A>,
    inputs: &I,
    outputs: &mut O,
) -> Transition<A::FaultInfo>
where
    A: AutomationSequence<I, O>,
{
    let outcome = variant
        .as_mut()
        .map_or(Transition::Complete, |seq| seq.poll(inputs, outputs));

    if outcome.at_end_state() {
        drop(variant.take());
    }

    outcome
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Reports `Incomplete` (decrementing its counter) until the counter
    /// reaches zero, then reports `Complete`.
    #[derive(Debug, Default)]
    struct Countdown(usize);

    impl AutomationSequence<(), ()> for Countdown {
        type FaultInfo = ();

        fn poll(&mut self, _: &(), _: &mut ()) -> Transition<Self::FaultInfo> {
            if self.0 == 0 {
                Transition::Complete
            } else {
                self.0 -= 1;
                Transition::Incomplete
            }
        }
    }

    /// Faults with its stored code on the very first poll.
    struct FailWith(u32);

    impl AutomationSequence<(), ()> for FailWith {
        type FaultInfo = u32;

        fn poll(&mut self, _: &(), _: &mut ()) -> Transition<Self::FaultInfo> {
            Transition::Fault(self.0)
        }
    }

    fn assert_is_automation_sequence<A, I, O>(_: &A)
    where
        A: AutomationSequence<I, O>,
    {
    }

    #[test]
    fn all_is_an_automation_sequence() {
        let items = All::new([Countdown(1), Countdown(5)]);

        assert_is_automation_sequence(&Countdown(0));
        assert_is_automation_sequence(&items);
    }

    #[test]
    fn poll_all() {
        let mut seq = All::new([Countdown(1), Countdown(5), Countdown(2)]);
        let mut polls = 0;

        while seq.poll(&(), &mut ()) == Transition::Incomplete {
            polls += 1;
        }

        // The longest countdown needs 5 incomplete polls before completing.
        assert_eq!(polls, 5);
        // Once finished, further polls keep reporting completion.
        assert_eq!(seq.poll(&(), &mut ()), Transition::Complete);
    }

    #[test]
    fn empty_all_completes_immediately() {
        let items: [Countdown; 0] = [];
        let mut seq = All::new(items);

        assert_eq!(seq.poll(&(), &mut ()), Transition::Complete);
    }

    #[test]
    fn first_fault_is_reported() {
        let mut seq = All::new([FailWith(7), FailWith(9)]);

        assert_eq!(seq.poll(&(), &mut ()), Transition::Fault(7));
    }
}