mctp_rs/
mctp_packet_context.rs

1use crate::{
2    MctpMessage, MctpMessageHeaderTrait, MctpMessageTrait, MctpPacketError,
3    deserialize::{parse_message_body, parse_transport_header},
4    endpoint_id::EndpointId,
5    error::{MctpPacketResult, ProtocolError},
6    mctp_message_tag::MctpMessageTag,
7    mctp_sequence_number::MctpSequenceNumber,
8    medium::{MctpMedium, MctpMediumFrame},
9    serialize::SerializePacketState,
10};
11
12/// Represents the state needed to construct a repsonse to a request:
13/// the MCTP transport source/destination, the sequence number to use for
14/// the reply, and the medium-specific context that came with the request.
15#[derive(Debug, PartialEq, Eq)]
16#[cfg_attr(feature = "defmt", derive(defmt::Format))]
17pub struct MctpReplyContext<M: MctpMedium> {
18    pub destination_endpoint_id: EndpointId,
19    pub source_endpoint_id: EndpointId,
20    pub packet_sequence_number: MctpSequenceNumber,
21    pub message_tag: MctpMessageTag,
22    pub medium_context: M::ReplyContext,
23}
24
25/// Context for serializing and deserializing an MCTP message, which may be split among multiple
26/// packets.
27pub struct MctpPacketContext<'buf, M: MctpMedium> {
28    assembly_state: AssemblyState,
29    medium: M,
30    packet_assembly_buffer: &'buf mut [u8],
31}
32
33impl<'buf, M: MctpMedium> MctpPacketContext<'buf, M> {
34    pub fn new(medium: M, packet_assembly_buffer: &'buf mut [u8]) -> Self {
35        Self {
36            medium,
37            assembly_state: AssemblyState::Idle,
38            packet_assembly_buffer,
39        }
40    }
41
42    pub fn deserialize_packet(
43        &mut self,
44        packet: &[u8],
45    ) -> MctpPacketResult<Option<MctpMessage<'_, M>>, M> {
46        let (medium_frame, packet) = self.medium.deserialize(packet)?;
47        let (transport_header, packet) = parse_transport_header::<M>(packet)?;
48
49        let mut state = match self.assembly_state {
50            AssemblyState::Idle => {
51                if transport_header.start_of_message == 0 {
52                    return Err(MctpPacketError::ProtocolError(
53                        ProtocolError::ExpectedStartOfMessage,
54                    ));
55                }
56
57                AssemblingState {
58                    message_tag: transport_header.message_tag,
59                    tag_owner: transport_header.tag_owner,
60                    source_endpoint_id: transport_header.source_endpoint_id,
61                    packet_sequence_number: transport_header.packet_sequence_number,
62                    packet_assembly_buffer_index: 0,
63                }
64            }
65            AssemblyState::Receiving(assembling_state) => {
66                if transport_header.start_of_message != 0 {
67                    return Err(MctpPacketError::ProtocolError(
68                        ProtocolError::UnexpectedStartOfMessage,
69                    ));
70                }
71                if assembling_state.message_tag != transport_header.message_tag {
72                    return Err(MctpPacketError::ProtocolError(
73                        ProtocolError::MessageTagMismatch(
74                            assembling_state.message_tag,
75                            transport_header.message_tag,
76                        ),
77                    ));
78                }
79                if assembling_state.tag_owner != transport_header.tag_owner {
80                    return Err(MctpPacketError::ProtocolError(
81                        ProtocolError::TagOwnerMismatch(
82                            assembling_state.tag_owner,
83                            transport_header.tag_owner,
84                        ),
85                    ));
86                }
87                if assembling_state.source_endpoint_id != transport_header.source_endpoint_id {
88                    return Err(MctpPacketError::ProtocolError(
89                        ProtocolError::SourceEndpointIdMismatch(
90                            assembling_state.source_endpoint_id,
91                            transport_header.source_endpoint_id,
92                        ),
93                    ));
94                }
95                let expected_sequence_number = assembling_state.packet_sequence_number.next();
96                if expected_sequence_number != transport_header.packet_sequence_number {
97                    return Err(MctpPacketError::ProtocolError(
98                        ProtocolError::UnexpectedPacketSequenceNumber(
99                            expected_sequence_number,
100                            transport_header.packet_sequence_number,
101                        ),
102                    ));
103                }
104                assembling_state
105            }
106        };
107
108        let buffer_idx = state.packet_assembly_buffer_index;
109        let packet_size = medium_frame.packet_size();
110        if packet_size < 4 {
111            return Err(MctpPacketError::HeaderParseError(
112                "transport frame indicated packet length < 4",
113            ));
114        }
115        let packet_size = packet_size - 4; // to account for the transport header
116        if packet.len() < packet_size {
117            return Err(MctpPacketError::HeaderParseError(
118                "packet.len() < packet_size",
119            ));
120        }
121        // Check bounds to prevent buffer overflow
122        if buffer_idx + packet_size > self.packet_assembly_buffer.len() {
123            return Err(MctpPacketError::HeaderParseError(
124                "packet assembly buffer overflow - insufficient space",
125            ));
126        }
127        self.packet_assembly_buffer[buffer_idx..buffer_idx + packet_size]
128            .copy_from_slice(&packet[..packet_size]);
129        state.packet_assembly_buffer_index += packet_size;
130
131        let message = if transport_header.end_of_message == 1 {
132            self.assembly_state = AssemblyState::Idle;
133            let (message_body, message_integrity_check) = parse_message_body::<M>(
134                &self.packet_assembly_buffer[..state.packet_assembly_buffer_index],
135            )?;
136            Some(MctpMessage {
137                reply_context: MctpReplyContext {
138                    destination_endpoint_id: transport_header.destination_endpoint_id,
139                    source_endpoint_id: transport_header.source_endpoint_id,
140                    packet_sequence_number: transport_header.packet_sequence_number,
141                    message_tag: transport_header.message_tag,
142                    medium_context: medium_frame.reply_context(),
143                },
144                message_buffer: message_body,
145                message_integrity_check,
146            })
147        } else {
148            self.assembly_state = AssemblyState::Receiving(state);
149            None
150        };
151
152        Ok(message)
153    }
154
155    pub fn serialize_packet<P: MctpMessageTrait<'buf>>(
156        &'buf mut self,
157        reply_context: MctpReplyContext<M>,
158        message: (P::Header, P),
159    ) -> MctpPacketResult<SerializePacketState<'buf, M>, M> {
160        match self.assembly_state {
161            AssemblyState::Idle => {}
162            _ => {
163                return Err(MctpPacketError::ProtocolError(
164                    ProtocolError::SendMessageWhileAssembling,
165                ));
166            }
167        };
168
169        self.packet_assembly_buffer[0] = P::MESSAGE_TYPE;
170        let header_size = message.0.serialize(&mut self.packet_assembly_buffer[1..])?;
171        let body_size = message
172            .1
173            .serialize(&mut self.packet_assembly_buffer[header_size + 1..])?;
174
175        let (message, rest) = self
176            .packet_assembly_buffer
177            .split_at_mut(header_size + body_size + 1);
178
179        Ok(SerializePacketState {
180            medium: &self.medium,
181            reply_context,
182            current_packet_num: 0,
183            serialized_message_header: false,
184            message_buffer: message,
185            assembly_buffer: rest,
186        })
187    }
188}
189
190#[derive(Debug, Copy, Clone, PartialEq, Eq)]
191#[cfg_attr(feature = "defmt", derive(defmt::Format))]
192enum AssemblyState {
193    Idle,
194    Receiving(AssemblingState),
195}
196
197#[derive(Debug, Copy, Clone, PartialEq, Eq)]
198#[cfg_attr(feature = "defmt", derive(defmt::Format))]
199struct AssemblingState {
200    message_tag: MctpMessageTag,
201    tag_owner: u8,
202    source_endpoint_id: EndpointId,
203    packet_sequence_number: MctpSequenceNumber,
204    packet_assembly_buffer_index: usize,
205}