1use crate::{
2 MctpMessage, MctpMessageHeaderTrait, MctpMessageTrait, MctpPacketError,
3 deserialize::{parse_message_body, parse_transport_header},
4 endpoint_id::EndpointId,
5 error::{MctpPacketResult, ProtocolError},
6 mctp_message_tag::MctpMessageTag,
7 mctp_sequence_number::MctpSequenceNumber,
8 medium::{MctpMedium, MctpMediumFrame},
9 serialize::SerializePacketState,
10};
11
12#[derive(Debug, PartialEq, Eq)]
16#[cfg_attr(feature = "defmt", derive(defmt::Format))]
17pub struct MctpReplyContext<M: MctpMedium> {
18 pub destination_endpoint_id: EndpointId,
19 pub source_endpoint_id: EndpointId,
20 pub packet_sequence_number: MctpSequenceNumber,
21 pub message_tag: MctpMessageTag,
22 pub medium_context: M::ReplyContext,
23}
24
25pub struct MctpPacketContext<'buf, M: MctpMedium> {
28 assembly_state: AssemblyState,
29 medium: M,
30 packet_assembly_buffer: &'buf mut [u8],
31}
32
33impl<'buf, M: MctpMedium> MctpPacketContext<'buf, M> {
34 pub fn new(medium: M, packet_assembly_buffer: &'buf mut [u8]) -> Self {
35 Self {
36 medium,
37 assembly_state: AssemblyState::Idle,
38 packet_assembly_buffer,
39 }
40 }
41
42 pub fn deserialize_packet(
43 &mut self,
44 packet: &[u8],
45 ) -> MctpPacketResult<Option<MctpMessage<'_, M>>, M> {
46 let (medium_frame, packet) = self.medium.deserialize(packet)?;
47 let (transport_header, packet) = parse_transport_header::<M>(packet)?;
48
49 let mut state = match self.assembly_state {
50 AssemblyState::Idle => {
51 if transport_header.start_of_message == 0 {
52 return Err(MctpPacketError::ProtocolError(
53 ProtocolError::ExpectedStartOfMessage,
54 ));
55 }
56
57 AssemblingState {
58 message_tag: transport_header.message_tag,
59 tag_owner: transport_header.tag_owner,
60 source_endpoint_id: transport_header.source_endpoint_id,
61 packet_sequence_number: transport_header.packet_sequence_number,
62 packet_assembly_buffer_index: 0,
63 }
64 }
65 AssemblyState::Receiving(assembling_state) => {
66 if transport_header.start_of_message != 0 {
67 return Err(MctpPacketError::ProtocolError(
68 ProtocolError::UnexpectedStartOfMessage,
69 ));
70 }
71 if assembling_state.message_tag != transport_header.message_tag {
72 return Err(MctpPacketError::ProtocolError(
73 ProtocolError::MessageTagMismatch(
74 assembling_state.message_tag,
75 transport_header.message_tag,
76 ),
77 ));
78 }
79 if assembling_state.tag_owner != transport_header.tag_owner {
80 return Err(MctpPacketError::ProtocolError(
81 ProtocolError::TagOwnerMismatch(
82 assembling_state.tag_owner,
83 transport_header.tag_owner,
84 ),
85 ));
86 }
87 if assembling_state.source_endpoint_id != transport_header.source_endpoint_id {
88 return Err(MctpPacketError::ProtocolError(
89 ProtocolError::SourceEndpointIdMismatch(
90 assembling_state.source_endpoint_id,
91 transport_header.source_endpoint_id,
92 ),
93 ));
94 }
95 let expected_sequence_number = assembling_state.packet_sequence_number.next();
96 if expected_sequence_number != transport_header.packet_sequence_number {
97 return Err(MctpPacketError::ProtocolError(
98 ProtocolError::UnexpectedPacketSequenceNumber(
99 expected_sequence_number,
100 transport_header.packet_sequence_number,
101 ),
102 ));
103 }
104 assembling_state
105 }
106 };
107
108 let buffer_idx = state.packet_assembly_buffer_index;
109 let packet_size = medium_frame.packet_size();
110 if packet_size < 4 {
111 return Err(MctpPacketError::HeaderParseError(
112 "transport frame indicated packet length < 4",
113 ));
114 }
115 let packet_size = packet_size - 4; if packet.len() < packet_size {
117 return Err(MctpPacketError::HeaderParseError(
118 "packet.len() < packet_size",
119 ));
120 }
121 if buffer_idx + packet_size > self.packet_assembly_buffer.len() {
123 return Err(MctpPacketError::HeaderParseError(
124 "packet assembly buffer overflow - insufficient space",
125 ));
126 }
127 self.packet_assembly_buffer[buffer_idx..buffer_idx + packet_size]
128 .copy_from_slice(&packet[..packet_size]);
129 state.packet_assembly_buffer_index += packet_size;
130
131 let message = if transport_header.end_of_message == 1 {
132 self.assembly_state = AssemblyState::Idle;
133 let (message_body, message_integrity_check) = parse_message_body::<M>(
134 &self.packet_assembly_buffer[..state.packet_assembly_buffer_index],
135 )?;
136 Some(MctpMessage {
137 reply_context: MctpReplyContext {
138 destination_endpoint_id: transport_header.destination_endpoint_id,
139 source_endpoint_id: transport_header.source_endpoint_id,
140 packet_sequence_number: transport_header.packet_sequence_number,
141 message_tag: transport_header.message_tag,
142 medium_context: medium_frame.reply_context(),
143 },
144 message_buffer: message_body,
145 message_integrity_check,
146 })
147 } else {
148 self.assembly_state = AssemblyState::Receiving(state);
149 None
150 };
151
152 Ok(message)
153 }
154
155 pub fn serialize_packet<P: MctpMessageTrait<'buf>>(
156 &'buf mut self,
157 reply_context: MctpReplyContext<M>,
158 message: (P::Header, P),
159 ) -> MctpPacketResult<SerializePacketState<'buf, M>, M> {
160 match self.assembly_state {
161 AssemblyState::Idle => {}
162 _ => {
163 return Err(MctpPacketError::ProtocolError(
164 ProtocolError::SendMessageWhileAssembling,
165 ));
166 }
167 };
168
169 self.packet_assembly_buffer[0] = P::MESSAGE_TYPE;
170 let header_size = message.0.serialize(&mut self.packet_assembly_buffer[1..])?;
171 let body_size = message
172 .1
173 .serialize(&mut self.packet_assembly_buffer[header_size + 1..])?;
174
175 let (message, rest) = self
176 .packet_assembly_buffer
177 .split_at_mut(header_size + body_size + 1);
178
179 Ok(SerializePacketState {
180 medium: &self.medium,
181 reply_context,
182 current_packet_num: 0,
183 serialized_message_header: false,
184 message_buffer: message,
185 assembly_buffer: rest,
186 })
187 }
188}
189
190#[derive(Debug, Copy, Clone, PartialEq, Eq)]
191#[cfg_attr(feature = "defmt", derive(defmt::Format))]
192enum AssemblyState {
193 Idle,
194 Receiving(AssemblingState),
195}
196
197#[derive(Debug, Copy, Clone, PartialEq, Eq)]
198#[cfg_attr(feature = "defmt", derive(defmt::Format))]
199struct AssemblingState {
200 message_tag: MctpMessageTag,
201 tag_owner: u8,
202 source_endpoint_id: EndpointId,
203 packet_sequence_number: MctpSequenceNumber,
204 packet_assembly_buffer_index: usize,
205}