CamVid, short for the Cambridge-driving Labeled Video Database,
is a dataset from the University of Cambridge in which street-scene images are annotated with 32 class labels.
We use this dataset to train a U-Net model and see how well the training turns out.
A U-Net model has three main components:
Encoder (captures semantics)
Decoder (restores spatial resolution)
Skip connections (reinject fine detail)
Because this kind of task assigns the best-matching label to every pixel, each pixel carries one score per class.
A U-Net captures semantics in the encoder while the decoder plus skip connections preserve spatial detail,
which is why U-Net is well suited to this kind of segmentation task.
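The per-pixel scoring described above can be illustrated with a tiny tensor (the spatial size here is an arbitrary toy value, not CamVid's):

```python
import torch

num_classes = 32                             # CamVid uses 32 labels
logits = torch.randn(1, num_classes, 4, 4)   # (batch, classes, height, width)

# Each pixel gets one score per class; taking argmax over the class
# dimension yields the predicted label for every pixel.
pred = logits.argmax(dim=1)                  # shape (1, 4, 4)
```

The segmentation output therefore has the same height and width as the image, with the class dimension collapsed to a single index per pixel.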

Here we use the body of a resnet34 as the encoder backbone:
encoder = create_body(resnet34(weights=weights), cut=-2)
In create_head, the original resnet34 head is replaced with a U-Net decoder:
from typing import Dict, List, Optional, Tuple

def create_head(
    num_features: int,
    num_classes: int,
    lin_ftrs: Optional[List[int]] = None,
    ps: float = 0.5,
    pool: bool = True,
    concat_pool: bool = True,
    first_bn: bool = True,
    bn_final: bool = False,
    lin_first: bool = False,
    y_range: Optional[Tuple[float, float]] = None
) -> UnetBlock:
    """Create U-Net decoder head.

    Args:
        num_features: Number of input features from encoder
        num_classes: Number of output classes
        lin_ftrs: Linear layer features (unused for U-Net)
        ps: Dropout probability (unused)
        pool: Whether to use pooling (unused)
        concat_pool: Whether to use concatenated pooling (unused)
        first_bn: Whether to use batch norm first (unused)
        bn_final: Whether to use final batch norm (unused)
        lin_first: Whether to use linear layer first (unused)
        y_range: Output range (unused)

    Returns:
        U-Net decoder block
    """
    return UnetBlock(num_features, num_classes)
head = create_head(
    ENCODER_OUTPUT_FEATURES,
    num_classes,
    lin_ftrs=[512, 512],
    lin_first=True
)
The decoder's flow is easy to follow below:
the code first concatenates the encoder's output with the decoder's input, reduces the channel dimension with a conv block, then upsamples so that the output's height and width match (or come close to) the input's.
In the final classification layer, the output channels are projected down to the number of classes, so that every pixel carries one score per class and the highest-scoring class can be picked at prediction time.
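The concat → conv → upsample pattern can be sketched in isolation (the channel and spatial sizes here are made up for illustration):

```python
import torch
import torch.nn as nn

skip = torch.randn(1, 256, 16, 16)   # encoder output (skip connection)
x = torch.randn(1, 256, 16, 16)      # decoder input

y = torch.cat((skip, x), dim=1)      # concatenate -> (1, 512, 16, 16)
conv = nn.Conv2d(512, 256, kernel_size=3, padding=1)
y = conv(y)                          # reduce channels -> (1, 256, 16, 16)
up = nn.Upsample(scale_factor=2)
y = up(y)                            # restore resolution -> (1, 256, 32, 32)
```

Each decoder stage repeats this pattern with progressively fewer channels, which is exactly the structure of the UnetBlock below.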
class UnetBlock(nn.Module):
    def __init__(self, num_features: int, num_classes: int, stride: int = 1):
        """Initialize U-Net decoder block.

        Args:
            num_features: Number of input features from encoder
            num_classes: Number of output segmentation classes
            stride: Stride for convolution blocks
        """
        super().__init__()
        # Decoder blocks with progressively fewer features
        self.block = self._create_conv_block(num_features * 2, num_features, stride)
        self.block2 = self._create_conv_block(num_features, num_features // 2, stride)
        self.block4 = self._create_conv_block(num_features // 2, num_features // 4, stride)
        self.block8 = self._create_conv_block(num_features // 4, num_features // 8, stride)
        # Upsampling layers
        self.upsample_2x = UpsampleConv(num_features, num_features // 2, scale=2, k=2)
        self.upsample_4x = UpsampleConv(num_features // 2, num_features // 4, scale=2, k=2)
        self.upsample_8x = UpsampleConv(num_features // 4, num_features // 8, scale=2, k=2)
        self.upsample_8x_static = UpsampleConv(num_features // 8, num_features // 8, scale=2, k=2)
        self.upsample_16x_static = UpsampleConv(num_features // 8, num_features // 8, scale=2, k=2)
        # Final classification layer
        self.final_conv = ConvLayer(num_features // 8, num_classes, ks=2, stride=1)

    # ...

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through U-Net decoder with skip connections.

        Args:
            x: Input tensor from encoder

        Returns:
            Segmentation prediction tensor
        """
        # extracted_features holds the encoder activations saved for the
        # skip connections (captured elsewhere, e.g. via forward hooks)
        # Stage 1: Concatenate deepest encoder features
        y = torch.cat((extracted_features[3][1], x), dim=1)
        extracted_features.pop(3)
        y = self.block(y)
        y = self.upsample_2x(y)
        # Stage 2: Skip connection with encoder layer 2
        y = self.center_crop(y, extracted_features[2][1])
        y = torch.cat((extracted_features[2][1], y), dim=1)
        extracted_features.pop(2)
        y = self.block2(y)
        y = self.upsample_4x(y)
        # ...
When building the U-Net blocks, pay special attention to whether the encoder's output and the decoder's input are spatially aligned.
Here we use center_crop to trim the feature map:
def center_crop(
    self,
    feature_map: torch.Tensor,
    target: torch.Tensor
) -> torch.Tensor:
    """Center crop feature map to match target dimensions.

    Args:
        feature_map: Feature map to crop
        target: Target tensor with desired dimensions

    Returns:
        Cropped feature map matching target dimensions
    """
    _, _, height, width = feature_map.shape
    _, _, target_height, target_width = target.shape
    start_h = (height - target_height) // 2
    start_w = (width - target_width) // 2
    return feature_map[
        :,
        :,
        start_h:(start_h + target_height),
        start_w:(start_w + target_width)
    ]
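A standalone sanity check of the same slicing arithmetic (toy tensors, written as a free function rather than a method):

```python
import torch

def crop_center(feature_map, target):
    # Same arithmetic as the center_crop method above
    _, _, h, w = feature_map.shape
    _, _, th, tw = target.shape
    sh, sw = (h - th) // 2, (w - tw) // 2
    return feature_map[:, :, sh:sh + th, sw:sw + tw]

fm = torch.arange(36.0).reshape(1, 1, 6, 6)
tgt = torch.zeros(1, 1, 4, 4)
cropped = crop_center(fm, tgt)
# a 6x6 map is trimmed by one row/column on each side down to 4x4
```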
Note that the target images are stored as RGB, while the U-Net model outputs a segmentation map,
so when creating the dataloaders the RGB targets must first be converted to segmentation indices.
CamVid ships with a CSV file that records the RGB color of each segmentation class, from which we can build a color-to-index mapping table, stored here in a dictionary.
def load_class_mapping(csv_path: Path) -> Tuple[List[str], Dict[Tuple[int, int, int], int]]:
    """Load class names and RGB-to-index mapping from CSV.

    Args:
        csv_path: Path to class_dict.csv file

    Returns:
        Tuple of (class_names, rgb_to_index_mapping)
    """
    df = pd.read_csv(csv_path)
    class_names = df['name'].tolist()
    color_to_idx = {
        tuple(row[['r', 'g', 'b']].values.tolist()): idx
        for idx, row in df.iterrows()
    }
    return class_names, color_to_idx
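For example, with a hypothetical two-row stand-in for class_dict.csv (the class names and colors below are made up, not CamVid's real values), the dictionary comprehension above produces:

```python
import pandas as pd

# Hypothetical two-class stand-in for class_dict.csv
df = pd.DataFrame({
    'name': ['Sky', 'Road'],
    'r': [128, 64], 'g': [128, 64], 'b': [128, 64],
})
color_to_idx = {
    tuple(row[['r', 'g', 'b']].values.tolist()): idx
    for idx, row in df.iterrows()
}
# color_to_idx == {(128, 128, 128): 0, (64, 64, 64): 1}
```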
def get_camvid_mask(
    image_path: Path,
    color_to_idx: Dict[Tuple[int, int, int], int]
) -> torch.Tensor:
    """Convert RGB mask to indexed tensor.

    Args:
        image_path: Path to input image (used to find mask)
        color_to_idx: Mapping from RGB colors to class indices

    Returns:
        Indexed mask tensor with class indices
    """
    mask_path = get_camvid_mask_path(image_path)
    mask_img = PILImage.create(mask_path)
    indexed_mask = torch.zeros((mask_img.height, mask_img.width), dtype=torch.int8)
    mask_array = np.array(mask_img)
    # Match each class color in turn and assign its index
    for rgb, idx in color_to_idx.items():
        matches = (
            (mask_array[:, :, 0] == rgb[0]) &
            (mask_array[:, :, 1] == rgb[1]) &
            (mask_array[:, :, 2] == rgb[2])
        )
        indexed_mask[matches] = idx
    return indexed_mask  # already a tensor; no need to re-wrap with torch.tensor
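The per-color comparison can also be written as a single broadcasted equality check in NumPy; a small sketch with a hypothetical two-color palette (not CamVid's real colors):

```python
import numpy as np

# Hypothetical two-color palette for illustration
color_to_idx = {(255, 0, 0): 0, (0, 255, 0): 1}

mask = np.array([[[255, 0, 0], [0, 255, 0]],
                 [[0, 255, 0], [255, 0, 0]]], dtype=np.uint8)

indexed = np.zeros(mask.shape[:2], dtype=np.int64)
for rgb, idx in color_to_idx.items():
    # Compare all three channels at once, then reduce over the channel axis
    indexed[(mask == rgb).all(axis=-1)] = idx
# indexed == [[0, 1], [1, 0]]
```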
def create_dataloaders(
    data_path: Path,
    class_codes: List[str],
    color_to_idx: Dict[Tuple[int, int, int], int],
    batch_size: int = DEFAULT_BATCH_SIZE
):
    """Create FastAI dataloaders for CamVid dataset.

    Args:
        data_path: Root path to dataset
        class_codes: List of class names
        color_to_idx: RGB to class index mapping
        batch_size: Batch size for training

    Returns:
        FastAI DataLoaders object
    """
    return DataBlock(
        blocks=(ImageBlock, MaskBlock(codes=class_codes)),
        get_items=get_image_files,
        splitter=RandomSplitter(valid_pct=VALID_SPLIT_PCT, seed=RANDOM_SEED),
        get_y=lambda fn: get_camvid_mask(fn, color_to_idx),
        item_tfms=Resize(DEFAULT_IMAGE_SIZE),
        batch_tfms=RandomResizedCrop(DEFAULT_IMAGE_SIZE)
    ).dataloaders(data_path, bs=batch_size)
After training the model,
we can finally compare the segmentation target with the prediction.

segmentation target

segmentation prediction

The experiment ran fit_one_cycle for 3 epochs with a learning rate of 3e-3.
The output shows that the prediction already has a rough shape, but the boundaries are still fairly blurry,
so there is clearly room for improvement.