CamVid stands for the Cambridge-driving Labeled Video Database.

It is a dataset from the University of Cambridge in which street scenes are annotated with 32 class labels.

We use this dataset to train a U-Net model and see how well the training turns out.

The U-Net model has three main components:

Encoder (preserves semantics)

Decoder (restores spatial resolution)

Skip connections (supply fine detail)


In this kind of task we want to know which label best fits each pixel, so every pixel carries a vector with one dimension per class.

U-Net can capture semantics while the decoder plus skip connections preserve detail,

which is why U-Net is well suited to this kind of segmentation task.


Here we use the body of resnet34 as the encoder backbone.

encoder = create_body(resnet34(weights=weights), cut=-2)
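fastai's `create_body` with `cut=-2` simply drops the network's last two child modules (resnet34's average pool and fc layers), leaving a convolutional feature extractor. A torch-only sketch of the same idea, using a small toy backbone in place of resnet34:

```python
import torch
import torch.nn as nn

# Toy backbone laid out like resnet34: conv stages, then avgpool + classifier head
backbone = nn.Sequential(
    nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),                         # second-to-last child
    nn.Sequential(nn.Flatten(), nn.Linear(32, 10)),  # last child (classifier head)
)

# cut=-2 keeps all children except the last two (pool + head),
# which is what create_body(resnet34(...), cut=-2) does
encoder = nn.Sequential(*list(backbone.children())[:-2])

x = torch.randn(1, 3, 64, 64)
print(encoder(x).shape)  # torch.Size([1, 32, 16, 16]) -- a spatial feature map, not logits
```

The encoder output keeps its spatial dimensions, which is exactly what the decoder needs to upsample from.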

# Create U-Net decoder head

In create_head, the original resnet34 head is replaced with a U-Net decoder.

def create_head(
    num_features: int,
    num_classes: int,
    lin_ftrs: Optional[List[int]] = None,
    ps: float = 0.5,
    pool: bool = True,
    concat_pool: bool = True,
    first_bn: bool = True,
    bn_final: bool = False,
    lin_first: bool = False,
    y_range: Optional[Tuple[float, float]] = None
) -> UnetBlock:
    """Create U-Net decoder head.

    Args:
        num_features: Number of input features from encoder
        num_classes: Number of output classes
        lin_ftrs: Linear layer features (unused for U-Net)
        ps: Dropout probability (unused)
        pool: Whether to use pooling (unused)
        concat_pool: Whether to use concatenated pooling (unused)
        first_bn: Whether to use batch norm first (unused)
        bn_final: Whether to use final batch norm (unused)
        lin_first: Whether to use linear layer first (unused)
        y_range: Output range (unused)

    Returns:
        U-Net decoder block
    """
    return UnetBlock(num_features, num_classes)

head = create_head(
    ENCODER_OUTPUT_FEATURES,
    num_classes,
    lin_ftrs=[512, 512],
    lin_first=True
)


Below, the decoder's flow is easy to follow:

the code first concatenates the encoder's output with the decoder's input, reduces dimensionality through a conv block, then upsamples so the output's height and width end up the same as (or close to) the input's.


In the final classification layer we convert the output channels to the number of classes, so that every pixel has a class-sized vector from which the highest-scoring class can be picked at prediction time.
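A quick sketch of that per-pixel decision: given logits of shape (batch, num_classes, H, W), an argmax over the class dimension collapses them into a per-pixel label map (class count and sizes below are just demo values):

```python
import torch

num_classes, H, W = 32, 4, 6                # CamVid has 32 classes; tiny H, W for the demo
logits = torch.randn(1, num_classes, H, W)  # final-layer output: one score per class per pixel

pred = logits.argmax(dim=1)                 # pick the highest-scoring class at every pixel
print(pred.shape)                           # torch.Size([1, 4, 6])
```

One class index per pixel is exactly the format of the indexed target masks built later in this post.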


class UnetBlock(nn.Module):

    def __init__(self, num_features: int, num_classes: int, stride: int = 1):
        """Initialize U-Net decoder block.

        Args:
            num_features: Number of input features from encoder
            num_classes: Number of output segmentation classes
            stride: Stride for convolution blocks
        """
        super().__init__()

        # Decoder blocks with progressively fewer features
        self.block = self._create_conv_block(num_features * 2, num_features, stride)
        self.block2 = self._create_conv_block(num_features, num_features // 2, stride)
        self.block4 = self._create_conv_block(num_features // 2, num_features // 4, stride)
        self.block8 = self._create_conv_block(num_features // 4, num_features // 8, stride)

        # Upsampling layers
        self.upsample_2x = UpsampleConv(num_features, num_features // 2, scale=2, k=2)
        self.upsample_4x = UpsampleConv(num_features // 2, num_features // 4, scale=2, k=2)
        self.upsample_8x = UpsampleConv(num_features // 4, num_features // 8, scale=2, k=2)
        self.upsample_8x_static = UpsampleConv(num_features // 8, num_features // 8, scale=2, k=2)
        self.upsample_16x_static = UpsampleConv(num_features // 8, num_features // 8, scale=2, k=2)

        # Final classification layer
        self.final_conv = ConvLayer(num_features // 8, num_classes, ks=2, stride=1)

    ...

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through U-Net decoder with skip connections.

        Args:
            x: Input tensor from encoder

        Returns:
            Segmentation prediction tensor
        """
        # Stage 1: Concatenate deepest encoder features
        y = torch.cat((extracted_features[3][1], x), dim=1)
        extracted_features.pop(3)
        y = self.block(y)
        y = self.upsample_2x(y)

        # Stage 2: Skip connection with encoder layer 2
        y = self.center_crop(y, extracted_features[2][1])
        y = torch.cat((extracted_features[2][1], y), dim=1)
        extracted_features.pop(2)
        y = self.block2(y)
        y = self.upsample_4x(y)

        ...
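UpsampleConv itself is not shown in the excerpt. Purely as a plausible stand-in, the following assumes it is a transposed convolution (stride = scale) followed by batch norm and ReLU; with k equal to scale this exactly doubles the spatial size:

```python
import torch
import torch.nn as nn

class UpsampleConv(nn.Module):
    """Hypothetical sketch: learnable 2x upsampling via transposed convolution."""

    def __init__(self, in_channels: int, out_channels: int, scale: int = 2, k: int = 2):
        super().__init__()
        # output size = (H - 1) * stride + k, so stride=2 with k=2 doubles H and W
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, stride=scale)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.bn(self.up(x)))

y = UpsampleConv(64, 32)(torch.randn(1, 64, 16, 16))
print(y.shape)  # torch.Size([1, 32, 32, 32])
```

The original class may of course implement it differently (e.g. nn.Upsample followed by a plain conv); the shape contract is what matters to the decoder.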


When building the U-Net block, pay particular attention to whether the encoder's output and the decoder's input are spatially aligned.

Here we use center_crop to trim the feature map.


def center_crop(
    self,
    feature_map: torch.Tensor,
    target: torch.Tensor
) -> torch.Tensor:
    """Center crop feature map to match target dimensions.

    Args:
        feature_map: Feature map to crop
        target: Target tensor with desired dimensions

    Returns:
        Cropped feature map matching target dimensions
    """
    _, _, height, width = feature_map.shape
    _, _, target_height, target_width = target.shape

    start_h = (height - target_height) // 2
    start_w = (width - target_width) // 2

    return feature_map[
        :,
        :,
        start_h:(start_h + target_height),
        start_w:(start_w + target_width)
    ]
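The same cropping logic as a free function, for a quick shape check (it assumes the feature map is at least as large as the target):

```python
import torch

def center_crop(feature_map: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Identical logic to the method above, written as a standalone function
    _, _, h, w = feature_map.shape
    _, _, th, tw = target.shape
    sh, sw = (h - th) // 2, (w - tw) // 2
    return feature_map[:, :, sh:sh + th, sw:sw + tw]

decoder_fm = torch.randn(1, 64, 36, 36)   # upsampled decoder features (slightly larger)
encoder_fm = torch.randn(1, 64, 32, 32)   # encoder skip connection to match
print(center_crop(decoder_fm, encoder_fm).shape)  # torch.Size([1, 64, 32, 32])
```

After the crop, both tensors can be concatenated along the channel dimension without a size mismatch.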

Note that since the target image is RGB while the U-Net model's output is a segmentation map,

the target image must be converted from RGB to segmentation indices when creating the dataloader.


CamVid ships with a file (class_dict.csv in the code below) recording the RGB value of each segmentation class. From it we can build a color-to-index mapping table, stored here in a dictionary.


def load_class_mapping(csv_path: Path) -> Tuple[List[str], Dict[Tuple[int, int, int], int]]:
    """Load class names and RGB-to-index mapping from CSV.

    Args:
        csv_path: Path to class_dict.csv file

    Returns:
        Tuple of (class_names, rgb_to_index_mapping)
    """
    df = pd.read_csv(csv_path)
    class_names = df['name'].tolist()

    color_to_idx = {
        tuple(row[['r', 'g', 'b']].values.tolist()): idx
        for idx, row in df.iterrows()
    }

    return class_names, color_to_idx
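Assuming class_dict.csv has name, r, g, b columns, the mapping can be sanity-checked against an in-memory CSV; pd.read_csv accepts any file-like object, and the function is repeated here only so the snippet stands alone:

```python
import io
from typing import Dict, List, Tuple

import pandas as pd

def load_class_mapping(csv_path) -> Tuple[List[str], Dict[Tuple[int, int, int], int]]:
    # Same logic as above; csv_path may be a path or any file-like object
    df = pd.read_csv(csv_path)
    class_names = df['name'].tolist()
    color_to_idx = {
        tuple(row[['r', 'g', 'b']].values.tolist()): idx
        for idx, row in df.iterrows()
    }
    return class_names, color_to_idx

# Two hypothetical rows in the class_dict.csv layout
csv_text = "name,r,g,b\nSky,128,128,128\nBuilding,128,0,0\n"
names, color_to_idx = load_class_mapping(io.StringIO(csv_text))
print(names)                          # ['Sky', 'Building']
print(color_to_idx[(128, 128, 128)])  # 0
```

The dictionary keys are plain int tuples (`.tolist()` converts numpy scalars to Python ints), so lookups with ordinary RGB tuples work as expected.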


def get_camvid_mask(
    image_path: Path,
    color_to_idx: Dict[Tuple[int, int, int], int]
) -> torch.Tensor:
    """Convert RGB mask to indexed tensor.

    Args:
        image_path: Path to input image (used to find mask)
        color_to_idx: Mapping from RGB colors to class indices

    Returns:
        Indexed mask tensor with class indices
    """
    mask_path = get_camvid_mask_path(image_path)
    mask_img = PILImage.create(mask_path)

    indexed_mask = torch.zeros((mask_img.height, mask_img.width), dtype=torch.int8)
    mask_array = np.array(mask_img)

    # Vectorized comparison per color: one boolean match mask per class
    for rgb, idx in color_to_idx.items():
        matches = (
            (mask_array[:, :, 0] == rgb[0]) &
            (mask_array[:, :, 1] == rgb[1]) &
            (mask_array[:, :, 2] == rgb[2])
        )
        indexed_mask[matches] = idx

    # indexed_mask is already a tensor; wrapping it in torch.tensor() again
    # would only copy it and trigger a warning
    return indexed_mask
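The loop above runs one full-image comparison per class. As an alternative sketch of my own (not from the original post), the whole mapping can be done in a single broadcasted comparison; note that pixels whose color is not in the table silently fall back to the first listed class, so a real pipeline should validate its inputs:

```python
import numpy as np

def rgb_to_index(mask_array: np.ndarray, color_to_idx: dict) -> np.ndarray:
    """Map an (H, W, 3) RGB mask to (H, W) class indices in one broadcast."""
    colors = np.array(list(color_to_idx.keys()), dtype=mask_array.dtype)  # (K, 3)
    indices = np.array(list(color_to_idx.values()))                       # (K,)
    # (H, W, 1, 3) == (1, 1, K, 3), all() over RGB -> (H, W, K) match matrix
    matches = (mask_array[:, :, None, :] == colors[None, None, :, :]).all(axis=-1)
    # argmax picks the matching class; unmatched pixels fall back to index 0
    return indices[matches.argmax(axis=-1)]

demo = np.array([[[0, 0, 0], [255, 0, 0]],
                 [[255, 0, 0], [0, 0, 0]]], dtype=np.uint8)
print(rgb_to_index(demo, {(0, 0, 0): 0, (255, 0, 0): 1}))
# [[0 1]
#  [1 0]]
```

The (H, W, K) boolean tensor costs memory proportional to the class count, which is a reasonable trade at CamVid's 32 classes.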


def create_dataloaders(
    data_path: Path,
    class_codes: List[str],
    color_to_idx: Dict[Tuple[int, int, int], int],
    batch_size: int = DEFAULT_BATCH_SIZE
):
    """Create FastAI dataloaders for CamVid dataset.

    Args:
        data_path: Root path to dataset
        class_codes: List of class names
        color_to_idx: RGB to class index mapping
        batch_size: Batch size for training

    Returns:
        FastAI DataLoaders object
    """
    return DataBlock(
        blocks=(ImageBlock, MaskBlock(codes=class_codes)),
        get_items=get_image_files,
        splitter=RandomSplitter(valid_pct=VALID_SPLIT_PCT, seed=RANDOM_SEED),
        get_y=lambda fn: get_camvid_mask(fn, color_to_idx),
        item_tfms=Resize(DEFAULT_IMAGE_SIZE),
        # RandomResizedCrop is an item-level transform in fastai;
        # aug_transforms() is the usual choice for batch-level augmentation
        batch_tfms=aug_transforms()
    ).dataloaders(data_path, bs=batch_size)


Once model training is done,

we can finally compare the segmentation target with the prediction.


Segmentation target

Segmentation prediction

The experiment runs fit_one_cycle for 3 epochs with a learning rate of 3e-3.

learn.fit_one_cycle(3, lr_max=3e-3)
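fit_one_cycle implements the one-cycle policy: the learning rate warms up toward lr_max and then anneals back down. The same schedule can be sketched with plain PyTorch (toy model and step count just for illustration):

```python
import torch

model = torch.nn.Linear(8, 2)  # toy model, only needed to drive the scheduler
opt = torch.optim.SGD(model.parameters(), lr=3e-3)
sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=3e-3, total_steps=100)

lrs = []
for _ in range(100):
    opt.step()        # a real loop would compute a loss and backprop first
    sched.step()
    lrs.append(sched.get_last_lr()[0])

# warm-up then cosine anneal: the peak sits at max_lr, the tail far below it
print(max(lrs) <= 3e-3 + 1e-12, lrs[-1] < lrs[0])
```

This warm-up/anneal shape is why a single lr_max is enough to specify the whole schedule.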

The output shows that the prediction has taken rough shape, but the boundaries are still fairly blurry,

so there is clearly room for improvement.

Posted by ChanLight on 痞客邦 (Pixnet).